From a2769b528026fedb82f2b512309000342bedc263 Mon Sep 17 00:00:00 2001 From: Sai Chaitanya Gajula Date: Tue, 3 Dec 2024 00:21:47 +0530 Subject: [PATCH] Fix custom_tensor_utils for environments without spconv (#3538) Signed-off-by: Sai Chaitanya Gajula --- .../python/aimet_torch/custom/custom_tensor_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/custom/custom_tensor_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/custom/custom_tensor_utils.py index 2552dcde929..6394edd7963 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/custom/custom_tensor_utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/custom/custom_tensor_utils.py @@ -39,13 +39,17 @@ try: import spconv.pytorch as spconv except ImportError as e: - to_torch_tensor = None - to_custom_tensor = None + def to_torch_tensor(tensors): + """ placeholder in case spconv doesn't exist """ + return tensors + + def to_custom_tensor(tensors): + """ placeholder in case spconv doesn't exist """ + return tensors else: from typing import List, Union, Tuple import torch - def to_torch_tensor(original: Union[List, Tuple]) -> List[torch.Tensor]: """ Convert custom tensors to torch tensors @@ -77,7 +81,6 @@ def to_custom_tensor(original: Union[List, Tuple], torch_tensors: List[torch.Ten tensor = torch_tensor if isinstance(orig, spconv.SparseConvTensor): tensor = orig.replace_feature(torch_tensor) - outputs.append(tensor) return outputs