forked from nod-ai/transformer-benchmarks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fusion_base.py
56 lines (48 loc) · 2.42 KB
/
fusion_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
from logging import getLogger
from onnx_model import OnnxModel
from typing import Union, List
from onnx import GraphProto
# Module-level logger; Fusion.apply() reports its start and fused-node counts through it.
logger = getLogger(name=__name__)
class Fusion:
    """Base class for a graph fusion over an ``OnnxModel``.

    A subclass implements :meth:`fuse`, which :meth:`apply` calls once for every
    model node whose ``op_type`` is listed in ``search_op_types``. ``fuse`` records
    the rewrite it wants in ``nodes_to_remove`` / ``nodes_to_add`` (and the
    name-to-graph mapping passed to ``model.add_nodes``); :meth:`apply` then
    commits all recorded edits to the model in a single pass.
    """

    def __init__(self,
                 model: "OnnxModel",
                 fused_op_type: str,
                 search_op_types: Union[str, List[str]],
                 description: Union[str, None] = None):
        """Initialise the fusion's bookkeeping state.

        Args:
            model: the model that is searched and rewritten by :meth:`apply`.
            fused_op_type: op_type of the node this fusion produces. When a subclass
                does not update ``fused_count``, nodes of this type in
                ``nodes_to_add`` are counted instead.
            search_op_types: a single op_type, or a sequence of op_types, whose
                nodes are handed to :meth:`fuse`.
            description: optional qualifier; log messages then show
                ``"<fused_op_type>(<description>)"`` instead of ``fused_op_type``.
        """
        # Copy the sequence so later mutation of the caller's list cannot leak in,
        # and so tuples or other iterables of op-type strings are accepted too.
        self.search_op_types: List[str] = (
            [search_op_types] if isinstance(search_op_types, str) else list(search_op_types)
        )
        self.fused_op_type: str = fused_op_type
        self.description: str = f"{fused_op_type}({description})" if description else fused_op_type
        self.model: "OnnxModel" = model
        # Edits recorded by fuse(); committed by apply().
        self.nodes_to_remove: List = []
        self.nodes_to_add: List = []
        # When True, apply() calls model.prune_graph() instead of model.update_graph().
        self.prune_graph: bool = False
        # Passed to model.add_nodes() alongside nodes_to_add.
        self.node_name_to_graph_name: dict = {}
        # Name of the graph owning the node currently passed to fuse(); set by apply().
        self.this_graph_name: Union[str, None] = None
        # Subclasses may update this; apply() also counts fused_op_type nodes in
        # nodes_to_add and reports the larger of the two.
        self.fused_count: int = 0

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Examine one candidate node and record any rewrite. Subclasses must override.

        Args:
            node: a model node whose op_type is one of ``search_op_types``.
            input_name_to_nodes: result of ``model.input_name_to_nodes()``, computed
                once at the start of :meth:`apply`.
            output_name_to_node: result of ``model.output_name_to_node()``, computed
                once at the start of :meth:`apply`.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement fuse()")

    def apply(self):
        """Run :meth:`fuse` over all candidate nodes, then commit the recorded edits.

        For each op type in ``search_op_types`` (in order) and each model node of
        that type, the owning graph's name is stored in ``this_graph_name`` before
        :meth:`fuse` is called. Afterwards ``nodes_to_remove`` are removed,
        ``nodes_to_add`` are added, and the graph is pruned (if ``prune_graph`` is
        set) or updated (if any node was removed or added).

        Raises:
            RuntimeError: if a candidate node does not belong to any graph of the model.
        """
        logger.debug("start %s fusion...", self.description)
        # Computed once before the loop and passed unchanged to every fuse() call.
        input_name_to_nodes = self.model.input_name_to_nodes()
        output_name_to_node = self.model.output_name_to_node()

        # This assumes that two search ops will not be fused at the same time!
        for search_op_type in self.search_op_types:
            for node in self.model.get_nodes_by_op_type(search_op_type):
                graph = self.model.get_graph_by_node(node)
                if graph is None:
                    raise RuntimeError("Can not find node in any graphs")
                self.this_graph_name = graph.name
                self.fuse(node, input_name_to_nodes, output_name_to_node)

        # fused_count is optional for subclasses, so also count fused_op_type nodes queued for insertion.
        added_fused = sum(1 for new_node in self.nodes_to_add if new_node.op_type == self.fused_op_type)
        count = max(self.fused_count, added_fused)
        if count > 0:
            logger.info("Fused %s count: %s", self.description, count)

        self.model.remove_nodes(self.nodes_to_remove)
        self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)

        if self.prune_graph:
            self.model.prune_graph()
        elif self.nodes_to_remove or self.nodes_to_add:
            self.model.update_graph()