当前位置: 首页 > news > 正文

PatternMatcher-Pytorch

import os
import torch
import torch.nn as nn
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
import tempfile

from torch._inductor.compile_fx import compile_fx

# Working directory for this demo's artifacts. The original hard-coded a path
# inside one developer's home directory ("/home/xytpai/workspace/work/temp"),
# which is generally not writable on any other machine. Allow an explicit
# override through PATTERN_MATCHER_DEMO_DIR and otherwise fall back to a
# subdirectory of the system temporary directory.
cache_dir = os.environ.get(
    "PATTERN_MATCHER_DEMO_DIR",
    os.path.join(tempfile.gettempdir(), "pattern_matcher_demo"),
)

# Environment variables to export before compiling; TORCHINDUCTOR_CACHE_DIR
# redirects Inductor's on-disk cache into the demo directory.
envs = {
    "TORCHINDUCTOR_CACHE_DIR": os.path.join(cache_dir, "inductor"),
}
# Export the cache-related environment variables before anything is compiled.
os.environ.update(envs)


@torch.library.custom_op("myops::add", mutates_args=["result"])
def myops_add(result: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> None:
    """Write ``x + y`` into ``result`` in place."""
    torch.add(x, y, out=result)


@torch.library.custom_op("myops::relu", mutates_args=["result"])
def myops_relu(result: torch.Tensor, x: torch.Tensor) -> None:
    """Write ``relu(x)`` into ``result`` in place."""
    result.copy_(x).relu_()


@torch.library.custom_op("myops::add_relu", mutates_args=["result"])
def myops_add_relu(result: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> None:
    """Fused kernel: write ``relu(x + y)`` into ``result`` in place."""
    result.copy_(torch.add(x, y)).relu_()


def pattern(
    result: torch.Tensor,
    result_add: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
):
    """Search graph: ``myops.add`` into ``result_add``, then ``myops.relu``.

    Expressed with the (v1) ``auto_functionalized`` higher-order op; index 1
    of each call's result is the updated ``result`` tensor that the code
    feeds forward.
    """
    added = auto_functionalized(
        torch.ops.myops.add.default, result=result_add, x=x, y=y
    )
    activated = auto_functionalized(
        torch.ops.myops.relu.default, result=result, x=added[1]
    )
    return activated[1]


def replacement(
    result: torch.Tensor,
    result_add: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
):
    """Replacement graph: a single fused ``myops.add_relu`` into ``result``.

    ``result_add`` is unused here, but it stays in the signature because
    ``pattern`` and ``replacement`` are registered with the same inputs.
    """
    fused = auto_functionalized(
        torch.ops.myops.add_relu.default, result=result, x=x, y=y
    )
    return fused[1]


# Example tensors handed to register_replacement for tracing the two graphs,
# one per argument in order: result, result_add, x, y.
inputs = [torch.empty(5, 4, dtype=torch.float) for _ in range(4)]
pm_pass = pm.PatternMatcherPass(pass_name="fusion_pass")
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)


def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
    """Post-grad Inductor pass that applies the add+relu fusion to ``graph``.

    Prints the graph before rewriting, the value returned by
    ``pm_pass.apply`` (the match count) and the rewritten graph. It then
    removes nodes left dead by the rewrite. The same (mutated) graph object
    is returned.
    """
    print(graph)
    num_matches = pm_pass.apply(graph)
    print(num_matches)
    print(graph)
    graph.eliminate_dead_code()
    return graph


def custom_backend(graph: torch.fx.GraphModule, example_inputs):
    """torch.compile backend: Inductor with ``custom_pass`` as a post-grad pass.

    ``pattern``/``replacement`` are written against the v1
    ``auto_functionalized`` op. For that reason auto_functionalized v2 is
    disabled here, for exactly this compile. Otherwise, on PyTorch versions
    where v2 is the default, the pattern would silently match nothing.
    """
    from torch._inductor import config

    current_config = config.get_config_copy()
    current_config["post_grad_custom_post_pass"] = custom_pass
    # Set it in the patches compile_fx applies itself. A config patch applied
    # around the module's forward may not be in effect when the graph is
    # functionalized.
    current_config["enable_auto_functionalized_v2"] = False
    return compile_fx(graph, example_inputs, config_patches=current_config)


# def fw_add(x, y):
#     out = torch.empty_like(x)
#     torch.ops.myops.add(out, x, y)
#     return out
# def fw_relu(x):
#     out = torch.empty_like(x)
#     torch.ops.myops.relu(out, x)
#     return out


@torch.compile(backend=custom_backend)
class SimpleModel(nn.Module):
    """Toy module computing ``myops.relu(myops.add(x, y))``.

    The class is wrapped by ``torch.compile`` with ``custom_backend``, so the
    add -> relu chain in ``forward`` is what the registered fusion pattern
    is meant to rewrite into ``myops.add_relu``.
    """

    # Alternative forward built on the fw_add / fw_relu helpers:
    # def forward(self, x, y):
    #     x = fw_add(x, y)
    #     x = fw_relu(x)
    #     return x

    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def forward(self, x, y):
        # Sum into a scratch buffer, then ReLU it into a second buffer.
        summed = torch.empty_like(x)
        torch.ops.myops.add(summed, x, y)
        activated = torch.empty_like(x)
        torch.ops.myops.relu(activated, summed)
        return activated


model = SimpleModel()
# Two random 10x10 inputs, drawn in the order x, then y. Calling `model`
# runs it through the torch.compile wrapper configured with custom_backend.
x, y = (torch.rand(10, 10) for _ in range(2))
z = model(x, y)
http://www.zskr.cn/news/9316.html

相关文章:

  • uboot启动流程
  • 内存泄漏
  • Context Engineering
  • github/网盘/公众号信息收集
  • AtCoder Regular Contest 206 (Div. 2) 部分题解
  • InfluxDB 的模糊查询总结
  • 多表关系和多表查询
  • 【反比例函数】【做题笔记】【图形存在性】题目合集
  • 20250920 嘉定江桥---江苏吴江区太湖 往返160KM骑行小记
  • 工作队列(Work Queues)与消息确认(Ack)
  • 6-5 汇聚层
  • 6-4 多输入多输出通道
  • 6-2图像卷积
  • 基于WOA鲸鱼优化的XGBoost序列预测算法matlab仿真
  • Arch下实现人脸识别登录:howdy的配置与使用
  • Winform的Formborder.None情况下,解决不能拖动的问题
  • 实用指南:centos sshd:xxx.xxx.xxx.xxx:allow 如何设置
  • fedora无法看视频?编解码器详细安装教程
  • 高并发高吞吐量
  • 服务降级
  • 镜像制作
  • IAR Embedded Workbench中的MCU启动过程分析
  • CSP-S 2025
  • ENVI系列教程(七)——自定义 RPC 资料图像正射校正
  • Linux 笔记本充电限制【转发】
  • 别样的CSP-S初赛大战(又名:我和油一的那些年)
  • 范德蒙德卷积入门
  • 用 【C# + WinUI3 + 图像动画】 来理解:高数 - 函数 - 初等函数 - 行人-
  • ansible语句
  • 代码随想录算法训练营第四天 |24. 两两交换链表中的节点、19.删除链表的倒数第N个节点、面试题 02.07. 链表相交、142.环形链表II