torch.fx

前言

最近在学习一些AI编译器,推理结构的常识,恰好看到了torch.fx这个部分。这个其实在1.10就现已出来了,可是一直不知道,所以花了一点时刻学习了这部分的内容。

以下全部的代码依据Mac M1 pytorch 1.13,其他的os/版别没有进行测验

1.什么是torch.fx

首要去查看官网docTORCH.FX

FX is a toolkit for developers to use to transform nn.Module instances. 这句话很好的界说了FX的本质:用来改动module实例的一种东西。包含了三个首要的组件:symbolic tracer intermediate representation python code generation 符号追踪能够捕获模块的语义进行解析;中心表明也便是IR记录了中心的操作,比如输入输出和调用的函数等;代码生成这个比较有意思,由于这是一个python-to-python的转化东西,这就从本质上区别了FX与一些AI编译器,推理库的区别。从流程上看,FX与推理库都是解析模型生成IR,然后交融算子呀优化等等,可是FX仅仅为了优化改动模型的功用,最终落脚点仍是在python上;而其他的库都是通过一系列优化后能够脱离python依赖布置到c++等边缘环境上。

2. torch.fx有什么用

已然运用fx能够改动module,那么详细能够有哪些运用场景呢?我总结了下面几个首要的

  • 追踪模型图,改动模型部分结构,替换某些算子
  • 在python代码的层面临模型进行优化
  • 依据trace得到的成果更好的可视化模型
  • 对模型进行量化

2.1 模型算子替换

首要来看看官网给出的比如

import torch
from torch import nn
from torch import fx
from torch.fx import symbolic_trace
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.param=nn.Parameter(torch.Tensor([1,2,3,4]))
    def forward(self,x):
        return (x+self.param).clamp(min=0.0,max=1.0)
model=MyModel()
symbolic_traced=symbolic_trace(model)
print(symbolic_traced.graph)
print(symbolic_traced.code)
symbolic_traced.graph.print_tabular()

关于torch.fx的使用

从图里咱们能够清楚地看到模型进行的操作以及IR,它也很好的界说了算子的分类(这个对下面部分内容很有用)。然后咱们假如想用sigmoid替换clamp,假如按照官网以及大多数已有文章的比如是有过错的

# 将clamp转为sigmoid
def transform(m):
    gm=fx.Tracer().trace(m)
    for node in gm.nodes:
        if node.op=='call_method':
            if node.target=="clamp":
                print(node.target)
                node.target=torch.sigmoid
    gm.lint()
    return fx.GraphModule(m,gm)
trans_model=transform(model)
print(trans_model.graph)
print(trans_model.code)
trans_model.graph.print_tabular()

关于torch.fx的使用

很明显能够看到node.target有必要是字符串,所以这样替换是不对的。而原示例给出的是torch.mul替换torch.add,假如测验那个代码,node.target==torch.add这个根本不会建立(target是str),所以这儿我才将target条件更正。

那怎样替换clamp呢,并且还要验证替换后模型的成果无误差

# 将clamp转为sigmoid
def transform(m):
    gm=fx.Tracer().trace(m)
    for node in gm.nodes:
        if node.op=='call_method':
            if node.name=="clamp":
                print(node.target)
                node.target="sigmoid"
                node.name="sigmoid"
                node.kwargs={}
    gm.lint()
    return fx.GraphModule(m,gm)
trans_model=transform(model)
print(trans_model.graph)
print(trans_model.code)
trans_model.graph.print_tabular()

关于torch.fx的使用

从模型打印成果来看替换是成功的,可是还要通过输出检验

class MyModel1(nn.Module):
    def __init__(self):
        super().__init__()
        self.param=nn.Parameter(torch.Tensor([1,2,3,4]))
        #self.linear=torch.nn.Linear(4,5)
    def forward(self,x):
        return (x+self.param).sigmoid()
test=MyModel1()
inputs = torch.randn(1,4)
torch.testing.assert_close(test(inputs),trans_model(inputs))

这儿没有任何输出,证明输出与gt共同。当然不止一种完成,下面给出其他两种

# 将clamp转为sigmoid
def transform(m):
    gm=symbolic_trace(m)
    for node in gm.graph.nodes:
        if node.op=='call_method':
            if node.name=="clamp":
                print(node.target)
                node.target="sigmoid"
                node.name="sigmoid"
                node.kwargs={}
    gm.recompile()
    return gm
trans_model=transform(model)
print(trans_model.graph)
print(trans_model.code)
torch.testing.assert_close(test(inputs),trans_model(inputs))
# 将clamp转为sigmoid
from torch.fx import replace_pattern
def pattern(x):
    return x.clamp(min=0.0,max=1.0)
def replacement(x):
    return x.sigmoid()
