paddle_fx_poc.py
import paddle
import operator
from typing import Any, Callable

# This is a proof of concept for a Paddle IR. It is not intended to be complete.
# Much of the code is adapted from torch.fx.
# Since Paddle does not have a protocol like __torch_function__,
# paddle.add, paddle.nn.functional.relu and paddle.nn.Layer.__call__ have to be
# monkey patched so that calls to them create proxies.
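
# Map a submodule instance back to its attribute name on the root layer; the
# name is used as the target of a call_module node.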
def _find_module(root, m):
    for n, p in root.named_children():
        if m is p:
            return n
    raise NameError('module is not installed as a submodule')

def print_nodes(nodes):
    from tabulate import tabulate
    node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in nodes]
    print(tabulate(node_specs, headers=['opcode', 'name', 'target', 'args', 'kwargs']))

# Store all nodes in a global list so we can print them later.
nodes = []
# Nodes represent a definition of a value in our graph of operators.
class Node:
    def __init__(self, name, op, target, args, kwargs):
        self.name = name      # unique name of the value being created
        self.op = op          # the kind of operation: placeholder|call_method|call_module|call_function|get_param
        self.target = target  # for method/module/function, the name of the method/module/function/attr
                              # being invoked, e.g. add, layer1, or paddle.add
        self.args = args
        self.kwargs = kwargs

    def __repr__(self):
        return self.name

def create_node(op, target=None, args=None, kwargs=None, name=None):
    assert op in ('call_function', 'call_method', 'get_param', 'call_module', 'placeholder')
    args = () if args is None else args
    kwargs = {} if kwargs is None else kwargs
    n = Node(name if name is not None else 'noname', op, target, args, kwargs)
    nodes.append(n)
    return n

def _create_proxy(op, target, args, kwargs, name=None):
    n = create_node(op, target, args, kwargs, name)
    return Proxy(n)
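
# A Proxy stands in for a runtime value (e.g. a tensor) during tracing. Nothing
# is actually computed; every operation on a Proxy just records a Node in the graph.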
class Proxy:
    def __init__(self, node):
        self.node = node

    def __repr__(self):
        return f'Proxy({self.node.name})'
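
# Operator tables borrowed from torch.fx. Only the keys are used below to install
# magic methods on Proxy; the format strings are kept for reference but are not
# used by this proof of concept.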
reflectable_magic_methods = {
    'add': '{} + {}',
    'sub': '{} - {}',
    'mul': '{} * {}',
    'floordiv': '{} // {}',
    'truediv': '{} / {}',
    'div': '{} / {}',
    'mod': '{} % {}',
    'pow': '{} ** {}',
    'lshift': '{} << {}',
    'rshift': '{} >> {}',
    'and': '{} & {}',
    'or': '{} | {}',
    'xor': '{} ^ {}',
    'getitem': '{}[{}]',
}

magic_methods = dict({
    'eq': '{} == {}',
    'ne': '{} != {}',
    'lt': '{} < {}',
    'gt': '{} > {}',
    'le': '{} <= {}',
    'ge': '{} >= {}',
    'pos': '+{}',
    'neg': '-{}',
    'invert': '~{}',
}, **reflectable_magic_methods)
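
# Install __add__, __sub__, __getitem__, etc. on Proxy so that Python operators
# applied to proxies are recorded as call_function nodes on the corresponding
# function from the operator module.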
for method in magic_methods:
    def scope(method):
        def impl(*args, **kwargs):
            # a few names are spelled with a trailing underscore in the operator
            # module (e.g. 'and_'/'or_'), so fall back to that spelling
            target = getattr(operator, method, None) or getattr(operator, method + '_')
            return _create_proxy('call_function', target, args, kwargs, method)
        impl.__name__ = method
        as_magic = f'__{method}__'
        setattr(Proxy, as_magic, impl)
    scope(method)
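
# my_trace symbolically runs root.forward: real arguments are replaced with
# placeholder proxies, and paddle.add, paddle.nn.functional.relu and
# paddle.nn.Layer.__call__ are temporarily monkey patched so that every call is
# recorded as a Node instead of being executed.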
def my_trace(root: Callable[..., Any]):
    fn = type(root).forward
    co = fn.__code__
    args = [root]
    names_iter = iter(co.co_varnames)
    next(names_iter)  # skip self
    for _ in range(1, co.co_argcount):
        name = next(names_iter)
        args.append(_create_proxy('placeholder', name, None, None, name))
    # monkey patch paddle.add to create a proxy for it
    orig_call = paddle.add
    def paddle_add_wrapper(*args, **kwargs):
        return _create_proxy('call_function', orig_call, args, kwargs, 'add')
    # monkey patch paddle.nn.functional.relu to create a proxy for it
    orig_relu_call = paddle.nn.functional.relu
    def paddle_relu_wrapper(*args, **kwargs):
        return _create_proxy('call_function', orig_relu_call, args, kwargs, 'relu')
    # monkey patch paddle.nn.Layer.__call__ to create a proxy for it
    orig_module_call = paddle.nn.Layer.__call__
    def module_call_wrapper(mod, *args, **kwargs):
        target = _find_module(root, mod)
        return _create_proxy('call_module', target, args, kwargs, target)
    try:
        paddle.add = paddle_add_wrapper
        paddle.nn.functional.relu = paddle_relu_wrapper
        paddle.nn.Layer.__call__ = module_call_wrapper
        fn(*args)
    finally:
        paddle.add = orig_call
        paddle.nn.functional.relu = orig_relu_call
        paddle.nn.Layer.__call__ = orig_module_call
    return
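
# A small demo network: three Linear layers followed by paddle.add and relu,
# exercising both the call_module and the call_function paths.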
class MyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self._fc1 = paddle.nn.Linear(in_features=10, out_features=10)
        self._fc2 = paddle.nn.Linear(in_features=10, out_features=10)
        self._fc3 = paddle.nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        x = self._fc1(x)
        x = self._fc2(x)
        x = self._fc3(x)
        y = paddle.add(x=x, y=x)
        return paddle.nn.functional.relu(x=y)

net = MyNet()
nodes = []
# tracing a paddle layer
my_trace(net)
print_nodes(nodes)
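
# The trace above should record, roughly: a placeholder for x, a call_module node
# for each of _fc1/_fc2/_fc3, a call_function node for paddle.add, and a
# call_function node for relu (exact reprs depend on the Paddle version and on
# how tabulate renders the Proxy arguments).

# Minimal extra sketch, not part of the original PoC: tracing a layer whose
# forward uses the Python `+` operator, which is captured through the magic
# methods installed on Proxy above. The class name MyAddNet is made up for this
# illustration.
class MyAddNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self._fc = paddle.nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        y = self._fc(x)
        return y + y  # recorded as a call_function node on operator.add

nodes = []
my_trace(MyAddNet())
print_nodes(nodes)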