From e87c56be6ccdbcd3cf2c38d8e3e4e850d862d6ab Mon Sep 17 00:00:00 2001
From: Kevin Hsieh
Date: Fri, 9 Apr 2021 23:17:24 -0700
Subject: [PATCH] Add leaky relu and flatten to map_torch_types_to_onnx

Signed-off-by: Kevin Hsieh
---
 .../torch/src/python/aimet_torch/onnx_utils.py          | 2 ++
 TrainingExtensions/torch/test/python/test_quantizer.py  | 9 +++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
index 758de53c408..6ce17e2b812 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
@@ -157,6 +157,8 @@
     nn.Sigmoid: ['Sigmoid'],
     nn.Upsample: ['Upsample'],
     nn.PReLU: ['PRelu'],
+    nn.LeakyReLU: ['LeakyRelu'],
+    nn.Flatten: ['Flatten'],
     elementwise_ops.Add: ['Add'],
     elementwise_ops.Subtract: ['Sub'],
     elementwise_ops.Multiply: ['Mul'],
diff --git a/TrainingExtensions/torch/test/python/test_quantizer.py b/TrainingExtensions/torch/test/python/test_quantizer.py
index f67cb89ec59..5505d435d2e 100644
--- a/TrainingExtensions/torch/test/python/test_quantizer.py
+++ b/TrainingExtensions/torch/test/python/test_quantizer.py
@@ -187,7 +187,8 @@ def __init__(self):
 
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
         self.maxpool2 = nn.MaxPool2d(2)
-        self.relu2 = nn.ReLU()
+        self.relu2 = nn.LeakyReLU()
+        self.flatten = nn.Flatten()
 
         self.fc1 = nn.Linear(320, 50)
         self.relu3 = nn.ReLU()
@@ -201,7 +202,7 @@ def forward(self, x1, x2):
         x2 = self.relu1_b(self.maxpool1_b(self.conv1_b(x2)))
         x = x1 + x2
         x = self.relu2(self.maxpool2(self.conv2(x)))
-        x = x.view(-1, 320)
+        x = self.flatten(x)
         x = self.relu3(self.fc1(x))
         x = self.dropout(x)
         x = self.fc2(x)
@@ -711,7 +712,7 @@ def forward_pass(model, args):
 
         activation_encodings = encodings['activation_encodings']
         param_encodings = encodings['param_encodings']
-        self.assertEqual(15, len(activation_encodings))
+        self.assertEqual(16, len(activation_encodings))
         self.assertIn('conv1_a.bias', param_encodings)
         self.assertEqual(param_encodings['conv1_a.bias'][0]['bitwidth'], 32)
         self.assertEqual(6, len(param_encodings['conv1_a.weight'][0]))
@@ -722,7 +723,7 @@ def forward_pass(model, args):
 
         activation_encodings = encodings['activation_encodings']
         param_encodings = encodings['param_encodings']
-        self.assertEqual(15, len(activation_encodings))
+        self.assertEqual(16, len(activation_encodings))
         self.assertIn('conv1_a.bias', param_encodings)
         self.assertEqual(param_encodings['conv1_a.bias'][0]['bitwidth'], 32)
         self.assertEqual(6, len(param_encodings['conv1_a.weight'][0]))
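
Note (illustration, not part of the patch): the sketch below shows how a torch-module-to-ONNX-op mapping like map_torch_types_to_onnx is typically consumed — look up a module's type to find the ONNX op type(s) it is expected to export as. Only the two entries added by this patch plus one pre-existing entry are reproduced; the standalone dictionary and the helper expected_onnx_ops are assumptions made for the example and are not AIMET APIs.

    import torch.nn as nn

    # Subset of the mapping; the last two entries are the ones added by this patch.
    torch_to_onnx = {
        nn.PReLU: ['PRelu'],
        nn.LeakyReLU: ['LeakyRelu'],
        nn.Flatten: ['Flatten'],
    }

    def expected_onnx_ops(module: nn.Module):
        # Look up the exact module class; returns None for unmapped types.
        return torch_to_onnx.get(type(module))

    print(expected_onnx_ops(nn.LeakyReLU()))  # ['LeakyRelu']
    print(expected_onnx_ops(nn.Flatten()))    # ['Flatten']

The updated test model now uses nn.LeakyReLU and an explicit nn.Flatten module in place of x.view(-1, 320), and the expected activation-encoding count in the test assertions rises from 15 to 16 accordingly.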