replace_pattern(symbolic_traced,pattern,replacement)
print(symbolic_traced.graph)
print(symbolic_traced.code)
torch.testing.assert_close(test(inputs),symbolic_traced(inputs))

2.2 算子交融

在做推理布置的时候最常用的便是算子交融,也便是将多个算子的核算在数学上进行等效替换,然后减少了算子数量以及全体的核算量,加速了推理时刻。torch.fx也给了咱们很好的算子交融替换帮助,由于上面说了有了trace咱们能够很轻松地对模型算子进行替换,例如最常见的conv+bn交融丢掉dropout

这部分代码能够参阅官方样例/torch/fx/experimental/optimization.py,我这儿直接白嫖过来演示一下

from torch.nn.utils.fusion import fuse_conv_bn_eval
from torch.fx.node import Argument, Target
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import copy
def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    if len(node.args) == 0:
        return False
    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type, current_node in zip(pattern, nodes):
        if not isinstance(current_node, fx.Node):
            return False
        if current_node.op != 'call_module':
            return False
        if not isinstance(current_node.target, str):
            return False
        if current_node.target not in modules:
            return False
        if type(modules[current_node.target]) is not expected_type:
            return False
    return True
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)
def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d),
                (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)
    for pattern in patterns:
        for node in new_graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)
def remove_dropout(model: nn.Module) -> nn.Module:
    """
    Removes all dropout layers from the module.
    """
    fx_model = fx.symbolic_trace(model)
    class DropoutRemover(torch.fx.Transformer):
        def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
            if isinstance(self.submodules[target], nn.Dropout):
                assert len(args) == 1
                return args[0]
            else:
                return super().call_module(target, args, kwargs)
    return DropoutRemover(fx_model).transform()
class TestConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,**kwargs):
        super(TestConv2d,self).__init__()
        self.conv=nn.Conv2d(in_channels,out_channels,**kwargs)
        self.bn=nn.BatchNorm2d(out_channels)
        self.relu=nn.ReLU(True)
    def forward(self,x):
        x=self.conv(x)
        x=self.bn(x)
        x=self.relu(x)
        return x
class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=TestConv2d(3,32,kernel_size=3)
        self.conv2=TestConv2d(32,64,kernel_size=3)
        self.dropout=nn.Dropout(0.3)
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.dropout(x)
        return x
def show(string,count):
    print(f"{'='*count}{string}{'='*count}")
test_model=TestModel()
# 在eval下进行交融,丢掉
test_model.eval()
### origin
origin_model=symbolic_trace(test_model)
show("origin result",20)
print(origin_model.graph)
print(origin_model.code)
### fuse
fuse_model=fuse(test_model)
fuse_model=remove_dropout(fuse_model)
show("fuse result",20)
print(fuse_model.graph)
print(fuse_model.code)

关于torch.fx的使用

能够看到通过算子交融与丢掉,模型没有了bn dropout十分简练。有人会说为什么不把relu也融进conv,这在量化中能够完成切断可是假如是全精度也便是FP32下假如scale和zeropoint不共同没法量化回来,所以这儿并没有进行交融。

2.3 模型可视化

不知道多少人用过torchviz对模型进行过可视化,不能说欠好只能说根本不直观。这儿我恰好看到了一篇讲运用fx进行模型结构可视化的博客,可惜博主代码没有全部给出来。不过依据他的文章也算是给了我一种很好的思路,已然咱们都有模型的DAG,IR,那咱们应该能够更加直观的完成模型结构的可视化。所以这部分就算是完成博主没有给出来的代码,模型界说就用博主博客中的模型

运用torch.fx提取PyTorch网络结构信息制作网络结构图 – wrong.wang,咱们能够先去看看博主的这篇文章我不过多讲重复内容。另外假如想完成功用,还得去研究一下fx解说器的源码torch.fx.interpreter — PyTorch 1.13 documentation

from torchviz import make_dot
import graphviz
import torch.nn.functional as F
class TestModel(nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()
        self.bias = nn.Parameter(torch.randn(1))
        self.main = nn.Sequential(nn.Conv2d(3, 4, 1), nn.ReLU(True))
        self.skip = nn.Conv2d(2, 4, 3, stride=1, padding=1)
    def forward(self, x, y):
        x = self.main(x)
        y = (self.skip(y)+self.bias).clamp(0, 1)
        x_size = x.size()[-2:]
        y = F.interpolate(y, x_size, mode="bilinear", align_corners=False)
        return torch.sigmoid(x) + y
x=torch.randn(1,3,16,16)
y=torch.randn(1,2,8,8)
test_model=TestModel()
z=test_model(x,y)
g=make_dot(z,params=dict(test_model.named_parameters()))
g.render(directory="test",format='svg',view=False)

首要用torchviz制作一下模型

关于torch.fx的使用

看着这张图,似懂非懂的样子,并不能直观的看到模型的结构。然后开端完成博主的代码

import traceback
class Get_IR(torch.fx.Interpreter):
    def run_node(self,n):
        try:
            result=super().run_node(n)
        except Exception:
            traceback.print_exc()
            raise RuntimeError(f"Error while run node:{n.format_node()}")
        is_find=False
        def extract_meta(t):
            if isinstance(t,torch.Tensor):
                nonlocal is_find
                is_find=True
                return _extra_meta(t)
            else:
                return t
        def _extra_meta(t):
            if n.op=="call_module":
                submod=self.fetch_attr(n.target)
                return {'name':n.name,'op':n.op,'args':n.args,'shape':t.shape,'target':n.target,'kw':n.kwargs,'mod':submod}
            elif n.op=="call_method":
                return {'name':n.name,'op':n.op,'args':n.args,'shape':t.shape,'target':n.target,'kw':n.kwargs}
            elif n.op=="call_function":
                return {'name':n.name,'op':n.op,'args':n.args,'shape':t.shape,'target':n.target,'kw':n.kwargs}
            else:
                return {'name':n.name,'op':n.op,'args':n.args,'shape':t.shape}
        n.meta["result"]=torch.fx.node.map_aggregate(result,extract_meta)
        n.meta["find"]=is_find
        return result
traced=symbolic_trace(test_model)
args=(x,y)
kwargs={}
_=Get_IR(traced).run(*args,**kwargs)
print(traced.graph.print_tabular())
for node in traced.graph.nodes:
    print(node.meta)

其实这部分便是运用解说器会遍历图中的每个节点,所以咱们只需求自界说一下run_node(),在里边加入解析网络结构,输入输出的功用就能够了。

关于torch.fx的使用

能够看到meta里边现已有了模型结构所需求的全部,可是这儿尽管打印出来sizegetitem是存在的,可是实际上并没有在条件中解析到,现在还没找到原因。

def create_str(node):
    if node.op=="call_module":
        return f"<<TABLE><TR><TD COLSPAN='2'>{node.meta['result']['mod']}</TD></TR><TR><TD>{node.meta['result']['name']}</TD><TD>{node.meta['result']['shape']}</TD></TR></TABLE>>"
    elif node.meta['find']:
        return f"<<TABLE><TR><TD>{node.meta['result']['name']}</TD></TR><TR><TD>{node.meta['result']['shape']}</TD></TR></TABLE>>"
    else:
        return f"<<TABLE><TR><TD>{node.meta['result']}</TD></TR></TABLE>>"
def single_node(model: torch.nn.Module, graph: graphviz.Digraph, node: torch.fx.Node):
    node_label = create_str(node) # 生成其时节点的label
    node_kwargs = dict(shape='plaintext',align='center',fontname='monospace')
    graph.node(node.name, label=node_label, **node_kwargs) # 在Graphviz图中增加其时节点
    # 遍历其时节点的全部输入节点,增加Graphviz图中的边
    for in_node in node.all_input_nodes:
        edge_kwargs = dict()
        if (
            not node.meta["find"]
            or not in_node.meta["find"]
        ):
            # 假如其时节点的输入和输出中都没有Tensor,就把其时边置为浅灰色虚线,弱化显示
            edge_kwargs.update(dict(style="dashed", color="lightgrey"))
        # 增加其时边
        graph.edge(in_node.name, node.name, **edge_kwargs)
def model_graph(model: torch.nn.Module, *args, **kwargs) -> graphviz.Digraph:
    # 将nn.Module转化为torch.fx.GraphModule,获取核算图
    symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(model)
    # 履行一下网络,以此获取每个节点输入输出的详细信息
    Get_IR(symbolic_traced).run(*args, **kwargs)
    # 界说一个Graphviz网络
    graph = graphviz.Digraph("model", format="svg", node_attr={"shape": "plaintext"})
    for node in symbolic_traced.graph.nodes: # 遍历全部节点
        single_node(model, graph, node)
    return graph
model = TestModel()
graph = model_graph(model, torch.randn(1, 3, 16, 16), torch.randn(1, 2, 8, 8))
graph.render(directory="test", view=False)

关于torch.fx的使用

这样来看模型成果就明晰许多,也和博主给出的成果高度还原。其时便是由于看到了这个结构图所以让我好好看了一遍解说器部分的源码来完成这个效果,假如未来自己做推理结构期望也能很明晰直观地给出模型结构图这和简单易用一样都是最基本的。

2.4 量化

在不大起伏减小模型精度的情况下,对已有练习好的模型以低精度履行核算这便是量化。一般关于pytorch便是从FP32(FP16假如有amp)转到INT8 能够参阅torch的官方文档pytorch.org/docs/master…

运用fx能够轻松的插入量化节点,并进行校准。不过量化需求已知数据分布,所以下面的步骤便是

  1. 用某个数据集训一个模型
  2. 量化
  3. 校准
  4. 对比检验

这儿我就用resnet18在cifar10上练习得到模型为例,练习部分的代码网上许多这儿就不再给出

model=resnet18(pretrained=True)
model.fc=nn.Linear(model.fc.in_features,10)
if not os.path.exists("raw.pth"):
    train_model(model,train_loader,test_loader,10,torch.device("mps:0"))
    torch.save(model.state_dict(),"raw.pth")

这儿说个坑哈,千万别用mac练习太慢了。假如用cuda估量几分钟以内就算完了,可是由于用服务器不能多屏仍是觉得欠好所以忍着在mac上练习(顺便摸摸鱼)

然后开端量化,参阅pytorch.org/tutorials/p…进行后练习动态量化

print(torch.backends.quantized.supported_engines)

关于torch.fx的使用

这个很重要,得知道运用的渠道支撑的engine

import os
import time
import copy
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.models.resnet import resnet18
from torch.quantization.quantize_fx import prepare_fx,convert_fx
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.graph_module import ObservedGraphModule
model=resnet18(pretrained=True)
model.fc=nn.Linear(model.fc.in_features,10)
model.load_state_dict(torch.load("raw.pth",map_location='cpu'))
model.to(torch.device("cpu"))
model.eval()
torch.backends.quantized.engine = 'qnnpack'
qconfig_mapping=get_default_qconfig_mapping("qnnpack")
model_to_quantize=copy.deepcopy(model)
prepared_model=prepare_fx(model_to_quantize,qconfig_mapping,example_inputs=torch.randn([1,3,224,224]))
print(f"prepared model {prepared_model.graph.print_tabular()}")
quantized_model=convert_fx(prepared_model)
print(f"{'='*100}")
print(f"quantized model {quantized_model.graph.print_tabular()}")

这儿就载入练习好的模型,然后进行量化。依据官网的比如找到核心内容仿照就好

关于torch.fx的使用

能够看到图中转为了torch.quint8,模型的巨细肯定也缩小了许多

def print_size_of_model(model):
    torch.save(model.state_dict(),"tmp.pt")
    print(f"The model size:{os.path.getsize('tmp.pt')/1e6}MB")
    os.remove("tmp.pt")
print_size_of_model(prepared_model)
print_size_of_model(quantized_model)

关于torch.fx的使用

模型巨细差不多变成了本来的1/4,可是光变小不可还得看精度

# 测验一下精度
train_loader,test_loader=prepare_dataloader()
example_data=torch.randn([1,3,224,224])
out1=model(example_data)
out2=quantized_model(example_data)
print(torch.allclose(out1,out2,1e-3))
out1
out2
evaluate_model(model,test_loader,device='cpu')
evaluate_model(quantized_model,test_loader,device='cpu')

关于torch.fx的使用

直接G了,这什么鬼呀尽管推理时刻差不多少了一半可是这准确率跟瞎猜差不多了,这可不可!!!所以还需求进行量化的重要一步:校准

咱们需求已知数据分布的情况下对模型进行量化才能使量化后的模型依然保持准确率,所以下面就进行量化校准

# 校准康复精度
model_to_quantize=copy.deepcopy(model)
prepared_model=prepare_fx(model_to_quantize,qconfig_mapping,example_inputs=torch.randn([1,3,224,224]))
prepared_model.eval()
with torch.inference_mode():
    for inputs,labels in test_loader:
        prepared_model(inputs)
quantized_recover_model=convert_fx(prepared_model)
out3=quantized_recover_model(example_data)
print(torch.allclose(out1,out3,1e-3))
out3
evaluate_model(quantized_recover_model,test_loader,device='cpu')

关于torch.fx的使用

关于torch.fx的使用

尽管这儿精度并没有对齐,可是准确率仍是康复上来了。关于边缘,移动端的布置来说,这么一点点微小的准确率丢失能够换来存储占用小75%,推理速度进步一倍,这是谁都能接受的。

最后

看了AI编译器,推理结构后再来看fx,总感觉相似可是又不同。就像之前说的本质上二者就不同,fx只存在于python而不考虑硬件布置上,可是假如咱们首要运用fx在python端尽力优化好然后再去推理结构上微调一下结构,那会比反复调整推理结构适应全部可能的算子轻松许多,毕竟python仍是比c++写起来坑少许多的,并且这样的话推理结构就能够很自然的附带出python的推理api,期望以后有时刻我能够依据这个思路早点写出来。