From 15acd5de2a59c81acc7b6d99fc812f3d623982fd Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Mon, 13 Nov 2023 09:02:07 +0000 Subject: [PATCH] [FXML-3548] Bump torch-mlir Bump torch-mlir to ff7f8b21dcc842a4f70209a6d255d54c4ef6e39b and llvm to d13da154a7c7eff77df8686b2de1cfdfa7cc7029. For now, llvm points to the upstream commit; this will change again once xilinx/llvm-project itself has been bumped. The test failure in `Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir` is expected, as the test requires changes from the Xilinx LLVM fork. --- .github/workflows/RollPyTorch.yml | 19 +- .github/workflows/bazelBuildAndTest.yml | 20 +- .github/workflows/merge-rollpytorch.yml | 2 +- .gitmodules | 9 +- CITATION.cff | 19 + CMakeLists.txt | 24 +- README.md | 12 +- build_tools/autogen_ltc_backend.py | 3 +- build_tools/autogen_ltc_backend.yaml | 56 +- .../python_deploy/build_linux_packages.sh | 53 +- build_tools/update_torch_ods.sh | 3 +- docs/code_owners.md | 6 +- docs/development.md | 7 +- e2e_testing/main.py | 11 +- e2e_testing/xfail_sets.py | 288 +- .../lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 2 +- .../Dialect/TMTensor/Transforms/Bufferize.cpp | 6 +- externals/llvm-project | 2 +- externals/mlir-hlo | 1 - externals/stablehlo | 1 + include/torch-mlir-c/TorchTypes.h | 75 +- .../TorchToStablehlo/StablehloLegalizeUtils.h | 3 +- .../TorchToTosa/TosaLegalizeCommon.h | 6 + .../TorchToTosa/TosaLegalizeUtils.h | 3 +- include/torch-mlir/Conversion/Utils/Utils.h | 2 +- .../Dialect/Torch/IR/GeneratedTorchOps.td | 2338 +++++++++++++++-- .../torch-mlir/Dialect/Torch/IR/TorchOps.h | 1 + .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 2 +- .../torch-mlir/Dialect/Torch/IR/TorchTypes.td | 3 +- .../TorchConversion/Transforms/Passes.h | 8 + .../TorchConversion/Transforms/Passes.td | 12 + lib/CAPI/TorchTypes.cpp | 114 +- lib/CMakeLists.txt | 10 +- lib/Conversion/Passes.cpp | 1 - lib/Conversion/TorchToLinalg/DataMovement.cpp | 653 +++-- .../TorchToLinalg/IndirectDataMovement.cpp | 32 +- lib/Conversion/TorchToLinalg/Linear.cpp | 33 +- lib/Conversion/TorchToLinalg/Pooling.cpp | 125 +- lib/Conversion/TorchToLinalg/Reduction.cpp | 70 +- .../TorchToLinalg/TensorConstructors.cpp | 6 +- .../TorchToLinalg/Uncategorized.cpp | 61 +- lib/Conversion/TorchToLinalg/Utils.cpp | 88 +- lib/Conversion/TorchToLinalg/Utils.h | 13 +- lib/Conversion/TorchToSCF/TorchToSCF.cpp | 14 +- lib/Conversion/TorchToStablehlo/Basic.cpp | 297 ++- .../TorchToStablehlo/CMakeLists.txt | 3 +- .../TorchToStablehlo/GatherScatter.cpp | 328 ++- lib/Conversion/TorchToStablehlo/Linear.cpp | 2 +- lib/Conversion/TorchToStablehlo/Pooling.cpp | 336 +-- lib/Conversion/TorchToStablehlo/Reduction.cpp | 81 + .../StablehloLegalizeUtils.cpp | 9 +- .../TorchToStablehlo/TorchToStablehlo.cpp | 4 +- .../TorchToTMTensor/TorchToTMTensor.cpp | 42 +- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 265 +- .../TorchToTosa/TosaLegalizeCommon.cpp | 275 +- .../TorchToTosa/TosaLegalizeUtils.cpp | 2 - lib/Conversion/Utils/Utils.cpp | 6 +- lib/Dialect/Torch/IR/CMakeLists.txt | 3 + lib/Dialect/Torch/IR/TorchOps.cpp | 292 +- lib/Dialect/Torch/IR/TorchTypes.cpp | 20 +- .../Transforms/AbstractInterpLibrary.cpp | 748 ++++-- .../Transforms/AdjustCallingConventions.cpp | 45 - .../Torch/Transforms/DecomposeComplexOps.cpp | 687 ++++- .../Transforms/LowerToBackendContract.cpp | 14 +- .../Torch/Transforms/RecomposeComplexOps.cpp | 217 +- .../Torch/Transforms/RefinePublicReturn.cpp | 5 +- .../ReifyAbstractInterpCalculationsUtils.cpp | 15 +- .../Transforms/SimplifyDtypeCalculations.cpp | 5 +
lib/Dialect/Torch/Utils/Utils.cpp | 17 +- .../TorchConversion/Transforms/CMakeLists.txt | 6 +- .../Transforms/ConvertCustomQuantOp.cpp | 226 ++ .../Transforms/UnpackQuantTensor.cpp | 143 + lib/InitAll.cpp | 14 +- python/torch_mlir/_dynamo_fx_importer.py | 6 +- python/torch_mlir/compiler_utils.py | 8 +- .../csrc/base_lazy_backend/CMakeLists.txt | 5 + .../mlir_lowering_context.cpp | 54 +- .../mlir_native_functions.cpp | 340 ++- .../csrc/base_lazy_backend/mlir_node.cpp | 35 +- .../csrc/base_lazy_backend/mlir_node.h | 13 + .../base_lazy_backend/mlir_node_lowering.cpp | 20 +- .../csrc/base_lazy_backend/ops/index.cpp | 99 + .../csrc/base_lazy_backend/ops/index.h | 58 + .../csrc/base_lazy_backend/ops/ivalue.cpp | 36 + .../csrc/base_lazy_backend/ops/ivalue.h | 37 + .../csrc/base_lazy_backend/ops/split.cpp | 101 + .../csrc/base_lazy_backend/ops/split.h | 65 + .../csrc/base_lazy_backend/ops/unbind_int.cpp | 54 + .../csrc/base_lazy_backend/ops/unbind_int.h | 37 + .../base_lazy_backend/shape_inference.cpp | 372 ++- .../csrc/base_lazy_backend/tensor.cpp | 29 + .../csrc/base_lazy_backend/tensor.h | 24 + .../base_lazy_backend/utils/string_utils.h | 18 + .../csrc/base_lazy_backend/utils/sys_utils.h | 8 + .../reference_lazy_backend/backend_impl.cpp | 25 +- .../reference_lazy_backend_pybind.cpp | 7 + python/torch_mlir/dialects/TorchBinding.td | 1 - .../build_tools/abstract_interp_lib_gen.py | 424 ++- .../jit_ir/build_tools/library_generator.py | 21 +- .../importer/jit_ir/build_tools/registry.py | 18 +- .../jit_ir/build_tools/torch_ods_gen.py | 105 +- .../jit_ir/csrc/torch_to_mlir_utils.cpp | 57 +- .../linalg_on_tensors_backends/refbackend.py | 1 - .../stablehlo_backends/linalg_on_tensors.py | 50 - .../torch_mlir_e2e_test/test_suite/basic.py | 504 +++- .../test_suite/constant_alloc.py | 144 + python/torch_mlir_e2e_test/test_suite/conv.py | 40 +- .../test_suite/elementwise.py | 277 +- .../torch_mlir_e2e_test/test_suite/pooling.py | 156 ++ .../test_suite/reduction.py | 72 + python/torch_mlir_e2e_test/test_suite/rng.py | 60 + .../torch_mlir_e2e_test/test_suite/scatter.py | 29 + .../test_suite/slice_like.py | 102 + .../test_suite/type_conversion.py | 43 + pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- setup.py | 2 +- test/Conversion/TorchToArith/basic.mlir | 21 - test/Conversion/TorchToLinalg/basic.mlir | 37 +- test/Conversion/TorchToStablehlo/scatter.mlir | 35 + test/Conversion/TorchToTosa/basic.mlir | 278 +- ...orch-backend-to-tosa-backend-pipeline.mlir | 42 +- .../Torch/adjust-calling-conventions.mlir | 17 - test/Dialect/Torch/canonicalize.mlir | 112 +- test/Dialect/Torch/decompose-complex-ops.mlir | 24 + test/Dialect/Torch/invalid.mlir | 2 +- test/Dialect/Torch/refine-public-return.mlir | 19 + .../Torch/reify-dtype-calculations.mlir | 15 + .../Torch/simplify-dtype-calculations.mlir | 10 +- .../Torch/verify-backend-contract-error.mlir | 29 + .../convert-custom-quant-op.mlir | 45 + .../TorchConversion/unpack-quant-tensor.mlir | 13 + .../verify-tosa-backend-contract.mlir | 2 +- test/python/custom_op_shape_dtype_fn.py | 42 +- .../importer/jit_ir/node_import/debug-info.py | 11 +- tools/torch-mlir-lsp-server/CMakeLists.txt | 2 + .../torch-mlir-lsp-server.cpp | 2 + tools/torch-mlir-opt/CMakeLists.txt | 8 + tools/torch-mlir-opt/torch-mlir-opt.cpp | 10 +- torchvision-requirements.txt | 2 +- utils/bazel/WORKSPACE.bazel | 19 +- utils/bazel/torch-mlir-overlay/BUILD.bazel | 8 +- 142 files changed, 10362 insertions(+), 2328 deletions(-) create mode 100644 CITATION.cff delete mode 160000 externals/mlir-hlo create 
mode 160000 externals/stablehlo create mode 100644 lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp create mode 100644 lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/index.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/index.h create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.h create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/split.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/split.h create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.h create mode 100644 python/torch_mlir/csrc/base_lazy_backend/tensor.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/tensor.h delete mode 100644 python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py create mode 100644 test/Conversion/TorchToStablehlo/scatter.mlir create mode 100644 test/Dialect/TorchConversion/convert-custom-quant-op.mlir create mode 100644 test/Dialect/TorchConversion/unpack-quant-tensor.mlir diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 51f3f874b065..5c8d74ee0941 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -24,9 +24,21 @@ jobs: - name: Get torch-mlir uses: actions/checkout@v3 with: - submodules: 'true' + submodules: 'false' token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + - name: Get LLVM and StableHlo submodules + run: | + set -eo pipefail + cd ${GITHUB_WORKSPACE} + + # Fetching the submodules concurrently may cause problems, so we fetch + # them one after another. + rm -f .git/modules/externals/llvm-project/index.lock + rm -f .git/modules/externals/stablehlo/index.lock + git submodule update --init --recursive externals/llvm-project + git submodule update --init --recursive externals/stablehlo + - name: Setup ccache uses: ./.github/actions/setup-build with: @@ -71,15 +83,14 @@ jobs: echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV} echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV} - - name: Build and test (in-tree), also update ODS and abstract interpretation library + - name: Build and test (out-of-tree), also update ODS and abstract interpretation library if: env.PT_HASH_CHANGED != '0' run: | cd ${GITHUB_WORKSPACE} - TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \ + TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \ TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \ - TM_PYTHON_VERSIONS="cp311-cp311" \ ./build_tools/python_deploy/build_linux_packages.sh - name: Post issue comment on build failure diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index 43630adcbd77..d0d11ad5a6eb 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -58,33 +58,33 @@ jobs: -t torch-mlir:ci \ . 
- - name: Bazel build torch-mlir + - name: Verify buildifier was run (bazel lint) run: | docker run --rm \ -v "$(pwd)":"/opt/src/torch-mlir" \ -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ torch-mlir:ci \ - bazel build @torch-mlir//:torch-mlir-opt + bazel run @torch-mlir//:buildifier + if [ -n "$(git status --porcelain)" ]; then + echo "Please 'bazel run @torch-mlir//:buildifier' and commit changes." + exit 1 + fi - - name: Bazel test torch-mlir (lit tests) + - name: Bazel build torch-mlir run: | docker run --rm \ -v "$(pwd)":"/opt/src/torch-mlir" \ -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ torch-mlir:ci \ - bazel test @torch-mlir//test/... + bazel build @torch-mlir//:torch-mlir-opt - - name: Verify buildifier was run (bazel lint) + - name: Bazel test torch-mlir (lit tests) run: | docker run --rm \ -v "$(pwd)":"/opt/src/torch-mlir" \ -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ torch-mlir:ci \ - bazel run @torch-mlir//:buildifier - if [ -n "$(git status --porcelain)" ]; then - echo "Please 'bazel run @torch-mlir//:buildifier' and commit changes." - exit 1 - fi + bazel test @torch-mlir//test/... # Switch back bazel cache directory to user ownership # to allow GHA post-cache step to save cache without diff --git a/.github/workflows/merge-rollpytorch.yml b/.github/workflows/merge-rollpytorch.yml index 4fc497ba99c6..7247a3683281 100644 --- a/.github/workflows/merge-rollpytorch.yml +++ b/.github/workflows/merge-rollpytorch.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest if: | github.repository == 'llvm/torch-mlir' && - github.event.workflow_run.actor.login == 'silvasean' && + github.event.workflow_run.actor.login == 'stellaraccident' && github.event.workflow_run.conclusion == 'success' steps: diff --git a/.gitmodules b/.gitmodules index 5b0f4e7479eb..8b46098d9615 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,6 @@ [submodule "externals/llvm-project"] path = externals/llvm-project - url = https://github.com/Xilinx/llvm-project.git - branch = misc_fixes -[submodule "externals/mlir-hlo"] - path = externals/mlir-hlo - url = https://github.com/tensorflow/mlir-hlo.git + url = https://github.com/llvm/llvm-project.git +[submodule "externals/stablehlo"] + path = externals/stablehlo + url = https://github.com/openxla/stablehlo.git diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 000000000000..c6ccb034610a --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,19 @@ +cff-version: 1.2.0 +title: Torch-MLIR +message: >- + If you use this software, please cite it using the + metadata from this file. +type: software +authors: + - name: LLVM +repository-code: 'https://github.com/llvm/torch-mlir' +abstract: >- + The Torch-MLIR project aims to provide first class support + from the PyTorch ecosystem to the MLIR ecosystem. 
+keywords: + - Compiler + - PyTorch + - MLIR +license: + - Apache-2.0 with LLVM Exceptions + - BSD diff --git a/CMakeLists.txt b/CMakeLists.txt index a3c636fc6272..deeb99c20216 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -118,14 +118,7 @@ else() endif() if (TORCH_MLIR_ENABLE_STABLEHLO) - set(STABLEHLO_BUILD_EMBEDDED ON) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo - ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo - EXCLUDE_FROM_ALL) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo) - include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include) - include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo) endif() set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") @@ -229,3 +222,18 @@ if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY) COMPONENT torch-mlir-headers) endif() endif() + +# Important: If loading StableHLO in this fashion, it must come last, +# after all of our libraries and test targets have been defined. +# It seems that both projects abuse upstream CMake macros that accumulate +# properties. +# Getting this wrong results in building large parts of the stablehlo +# project that we don't actually depend on. Further, some of those parts +# do not even compile on all platforms. +if (TORCH_MLIR_ENABLE_STABLEHLO) + set(STABLEHLO_BUILD_EMBEDDED ON) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo + ${CMAKE_CURRENT_BINARY_DIR}/stablehlo + EXCLUDE_FROM_ALL) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo) +endif() diff --git a/README.md b/README.md index e273cedea230..c5fa561bcd15 100644 --- a/README.md +++ b/README.md @@ -43,17 +43,17 @@ We have few paths to lower down to the Torch MLIR Dialect. ## Install torch-mlir snapshot -At the time of writing, we release pre-built snapshot of torch-mlir for Python 3.10 on Linux and macOS. +At the time of writing, we release pre-built snapshots of torch-mlir for Python 3.11 on Linux and macOS. -If you have Python 3.10, the following commands initialize a virtual environment. +If you have Python 3.11, the following commands initialize a virtual environment. ```shell -python3.10 -m venv mlir_venv +python3.11 -m venv mlir_venv source mlir_venv/bin/activate ``` -Or, if you want to switch over multiple versions of Python using conda, you can create a conda environment with Python 3.10. +Or, if you want to switch between multiple versions of Python using conda, you can create a conda environment with Python 3.11. ```shell -conda create -n torch-mlir python=3.10 +conda create -n torch-mlir python=3.11 conda activate torch-mlir python -m pip install --upgrade pip ``` @@ -61,7 +61,7 @@ python -m pip install --upgrade pip Then, we can install torch-mlir with the corresponding torch and torchvision nightlies.
``` pip install --pre torch-mlir torchvision \ - -f https://llvm.github.io/torch-mlir/package-index/ + -f https://llvm.github.io/torch-mlir/package-index/ \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu ``` diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 5af371d56ef9..4444015805bd 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -467,7 +467,8 @@ def gen_fallback_code(*args, **kwargs): node_base="torch::lazy::TorchMlirNode", node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")), tensor_class=self.tensor_class, - tensor_class_hdr="torch/csrc/lazy/core/tensor.h", + tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h", + create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor", shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")), lazy_ir_generator=GenMlirLazyIr, ) diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml index f6366dd20e36..bfc4641640aa 100644 --- a/build_tools/autogen_ltc_backend.yaml +++ b/build_tools/autogen_ltc_backend.yaml @@ -1,16 +1,7 @@ blacklist: -# List of unsupported ops in LTC autogen because of some error -- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here -- _index_put_impl # Error: TODO not sure if there are other valid types to handle here -- empty_like # Error: TODO add support for type BaseType(name=) -- index.Tensor # Error: TODO not sure if there are other valid types to handle here -- index_put # Error: TODO not sure if there are other valid types to handle here -- index_put_ # Error: TODO not sure if there are other valid types to handle here - -# Ops with list of tensors output -- split.Tensor -- unbind.int -- chunk +# Disabled in favour of `aten::index_put`, which supports optional indices via the `hacked_twin` JIT hack. +# It also doesn't have the confusing `unsafe` argument.
+- _index_put_impl # Additional ops which autogen is supported for but don't compile yet - _convolution @@ -21,48 +12,34 @@ blacklist: # Disabled for consistency with TS backend - lift_fresh_copy -- new_empty - rsub -- slice.Tensor # Disabled in favour of slice_copy.Tensor -- zeros -- ones -- arange -- arange.start -- arange.start_step -- fill.Scalar -- scalar_tensor # Disabled in favour of functionalized alternatives - _reshape_alias -- expand - permute - select.int -- squeeze - squeeze.dim -- t - transpose.int +- expand +- squeeze - unsqueeze - view +- slice.Tensor +- split.Tensor +- split_with_sizes +- unbind.int -whitelist: -# Enabled for consistency with TS backend -- arange.start_out - -# List of ops to autogen even if not supported by Torch-MLIR explicitly -#- split_copy.Tensor -#- split_with_sizes_copy -#- unbind_copy.int # List of supported ops that we don't want to do the full codegen for supported: -# - bernoulli -# - bernoulli_ - _to_copy - clone -- empty.memory_format -- empty_strided -- fill_.Scalar - _unsafe_view +- unbind_copy.int +- split_copy.Tensor +- split_with_sizes_copy +- index.Tensor +- index_put # ops required for functionalization - lift @@ -83,20 +60,21 @@ supported: - _trilinear - linalg_pinv.atol_rtol_tensor - logsumexp.out +- t # List of ops that will take in symints for the size instead of ints symint: -- empty.memory_format - new_empty_strided - expand_copy - narrow_copy - slice_backward - slice_copy.Tensor +- split_copy.Tensor - slice_scatter -- view - view_copy - as_strided_copy - as_strided_scatter +- split_with_sizes_copy additional_ops: diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 2d5d38568cf6..9f4d265b278a 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -177,14 +177,20 @@ function run_in_docker() { ;; out-of-tree) setup_venv "$python_version" "$TM_TORCH_VERSION" - build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version" + build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version" "$TM_TORCH_VERSION" + if [ "${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" == "ON" ]; then + pushd /main_checkout/torch-mlir + TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_torch_ods.sh + TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_abstract_interp_lib.sh + popd + fi if [ "${TM_SKIP_TESTS}" == "OFF" ]; then test_out_of_tree fi ;; in-tree) setup_venv "$python_version" "$TM_TORCH_VERSION" - build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version" + build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version" "$TM_TORCH_VERSION" if [ "${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" == "ON" ]; then pushd /main_checkout/torch-mlir ./build_tools/update_torch_ods.sh @@ -208,6 +214,14 @@ function run_in_docker() { function build_in_tree() { local torch_from_bin="$1" local python_version="$2" + + local torch_version="$3" + local enable_ltc="ON" + if [[ "${torch_version}" == "stable" ]] + then + enable_ltc="OFF" + fi + echo ":::: Build in-tree Torch from binary: $torch_from_bin with Python: $python_version" cmake -GNinja -B/main_checkout/torch-mlir/build \ -DCMAKE_BUILD_TYPE=Release \ @@ -225,7 +239,7 @@ function build_in_tree() { -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="/main_checkout/torch-mlir/externals/llvm-external-projects/torch-mlir-dialects" \ -DLLVM_TARGETS_TO_BUILD=host \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DTORCH_MLIR_ENABLE_LTC=ON \ + 
-DTORCH_MLIR_ENABLE_LTC=${enable_ltc} \ -DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_bin" \ -DTORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO} \ -DTORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH} \ @@ -269,7 +283,7 @@ function _check_file_not_changed_by() { function test_in_tree() { local torch_version="$1" - + echo ":::: Test in-tree" cmake --build /main_checkout/torch-mlir/build --target check-torch-mlir-all @@ -287,12 +301,21 @@ function test_in_tree() { echo ":::: Run Lazy Tensor Core e2e integration tests" python -m e2e_testing.main --config=lazy_tensor_core -v + + echo ":::: Run Linalg e2e integration tests" + python -m e2e_testing.main --config=linalg -v + + # Dynamo is changing a lot in nightly versions, and thus the implementation + # tends to become incompatible with the stable version. + echo ":::: Run TorchDynamo e2e integration tests" + python -m e2e_testing.main --config=torchdynamo -v ;; stable) echo ":::: Test with stable torch" - echo ":::: Run Lazy Tensor Core e2e integration tests in experimental mode" - python -m e2e_testing.main --config=lazy_tensor_core -v --ignore_failures + # Disabled until the next stable PyTorch release (v2.1) is available + # echo ":::: Run Lazy Tensor Core e2e integration tests in experimental mode" + # python -m e2e_testing.main --config=lazy_tensor_core -v --ignore_failures ;; *) echo "Unrecognized torch version '$torch_version'" @@ -303,15 +326,6 @@ function test_in_tree() { echo ":::: Run make_fx + TOSA e2e integration tests" python -m e2e_testing.main --config=make_fx_tosa -v - echo ":::: Run TorchDynamo e2e integration tests" - python -m e2e_testing.main --config=torchdynamo -v - - echo ":::: Run Linalg e2e integration tests" - python -m e2e_testing.main --config=linalg -v - - echo ":::: Run StableHLO e2e integration tests" - python -m e2e_testing.main --config=stablehlo -v - echo ":::: Run TOSA e2e integration tests" python -m e2e_testing.main --config=tosa -v } @@ -352,6 +366,13 @@ function build_out_of_tree() { local python_version="$2" echo ":::: Build out-of-tree Torch from binary: $torch_from_bin with Python: $python_version" + local torch_version="$3" + local enable_ltc="ON" + if [[ "${torch_version}" == "stable" ]] + then + enable_ltc="OFF" + fi + if [ ! -d "/main_checkout/torch-mlir/llvm-build/lib/cmake/mlir/" ] then echo ":::: LLVM / MLIR is not built so building it first.." @@ -385,7 +406,7 @@ -DLLVM_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/llvm/" \ -DMLIR_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/mlir/" \ -DMLIR_ENABLE_BINDINGS_PYTHON=OFF \ - -DTORCH_MLIR_ENABLE_LTC=ON \ + -DTORCH_MLIR_ENABLE_LTC=${enable_ltc} \ -DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_bin" \ -DTORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO} \ -DTORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH} \ diff --git a/build_tools/update_torch_ods.sh b/build_tools/update_torch_ods.sh index 6bc4b7109bbd..e0564a62dff8 100755 --- a/build_tools/update_torch_ods.sh +++ b/build_tools/update_torch_ods.sh @@ -41,7 +41,8 @@ if [ !
-z ${TORCH_MLIR_EXT_MODULES} ]; then ext_module="${TORCH_MLIR_EXT_MODULES}" fi -PYTHONPATH="${pypath}" python \ +set +u +PYTHONPATH="${PYTHONPATH}:${pypath}" python \ -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \ --torch_ir_include_dir="${torch_ir_include_dir}" \ --pytorch_op_extensions="${ext_module}" \ diff --git a/docs/code_owners.md b/docs/code_owners.md index 3a37c6245f52..fa43136332d0 100644 --- a/docs/code_owners.md +++ b/docs/code_owners.md @@ -12,14 +12,14 @@ and Clang's ### All parts not covered by anyone else -- Sean Silva (@silvasean) -- Stella Laurenzo (@stellaraccident) - mostly emeritus +- Stella Laurenzo (@stellaraccident) +- Sean Silva (@silvasean) - emeritus -------------------------------------------------------------------------------- ### `torch` dialect and other core IR pieces, Python bindings/API, JIT IR importer -- Sean Silva (@silvasean) +- Stella Laurenzo (@stellaraccident) ### TorchToLinalg, Shape inference, Dtype refinement, MaximizeValueSemantics diff --git a/docs/development.md b/docs/development.md index 048f363c0763..323db4d8ba9c 100644 --- a/docs/development.md +++ b/docs/development.md @@ -408,13 +408,18 @@ Torch-MLIR by default builds with the latest nightly PyTorch version. This can b # Updating the LLVM and MLIR-HLO submodules Torch-MLIR depends on `llvm-project` (which contains, among other things, -upstream MLIR) and `mlir-hlo`, both of which are submodules in the `externals/` +upstream MLIR) and `stablehlo`, both of which are submodules in the `externals/` directory. We aim to update these at least weekly to bring in the latest features and spread out over time the effort of updating our code for MLIR API breakages. ## Which LLVM commit should I pick? +NOTE: This section is in flux. Specifically, the `mlir-hlo` dep has been +dropped and the project is running off of a `stablehlo` fork which can be +patched for certain OS combinations. As of 2023-09-12, stellaraccident@ +is massaging this situation. Please reach out for advice on updating. + Since downstream projects may want to build Torch-MLIR (and thus LLVM and MLIR-HLO) in various configurations (Release versus Debug builds; on Linux, Windows, or macOS; possibly with Clang, LLD, and LLDB enabled), it is crucial to diff --git a/e2e_testing/main.py b/e2e_testing/main.py index 3893edee4765..57cc4f1ca223 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -24,13 +24,13 @@ ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend -from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend from .xfail_sets import ( LINALG_XFAIL_SET, MAKE_FX_TOSA_PASS_SET, STABLEHLO_PASS_SET, + STABLEHLO_CRASHING_SET, TOSA_PASS_SET, LTC_XFAIL_SET, LTC_CRASHING_SET, @@ -43,7 +43,7 @@ register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -51,7 +51,6 @@ def _get_argparse(): help=f""" Meaning of options: "linalg": run through torch-mlir"s default Linalg-on-Tensors backend. -"stablehlo": run through torch-mlir"s default StableHLO backend.
"tosa": run through torch-mlir"s default TOSA backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). @@ -74,7 +73,7 @@ def _get_argparse(): parser.add_argument("--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed", metavar="TEST", type=str, nargs="+", help="A set of tests to not attempt to run, since they crash and cannot be XFAILed.") - parser.add_argument("--ignore_failures", + parser.add_argument("--ignore_failures", default=False, action="store_true", help="return exit code 0 even if the test fails to unblock pipeline") @@ -99,10 +98,6 @@ def main(): config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True) xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET crashing_set = set() - elif args.config == "stablehlo": - config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) - xfail_set = all_test_unique_names - STABLEHLO_PASS_SET - crashing_set = set() elif args.config == "native_torch": config = NativeTorchTestConfig() xfail_set = set() diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index d6ddb62fbc2b..74eb5b9deb35 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -14,6 +14,7 @@ from torch_mlir._version import torch_version_for_comparison, version LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { "Conv1dNoPaddingModule_basic", "Conv1dNoPaddingTransposeModule_basic", "Conv1dNoPaddingGroupModule_basic", @@ -25,6 +26,11 @@ "EyeStaticModule_basic", # No lowering available "FakeQuantizePerTensorAffineCachemaskModule_basic", + # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed + # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", } TORCHDYNAMO_XFAIL_SET = { @@ -71,6 +77,7 @@ "ElementwiseFlattenBroadcastModule_basic", "FlattenRank0Module_basic", "UniformModule_basic", + "UniformStaticShapeModule_basic", # error: unsupported by backend contract: tensor with unknown rank # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32> "ElementwisePreluModule_basic", @@ -174,6 +181,9 @@ # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt 'SqrtIntConstantModule_basic', + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size + 'BroadcastDynamicDimModule_basic', + # START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int 'AtenIntBoolOpConstFalseModule_basic', 'AtenIntBoolOpConstTrueModule_basic', @@ -268,8 +278,6 @@ "RandnGeneratorModule_basic", # START tests failing due to: complex floating point ops - "AtenComplexImagModule_basic", - "AtenComplexRealModule_basic", # END tests failing due to: complex floating point ops # ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int @@ -292,6 +300,7 @@ # tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0 "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", # failed to legalize
operation 'torch.aten.clamp' that was explicitly marked illegal "ElementwiseClampIntModule_basic", @@ -300,8 +309,29 @@ # No lowering to linalg "FakeQuantizePerTensorAffineCachemaskModule_basic", + # AssertionError: Unregistered operation: torch.aten._unsafe_index_put + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + + # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed + # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + + # Exception: Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: _scaled_dot_product_flash_attention; + "ScaledDotProductAttentionSameModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + + # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only + "AtenEmbeddingBagStaticModule_basic", } +if torch_version_for_comparison() < version.parse("2.1.0.dev"): + TORCHDYNAMO_XFAIL_SET -= { + "ScaledDotProductAttentionSameModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + } + TORCHDYNAMO_CRASHING_SET = { # No upstream decompositions. # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) @@ -333,18 +363,51 @@ "ToCopyModule_basic", "TransposeIntModule_basic", "TransposeIntNegDimsModule_basic", - - # See https://github.com/llvm/torch-mlir/issues/2178 - "Add_Module_basic" + "IndexPutImpl2DNoneIndexStaticModule_basic", } STABLEHLO_PASS_SET = { + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "AddIntModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "CeilFloatModule_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "EqIntModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "MulIntModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "SqrtIntModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TensorToBoolZeroRank_basic", + "TensorToIntZeroRank_basic", + "TensorToFloatZeroRank_basic", + "IndexTensorStaticContiguousWithNoneModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "AliasModule_basic", + "TensorIntModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", "AnyBoolTrueModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", + "AtenFloatScalarModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", "AtenSubFloatModule_basic", "BoolFloatConstantModule_basic", "BoolIntConstantModule_basic", @@ -378,6 +441,7 @@ "ConstantBoolParameterModule_basic", "MaskedFillScalarIntValueStaticModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AddSizeIntModule_basic", "AddSizeIntNegDimModule_basic", @@ -403,7 +467,8 @@
"BatchNorm1DStaticShapeModule_basic", "ResNet18StaticModule_basic", "AtenToDtypeModule_basic", - "BmmModule_basic", + "BmmFloatModule_basic", + "BmmIntModule_basic", "BroadcastToModule_basic", "BroadcastToSameRankStaticModule_basic", "BroadcastToDifferentRankStaticModule_basic", @@ -429,6 +494,7 @@ "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseBitwiseOrStaticShapeModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseClampModule_basic", @@ -442,6 +508,8 @@ "ElementwiseExpModule_basic", "ElementwiseFlattenBroadcastModule_basic", "ElementwiseLeakyReluModule_basic", + "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", "ElementwiseLogModule_basic", "ElementwiseNegModule_basic", "ElementwiseRsqrtModule_basic", @@ -470,6 +538,7 @@ "ElementwiseNeFloatScalarModule_basic", "ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic", + "ElementwiseEqBoolScalarModule_basic", "ElementwiseErfModule_basic", "ElementwiseGeluModule_basic", "ElementwiseGtFloatScalarModule_basic", @@ -507,10 +576,20 @@ "EmbeddingModuleI32_basic", "EmbeddingModuleI64_basic", "EmbeddingModuleF16_basic", + "EmptyLikeMemoryFormatModule_basic", + "EmptyLikeModule_defaultDtype", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_float", + "EmptyLikeModule_int", "ExpandAsIntModule_basic", "ExpandModule_basic", + "Fill_TensorFloat64WithFloat32_basic", + "Fill_TensorFloat64WithFloat64_basic", + "Fill_TensorFloat64WithInt64_basic", "Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithInt64Static_basic", + "FlipModuleStaticShape_basic", + "FlipNegativeIndexModule_basic", "FullLikeModuleDefaultDtype_basic", "FullLikeModuleFalsePinMemory_basic", "FullLikeModuleFloat2D_basic", @@ -525,6 +604,14 @@ "FullModuleFloat3D_basic", "FullModuleInt2D_basic", "FullModuleInt3D_basic", + "NewFullModuleDefaultDtype_basic", + "NewFullModuleFalsePinMemory_basic", + "NewFullModuleFloat2D_basic", + "NewFullModuleFloat3DStatic_basic", + "NewFullModuleFloat3D_basic", + "NewFullModuleInt2DStatic_basic", + "NewFullModuleInt2D_basic", + "NewFullModuleInt3D_basic", "GatherStaticModule_basic", "GatherModule_basic", "Gather2DInputModdule_basic", @@ -629,10 +716,15 @@ "ViewOffsetBackwardTestStaticModule_basic", "NumToTensorFloatModule_basic", "AtenToDeviceModule_basic", + "AvgPool1dStaticModule_basic", "AvgPool2dStaticModule_basic", "Conv1dNoPaddingModule_basic", "Conv1dNoPaddingGroupModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Convolution2DStaticModule_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "Convolution2DGroupsStatic_basic", @@ -710,9 +802,13 @@ "NewEmptyModuleNonDefaultFloatDtype_basic", "NewEmptyModuleNonDefaultIntDtype_basic", "NewEmptyStridedModuleDefaultDtype_basic", + "EmptyStridedModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", "ZerosLikeModule_defaultDtype", "ZerosLikeModule_falsePinMemory", "ZerosLikeModule_float", @@ -746,6 +842,9 @@ 
"NewZerosStaticModuleLayoutStrided_basic", "DropoutEvalIntModule_basic", "DropoutEvalFloatModule_basic", + "DropoutTrainStaticShapeModule_basic", + "NativeDropoutEvalFloatModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", "ContiguousModule_basic", "DropoutModule_basic", "ViewCollapseModule_basic", @@ -770,6 +869,9 @@ "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", "PrimsSumFloatModule_basic", + "ReduceMinFloatModule_basic", + "ReduceMinSignedIntModule_basic", + "ReduceMinUnsignedIntModule_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", "ReduceSumFloatModule_basic", @@ -781,6 +883,7 @@ "ReshapeExpandModule_basic", "RollModule_basic", "TestMultipleTensorReturn_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "BaddbmmStaticModule_basic", "BaddbmmBroadcast1DInputModule_basic", @@ -789,6 +892,8 @@ "NarrowHorizontalTest_basic", "NarrowVerticalTest2_basic", "NarrowVerticalTest_basic", + "NarrowTensorHorizontalModule_basic", + "NarrowTensorVerticalModule_basic", "NumToTensorIntModule_basic", "NumpyTRank0Module_basic", "NumpyTRank1Module_basic", @@ -809,6 +914,7 @@ "ToDtypeLayoutNoneModule_basic", "ToDtypeLayoutStridedModule_basic", "TypeAsSameModule_basic", + "TypeAsDifferentModule_basic", "TypeConversionF32ToF64Module_basic", "TypeConversionF64ToF32Module_basic", "TypeConversionI1ToF32Module_basic", @@ -838,6 +944,9 @@ "AtenComplex64Module_basic", "SplitTensorGetItem_Module_basic", "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitTensorLastSmallerModule_basic", + "SplitWithSizesListUnpackModule_basic", "UnbindIntListUnpack_Module_basic", "UnbindIntGetItem_Module_basic", "ChunkListUnpack_Module_basic", @@ -847,12 +956,33 @@ "RandIntLowModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", + "RandModule_basic", + "UniformStaticShapeModule_basic", "UniformNoCorrelationModule_basic", + "TupleModule_basic", + "AtenEmbeddingBagStaticModule_basic", +} + +STABLEHLO_CRASHING_SET = { + # These e2e tests crash because currently mlir-hlo's shape-component-analysis + # only support exact one index in tensor::ExtractOp when it's related with + # some tensors' shape. REF: + # https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586 + # FIXME if upstream mlir-hlo fix this. + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + + "Aten_EmbeddingBagExample_basic", + "AtenEmbeddingBagSumExample_basic" } # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
TOSA_PASS_SET = { + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "AliasModule_basic", "MaxPool2dEmptyStrideStaticModule_basic", "ConstantBoolParameterModule_basic", "ElementwiseCloneContiguousModule_basic", @@ -864,11 +994,15 @@ "ElementwiseExpModule_basic", "ElementwiseReluModule_basic", "ElementwiseLeakyReluModule_basic", + "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", "ElementwiseFloorModule_basic", "ElementwiseLogModule_basic", "ElementwiseBinaryStaticShapeModule_basic", "ElementwiseMinimumModule_basic", "ElementwiseMinimumIntModule_basic", + "ElementwiseMinOtherIntModule_basic", + "ElementwiseMinOtherModule_basic", "ElementwiseMaximumModule_basic", "ElementwiseMaximumIntModule_basic", "ElementwiseSinModule_basic", @@ -880,6 +1014,8 @@ "ElementwiseClampMinModule_basic", "ElementwiseClampModule_basic", "ElementwiseClampIntModule_basic", + "ElementwiseMaxOtherIntModule_basic", + "ElementwiseMaxOtherModule_basic", "ViewDoubleMergeStaticModule_basic", "ViewCollapseOnesMiddleModule_basic", "ViewFiveTestStaticModule_basic", @@ -922,7 +1058,7 @@ "ElementwisePowTensorModule_basic", "ElementwisePowTensorStaticModule_basic", "AtenToDtypeModule_basic", - "BmmModule_basic", + "BmmFloatModule_basic", "MmDagModule_basic", "Matmul4dStatic_basic", "Matmul_dot", @@ -934,6 +1070,8 @@ "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseOrTensorModule_basic", "ElementwiseBitwiseOrModule_basic", "ElementwiseBitwiseOrStaticShapeModule_basic", "ElementwiseBitwiseXorModule_basic", @@ -986,6 +1124,8 @@ "Conv1dNoPaddingModule_basic", "Conv1dNoPaddingGroupModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "BatchNorm1DModule_basic", "BatchNorm1DWith2DInputModule_basic", "BatchNorm2DModule_basic", @@ -1114,6 +1254,10 @@ "FullModuleFloat3D_basic", "FullModuleFalsePinMemory_basic", "FullModuleInt2D_basic", + "NewFullModuleDefaultDtype_basic", + "NewFullModuleFalsePinMemory_basic", + "NewFullModuleFloat3DStatic_basic", + "NewFullModuleFloat3D_basic", "MaskedFillScalarDefaultModule_basic", "MaskedFillScalarFloatValueModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", @@ -1145,6 +1289,7 @@ "SliceStaticModule_basic", "SliceSizeTwoStepDivisibleStaticModule_basic", "SliceOutOfLowerBoundStartIndexStaticModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", "ArangeStartStepIntModule_basic", "ArangeDtypeFloatModule_basic", "ArangeIntModule_basic", @@ -1240,10 +1385,21 @@ "Fill_TensorFloat64WithFloat32Static_basic", "SplitTensorGetItem_Module_basic", "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitTensorLastSmallerModule_basic", + "SplitWithSizesListUnpackModule_basic", "ChunkListUnpack_Module_basic", "ChunkListUnpackUneven_Module_basic", "RepeatInterleaveStaticModule_basic", "RepeatInterleaveFillModule_basic", + "TupleModule_basic", + "NumpyTRank0Module_basic", + "Permute0RankModule_basic", + "Add_Module_basic", + "SoftmaxIntModule_basic", + "SoftmaxIntNegDimModule_basic", + "_LogSoftmaxModule_basic", + "_SoftmaxModule_basic", } MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { @@ -1269,17 +1425,38 @@ "NormalizeModule_basic", "ReduceFrobeniusNormKeepDimModule_basic", 
"ReduceFrobeniusNormModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", }) - { ### Test failing in make_fx_tosa but not in tosa # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", + + # failed to legalize operation 'torch.aten.max_pool2d_with_indices + "MaxPool2dEmptyStrideStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticModule_basic", + "ResNet18StaticModule_basic", + + # Unimplemented operator 'aten._index_put_impl_.hacked_twin' + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + # RuntimeError: The size of tensor a (7) must match the size of tensor b (3) at non-singleton dimension 1 + "Add_Module_basic", } if torch_version_for_comparison() < version.parse("2.1.0.dev"): MAKE_FX_TOSA_PASS_SET -= { # 'tensor.expand_shape' op expected rank expansion, but found source rank 1 >= result rank 1 "ReshapeCollapseModule_basic", + + # failed to lower torch.aten.empty.memory_format + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", } LTC_CRASHING_SET = { @@ -1287,7 +1464,10 @@ "Conv1dNoPaddingModule_basic", "Conv1dNoPaddingTransposeModule_basic", "Conv1dNoPaddingGroupModule_basic", - "Add_Module_basic" + "Add_Module_basic", + # TODO: update test to move all inputs to the lazy device. Otherwise test fails with: + # Check failed: lazy_tensor Input tensor is not a lazy tensor: CPUBoolType. + "HBC_basic", } LTC_XFAIL_SET = { @@ -1300,8 +1480,6 @@ "_ConvolutionDeprecated2DBenchmarkModule_basic", "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AddIntModule_basic", "AtenIntBoolOpModule_basic", "BernoulliTensorModule_basic", @@ -1314,42 +1492,12 @@ "BoolIntTrueModule_basic", "CeilFloatModule_basic", "DivFloatModule_basic", - "ElementwiseAtenFloorDivideBroadcastModule_basic", - "ElementwiseAtenFloorDivideModule_basic", "EqIntModule_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", "GtFloatIntModule_basic", "GtIntModule_basic", - "HBC_basic", - "HardtanhBackward_basic", - "IndexPut1DFloatAccumulateModule_basic", - "IndexPut1DFloatNonAccumulateModule_basic", - "IndexPut1DIntAccumulateModule_basic", - "IndexPut1DIntNonAccumulateModule_basic", - "IndexPut2DFloatAccumulateModule_basic", - "IndexPut2DFloatNonAccumulateModule_basic", - "IndexPut2DIntAccumulateModule_basic", - "IndexPut2DIntNonAccumulateModule_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", - "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", - "IndexPut3DFloatAccumulateModule_basic", - "IndexPut3DFloatNonAccumulateModule_basic", - "IndexPut3DIntAccumulateModule_basic", - "IndexPut3DIntNonAccumulateModule_basic", - "IndexPutHackedTwin1DFloatAccumulateModule_basic", - "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", - "IndexPutHackedTwin1DIntAccumulateModule_basic", - "IndexPutHackedTwin1DIntNonAccumulateModule_basic", - "IndexPutHackedTwin2DFloatAccumulateModule_basic", - "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", - "IndexPutHackedTwin2DIntAccumulateModule_basic", - "IndexPutHackedTwin2DIntNonAccumulateModule_basic", - "IndexPutHackedTwin3DFloatAccumulateModule_basic", - "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", - 
"IndexPutHackedTwin3DIntAccumulateModule_basic", - "IndexPutHackedTwin3DIntNonAccumulateModule_basic", "IndexPutImpl1DFloatAccumulateModule_basic", "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntAccumulateModule_basic", @@ -1357,36 +1505,16 @@ "IndexPutImpl2DFloatAccumulateModule_basic", "IndexPutImpl2DFloatNonAccumulateModule_basic", "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", - "IndexTensorModule3dInput_basic", - "IndexTensorModule3dInputStatic_basic", - "IndexTensorModule_basic", - "IndexTensorStaticModule_basic", - "IndexTensorMultiIndexStaticModule_basic", - "IndexTensorMultiInputContiguousCenter_basic", - "IndexTensorMultiInputNonContiguous_basic", - "IndexTensorMultiInputOneDim_basic", - "IndexTensorMultiInputThreeIndexers_basic", - "IndexTensorMultiInput_basic", - "IndexTensorSelectDimModule_basic", - "IndexTensorMultiInputContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguousDynamic_basic", - "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", - "IndexTensorHackedTwinModule_basic", - "IndexTensorHackedTwinModule3dInput_basic", - "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", - "LiftFreshCopyModule_basic", "Matmul_dot", "MulIntModule_basic", "DivIntModule_basic", "NeFloatIntModule_basic", "NeIntModule_basic", "QuantizedMLP_basic", - "RandLikeDtypeModule_basic", - "RandLikeModule_basic", "RollModule_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", @@ -1398,8 +1526,6 @@ "SqrtIntModule_basic", "SubFloatModule_basic", "SubIntModule_basic", - "TensorsConcatNegativeDimModule_basic", - "TensorsConcatPromoteDTypeModule_basic", "TensorsStackPromoteDTypeModule_basic", "TensorToBoolZeroRank_basic", "TensorToBool_basic", @@ -1407,33 +1533,21 @@ "TensorToFloat_basic", "TensorToIntZeroRank_basic", "TensorToInt_basic", - "TensorsConcatModule_basic", - "TensorsConcatStaticModule_basic", - "TensorsConcatNegativeDimStaticModule_basic", - "TensorsConcatPromoteDTypeStaticModule_basic", "UniformModule_basic", - "UniformNoCorrelationModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "AtenEmbeddingBagSumExample_basic", "Aten_EmbeddingBagExample_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", - "ElementwiseRemainderScalarModule_Float_basic", - "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRemainderScalarModule_Bool_basic", "AtenIntTensorByteDtypeModule_basic", "AtenIntTensorCharDtypeModule_basic", - "Fill_TensorFloat32WithFloat32_basic", - "Fill_TensorFloat32WithFloat64_basic", - "Fill_TensorFloat32WithInt64_basic", "UpSampleNearest2dBackwardVec_basic", "UpSampleNearest2dBackwardOutputSizeNone_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2DPadded_basic", "VarMeanCorrectionModule_basic", "VarMeanCorrectionNoneModule_basic", - "PrimsConvertElementTypeModule_basic", - "PrimsSumFloatModule_basic", "ElementwisePreluModule_basic", "VarMeanBiasedModule_basic", "VarMeanUnbiasedModule_basic", @@ -1443,26 +1557,13 @@ "BernoulliModule_basic", "BernoulliPModule_basic", "DropoutTrainModule_basic", + "DropoutTrainStaticShapeModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", "StdCorrectionKeepDimModule_basic", 
"StdCorrectionNoneModule_basic", - "VarBiasedModule_basic", - "VarCorrectionAllDimReduceModule_basic", - "VarCorrectionEmptyDimModule_basic", "VarCorrectionKeepDimModule_basic", - "VarCorrectionLargeInputModule_basic", - "VarCorrectionModule_basic", "VarCorrectionNoneModule_basic", - "VarCorrectionSingleDimReduceModule_basic", - "VarDimAllDimReduceModule_basic", - "VarDimBiasedModule_basic", - "VarDimEmptyDimModule_basic", - "VarDimModule_basic", - "VarDimMultiDimModule_basic", - "VarDimNegativeModule_basic", - "VarDimNoneDimModule_basic", - "VarDimSingleDimModule_basic", - "VarDimUnbiasedModule_basic", - "VarUnbiasedModule_basic", "AtenFloatScalarModule_basic", "PrimsSqueezeModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", @@ -1491,4 +1592,9 @@ "RepeatInterleaveFillModule_basic", "Im2ColModule_basic", "FakeQuantizePerTensorAffineCachemaskModule_basic", + "AtenRealView128Module_basic", + "AtenRealView64Module_basic", + "UniformStaticShapeModule_basic", + "AtenEmbeddingBagStaticModule_basic", + "EmptyStridedModule_basic", } diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index ba7ed76c81cf..dcb2f4215891 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -233,7 +233,7 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, loc, init, [&](OpBuilder &b, Location loc, Value elem, Value acc) { Value x = b.create(loc, weight, localIVs); - Value max = b.create(loc, x, acc); + Value max = b.create(loc, x, acc); b.create(loc, max); }); }) diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index 36d061f3237e..64352ad1d5ce 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -31,7 +31,7 @@ using namespace ::mlir::torch::TMTensor; static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { auto memrefType = memref.getType().cast(); auto alloc = b.create( - loc, memrefType, linalg::createDynamicDimensions(b, loc, memref)); + loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType()); b.create(loc, memref, alloc); return alloc; } @@ -73,8 +73,8 @@ allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp, } resultBuffers.push_back(b.create( - loc, memrefType, - linalg::createDynamicDimensions(b, loc, resultTensor))); + loc, memref::getMixedSizes(b, loc, resultTensor), + memrefType.getElementType())); } return success(); } diff --git a/externals/llvm-project b/externals/llvm-project index 1683a67080e3..d13da154a7c7 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 1683a67080e30a9c8055728d02640668d66e12f7 +Subproject commit d13da154a7c7eff77df8686b2de1cfdfa7cc7029 diff --git a/externals/mlir-hlo b/externals/mlir-hlo deleted file mode 160000 index a4ac6990f751..000000000000 --- a/externals/mlir-hlo +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a4ac6990f7519a569a380452d7c1d3764aad7e59 diff --git a/externals/stablehlo b/externals/stablehlo new file mode 160000 index 
000000000000..77a59815a82b --- /dev/null +++ b/externals/stablehlo @@ -0,0 +1 @@ +Subproject commit 77a59815a82b34f7b08ed2d42a711d9920682d0e diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index 4524b9d5a78e..c852dd61387d 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -34,6 +34,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className); +/// Gets the !torch.nn.Module typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.optional type. //===----------------------------------------------------------------------===// @@ -49,6 +52,9 @@ torchMlirTorchOptionalTypeGet(MlirType containedType); MLIR_CAPI_EXPORTED MlirType torchMlirTorchOptionalTypeGetContained(MlirType containedType); +/// Gets the !torch.optional typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.tuple type. //===----------------------------------------------------------------------===// @@ -65,7 +71,11 @@ torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes, MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t); /// Returns the pos-th type in the !torch.tuple type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos); +MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, + intptr_t pos); + +/// Gets the !torch.tuple typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID(); //===----------------------------------------------------------------------===// // torch.union type. @@ -83,7 +93,11 @@ torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes, MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t); /// Returns the pos-th type in the !torch.union type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos); +MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, + intptr_t pos); + +/// Gets the !torch.union typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID(); //===----------------------------------------------------------------------===// // torch.list type. @@ -98,6 +112,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType); /// Gets contained T in a !torch.list type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t); +/// Gets the !torch.list typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.Device type. //===----------------------------------------------------------------------===// @@ -108,6 +125,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t); /// Gets the !torch.Device type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context); +/// Gets the !torch.device typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.Generator type. 
//===----------------------------------------------------------------------===// @@ -118,6 +138,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t); /// Gets the !torch.Generator type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context); +/// Gets the !torch.generator typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.bool type. //===----------------------------------------------------------------------===// @@ -128,6 +151,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t); /// Gets the !torch.bool type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context); +/// Gets the !torch.bool typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.int type. //===----------------------------------------------------------------------===// @@ -138,6 +164,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t); /// Gets the !torch.int type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context); +/// Gets the !torch.int typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.float type. //===----------------------------------------------------------------------===// @@ -148,6 +177,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t); /// Gets the !torch.float type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context); +/// Gets the !torch.float typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.LinearParams type. //===----------------------------------------------------------------------===// @@ -159,6 +191,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context); +/// Gets the !torch.linearparams typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.qint8 type. //===----------------------------------------------------------------------===// @@ -169,6 +204,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t); /// Gets the !torch.qint8 type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context); +/// Gets the !torch.qint8 typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.quint8 type. //===----------------------------------------------------------------------===// @@ -179,6 +217,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t); /// Gets the !torch.quint8 type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context); +/// Gets the !torch.quint8 typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.tensor type. 
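
The `!torch.tensor` accessors just below pair a `HasDtype` query with a `GetDtype` getter, since a `!torch.tensor` need not have been refined with a dtype. A small sketch (not part of this patch) of the guard a caller would presumably use; the wrapper name is hypothetical:

    #include "mlir-c/IR.h"
    #include "torch-mlir-c/TorchTypes.h"

    // Hypothetical helper: return the dtype of a !torch.tensor, or a null
    // MlirType when no dtype has been refined onto the type.
    static MlirType getTorchTensorDtypeIfKnown(MlirType tensorType) {
      if (!torchMlirTorchNonValueTensorTypeHasDtype(tensorType))
        return MlirType{nullptr};
      return torchMlirTorchNonValueTensorTypeGetDtype(tensorType);
    }
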
//===----------------------------------------------------------------------===// @@ -217,10 +258,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t); /// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size /// indicates an unrefined/unknown size dimension. -MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes); +MLIR_CAPI_EXPORTED int64_t +torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes); /// Gets the the dtype (data type) of a !torch.tensor. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t); +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchNonValueTensorTypeGetDtype(MlirType t); + +/// Gets the !torch.tensor typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID(); //===----------------------------------------------------------------------===// // torch.vtensor type. @@ -259,11 +305,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t); /// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size /// indicates an unrefined/unknown size dimension. -MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes); +MLIR_CAPI_EXPORTED int64_t +torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes); /// Gets the the dtype (data type) of a !torch.vtensor. MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t); +/// Gets the !torch.vtensor typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.none type. //===----------------------------------------------------------------------===// @@ -274,6 +324,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t); /// Gets the !torch.none type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context); +/// Gets the !torch.none typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.str type. //===----------------------------------------------------------------------===// @@ -284,6 +337,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t); /// Gets the !torch.str type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context); +/// Gets the !torch.str typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.any type. //===----------------------------------------------------------------------===// @@ -294,6 +350,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t); /// Gets the !torch.str type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context); +/// Gets the !torch.any typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.number type. //===----------------------------------------------------------------------===// @@ -304,6 +363,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t); /// Gets the !torch.number type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context); +/// Gets the !torch.number typeid. 
+MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.dict type. //===----------------------------------------------------------------------===// @@ -324,6 +386,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t); /// Gets the value type of a !torch.dict type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t); +/// Gets the !torch.dict typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID(); + #ifdef __cplusplus } #endif diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index 6d31d267ac0b..e8d57b7f6a72 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h @@ -45,7 +45,8 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Operation *op, Value scalarValue, Type dtype); -Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); +Value promoteType(PatternRewriter &rewriter, Location loc, Value input, + TensorType outType); Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, TensorType outType); diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h index 3ff4581d6895..c1b355e3c50d 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h @@ -58,6 +58,12 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Value params_value, Value indices_value); +std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter, + Operation *op, Type outType, + Value paramsValue, Value indicesValue, + Value fillValues); + + // Lowers ReduceAll to a sequence of TOSA ops. std::optional<Value> convertReduceAllOp(PatternRewriter &rewriter, Operation *op, diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index a91074d43178..5e6934001d7c 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -59,7 +59,8 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter, // To create INT48 TOSA constant, need to pass in llvm::APInt instead. template <typename T> std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op, - ArrayRef<T> vec, ArrayRef<int64_t> shape, std::optional<Type> dtype = {}); + ArrayRef<T> vec, ArrayRef<int64_t> shape, + std::optional<Type> dtype = {}); LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, Value src, Type destType, Value &result); diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 485160b7e830..8795974a395c 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -76,7 +76,7 @@ SmallVector<Value> getAsConstantIndexValues(OpBuilder &b, Location loc, // convert their elements to valid target type. // TODO: remove this when list gets full support.
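
The new `convertScatterNdOp` declared above mirrors the existing `convertGatherNdOp` convention: the returned `std::optional<Value>` presumably carries the lowered value on success and is empty when the case cannot be handled. A sketch (not part of this patch) of the call-site shape inside a TOSA conversion pattern; the wrapper function and operand names are hypothetical, only the helper's signature comes from this header:

    // Assumes the usual mlir/torch-mlir namespaces of TosaLegalizeCommon.h.
    static LogicalResult lowerViaScatterNd(PatternRewriter &rewriter,
                                           Operation *op, Type outType,
                                           Value params, Value indices,
                                           Value fillValues) {
      std::optional<Value> scattered = convertScatterNdOp(
          rewriter, op, outType, params, indices, fillValues);
      if (!scattered)
        return failure(); // helper declined; let another pattern try
      rewriter.replaceOp(op, *scattered);
      return success();
    }
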
SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc, - TypeConverter *converter, + const TypeConverter *converter, SmallVectorImpl<Value> &vs); mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape, diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 412291292872..1e1c84c86def 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -113,6 +113,57 @@ def Torch_AtenHardtanh_Op : Torch_Op<"aten.hardtanh_", [ }]; } +def Torch_AtenEluOp : Torch_Op<"aten.elu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$alpha, + AnyTorchScalarType:$scale, + AnyTorchScalarType:$input_scale + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenEluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenElu_Op : Torch_Op<"aten.elu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::elu_ : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$alpha, + AnyTorchScalarType:$scale, + AnyTorchScalarType:$input_scale + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenElu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenElu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenReluOp : Torch_Op<"aten.relu", [ AllowsTypeRefinement, HasValueSemantics, @@ -385,6 +436,51 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [ }]; } +def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sgn : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSgnOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSgnOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenSgn_Op : Torch_Op<"aten.sgn_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sgn_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSgn_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSgn_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenHardsigmoidOp : Torch_Op<"aten.hardsigmoid", [ AllowsTypeRefinement, HasValueSemantics, @@ -520,6 +616,51 @@ def
Torch_AtenErf_Op : Torch_Op<"aten.erf_", [ }]; } +def Torch_AtenErfinvOp : Torch_Op<"aten.erfinv", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::erfinv : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenErfinvOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenErfinvOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenErfinv_Op : Torch_Op<"aten.erfinv_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::erfinv_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenErfinv_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenErfinv_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSiluOp : Torch_Op<"aten.silu", [ AllowsTypeRefinement, HasValueSemantics, @@ -2290,6 +2431,53 @@ def Torch_AtenClampMin_Op : Torch_Op<"aten.clamp_min_", [ }]; } +def Torch_AtenClampMinTensorOp : Torch_Op<"aten.clamp_min.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$min + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenClampMinTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenClampMinTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenClampMin_TensorOp : Torch_Op<"aten.clamp_min_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::clamp_min_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$min + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenClampMin_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenClampMin_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenClampMaxOp : Torch_Op<"aten.clamp_max", [ AllowsTypeRefinement, HasValueSemantics, @@ -2337,6 +2525,53 @@ def Torch_AtenClampMax_Op : Torch_Op<"aten.clamp_max_", [ }]; } +def Torch_AtenClampMaxTensorOp : Torch_Op<"aten.clamp_max.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$max + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenClampMaxTensorOp::parse(OpAsmParser &parser, 
OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenClampMaxTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenClampMax_TensorOp : Torch_Op<"aten.clamp_max_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::clamp_max_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$max + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenClampMax_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenClampMax_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenLog2Op : Torch_Op<"aten.log2", [ AllowsTypeRefinement, HasValueSemantics, @@ -3546,6 +3781,30 @@ def Torch_AtenMishOp : Torch_Op<"aten.mish", [ }]; } +def Torch_AtenXlogyTensorOp : Torch_Op<"aten.xlogy.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenXlogyTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenXlogyTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ AllowsTypeRefinement, HasValueSemantics, @@ -3832,86 +4091,209 @@ def Torch_AtenViewAsComplexOp : Torch_Op<"aten.view_as_complex", [ }]; } -def Torch_AtenUniformOp : Torch_Op<"aten.uniform", [ +def Torch_AtenViewAsRealOp : Torch_Op<"aten.view_as_real", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)`"; + let summary = "Generated op for `aten::view_as_real : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_FloatType:$from, - Torch_FloatType:$to, - AnyTorchOptionalGeneratorType:$generator + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenUniformOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenViewAsRealOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenUniformOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenViewAsRealOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement +def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::uniform_ : (Tensor, float, float, Generator?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::unbind_copy.int : (Tensor, int) -> (Tensor[])`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_FloatType:$from, - Torch_FloatType:$to, - AnyTorchOptionalGeneratorType:$generator + Torch_IntType:$dim ); let results = (outs - AnyTorchTensorType:$result + AnyTorchListOfTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenUniform_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenUnbindCopyIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenUniform_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenUnbindCopyIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenRandLikeOp : Torch_Op<"aten.rand_like", [ +def Torch_AtenSplitCopyTensorOp : Torch_Op<"aten.split_copy.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::split_copy.Tensor : (Tensor, int, int) -> (Tensor[])`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - AnyTorchOptionalIntType:$memory_format + Torch_IntType:$split_size, + Torch_IntType:$dim ); let results = (outs - AnyTorchTensorType:$result + AnyTorchListOfTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRandLikeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenSplitCopyTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenRandLikeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenSplitCopyTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [ +def Torch_AtenSplitWithSizesCopyOp : Torch_Op<"aten.split_with_sizes_copy", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::split_with_sizes_copy : (Tensor, int[], int) -> (Tensor[])`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$split_sizes, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSplitWithSizesCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSplitWithSizesCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenUniformOp : Torch_Op<"aten.uniform", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::uniform : (Tensor, float, float, Generator?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$from, + Torch_FloatType:$to, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUniformOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenUniformOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$from, + Torch_FloatType:$to, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUniform_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenUniform_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenRandLikeOp : Torch_Op<"aten.rand_like", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandLikeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenRandLikeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + +def Torch_AtenRandOp : Torch_Op<"aten.rand", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rand : (int[], int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenRandOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly @@ -3983,6 +4365,32 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [ }]; } +def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::multinomial : (Tensor, int, bool, Generator?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$num_samples, + Torch_BoolType:$replacement, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMultinomialOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenMultinomialOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [ AllowsTypeRefinement, HasValueSemantics, @@ -4172,6 +4580,56 @@ def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [ }]; } +def Torch_AtenRandomOp : Torch_Op<"aten.random", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::random : (Tensor, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandomOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenRandomOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenRandomFromOp : Torch_Op<"aten.random.from", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$from, + AnyTorchOptionalIntType:$to, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandomFromOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenRandomFromOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenTriuOp : Torch_Op<"aten.triu", [ AllowsTypeRefinement, HasValueSemantics, @@ -4414,6 +4872,32 @@ def Torch_AtenIndexPut_HackedTwinOp : Torch_Op<"aten.index_put_.hacked_twin", [ }]; } +def Torch_Aten_UnsafeIndexPutHackedTwinOp : Torch_Op<"aten._unsafe_index_put.hacked_twin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTensorType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_UnsafeIndexPutHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void Aten_UnsafeIndexPutHackedTwinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenLinearOp : Torch_Op<"aten.linear", [ AllowsTypeRefinement, HasValueSemantics, @@ -4990,6 +5474,32 @@ def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [ }]; } +def Torch_AtenNormalFunctionalOp : Torch_Op<"aten.normal_functional", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let 
summary = "Generated op for `aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$mean, + Torch_FloatType:$std, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNormalFunctionalOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNormalFunctionalOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [ AllowsTypeRefinement, HasValueSemantics, @@ -5106,56 +5616,260 @@ def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_in }]; } -def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [ +def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchListOfTorchIntType:$kernel_size, AnyTorchListOfTorchIntType:$stride, AnyTorchListOfTorchIntType:$padding, - Torch_BoolType:$ceil_mode, - Torch_BoolType:$count_include_pad, - AnyTorchOptionalIntType:$divisor_override + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + ParseResult AtenMaxPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenAvgPool2dOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + void AtenMaxPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [ +def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::softmax.int : (Tensor, int, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchOptionalIntType:$dtype + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode ); let results = (outs - AnyTorchTensorType:$result + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSoftmaxIntOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenMaxPool3dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 2); } - void AtenSoftmaxIntOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenMaxPool3dWithIndicesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 2); + } + }]; +} + +def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode, + AnyTorchTensorType:$indices + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxPool3dWithIndicesBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenMaxPool3dWithIndicesBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + +def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenAvgPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + +def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenAvgPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenAvgPool2dBackwardOp : Torch_Op<"aten.avg_pool2d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenAvgPool2dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + +def Torch_AtenAvgPool3dOp : Torch_Op<"aten.avg_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenAvgPool3dBackwardOp : Torch_Op<"aten.avg_pool3d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool3dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenAvgPool3dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + +def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::softmax.int : (Tensor, int, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSoftmaxIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSoftmaxIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } @@ -5312,159 +6026,352 @@ def Torch_AtenScatter_ValueOp : Torch_Op<"aten.scatter_.value", [ }]; } -def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ +def Torch_AtenMaskedScatterOp : Torch_Op<"aten.masked_scatter", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$output_size + AnyTorchTensorType:$mask, + AnyTorchTensorType:$source ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdaptiveAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenMaskedScatterOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenAdaptiveAvgPool2dOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenMaskedScatterOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenTopkOp : Torch_Op<"aten.topk", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::masked_scatter_ : (Tensor, Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$k, - Torch_IntType:$dim, - Torch_BoolType:$largest, - Torch_BoolType:$sorted + AnyTorchTensorType:$mask, + AnyTorchTensorType:$source ); let results = (outs - AnyTorchTensorType:$values, - AnyTorchTensorType:$indices + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTopkOp::parse(OpAsmParser &parser, 
OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 2); + ParseResult AtenMaskedScatter_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenTopkOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 2); + void AtenMaskedScatter_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [ +def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::transpose.int : (Tensor, int, int) -> (Tensor)`"; + let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim0, - Torch_IntType:$dim1 + AnyTorchListOfTorchIntType:$output_size ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTransposeIntOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenTransposeIntOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [ +def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$dims + AnyTorchListOfTorchIntType:$output_size ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAdaptiveAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenPermuteOp::print(OpAsmPrinter &printer) { + void AtenAdaptiveAvgPool2dOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [ +def Torch_Aten_AdaptiveAvgPool2dOp : Torch_Op<"aten._adaptive_avg_pool2d", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::movedim.int : (Tensor, int, int) -> (Tensor)`"; + let summary = "Generated op for `aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$source, - Torch_IntType:$destination + AnyTorchListOfTorchIntType:$output_size ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMovedimIntOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult Aten_AdaptiveAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMovedimIntOp::print(OpAsmPrinter &printer) { - 
printDefaultTorchOp(printer, *this, 3, 1); + void Aten_AdaptiveAvgPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBmmOp : Torch_Op<"aten.bmm", [ +def Torch_Aten_AdaptiveAvgPool2dBackwardOp : Torch_Op<"aten._adaptive_avg_pool2d_backward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bmm : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$mat2 + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBmmOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult Aten_AdaptiveAvgPool2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenBmmOp::print(OpAsmPrinter &printer) { + void Aten_AdaptiveAvgPool2dBackwardOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [ +def Torch_AtenAdaptiveAvgPool3dOp : Torch_Op<"aten.adaptive_avg_pool3d", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::cumsum : (Tensor, int, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenAdaptiveAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten_AdaptiveAvgPool3dOp : Torch_Op<"aten._adaptive_avg_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AdaptiveAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_AdaptiveAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AdaptiveAvgPool3dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_AdaptiveAvgPool3dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def 
Torch_AtenTopkOp : Torch_Op<"aten.topk", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$k, + Torch_IntType:$dim, + Torch_BoolType:$largest, + Torch_BoolType:$sorted + ); + let results = (outs + AnyTorchTensorType:$values, + AnyTorchTensorType:$indices + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTopkOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 2); + } + void AtenTopkOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 2); + } + }]; +} + +def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::transpose.int : (Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim0, + Torch_IntType:$dim1 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTransposeIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenTransposeIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$dims + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenPermuteOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::movedim.int : (Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$source, + Torch_IntType:$destination + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMovedimIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMovedimIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenBmmOp : Torch_Op<"aten.bmm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bmm : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mat2 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBmmOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBmmOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for 
`aten::cumsum : (Tensor, int, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, Torch_IntType:$dim, @@ -5583,6 +6490,31 @@ def Torch_Aten__And__TensorOp : Torch_Op<"aten.__and__.Tensor", [ }]; } +def Torch_Aten__Or__TensorOp : Torch_Op<"aten.__or__.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Or__TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Or__TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ AllowsTypeRefinement, HasValueSemantics, @@ -5854,308 +6786,557 @@ def Torch_AtenVarMeanDimOp : Torch_Op<"aten.var_mean.dim", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenVarMeanDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 2); + ParseResult AtenVarMeanDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 2); + } + void AtenVarMeanDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 2); + } + }]; +} + +def Torch_AtenNllLoss2dForwardOp : Torch_Op<"aten.nll_loss2d_forward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction, + Torch_IntType:$ignore_index + ); + let results = (outs + AnyTorchTensorType:$output, + AnyTorchTensorType:$total_weight + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNllLoss2dForwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 2); + } + void AtenNllLoss2dForwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 2); + } + }]; +} + +def Torch_AtenNllLoss2dBackwardOp : Torch_Op<"aten.nll_loss2d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction, + Torch_IntType:$ignore_index, + AnyTorchTensorType:$total_weight + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNllLoss2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenNllLoss2dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nll_loss_forward : (Tensor, 
Tensor, Tensor?, int, int) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction, + Torch_IntType:$ignore_index + ); + let results = (outs + AnyTorchTensorType:$output, + AnyTorchTensorType:$total_weight + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNllLossForwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 2); + } + void AtenNllLossForwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 2); + } + }]; +} + +def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction, + Torch_IntType:$ignore_index, + AnyTorchTensorType:$total_weight + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNllLossBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenNllLossBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bincount : (Tensor, Tensor?, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalTensorType:$weights, + Torch_IntType:$minlength + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBincountOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenBincountOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$ord, + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$keepdim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgVectorNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenLinalgVectorNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$A, + Torch_StringType:$mode + ); + let results = (outs + AnyTorchTensorType:$Q, + AnyTorchTensorType:$R + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgQrOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 2); + } + void AtenLinalgQrOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 2); + } + }]; +} + +def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFrobeniusNormDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenFrobeniusNormDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMseLossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMseLossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenMseLossBackwardOp : Torch_Op<"aten.mse_loss_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMseLossBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenVarMeanDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 2); + void AtenMseLossBackwardOp::print(OpAsmPrinter &printer) { + 
printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenNllLoss2dForwardOp : Torch_Op<"aten.nll_loss2d_forward", [ +def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$target, - AnyTorchOptionalTensorType:$weight, - Torch_IntType:$reduction, - Torch_IntType:$ignore_index + AnyTorchTensorType:$grad_output, + AnyTorchListOfTorchIntType:$output_size, + AnyTorchListOfTorchIntType:$input_size, + AnyTorchOptionalFloatType:$scales_h, + AnyTorchOptionalFloatType:$scales_w ); let results = (outs - AnyTorchTensorType:$output, - AnyTorchTensorType:$total_weight + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNllLoss2dForwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 2); + ParseResult AtenUpsampleNearest2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenNllLoss2dForwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 2); + void AtenUpsampleNearest2dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenNllLoss2dBackwardOp : Torch_Op<"aten.nll_loss2d_backward", [ +def Torch_AtenCrossEntropyLossOp : Torch_Op<"aten.cross_entropy_loss", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$grad_output, AnyTorchTensorType:$self, AnyTorchTensorType:$target, AnyTorchOptionalTensorType:$weight, Torch_IntType:$reduction, Torch_IntType:$ignore_index, - AnyTorchTensorType:$total_weight + Torch_FloatType:$label_smoothing ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNllLoss2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + ParseResult AtenCrossEntropyLossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenNllLoss2dBackwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + void AtenCrossEntropyLossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [ +def Torch_AtenNonzeroOp : Torch_Op<"aten.nonzero", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::nonzero : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$target, - AnyTorchOptionalTensorType:$weight, - Torch_IntType:$reduction, - Torch_IntType:$ignore_index + AnyTorchTensorType:$self ); let results = (outs - 
AnyTorchTensorType:$output, - AnyTorchTensorType:$total_weight + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNllLossForwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 2); + ParseResult AtenNonzeroOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenNllLossForwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 2); + void AtenNonzeroOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [ +def Torch_AtenNonzeroNumpyOp : Torch_Op<"aten.nonzero_numpy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::nonzero_numpy : (Tensor) -> (Tensor[])`"; let arguments = (ins - AnyTorchTensorType:$grad_output, - AnyTorchTensorType:$self, - AnyTorchTensorType:$target, - AnyTorchOptionalTensorType:$weight, - Torch_IntType:$reduction, - Torch_IntType:$ignore_index, - AnyTorchTensorType:$total_weight + AnyTorchTensorType:$self ); let results = (outs - AnyTorchTensorType:$result + AnyTorchListOfTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNllLossBackwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + ParseResult AtenNonzeroNumpyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenNllLossBackwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + void AtenNonzeroNumpyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [ +def Torch_AtenNonzeroStaticOp : Torch_Op<"aten.nonzero_static", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bincount : (Tensor, Tensor?, int) -> (Tensor)`"; + let summary = "Generated op for `aten::nonzero_static : (Tensor, int, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalTensorType:$weights, - Torch_IntType:$minlength + Torch_IntType:$size, + Torch_IntType:$fill_value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBincountOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenNonzeroStaticOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenBincountOp::print(OpAsmPrinter &printer) { + void AtenNonzeroStaticOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [ +def Torch_AtenBinaryCrossEntropyOp : Torch_Op<"aten.binary_cross_entropy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$ord, - AnyTorchOptionalListOfTorchIntType:$dim, - Torch_BoolType:$keepdim, - AnyTorchOptionalIntType:$dtype + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLinalgVectorNormOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult AtenBinaryCrossEntropyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenLinalgVectorNormOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenBinaryCrossEntropyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [ +def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy_backward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)`"; + let summary = "Generated op for `aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)`"; let arguments = (ins + AnyTorchTensorType:$grad_output, AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$dim, - Torch_BoolType:$keepdim + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFrobeniusNormDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenBinaryCrossEntropyBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenFrobeniusNormDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenBinaryCrossEntropyBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [ +def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$target, - Torch_IntType:$reduction + AnyTorchTensorType:$self ); let results = (outs - AnyTorchTensorType:$result + AnyTorchTensorType:$output, + AnyTorchTensorType:$buffer ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMseLossOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLogSigmoidForwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 2); } - void AtenMseLossOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLogSigmoidForwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 2); } }]; } -def 
Torch_AtenMseLossBackwardOp : Torch_Op<"aten.mse_loss_backward", [ +def Torch_AtenLogSigmoidBackwardOp : Torch_Op<"aten.log_sigmoid_backward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$grad_output, AnyTorchTensorType:$self, - AnyTorchTensorType:$target, - Torch_IntType:$reduction + AnyTorchTensorType:$buffer ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMseLossBackwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenLogSigmoidBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenMseLossBackwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenLogSigmoidBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [ +def Torch_AtenSigmoidBackwardOp : Torch_Op<"aten.sigmoid_backward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)`"; + let summary = "Generated op for `aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$grad_output, - AnyTorchListOfTorchIntType:$output_size, - AnyTorchListOfTorchIntType:$input_size, - AnyTorchOptionalFloatType:$scales_h, - AnyTorchOptionalFloatType:$scales_w + AnyTorchTensorType:$output ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenUpsampleNearest2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult AtenSigmoidBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenUpsampleNearest2dBackwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenSigmoidBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenCrossEntropyLossOp : Torch_Op<"aten.cross_entropy_loss", [ +def Torch_AtenCosineEmbeddingLossOp : Torch_Op<"aten.cosine_embedding_loss", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)`"; + let summary = "Generated op for `aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, + AnyTorchTensorType:$input1, + AnyTorchTensorType:$input2, AnyTorchTensorType:$target, - AnyTorchOptionalTensorType:$weight, - Torch_IntType:$reduction, - Torch_IntType:$ignore_index, - Torch_FloatType:$label_smoothing + Torch_FloatType:$margin, + Torch_IntType:$reduction ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCrossEntropyLossOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); 
+ ParseResult AtenCosineEmbeddingLossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenCrossEntropyLossOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenCosineEmbeddingLossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } @@ -6487,6 +7668,61 @@ def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [ }]; } +def Torch_AtenEyeOp : Torch_Op<"aten.eye", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + Torch_IntType:$n, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEyeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenEyeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenEyeMOp : Torch_Op<"aten.eye.m", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + Torch_IntType:$n, + Torch_IntType:$m, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEyeMOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenEyeMOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -6684,6 +7920,31 @@ def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [ }]; } +def Torch_AtenAllDimOp : Torch_Op<"aten.all.dim", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::all.dim : (Tensor, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAllDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenAllDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenAnyOp : Torch_Op<"aten.any", [ AllowsTypeRefinement, HasValueSemantics, @@ -6834,18 +8095,43 @@ def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [ ParseResult AtenArangeStartOutOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenArangeStartOutOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenArangeStartOutOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = 
"Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenArgmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenArgmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ +def Torch_AtenArgminOp : Torch_Op<"aten.argmin", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::argmin : (Tensor, int?, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchOptionalIntType:$dim, @@ -6856,10 +8142,10 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenArgmaxOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenArgminOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenArgmaxOp::print(OpAsmPrinter &printer) { + void AtenArgminOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; @@ -7086,6 +8372,54 @@ def Torch_AtenDetachOp : Torch_Op<"aten.detach", [ let hasFolder = 1; } +def Torch_AtenDeviceWithIndexOp : Torch_Op<"aten.device.with_index", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::device.with_index : (str, int) -> (Device)`"; + let arguments = (ins + Torch_StringType:$type, + Torch_IntType:$index + ); + let results = (outs + Torch_DeviceType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDeviceWithIndexOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenDeviceWithIndexOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenCudaOp : Torch_Op<"aten.cuda", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::cuda : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCudaOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCudaOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasCanonicalizer = 1; +} + def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [ AllowsTypeRefinement, HasValueSemantics, @@ -7351,6 +8685,34 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ let hasCanonicalizer = 1; } +def Torch_AtenEmptyStridedOp : Torch_Op<"aten.empty_strided", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEmptyStridedOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenEmptyStridedOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenExpandOp : Torch_Op<"aten.expand", [ AllowsTypeRefinement, ReadOnly @@ -7420,6 +8782,7 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [ } }]; let hasCanonicalizer = 1; + let hasFolder = 1; } def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [ @@ -7667,6 +9030,30 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [ }]; } +def Torch_AtenTileOp : Torch_Op<"aten.tile", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::tile : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$dims + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTileOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenTileOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenRepeatInterleaveTensorOp : Torch_Op<"aten.repeat_interleave.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -7738,6 +9125,31 @@ def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [ }]; } +def Torch_AtenResizeOp : Torch_Op<"aten.resize", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::resize : (Tensor, int[], int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenResizeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenResizeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [ AllowsTypeRefinement ]> { @@ -7825,70 +9237,220 @@ def Torch_AtenSumOp : Torch_Op<"aten.sum", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSumOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenSumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$keepdim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSumDimIntListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenSumDimIntListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenProdDimIntOp : Torch_Op<"aten.prod.dim_int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::prod.dim_int : (Tensor, int, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$keepdim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenProdDimIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenProdDimIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenMaxOp : Torch_Op<"aten.max", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenMaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenMaxOtherOp : Torch_Op<"aten.max.other", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max.other : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxOtherOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMaxOtherOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$values, + AnyTorchTensorType:$indices + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); + } + void AtenMaxDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); + } + }]; +} + +def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::amax : (Tensor, int[], bool) -> (Tensor)`"; + let arguments = (ins + 
AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenSumOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [ +def Torch_AtenMinOp : Torch_Op<"aten.min", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::min : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalListOfTorchIntType:$dim, - Torch_BoolType:$keepdim, - AnyTorchOptionalIntType:$dtype + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSumDimIntListOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenMinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenSumDimIntListOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenMinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenMaxOp : Torch_Op<"aten.max", [ +def Torch_AtenMinOtherOp : Torch_Op<"aten.min.other", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::max : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::min.other : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaxOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenMinOtherOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMaxOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenMinOtherOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [ +def Torch_AtenMinDimOp : Torch_Op<"aten.min.dim", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, Torch_IntType:$dim, @@ -7900,21 +9462,21 @@ def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenMinDimOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 2); } - void AtenMaxDimOp::print(OpAsmPrinter &printer) { + void AtenMinDimOp::print(OpAsmPrinter 
&printer) { printDefaultTorchOp(printer, *this, 3, 2); } }]; } -def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [ +def Torch_AtenAminOp : Torch_Op<"aten.amin", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::amax : (Tensor, int[], bool) -> (Tensor)`"; + let summary = "Generated op for `aten::amin : (Tensor, int[], bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchListOfTorchIntType:$dim, @@ -7925,10 +9487,10 @@ def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAminOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenAmaxOp::print(OpAsmPrinter &printer) { + void AtenAminOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; @@ -8016,6 +9578,7 @@ def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [ printDefaultTorchOp(printer, *this, 5, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenToPrimDeviceOp : Torch_Op<"aten.to.prim_Device", [ @@ -8093,7 +9656,6 @@ def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasFolder = 1; } def Torch_AtenViewOp : Torch_Op<"aten.view", [ @@ -8744,6 +10306,35 @@ def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [ }]; } +def Torch_AtenNewFullOp : Torch_Op<"aten.new_full", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + AnyTorchScalarType:$fill_value, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNewFullOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenNewFullOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenBaddbmmOp : Torch_Op<"aten.baddbmm", [ AllowsTypeRefinement, HasValueSemantics, @@ -8823,6 +10414,58 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [ }]; } +def Torch_AtenFmodTensorOp : Torch_Op<"aten.fmod.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFmodTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenFmodTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::unique_consecutive : (Tensor, bool, bool, int?) 
-> (Tensor, Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_BoolType:$return_inverse, + Torch_BoolType:$return_counts, + AnyTorchOptionalIntType:$dim + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1, + AnyTorchTensorType:$result2 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUniqueConsecutiveOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 3); + } + void AtenUniqueConsecutiveOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 3); + } + }]; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, @@ -8846,6 +10489,29 @@ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ }]; } +def Torch_AtenAliasOp : Torch_Op<"aten.alias", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::alias : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAliasOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAliasOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [ AllowsTypeRefinement, HasValueSemantics, @@ -9240,6 +10906,60 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [ }]; } +def Torch_AtenIm2colOp : Torch_Op<"aten.im2col", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$stride + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIm2colOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenIm2colOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenScatterReduceOp : Torch_Op<"aten.scatter.reduce", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + AnyTorchTensorType:$src, + Torch_StringType:$reduce + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenScatterReduceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenScatterReduceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ AllowsTypeRefinement, HasValueSemantics, @@ -9805,6 +11525,7 @@ def Torch_AtenAnyBoolOp : Torch_Op<"aten.any.bool", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [ @@ -9879,6 
+11600,30 @@ def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [ }]; } +def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$split_sizes, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSplitWithSizesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSplitWithSizesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [ AllowsTypeRefinement, ReadOnly @@ -10455,6 +12200,30 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [ }]; } +def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRemainderTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenRemainderTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [ AllowsTypeRefinement, HasValueSemantics, @@ -10624,6 +12393,7 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [ @@ -10673,6 +12443,7 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [ @@ -10721,6 +12492,7 @@ def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [ @@ -11184,6 +12956,7 @@ def Torch_AtenAddOp : Torch_Op<"aten.add", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenSubOp : Torch_Op<"aten.sub", [ @@ -11380,6 +13153,31 @@ def Torch_AtenNarrowOp : Torch_Op<"aten.narrow", [ }]; } +def Torch_AtenNarrowTensorOp : Torch_Op<"aten.narrow.Tensor", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$start, + Torch_IntType:$length + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNarrowTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNarrowTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [ AllowsTypeRefinement, HasValueSemantics, @@ -11738,6 +13536,34 @@ def 
Torch_AtenNativeDropoutBackwardOp : Torch_Op<"aten.native_dropout_backward", }]; } +def Torch_AtenEluBackwardOp : Torch_Op<"aten.elu_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchScalarType:$alpha, + AnyTorchScalarType:$scale, + AnyTorchScalarType:$input_scale, + Torch_BoolType:$is_result, + AnyTorchTensorType:$self_or_result + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEluBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenEluBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index 7783b26abf08..64b70e097c39 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -10,6 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H #define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index f372b966deea..c86244f5f1e3 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -506,7 +506,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> { } def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [ - DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorEntryOperands"]>]> { + DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands"]>]> { let summary = "TorchScript prim::Loop op"; let description = [{ This op (together with prim.Loop.condition) defines a looping construct diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index e168eaea204e..c083a8e8e217 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -128,7 +128,8 @@ class AnyTorchTensorType | torch.bool | i1 | | torch.qint8 | !torch.qint8 | | torch.quint8 | !torch.quint8 | - | torch.complex* | complex<*> | + | torch.complex64 | complex<f32> | + | torch.complex128 | complex<f64> | |-------------------|--------------------| ``` diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index e6493a154edd..d762bd840f7f 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -57,6 +57,14 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass(); std::unique_ptr<InterfacePass<FunctionOpInterface>> createFinalizingBackendTypeConversionPass(); +// These passes do a one-off conversion of a specific kind of quantized group +// matmul as a prototype. Generalized quantized operation handling will likely +// obviate them, but they are being carried for now in order to unblock progress +// on full integrations. See https://github.com/llvm/torch-mlir/issues/2417 for +// the plan to support a more generalized lowering for these graphs.
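For readers skimming the patch: the two constructors declared just below are ordinary func::FuncOp passes, so a downstream pipeline can schedule them with a normal pass manager. A minimal sketch, assuming only this header plus the usual MLIR includes; the helper name and the unpack-then-convert ordering are illustrative, not part of the patch:

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

    // Illustrative wiring of the two prototype quantization passes: unpack the
    // int4-in-int8 tensors first, then lower the custom quantized matmul op.
    static void addPrototypeQuantPasses(mlir::PassManager &pm) {
      using namespace mlir::torch::TorchConversion;
      pm.addNestedPass<mlir::func::FuncOp>(createUnpackQuantTensorPass());
      pm.addNestedPass<mlir::func::FuncOp>(createConvertCustomQuantOpPass());
    }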
+std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass(); +std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass(); + std::unique_ptr<OperationPass<ModuleOp>> createVerifyLinalgOnTensorsBackendContractPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index cb58dbbd998b..4d3e16a81c5c 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -48,4 +48,16 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contra let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()"; } #endif // TORCH_MLIR_ENABLE_STABLEHLO + +// The following passes are for a one-off conversion of a specific kind of quantized group matmul. +// They should not be included in default lowering flows until further along. +def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> { + let summary = "Unpack quantized int4 tensor from int8 container"; + let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()"; +} + +def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> { + let summary = "Convert torch custom quant op to linalg"; + let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()"; +} #endif // TORCHMLIR_TORCHCONVERSION_PASSES diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index 76ae43c2c38b..f4a9ca032fce 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -34,6 +34,10 @@ MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className))); } +MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() { + return wrap(Torch::NnModuleType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.optional type. //===----------------------------------------------------------------------===// @@ -47,8 +51,12 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) { } MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) { - auto type = unwrap(t).cast<Torch::OptionalType>(); - return wrap(type.getContainedType()); + auto type = unwrap(t).cast<Torch::OptionalType>(); + return wrap(type.getContainedType()); +} + +MlirTypeID torchMlirTorchOptionalTypeGetTypeID() { + return wrap(Torch::OptionalType::getTypeID()); } //===----------------------------------------------------------------------===// @@ -63,10 +71,9 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes, MlirType const *containedTypes) { return wrap(Torch::TupleType::get( - unwrap(context), - llvm::to_vector<6>( - llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes), - [](MlirType t) { return unwrap(t); })))); + unwrap(context), llvm::to_vector<6>(llvm::map_range( + llvm::ArrayRef(containedTypes, numContainedTypes), + [](MlirType t) { return unwrap(t); })))); } size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) { @@ -79,6 +86,10 @@ MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) { return wrap(type.getContainedTypes()[pos]); } +MlirTypeID torchMlirTorchTupleTypeGetTypeID() { + return wrap(Torch::TupleType::getTypeID()); +} +
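Each MlirTypeID getter added in this file follows the same pattern: it wraps the static getTypeID() of the corresponding C++ type so that C API clients (for example, Python bindings) can identify a Torch type without string comparisons. A minimal usage sketch; the helper name is illustrative, while mlirTypeGetTypeID and mlirTypeIDEqual are existing upstream MLIR C API functions:

    #include "mlir-c/IR.h"
    #include "mlir-c/Support.h"
    #include "torch-mlir-c/TorchTypes.h"

    // Illustrative check: is `t` a !torch.optional<...> type?
    static bool isTorchOptionalType(MlirType t) {
      return mlirTypeIDEqual(mlirTypeGetTypeID(t),
                             torchMlirTorchOptionalTypeGetTypeID());
    }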
//===----------------------------------------------------------------------===// // torch.union type. //===----------------------------------------------------------------------===// @@ -91,10 +102,9 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes, MlirType const *containedTypes) { return wrap(Torch::UnionType::get( - unwrap(context), - llvm::to_vector<6>( - llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes), - [](MlirType t) { return unwrap(t); })))); + unwrap(context), llvm::to_vector<6>(llvm::map_range( + llvm::ArrayRef(containedTypes, numContainedTypes), + [](MlirType t) { return unwrap(t); })))); } size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) { @@ -107,6 +117,10 @@ MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) { return wrap(type.getContainedTypes()[pos]); } +MlirTypeID torchMlirTorchUnionTypeGetTypeID() { + return wrap(Torch::UnionType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.list type. //===----------------------------------------------------------------------===// @@ -123,6 +137,10 @@ MlirType torchMlirTorchListTypeGetContainedType(MlirType t) { return wrap(unwrap(t).cast<Torch::ListType>().getContainedType()); } +MlirTypeID torchMlirTorchListTypeGetTypeID() { + return wrap(Torch::ListType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.Device type. //===----------------------------------------------------------------------===// @@ -135,6 +153,10 @@ MlirType torchMlirTorchDeviceTypeGet(MlirContext context) { return wrap(Torch::DeviceType::get(unwrap(context))); } +MlirTypeID torchMlirTorchDeviceTypeGetTypeID() { + return wrap(Torch::DeviceType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.Generator type. //===----------------------------------------------------------------------===// @@ -147,6 +169,10 @@ MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) { return wrap(Torch::GeneratorType::get(unwrap(context))); } +MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() { + return wrap(Torch::GeneratorType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.bool type. //===----------------------------------------------------------------------===// @@ -159,6 +185,10 @@ MlirType torchMlirTorchBoolTypeGet(MlirContext context) { return wrap(Torch::BoolType::get(unwrap(context))); } +MlirTypeID torchMlirTorchBoolTypeGetTypeID() { + return wrap(Torch::BoolType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.int type. //===----------------------------------------------------------------------===// @@ -171,6 +201,10 @@ MlirType torchMlirTorchIntTypeGet(MlirContext context) { return wrap(Torch::IntType::get(unwrap(context))); } +MlirTypeID torchMlirTorchIntTypeGetTypeID() { + return wrap(Torch::IntType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.float type. //===----------------------------------------------------------------------===// @@ -183,6 +217,10 @@ MlirType torchMlirTorchFloatTypeGet(MlirContext context) { return wrap(Torch::FloatType::get(unwrap(context))); } +MlirTypeID torchMlirTorchFloatTypeGetTypeID() { + return wrap(Torch::FloatType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.LinearParams type.
//===----------------------------------------------------------------------===// @@ -195,6 +233,10 @@ MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) { return wrap(Torch::LinearParamsType::get(unwrap(context))); } +MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() { + return wrap(Torch::LinearParamsType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.qint8 type. //===----------------------------------------------------------------------===// @@ -207,6 +249,10 @@ MlirType torchMlirTorchQInt8TypeGet(MlirContext context) { return wrap(Torch::QInt8Type::get(unwrap(context))); } +MlirTypeID torchMlirTorchQInt8TypeGetTypeID() { + return wrap(Torch::QInt8Type::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.quint8 type. //===----------------------------------------------------------------------===// @@ -219,6 +265,10 @@ MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) { return wrap(Torch::QUInt8Type::get(unwrap(context))); } +MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() { + return wrap(Torch::QUInt8Type::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.tensor type. //===----------------------------------------------------------------------===// @@ -258,11 +308,11 @@ int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) { } bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes(); + return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes(); } bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype(); + return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype(); } int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { @@ -282,6 +332,10 @@ MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) { return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype()); } +MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() { + return wrap(Torch::NonValueTensorType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.vtensor type. //===----------------------------------------------------------------------===// @@ -321,11 +375,11 @@ int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) { } bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast<Torch::ValueTensorType>().hasSizes(); + return unwrap(t).cast<Torch::ValueTensorType>().hasSizes(); } bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast<Torch::ValueTensorType>().hasDtype(); + return unwrap(t).cast<Torch::ValueTensorType>().hasDtype(); } int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { @@ -345,6 +399,10 @@ MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) { return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype()); } +MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() { + return wrap(Torch::ValueTensorType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.none type. //===----------------------------------------------------------------------===// @@ -357,6 +415,10 @@ MlirType torchMlirTorchNoneTypeGet(MlirContext context) { return wrap(Torch::NoneType::get(unwrap(context))); } +MlirTypeID torchMlirTorchNoneTypeGetTypeID() { + return wrap(Torch::NoneType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.str type.
//===----------------------------------------------------------------------===// @@ -369,6 +431,10 @@ MlirType torchMlirTorchStringTypeGet(MlirContext context) { return wrap(Torch::StringType::get(unwrap(context))); } +MlirTypeID torchMlirTorchStringTypeGetTypeID() { + return wrap(Torch::StringType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.any type. //===----------------------------------------------------------------------===// @@ -381,6 +447,10 @@ MlirType torchMlirTorchAnyTypeGet(MlirContext context) { return wrap(Torch::AnyType::get(unwrap(context))); } +MlirTypeID torchMlirTorchAnyTypeGetTypeID() { + return wrap(Torch::AnyType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.number type. //===----------------------------------------------------------------------===// @@ -393,6 +463,10 @@ MlirType torchMlirTorchNumberTypeGet(MlirContext context) { return wrap(Torch::NumberType::get(unwrap(context))); } +MlirTypeID torchMlirTorchNumberTypeGetTypeID() { + return wrap(Torch::NumberType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.Dict type. //===----------------------------------------------------------------------===// @@ -413,11 +487,15 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType, } MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) { - auto type = unwrap(t).cast<Torch::DictType>(); - return wrap(type.getKeyType()); + auto type = unwrap(t).cast<Torch::DictType>(); + return wrap(type.getKeyType()); } MlirType torchMlirTorchDictTypeGetValueType(MlirType t) { - auto type = unwrap(t).cast<Torch::DictType>(); - return wrap(type.getValueType()); + auto type = unwrap(t).cast<Torch::DictType>(); + return wrap(type.getValueType()); +} + +MlirTypeID torchMlirTorchDictTypeGetTypeID() { + return wrap(Torch::DictType::getTypeID()); } diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 4c37cca5efb4..03123d2edc67 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -3,10 +3,12 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(RefBackend) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) set(LinkedLibs MLIRFuncDialect MLIRIR MLIRSupport + ${extension_libs} TorchMLIRTorchPasses TorchMLIRTorchConversionDialect @@ -21,14 +23,6 @@ set(LinkedLibs TorchMLIRRefBackend ) -if(TORCH_MLIR_ENABLE_STABLEHLO) - list(APPEND LinkedLibs - MhloPasses - MhloToLinalg - StablehloToMhlo - ) -endif() - add_mlir_library(TorchMLIRInitAll InitAll.cpp diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 45714601ded0..0dae24678a4b 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -11,7 +11,6 @@ #ifdef TORCH_MLIR_ENABLE_STABLEHLO #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "transforms/passes.h" #endif // TORCH_MLIR_ENABLE_STABLEHLO #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 4877568a6bdc..9ec6a6006be7 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -34,6 +34,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +static int64_t productReduce(ArrayRef<int64_t> a) { + return accumulate(a.begin(), a.end(), /*init=*/1, std::multiplies<int64_t>()); +} + template <typename OpTy, typename OpAdaptor> LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor
adaptor, ConversionPatternRewriter &rewriter, @@ -177,144 +181,131 @@ namespace { class ConvertAtenViewOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + // If one of the two dims arrays has size 1, a mapping is created from the one + // dimension of the size-1 array to all the dimensions of the other array. For + // example for inputs: xDims = [6], yDims = [2, 3] the result in the indices + // arrays will be: xIndices = [0], yIndices = [0, 1]. + // + // An error is returned if the dimension size of the size-1 array is not equal + // to the product of all the dimension sizes in the other array, or if neither + // of the arrays is size-1. + static LogicalResult mapAllDimsToSingleDim(ArrayRef xDims, + ArrayRef yDims, + SmallVector &xIndices, + SmallVector &yIndices) { + auto isValidReduction = [](int64_t expectedReductionProduct, + ArrayRef arrayToReduce) -> bool { + if (llvm::count(arrayToReduce, kUnknownSize) > 0 || + expectedReductionProduct == kUnknownSize) + return true; + return productReduce(arrayToReduce) == expectedReductionProduct; + }; - // Helper for filling in remaining un-collapsed dims when the - // input/output dim is next to the next boundary dim. Additionally - // computes the size of a collapsed dynamic dim if necessary. - static LogicalResult - collapseToSingleDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter, - int64_t collapseDim, int64_t maxCollapseDim, - int64_t startExpandDim, int64_t maxExpandDim, - SmallVector &collapseShape, - const SmallVector &expandShape, - ReassociationIndices &expandIndices) { - int64_t collapseDimSize = 1; - for (auto i : llvm::seq(startExpandDim, maxExpandDim)) { - expandIndices.push_back(i); - if (collapseDimSize == kUnknownSize) - continue; - - int64_t expandedDimSize = expandShape[i]; - if (expandedDimSize == kUnknownSize) { - collapseDimSize = kUnknownSize; - continue; - } - collapseDimSize *= expandedDimSize; - } - int64_t rawCollapseDimSize = collapseShape[collapseDim]; - if (rawCollapseDimSize != kUnknownSize && collapseDimSize != kUnknownSize && - collapseDimSize != rawCollapseDimSize) { - return rewriter.notifyMatchFailure( - op, "desired size is not compatible with the input tensor size"); + if (xDims.size() == 1) { + if (!isValidReduction(xDims[0], yDims)) + return failure(); + xIndices.assign({0}); + yIndices.assign(llvm::to_vector(llvm::seq(0, yDims.size()))); + return success(); + } else if (yDims.size() == 1) { + if (!isValidReduction(yDims[0], xDims)) + return failure(); + yIndices.assign({0}); + xIndices.assign(llvm::to_vector(llvm::seq(0, xDims.size()))); + return success(); } - collapseShape[collapseDim] = collapseDimSize; - return success(); + return failure(); } - // Helper to find the minimum set of dims to collapse with the - // same number of elements as that of collapseDim. This function assumes - // the size of the collapsed dim is never dynamic. 
- static LogicalResult minimallyCollapseDimHelper( - AtenViewOp op, ConversionPatternRewriter &rewriter, int64_t collapseDim, - int64_t maxCollapseDim, int64_t startExpandDim, int64_t maxExpandDim, - SmallVector &collapseShape, SmallVector &expandShape, - ReassociationIndices &collapseIndices, - ReassociationIndices &expandIndices) { - - int64_t collapseDimSize = collapseShape[collapseDim]; - - int64_t expandedSize = 1; - int64_t collapsedSize = collapseDimSize; - - int64_t expandIndex = startExpandDim; - int64_t collapseIndex = collapseDim + 1; - - if (collapseDimSize == kUnknownSize) { - if (llvm::all_of(collapseShape, - [](int64_t value) { return value == kUnknownSize; }) && - llvm::all_of(expandShape, - [](int64_t value) { return value == kUnknownSize; })) { - - for (size_t i = 0; i < collapseShape.size(); i++) { - collapseIndices.push_back(i); - } - - for (size_t i = 0; i < expandShape.size(); i++) { - expandIndices.push_back(i); - } - - return success(); + // Starting from the beginning of the dims arrays, this helper finds the + // smallest set of consecutive dims in each array such that the product of the + // dim sizes in the two subsets is equal. The indices arrays are populated + // with the indices of the dims arrays that correspond to the subsets found. + // + // An error is returned if two subsets of dims with total number of elements + // equal to each other are not found. + static LogicalResult mapStaticallyKnownDims(ArrayRef xDims, + ArrayRef yDims, + SmallVector &xIndices, + SmallVector &yIndices) { + if (xDims.empty() || yDims.empty()) + return failure(); + int64_t xTotalSize = xDims[0]; + int64_t yTotalSize = yDims[0]; + SmallVector xIndicesResult({0}); + SmallVector yIndicesResult({0}); + size_t nextXIndex = 1; + size_t nextYIndex = 1; + while (xTotalSize != yTotalSize) { + if (xTotalSize < yTotalSize) { + if (nextXIndex == xDims.size() || xDims[nextXIndex] == kUnknownSize) + return failure(); + xTotalSize *= xDims[nextXIndex]; + xIndicesResult.push_back(nextXIndex++); + } else { + if (nextYIndex == yDims.size() || yDims[nextYIndex] == kUnknownSize) + return failure(); + yTotalSize *= yDims[nextYIndex]; + yIndicesResult.push_back(nextYIndex++); } } - while (expandIndex != maxExpandDim || collapseIndex != maxCollapseDim) { - if (expandIndex != maxExpandDim && expandedSize <= collapsedSize) { - int64_t expandDimSize = expandShape[expandIndex]; - if (expandDimSize != kUnknownSize) { - expandedSize *= expandDimSize; - } - expandIndices.push_back(expandIndex); - expandIndex++; - - } else if (collapseIndex != maxCollapseDim && - collapsedSize < expandedSize) { - collapseDimSize = collapseShape[collapseIndex]; - if (collapseDimSize != kUnknownSize) { - collapsedSize *= collapseDimSize; - } - collapseIndices.push_back(collapseIndex); - collapseIndex++; - } - - if (expandedSize == collapsedSize) - return success(); - } - return rewriter.notifyMatchFailure( - op, "total number of elements mismatch in the expansion"); + xIndices.assign(std::move(xIndicesResult)); + yIndices.assign(std::move(yIndicesResult)); + return success(); } - static void solveDynamicSize(SmallVector &inputShape, - SmallVector &outputShape) { - int64_t inputProduct = 1; - int64_t outputProduct = 1; - - int64_t inputDynamicValues = 0; - int64_t outputDynamicValues = 0; - - for (int64_t value : inputShape) { - if (value == -1) { - ++inputDynamicValues; - } else { - inputProduct *= value; - } - } - for (int64_t value : outputShape) { - if (value == -1) { - ++outputDynamicValues; - } else { - outputProduct *=
value; - } + // Calculates the size of a dynamic dimension if all other dimensions are + // statically known, and rewrites that dynamic dimension with the static size. + // + // Note: this function assumes that all the dimensions in `inputShape` map to + // all the dimensions in `outputShape`. + static void calculateSingleDynamicSize(MutableArrayRef inputShape, + MutableArrayRef outputShape) { + int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); + int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); + if (inputDynamicDimCount + outputDynamicDimCount != 1) + return; + + int64_t inputProduct = productReduce(inputShape); + int64_t outputProduct = productReduce(outputShape); + + if (inputDynamicDimCount == 1) { + inputProduct /= kUnknownSize; + *llvm::find(inputShape, kUnknownSize) = outputProduct / inputProduct; + } else { + outputProduct /= kUnknownSize; + *llvm::find(outputShape, kUnknownSize) = inputProduct / outputProduct; } + } - if (inputDynamicValues + outputDynamicValues == 1) { - if (inputDynamicValues) { - int64_t missingValue = outputProduct / inputProduct; - for (size_t i = 0; i < inputShape.size(); i++) { - if (inputShape[i] == -1) { - inputShape[i] = missingValue; - break; - } - } - } else { - int64_t missingValue = inputProduct / outputProduct; - for (size_t i = 0; i < outputShape.size(); i++) { - if (outputShape[i] == -1) { - outputShape[i] = missingValue; - break; - } + // Gets the shapes of the input and output tensors, making a best-effort + // attempt to extract static shape information given the inputs to + // `aten.view`. + static std::pair, SmallVector> + getInputAndOutputShape(Value inputTorchTensor, + SmallVector outputSizeTorchInt) { + SmallVector inputShape( + inputTorchTensor.getType().cast().getSizes()); + SmallVector outputShape(outputSizeTorchInt.size(), kUnknownSize); + for (auto [outputDim, outputDimSize] : + llvm::enumerate(outputSizeTorchInt)) { + int64_t inputDim; + int64_t outputDimSizeInt; + // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim + if (matchPattern(outputDimSize, + m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) { + outputShape[outputDim] = inputShape[inputDim]; + } else if (matchPattern(outputDimSize, + m_TorchConstantInt(&outputDimSizeInt))) { + if (outputDimSizeInt != -1) { + outputShape[outputDim] = outputDimSizeInt; } } } + + calculateSingleDynamicSize(inputShape, outputShape); + return std::make_pair(inputShape, outputShape); } LogicalResult @@ -325,10 +316,9 @@ class ConvertAtenViewOp : public OpConversionPattern { Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); - SmallVector inputShape = - makeShapeTorchCompatible(inputType.getShape()); + SmallVector inputSize = getTensorSizes(rewriter, loc, input); int64_t inputRank = inputType.getRank(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); @@ -349,6 +339,15 @@ class ConvertAtenViewOp : public OpConversionPattern { "unimplemented: the target size is " "not constructed from ListConstruct"); } + if (llvm::count_if(outputSizeTorchInt, [](Value size) -> bool { + int64_t sizeInt; + if (matchPattern(size, m_TorchConstantInt(&sizeInt))) + return sizeInt == -1; + return false; + }) > 1) { + return rewriter.notifyMatchFailure( + op, "at most one element in size list is allowed to be -1"); + } SmallVector 
outputSizeInt = getTypeConvertedValues( rewriter, loc, typeConverter, outputSizeTorchInt); if (resultRank != (int64_t)outputSizeInt.size()) { @@ -356,6 +355,9 @@ class ConvertAtenViewOp : public OpConversionPattern { op, "desired size list length mismatches with the result type rank"); } + auto [inputShape, outputShape] = + getInputAndOutputShape(op.getSelf(), outputSizeTorchInt); + // Currently, we only handle the cases where each dimension is either // being expanded or collapsed. We do not handle cases where it's neither // collapsing nor expanding like view of [2,3] for 3x2 tensor. @@ -364,90 +366,24 @@ class ConvertAtenViewOp : public OpConversionPattern { // [6] => [3, 2]. // Iterate through the view op size list to do the following: - // - // 1. Combine output size list and input tensor type info to get the most - // static outputShape. - // - // 2. Mark dims in unchangedDims for size list items where the output dim + // Mark dims in unchangedDims for size list items where the output dim // size comes from a `torch.aten.size.int(inputTensor, inputDim)`. We // naively assume this means the corresponding dimension is not expanded or // collapsed. Note this may technically not always be true. // TODO: think of a way better way to at least detect when this assumption // is violated for the cases of dynamic dimensions. - SmallVector outputShape(resultRank, kUnknownSize); - SmallVector unchangedDims; - std::optional inferredDimension; - for (auto en : llvm::enumerate(outputSizeTorchInt)) { + SmallVector> unchangedDims; + for (auto [outputDim, outputDimSize] : + llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; - int64_t size; - int64_t outputDim = en.index(); // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim - if (matchPattern(en.value(), + if (matchPattern(outputDimSize, m_TorchTensorSizeInt(op.getSelf(), &inputDim))) { - unchangedDims.emplace_back(); - unchangedDims.back().push_back(inputDim); - unchangedDims.back().push_back(outputDim); - if (!inputType.isDynamicDim(inputDim)) { - outputShape[outputDim] = inputShape[inputDim]; - continue; - } - } else if (matchPattern(en.value(), m_TorchConstantInt(&size))) { - if (size != -1) { - outputShape[outputDim] = size; - continue; - } - - if (inferredDimension.has_value()) { - return rewriter.notifyMatchFailure( - op, "at most one element in size list is allowed to be -1"); - } - inferredDimension = outputDim; + unchangedDims.push_back(std::make_pair(inputDim, outputDim)); } } - // Mark the end of the input/output shapes - unchangedDims.emplace_back(); - unchangedDims.back().push_back(inputRank); - unchangedDims.back().push_back(resultRank); - - // Use static information of input tensor to determine size of inferred - // dimension in output shape. - // - // If there is an inferred dimension and that is the only dimension - // in the output shape (i.e. the tensor is getting fully flattened), - // then we don't need to analyze the static information of the input - // shape since the reassociation of dimensions only requires rank - // information. 
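// Editor's walkthrough (illustrative, not part of the patch): viewing
// inputShape = [2, 6, ?] as outputShape = [12, ?] is handled by
// mapStaticallyKnownDims, which grows the input subset until the products
// match (2 * 6 == 12) and pairs input dims {0, 1} with output dim {0}; the
// remaining one-dim slices [?] -> [?] are then paired by
// mapAllDimsToSingleDim, so the loop below ends with
// inputAssociations = [[0, 1], [2]] and outputAssociations = [[0], [1]].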
- if (inferredDimension.has_value() && outputShape.size() > 1) { - if (llvm::count(outputShape, kUnknownSize) != 1 || - llvm::count(inputShape, kUnknownSize) != 0) { - return rewriter.notifyMatchFailure( - op, - "unimplemented: an inferred dimension is only supported when there " - "is enough static shape information to determine its size, or when " - "the input tensor is being flattened to a single dimension"); - } - auto productReduceKnownSizes = [](const ArrayRef sizes) { - auto knownSizes = llvm::make_filter_range( - sizes, [](int64_t val) { return val != kUnknownSize; }); - return std::accumulate(knownSizes.begin(), knownSizes.end(), /*init=*/1, - std::multiplies()); - }; - - int64_t numOfElements = productReduceKnownSizes(inputShape); - int64_t outputKnownNumOfElements = productReduceKnownSizes(outputShape); - if (numOfElements % outputKnownNumOfElements != 0) { - return rewriter.notifyMatchFailure( - op, "number of elements in input tensor must be divisible by " - "product of non-inferred dimensions in size list"); - } - outputShape[*inferredDimension] = - numOfElements / outputKnownNumOfElements; - } - - SmallVector inputSize = getTensorSizes(rewriter, loc, input); - ArrayRef outputShapeInt = llvm::ArrayRef(outputSizeInt); - ArrayRef inputShapeInt = llvm::ArrayRef(inputSize); + unchangedDims.push_back(std::make_pair(inputRank, resultRank)); // Association indices for expand/collapse ops. These two vectors // are populated such that two entries at the same index corresponds @@ -463,10 +399,6 @@ class ConvertAtenViewOp : public OpConversionPattern { SmallVector inputAssociations; SmallVector outputAssociations; - SmallVector inputShapeVec = llvm::to_vector(inputShape); - - solveDynamicSize(inputShapeVec, outputShape); - // The for loop does the following: // 1. Attempt to match the indices from inputDim and outputDim to the next // boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or @@ -482,119 +414,78 @@ class ConvertAtenViewOp : public OpConversionPattern { // the dynamic dimension with the one across from it and give up if we can't // reason about how the dimensions are associated. // e.g. [-1, -1] -> [2, 3, 4] - // 3. Set inputShapeVec and outputShape following the requirements by - // tensor.expand_shape verification code: - // a. As long as one or more of the related dimensions in the expanded - // shape is dynamic the collapsed dimension is dynamic. - // b. If all of the related dimensions are static, the collapsed - // dimension must be static. In other words, if a collapsed dimension is - // dynamic, at least one of the related dimensions need to be dynamic. + // For more information, see description of helper functions used in the + // `if-else` cases inside the while loop. 
int64_t inputDim = 0, outputDim = 0; - for (auto boundary : unchangedDims) { - // We assume dims specified by AtenSizeInt ops are unchanged - int64_t nextUnchangedInput = boundary[0]; - int64_t nextUnchangedOutput = boundary[1]; - - bool hasDynamic = false; + for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) { + // Used for ensuring that we don't have an ambiguous expansion + bool assumedDynamicDimNotSplit = false; while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) { - - inputAssociations.emplace_back(); - outputAssociations.emplace_back(); - - // outputDim is next to the boundary - if (outputDim == nextUnchangedOutput - 1) { - - if (hasDynamic && inputDim != nextUnchangedInput - 1) { - return rewriter.notifyMatchFailure( - op, "found ambiguous collapse of dynamic input sizes (e.g. " - "[-1, -1, -1] -> [-1, -1])"); - } - outputAssociations.back().push_back(outputDim); - if (failed(collapseToSingleDimHelper( - op, rewriter, outputDim, nextUnchangedOutput, inputDim, - nextUnchangedInput, outputShape, inputShapeVec, - inputAssociations.back()))) - return failure(); - outputDim = nextUnchangedOutput; - inputDim = nextUnchangedInput; - continue; - } - - // inputDim is next to the boundary - if (inputDim == nextUnchangedInput - 1) { - - if (hasDynamic && inputShape[inputDim] == kUnknownSize) { - return rewriter.notifyMatchFailure( - op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> " - "[-1, -1, -1])"); - } - inputAssociations.back().push_back(inputDim); - if (failed(collapseToSingleDimHelper( - op, rewriter, inputDim, nextUnchangedInput, outputDim, - nextUnchangedOutput, inputShapeVec, outputShape, - outputAssociations.back()))) - return failure(); - - outputDim = nextUnchangedOutput; - inputDim = nextUnchangedInput; - continue; - } - - int64_t inputMatchingDimSize = inputShapeVec[inputDim]; - int64_t outputMatchingDimSize = outputShape[outputDim]; - - // If the input is dynamic, first assume it is not split - if (inputMatchingDimSize == kUnknownSize) { - - checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim], - outputShapeInt[outputDim]); - outputShape[outputDim] = kUnknownSize; - inputAssociations.back().push_back(inputDim++); - outputAssociations.back().push_back(outputDim++); - hasDynamic = true; - continue; + auto inputShapeSlice = + MutableArrayRef(inputShape) + .slice(inputDim, nextUnchangedInput - inputDim); + auto outputShapeSlice = + MutableArrayRef(outputShape) + .slice(outputDim, nextUnchangedOutput - outputDim); + SmallVector inputSliceIndices; + SmallVector outputSliceIndices; + + // TODO: this can be removed by replacing it with a checkDimEqualHelper + // that takes into account the product of all the dimensions being + // reduced + if (assumedDynamicDimNotSplit && inputShapeSlice.size() == 1 && + outputShapeSlice.size() != 1 && + inputShapeSlice[0] == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "found ambiguous expand of dynamic input sizes " + "(e.g. 
[-1, -1] -> [-1, -1, -1])"); } - // inputDim size is larger; try to collapse onto it - if (inputMatchingDimSize >= outputMatchingDimSize) { - - inputAssociations.back().push_back(inputDim); - if (failed(minimallyCollapseDimHelper( - op, rewriter, inputDim, nextUnchangedInput, outputDim, - nextUnchangedOutput, inputShapeVec, outputShape, - inputAssociations.back(), outputAssociations.back()))) { - return failure(); + if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice, + inputSliceIndices, + outputSliceIndices))) { + calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice); + // Update shape to pass the tensor.expand_shape and + // tensor.collapse_shape verifiers. If one of the dimensions of the + // tensor being flattened is dynamic, the size of the flattened tensor + // must also be dynamic. + if (inputShapeSlice.size() == 1 && + llvm::count(outputShapeSlice, kUnknownSize) > 0) { + inputShapeSlice[0] = kUnknownSize; + } else if (outputShapeSlice.size() == 1 && + llvm::count(inputShapeSlice, kUnknownSize) > 0) { + outputShapeSlice[0] = kUnknownSize; } - hasDynamic = false; - outputDim = outputAssociations.back().back() + 1; - inputDim = inputAssociations.back().back() + 1; - continue; + } else if (succeeded(mapStaticallyKnownDims( + inputShapeSlice, outputShapeSlice, inputSliceIndices, + outputSliceIndices))) { + /// `mapStaticallyKnownDims` maps the smallest number of + /// input and output dimensions in the slice statically + /// known to have the same number of elements. + } else if (inputShapeSlice[0] == kUnknownSize) { + // If the input is dynamic, assume it is not split + checkDimEqualHelper(rewriter, loc, inputSize[inputDim], + outputSizeInt[outputDim]); + // If output dimension is not dynamic, improve static information of + // input + inputShape[inputDim] = outputShape[outputDim]; + inputSliceIndices.push_back(0); + outputSliceIndices.push_back(0); + assumedDynamicDimNotSplit = true; + } else { + return rewriter.notifyMatchFailure( + op, "unimplemented: found unhandled case of expansion/collapse " + "in `aten.view`"); } - // outputDim is larger; try to collapse onto it - outputAssociations.back().push_back(outputDim); - if (failed(minimallyCollapseDimHelper( - op, rewriter, outputDim, nextUnchangedOutput, inputDim, - nextUnchangedInput, outputShape, inputShapeVec, - outputAssociations.back(), inputAssociations.back()))) { - - return failure(); - } - hasDynamic = false; + inputAssociations.emplace_back(); + outputAssociations.emplace_back(); + for (int64_t inputSliceIndex : inputSliceIndices) + inputAssociations.back().push_back(inputSliceIndex + inputDim); + for (int64_t outputSliceIndex : outputSliceIndices) + outputAssociations.back().push_back(outputSliceIndex + outputDim); inputDim = inputAssociations.back().back() + 1; outputDim = outputAssociations.back().back() + 1; - continue; - } - - if (inputDim != nextUnchangedInput) { - hasDynamic = true; - if (inputAssociations.size() < 1) { - inputAssociations.emplace_back(); - outputAssociations.emplace_back(); - } - inputAssociations.back().push_back(inputDim++); - outputAssociations.back().push_back(outputDim++); - continue; } // Append the associations for the dims matching `aten.size.int` @@ -624,7 +515,7 @@ class ConvertAtenViewOp : public OpConversionPattern { Type adjustedResultType = RankedTensorType::get( makeShapeLLVMCompatible(outputShape), resultType.getElementType()); Type adjustedInputType = RankedTensorType::get( - makeShapeLLVMCompatible(inputShapeVec), resultType.getElementType()); + 
makeShapeLLVMCompatible(inputShape), resultType.getElementType()); Value castedInput = rewriter.create(loc, adjustedInputType, input); std::optional expandedInput; @@ -649,8 +540,9 @@ class ConvertAtenViewOp : public OpConversionPattern { intermediateShape.push_back(sum); } - Type intermediateResultType = RankedTensorType::get( - makeShapeLLVMCompatible(intermediateShape), resultType.getElementType()); + Type intermediateResultType = + RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape), + resultType.getElementType()); expandedInput = rewriter @@ -695,7 +587,7 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); int64_t inputRank = inputType.getRank(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); @@ -804,7 +696,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { op, "unimplemented: dim(th) dimension is not expected to be dynamic"); } - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); @@ -1014,10 +906,10 @@ class ConvertAtenPermuteOp : public OpConversionPattern { for (unsigned i = 0; i < inputRank; i++) swapExprs.push_back(idExprs[dimensions[i]]); - AffineMap inputMap = AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, - op->getContext()); - AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs, - op->getContext()); + AffineMap inputMap = + AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext()); + AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0, + swapExprs, op->getContext()); SmallVector indexingMaps{inputMap, outputMap}; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); @@ -1046,7 +938,7 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern { return failure(); Location loc = op.getLoc(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); RankedTensorType resultType = @@ -1081,7 +973,7 @@ class ConvertAtenCatOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); // Collect all the tensors to be concatenated. 
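// Editor's note (illustrative): the rewritten cat lowering below first
// normalizes every input tensor to the result element type with
// torch_to_linalg::convertTensorToElementType (replacing the previous
// hand-written linalg.generic whose payload called convertScalarToDtype),
// then copies each input into the result along the positive cat dimension.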
auto tensorList = op.getTensors(); @@ -1096,14 +988,9 @@ class ConvertAtenCatOp : public OpConversionPattern { typeConverter->convertType(op.getType()).cast(); auto outElemType = newResultType.getElementType(); - auto dtypePromoteBody = [&](OpBuilder &builder, Location loc, - ValueRange payloadArgs) { - Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType); - builder.create(loc, elem); - }; for (size_t i = 0; i < tensors.size(); ++i) { - tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric( - rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody); + tensors[i] = torch_to_linalg::convertTensorToElementType( + rewriter, loc, tensors[i], outElemType); } int rank = newResultType.getRank(); @@ -1114,7 +1001,7 @@ class ConvertAtenCatOp : public OpConversionPattern { dim = toPositiveDim(dim, rank); if (!isValidDim(dim, rank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - + SmallVector offsets, sizes, strides; sizes.reserve(rank); strides.resize(rank, rewriter.create(loc, 1)); @@ -1179,12 +1066,29 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "unimplemented: the size list is not from list construct"); } + // For dynamic input dimension we need to use the `broadcastToShape` + // which in this case is `inShapeConverted` because this shape will yield + // us the dimension size of the output. + SmallVector useBroadcastToShape; + for (auto x : inShape) { + int64_t dim; + if (!matchPattern(x, m_TorchConstantInt(&dim))) { + Operation *defOp = x.getDefiningOp(); + if (isa(defOp)) + useBroadcastToShape.push_back(true); + else + useBroadcastToShape.push_back(false); + } else { + useBroadcastToShape.push_back(false); + } + } + SmallVector inShapeConverted = getTypeConvertedValues( rewriter, op.getLoc(), getTypeConverter(), inShape); - Value result; - if (failed(torch_to_linalg::broadcastToGivenShape( - op, rewriter, self, inShapeConverted, result))) { + if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self, + inShapeConverted, result, + useBroadcastToShape))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } @@ -1295,7 +1199,7 @@ class ConvertAtenSliceScatterOp return failure(); Location loc = op.getLoc(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); @@ -1344,7 +1248,7 @@ class ConvertAtenViewAsComplexOp return failure(); Location loc = op.getLoc(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); MLIRContext *context = rewriter.getContext(); auto input = adaptor.getSelf(); @@ -1410,6 +1314,89 @@ class ConvertAtenViewAsComplexOp }; } // namespace +namespace { +class ConvertAtenViewAsRealOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenViewAsRealOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op.getLoc(); + const TypeConverter *typeConverter = getTypeConverter(); + MLIRContext *context = rewriter.getContext(); + + auto input = adaptor.getSelf(); + + RankedTensorType resultType = + typeConverter->convertType(op.getType()).cast(); + + RankedTensorType inputType = input.getType().cast(); + auto inputElementType = getElementTypeOrSelf(input.getType()); + if 
(!inputElementType.isa()) { + return op.emitError("only ComplexType is allowed as input type"); + } + Type elementType = resultType.getElementType(); + + // returned real tensor has a size increase, where the last dim has size 2 + SmallVector resultShape = + tensor::getMixedSizes(rewriter, loc, input); + resultShape.push_back( + rewriter.createOrFold(loc, 2)); + + Value outTensor = + rewriter.create(loc, resultShape, elementType); + + SmallVector inputExpr; + for (unsigned i = 0; i < resultType.getRank() - 1; i++) { + inputExpr.push_back(getAffineDimExpr(i, context)); + } + + AffineMap inputMap = + AffineMap::get(resultType.getRank(), 0, inputExpr, op->getContext()); + + inputExpr.push_back(getAffineDimExpr(resultType.getRank() - 1, context)); + + AffineMap outputMap = + AffineMap::get(resultType.getRank(), 0, inputExpr, op->getContext()); + + SmallVector indexingMaps{inputMap, outputMap}; + + SmallVector iteratorTypes(resultType.getRank(), utils::IteratorType::parallel); + + Value constantZero = + getConstant(rewriter, loc, 0, mlir::IndexType::get(context)); + auto realVar = + rewriter + .create( + loc, outTensor.getType(), input, outTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + + Value realVal = + b.create(loc, elementType, args[0]); + Value imagVal = + b.create(loc, elementType, args[0]); + Value lastIndex = + b.create(loc, inputType.getRank()); + Value cmpResult = b.create( + loc, arith::CmpIPredicate::eq, lastIndex, constantZero); + Value yieldValue = b.create( + loc, cmpResult, realVal, imagVal); + + b.create(loc, yieldValue); + }) + .getResult(0); + + rewriter.replaceOpWithNewOp(op, resultType, realVar); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1442,4 +1429,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 0aaecb7fbaac..cfbac2632a28 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -525,6 +525,17 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { }; } // namespace +static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index, + Value input, int64_t dim) { + Value cstZero = b.create(loc, b.getI64IntegerAttr(0)); + Value isIndexNegative = + b.create(loc, arith::CmpIPredicate::slt, index, cstZero); + Value inputShape = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim)); + Value toPositiveIndex = b.create(loc, index, inputShape); + return b.create(loc, isIndexNegative, toPositiveIndex, + index); +} + // IndexTensor for multiple input tensors broadcasts their shapes to a common // shape and then replaces the indexed dims with the indices given by the // indexing tensors: @@ -541,11 +552,11 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { // e.g. 
x: [2, 3] // x[[4], [6, 1]] -> x[6, 4] namespace { -class ConvertAtenIndexTensorOp : public OpConversionPattern { +class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor, + matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -731,8 +742,10 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern { b.create(loc, i)); } for (auto i : llvm::seq(0, (int)indexTensorDims.size())) { - extractionIndices.push_back( - castIntToIndex(b, loc, args[i])); + extractionIndices.push_back(castIntToIndex( + b, loc, + makeIndexValuePositive(b, loc, args[i], input, + extractionIndices.size()))); } for (auto i : llvm::seq((int)extractionIndices.size(), inputRank)) { @@ -744,8 +757,11 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern { for (auto i : llvm::seq(0, inputRank)) { if (indexCount < replacedIndexCount && i == indexTensorDims[indexCount]) { - extractionIndices.push_back( - castIntToIndex(b, loc, args[indexCount++])); + extractionIndices.push_back(castIntToIndex( + b, loc, + makeIndexValuePositive(b, loc, args[indexCount++], + input, + extractionIndices.size()))); continue; } extractionIndices.push_back(b.create( @@ -1091,8 +1107,8 @@ void mlir::torch::torch_to_linalg:: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index d36b8c309daf..23528bb01f80 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -113,6 +113,13 @@ class ConvertAtenFlipOp : public OpConversionPattern { if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis))) return rewriter.notifyMatchFailure(op, "only constant dim lists supported"); + for (unsigned i = 0, e = axis.size(); i < e; i++) { + axis[i] = toPositiveDim(axis[i], selfRank); + if (!isValidDim(axis[i], selfRank)) { + return rewriter.notifyMatchFailure(op, "axis is statically invalid"); + } + } + // Only used to calculate flipped values, i.e. those on the flip axes. Other // dims won't be used. SmallVector dims = getTensorSizes(rewriter, loc, self); @@ -434,16 +441,28 @@ class ConvertAtenBmmOp : public OpConversionPattern { Value rhs = adaptor.getMat2(); RankedTensorType lhsType = lhs.getType().cast(); RankedTensorType rhsType = rhs.getType().cast(); + Type newResultType = getTypeConverter()->convertType(op.getType()); + Type resultElementType = newResultType.cast().getElementType(); + Type lhsElementType = lhsType.cast().getElementType(); + Type rhsElementType = rhsType.cast().getElementType(); if (lhsType.getRank() != 3 || rhsType.getRank() != 3) { return rewriter.notifyMatchFailure( op, "expected both operands to aten.bmm to be rank 3"); } - if (!lhsType.getElementType().isa() || - lhsType.getElementType() != rhsType.getElementType()) - return op.emitError( - "unimplemented: non floating point operands or operands of " - "different types"); + + // Convert the input element types to match the result's element type.
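// Editor's sketch (hypothetical dtypes): with lhs elements i32, rhs elements
// f32, and an f32 result, the branch below converts only the operand whose
// element type differs from the result's (here the lhs), so both
// batch-matmul inputs end up as f32 and the old "operands of different
// types" failure is no longer needed.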
+ if (lhsElementType != rhsElementType) { + if (lhsElementType != resultElementType) { + // True if the lhs element type is not equal to the result's element type. + lhs = torch_to_linalg::convertTensorToElementType( + rewriter, loc, lhs, resultElementType); + } else { + // True if the rhs element type is not equal to the result's element type. + rhs = torch_to_linalg::convertTensorToElementType( + rewriter, loc, rhs, resultElementType); + } + } Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); @@ -458,10 +477,8 @@ class ConvertAtenBmmOp : public OpConversionPattern { // Check that the matrix shapes are valid for multiplication. checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); - Type newResultType = getTypeConverter()->convertType(op.getType()); - Type elementType = newResultType.cast().getElementType(); Value initTensor0 = createZeroInitTensor( - rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, elementType); + rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType); Value bmm = rewriter diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 850363724153..1d7ff925b6ed 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -32,14 +32,14 @@ using namespace mlir::torch::Torch; template static LogicalResult checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter, - TypeConverter *typeConverter, bool &ceilMode, + const TypeConverter *typeConverter, bool &ceilMode, SmallVectorImpl &kernelSizeIntValues, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts) { // Pattern match against the op's original operands, because otherwise we // will get the lowered version of the operands which is harder to pattern // match.
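// For instance (illustrative IR, not taken from the test suite): the
// original kernel-size operand
//   %1 = torch.prim.ListConstruct %int3, %int3
//          : (!torch.int, !torch.int) -> !torch.list<int>
// still exposes its constant elements to getListConstructElements and
// m_TorchConstantInt, whereas the adaptor's type-converted operands no
// longer carry the torch-level list structure.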
- SmallVector kernelSizeTorchInt; + SmallVector kernelSizeTorchInt; if (!getListConstructElements(op.getKernelSize(), kernelSizeTorchInt)) { return rewriter.notifyMatchFailure(op, "unimplemented: the kernel size is " @@ -77,7 +77,7 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter, template static LogicalResult createPoolingOp( Operation *op, ConversionPatternRewriter &rewriter, Value self, - bool supportNonFPInput, bool ceilMode, + bool supportNonFPInput, bool ceilMode, int64_t dimensionality, SmallVectorImpl &kernelSizeIntValues, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVectorImpl &dilationInts, Attribute initValueAttr, @@ -87,22 +87,23 @@ static LogicalResult createPoolingOp( if (!elementType.isa() && !supportNonFPInput) return op->emitError("unimplemented: non-floating point type"); - SmallVector lowPaddingIncludingNC = {0, 0}; + SmallVector lowPaddingIncludingNC = {0, 0}; lowPaddingIncludingNC.append(paddingInts); - SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; + SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; + if (ceilMode) { - highPaddingIncludingNC[2] += strideInts[0]; - highPaddingIncludingNC[3] += strideInts[1]; + for (int64_t i = 0; i < dimensionality; ++i) { + highPaddingIncludingNC[i + 2] += strideInts[i]; + } } + Value initValue = rewriter.create(loc, cast(initValueAttr)); paddedInput = torch_to_linalg::getPaddedTensor( op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC, initValue); - + Value N = getDimOp(rewriter, loc, self, 0); Value C = getDimOp(rewriter, loc, self, 1); - Value H = getDimOp(rewriter, loc, self, 2); - Value W = getDimOp(rewriter, loc, self, 3); SmallVector paddingIntValues = getAsConstantIntValues(rewriter, loc, paddingInts); @@ -111,15 +112,17 @@ static LogicalResult createPoolingOp( SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); - Value hOut = torch_to_linalg::getOutputDimForConvOps( - rewriter, loc, H, paddingIntValues[0], dilationIntValues[0], - kernelSizeIntValues[0], strideIntValues[0], ceilMode); - Value wOut = torch_to_linalg::getOutputDimForConvOps( - rewriter, loc, W, paddingIntValues[1], dilationIntValues[1], - kernelSizeIntValues[1], strideIntValues[1], ceilMode); + // Get dimension size for each dimension and calculate output size + for (int64_t i = dimensionality - 1; i > -1; --i) { + Value dimSize = getDimOp(rewriter, loc, self, i + 2); + Value outDim = torch_to_linalg::getOutputDimForConvOps( + rewriter, loc, dimSize, paddingIntValues[i], dilationIntValues[i], + kernelSizeIntValues[i], strideIntValues[i], ceilMode); + outTensorShape.insert(outTensorShape.begin(), {outDim}); + } // Create output tensor initialized with smallest floating point value. - outTensorShape.insert(outTensorShape.begin(), {N, C, hOut, wOut}); + outTensorShape.insert(outTensorShape.begin(), {N, C}); Value outTensorInitialized = createInitTensor(rewriter, loc, outTensorShape, elementType, initValue); @@ -138,6 +141,7 @@ static LogicalResult createPoolingOp( return success(); } + namespace { class ConvertAtenMaxPool2dOp : public OpConversionPattern { public: @@ -148,7 +152,7 @@ class ConvertAtenMaxPool2dOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); Value self = adaptor.getSelf(); int64_t selfRank = self.getType().cast().getRank(); // TODO: Add support for 3D inputs. 
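// Editor's sketch (not part of the patch): scalar model of the per-dimension
// output size that torch_to_linalg::getOutputDimForConvOps materializes as
// arith ops on SSA values; ceilMode rounds the division up instead of down.
static int64_t poolOutDimFloorMode(int64_t in, int64_t pad, int64_t dilation,
                                   int64_t kernel, int64_t stride) {
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}
// e.g. poolOutDimFloorMode(/*in=*/7, /*pad=*/1, /*dilation=*/1, /*kernel=*/3,
//                          /*stride=*/2) == 4.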
@@ -177,8 +181,9 @@ class ConvertAtenMaxPool2dOp : public OpConversionPattern { Value maxPool2d, paddedInput; if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/false, ceilMode, - kernelSizeIntValues, strideInts, paddingInts, dilationInts, - smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d))) + /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts, + dilationInts, smallestFPValueAttr, outTensorShape, paddedInput, + maxPool2d))) return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, maxPool2d); @@ -219,7 +224,7 @@ class ConvertAtenMaxPool2dWithIndicesOp if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); Value self = adaptor.getSelf(); RankedTensorType selfType = self.getType().cast(); Type elementType = selfType.getElementType(); @@ -253,8 +258,9 @@ class ConvertAtenMaxPool2dWithIndicesOp SmallVector outTensorShape; if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/false, ceilMode, - kernelSizeIntValues, strideInts, paddingInts, dilationInts, - smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d))) + /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts, + dilationInts, smallestFPValueAttr, outTensorShape, paddedInput, + maxPool2d))) return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); Value cstMinusOne = @@ -366,29 +372,32 @@ class ConvertAtenMaxPool2dWithIndicesOp }; } // namespace + namespace { -class ConvertAtenAvgPool2dOp : public OpConversionPattern { +template +class ConvertAtenAvgPoolOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenAvgPool2dOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); + Location loc = op->getLoc(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); Value self = adaptor.getSelf(); Type inputElementType = self.getType().cast().getElementType(); - Type resultType = getTypeConverter()->convertType(op.getType()); + Type resultType = typeConverter->convertType(op.getType()); Type resultElementType = resultType.cast().getElementType(); bool ceilMode; - SmallVector kernelSizeIntValues; - SmallVector strideInts, paddingInts, dilationInts{1, 1}; - if (failed(checkAndGetPoolingParameters( + SmallVector kernelSizeIntValues; + SmallVector strideInts, paddingInts, dilationInts(Dim, 1); + if (failed(checkAndGetPoolingParameters( op, rewriter, typeConverter, ceilMode, kernelSizeIntValues, strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); @@ -404,34 +413,36 @@ class ConvertAtenAvgPool2dOp : public OpConversionPattern { op, "unimplemented: count_include_pad is expected to be true"); } - // `sumPool2d` contains the result of sumpool2d operation over the input. - Value sumPool2d, paddedInput; - SmallVector outTensorShape; - if (failed(createPoolingOp( + // `sumPool` contains the result of sumpool operation over the input. 
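// Editor's note (illustrative): average pooling is assembled below as the
// sum pooling divided elementwise by a divisor; the 2-D instantiation uses
// kH * kW (or divisor_override when provided), while the 1-D instantiation
// uses the single kernel size. Zero padding therefore counts toward the
// average, which matches the count_include_pad == true mode enforced above.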
+ Value sumPool, paddedInput; + SmallVector outTensorShape; + if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, - kernelSizeIntValues, strideInts, paddingInts, dilationInts, - rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput, - sumPool2d))) - return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d"); - - Value kHtimeskW = rewriter.create( - loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); - Value divisor = op.getDivisorOverride().getType().isa() - ? kHtimeskW - : adaptor.getDivisorOverride(); + /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, paddingInts, + dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, + paddedInput, sumPool))) + return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); + Value divisor; + if constexpr (std::is_same()) { + Value kHtimeskW = rewriter.create( + loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); + divisor = op.getDivisorOverride().getType().template isa() + ? kHtimeskW + : adaptor.getDivisorOverride(); + } else { + divisor = kernelSizeIntValues[0]; + } divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); Value outputTensor = rewriter.create( loc, getAsOpFoldResult(outTensorShape), resultElementType); - SmallVector indexingMapsAvg(2, - rewriter.getMultiDimIdentityMap(4)); + SmallVector indexingMapsAvg(2, rewriter.getMultiDimIdentityMap(Dim+2)); SmallVector iteratorTypesAvg( - 4, utils::IteratorType::parallel); - - Value avgPool2d = + Dim+2, utils::IteratorType::parallel); + Value avgPool = rewriter .create( - loc, outputTensor.getType(), sumPool2d, outputTensor, + loc, outputTensor.getType(), sumPool, outputTensor, /*indexingMaps=*/indexingMapsAvg, /*iteratorTypes=*/iteratorTypesAvg, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -444,11 +455,12 @@ class ConvertAtenAvgPool2dOp : public OpConversionPattern { }) .getResult(0); - rewriter.replaceOpWithNewOp(op, resultType, avgPool2d); + rewriter.replaceOpWithNewOp(op, resultType, avgPool); return success(); } }; -} // namespace +} + void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, @@ -458,6 +470,9 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 9a1c0ae53729..641f1ef8cc1c 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -176,8 +176,8 @@ class ConvertAtenMaxDimOp : public OpConversionPattern { Value resultMax, predicate; if (inElementType.isa()) { - resultMax = - rewriter.create(nestedLoc, newValue, oldValue); + resultMax = rewriter.create(nestedLoc, newValue, + oldValue); predicate = rewriter.create( nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else { @@ -208,6 +208,13 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, if (isa(op)) return b.create(loc, b.getZeroAttr(elementType)); + if (isa(op)) { + if (elementType.isa()) + return b.create(loc, b.getFloatAttr(elementType, 1.0)); + else if (elementType.isa()) + return b.create(loc, b.getIntegerAttr(elementType, 1)); + } + if (isa(op)) { if 
(elementType.isa()) return b.create( @@ -224,6 +231,22 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, elementType.getIntOrFloatBitWidth()))); } + if (isa(op)) { + if (elementType.isa()) + return b.create( + loc, b.getFloatAttr( + elementType, + APFloat::getInf( + elementType.cast().getFloatSemantics(), + /*Negative=*/false))); + else if (elementType.isa() && + elementType.getIntOrFloatBitWidth() != 8) + return b.create( + loc, b.getIntegerAttr(elementType, + APSInt::getSignedMaxValue( + elementType.getIntOrFloatBitWidth()))); + } + if (isa(op) || isa(op)) return b.create(loc, b.getZeroAttr(elementType)); @@ -244,12 +267,20 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, return b.create(loc, self, result); else if (resultElementType.isa()) return b.create(loc, self, result); + } else if (isa(op)) { + Value self = + convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); + Value result = payloadArgs[1]; + if (resultElementType.isa()) + return b.create(loc, self, result); + else if (resultElementType.isa()) + return b.create(loc, self, result); } else if (auto max = dyn_cast(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) - return b.create(loc, self, result); + return b.create(loc, self, result); else if (resultElementType.isa()) { IntegerType intType = max.getSelf() .getType() @@ -261,6 +292,23 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, if (intType.isSigned()) return b.create(loc, self, result); } + } else if (auto min = dyn_cast(op)) { + Value self = + convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); + Value result = payloadArgs[1]; + if (resultElementType.isa()) + return b.create(loc, self, result); + else if (resultElementType.isa()) { + IntegerType intType = min.getSelf() + .getType() + .cast() + .getDtype() + .dyn_cast(); + if (intType.isUnsigned()) + return b.create(loc, self, result); + if (intType.isSigned()) + return b.create(loc, self, result); + } } else if (isa(op)) { // This creates payload for only the first of the two linalg.generic ops. // TODO: Short-circuit operations if `ord` is zero or one. 
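// Editor's sketch (not part of the patch): scalar model of the reduction
// that the linalg.generic payload performs, seeded with the identity
// elements created above (0 for sum, 1 for the new prod, -inf or the
// smallest signed integer for max, +inf or the largest signed integer for
// the new min).
template <typename T, typename Combine>
static T reduceAll(llvm::ArrayRef<T> xs, T identity, Combine combine) {
  T acc = identity;
  for (T x : xs)
    acc = combine(x, acc); // mirrors combine(convert(self), result) above
  return acc;
}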
@@ -307,6 +355,7 @@ class ConvertReductionOp : public ConversionPattern { "`keepdim` must be a constant bool"); SmallVector dimList; + int64_t dim; bool isNoneOrEmptyDimList = op.getDim().getType().template isa(); if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { @@ -319,6 +368,12 @@ class ConvertReductionOp : public ConversionPattern { } if (dimList.empty()) isNoneOrEmptyDimList = true; + } else if (matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { + dim = toPositiveDim(dim, inputType.getRank()); + if (!isValidDim(dim, inputType.getRank())) + return rewriter.notifyMatchFailure( + op, "`dim` argument must be valid, invalid received."); + opInfo.dimSet.insert(dim); } else if (!isNoneOrEmptyDimList) { return rewriter.notifyMatchFailure( op, "`dim` argument must be a constant int list or None"); @@ -340,11 +395,11 @@ class ConvertReductionOp : public ConversionPattern { ConversionPatternRewriter &rewriter) const { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; - if (isa(op)) { + if (isa(op)) { opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); - // `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the + // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the dimensions of the // input tensor. for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); @@ -355,6 +410,9 @@ class ConvertReductionOp : public ConversionPattern { if (auto sumOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(sumOp, operands, rewriter); + if (auto prodOp = dyn_cast(op)) + return computeReductionOpInfoForDimVariantOp(prodOp, operands, rewriter); + if (auto normOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter); @@ -519,7 +577,9 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 724430401ab1..7e73fabd8e9f 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -106,7 +106,7 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { } Location loc = op.getLoc(); - TypeConverter *typeConverter = this->getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); SmallVector resultSizeTorchInt, resultSize, resultSizeIndex; if (!getListConstructElements(op.getSize(), resultSizeTorchInt)) { return rewriter.notifyMatchFailure( @@ -211,7 +211,7 @@ class ConvertAtenEmptyMemoryFormatOp } Location loc = op.getLoc(); - TypeConverter *typeConverter = this->getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); SmallVector resultSizeTorchInt, resultSize, resultSizeIndex; if (!getListConstructElements(op.getSize(), resultSizeTorchInt)) { return rewriter.notifyMatchFailure( @@ -282,7 +282,7 @@ class ConvertAtenArangeStartStepOp } Location loc = op.getLoc(); - TypeConverter *typeConverter = this->getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp 
b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 5007786b5fef..1d25d22720d2 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -127,8 +127,10 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { } template -static Value createCalculationForMathOpWithDtypeConversion( - OpBuilder &b, TypeConverter *converter, Value payloadArg, Operation *op) { +static Value +createCalculationForMathOpWithDtypeConversion(OpBuilder &b, + const TypeConverter *converter, + Value payloadArg, Operation *op) { Type dtype = converter->convertType(op->getResult(0).getType()) .template cast() .getElementType(); @@ -207,7 +209,7 @@ createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs, } static Value createLinalgPayloadCalculationForElementwiseOp( - OpBuilder &b, Location loc, TypeConverter *converter, + OpBuilder &b, Location loc, const TypeConverter *converter, ValueRange payloadArgs, Operation *op, ArrayRef operands) { if (isa(op)) return b.create(loc, payloadArgs[0]); @@ -565,6 +567,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (dtype.isa()) { return b.create(loc, lhs, rhs); + } else if(dtype.isa()) { + return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); } @@ -658,18 +662,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp( divTensorMode.emitError("invalid rounding mode"); return nullptr; } + if (auto pow = dyn_cast(op)) { - if (!pow.getType() - .cast() - .getDtype() - .isa()) { + Type dtype = pow.getType().cast().getDtype(); + if (!dtype.isa()) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } - Type dtype = pow.getExponent().getType().cast().getDtype(); Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype); - return b.create(loc, selfPromoted, payloadArgs[0]); + Value expPromoted = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + return b.create(loc, selfPromoted, expPromoted); } + if (auto pow = dyn_cast(op)) { if (!pow.getType() .cast() @@ -1178,14 +1182,14 @@ class ConvertElementwiseOp : public ConversionPattern { AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, - AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, - AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, - AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, - AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, - AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, - AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, - AtenAtanOp, AtenRealOp, AtenImagOp>(op)) + AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, + AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, + AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp, + AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, + AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, + AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, + AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, + AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) 
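// Illustrative sketch for the pow lowering changed above (dtypes assumed, not
// taken from the patch): both the scalar base and the tensor exponent are now
// converted to the *result* dtype before the power is computed, so an i64
// scalar base with an f32 exponent tensor and f32 result becomes roughly
//
//   %base = arith.sitofp %self : i64 to f32   // via convertScalarToDtype
//   %exp  = ... payload element, converted to f32 ...
//   %pow  = math.powf %base, %exp : f32
//
// whereas previously the base was converted to the exponent tensor's dtype
// and the payload element was used unconverted.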
@@ -1707,17 +1711,18 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, - AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, - AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, - AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, - AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, - AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, - AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, - AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, - AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, - AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, - AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenRealOp, AtenImagOp>(); + AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, + AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, + AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, + AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, + AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, + AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, + AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, + AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, + AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, + AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, + AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, + AtenRealOp, AtenImagOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 27299458de8b..42c5d0b441cc 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -323,7 +323,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // Broadcasts input tensor based on the broadcastToShape. LogicalResult torch_to_linalg::broadcastToGivenShape( Operation *op, PatternRewriter &rewriter, Value input, - SmallVector broadcastToShape, Value &result) { + SmallVector broadcastToShape, Value &result, + SmallVector useBroadcastToShape) { RankedTensorType inputType = input.getType().cast(); SmallVector inputShape = makeShapeTorchCompatible(inputType.getShape()); @@ -335,13 +336,16 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Type elementType = inputType.getElementType(); Location loc = op->getLoc(); - MLIRContext *context = op->getContext(); SmallVector outShape; // Create affine map and shapes for tensor initialization. 
SmallVector outExpr; Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value zeroIndex = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value oneIndex = + rewriter.create(loc, rewriter.getIndexAttr(1)); size_t diff = broadcastToShape.size() - inputShape.size(); for (size_t i = 0; i < broadcastToShape.size(); i++) { Value shapeValue = broadcastToShape[i]; @@ -358,46 +362,65 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( } if (inputShape[j] == 1) { // Broadcast singleton dimension - Value one = - rewriter.create(loc, rewriter.getIndexAttr(1)); Value isNegative = rewriter.create( loc, arith::CmpIPredicate::slt, shapeValue, zero); Value select = rewriter.create( - loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue)); + loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue)); outShape.push_back(select); - outExpr.push_back(mlir::getAffineConstantExpr(0, context)); - continue; + } else { + // Case of dynamic input dimension wherein the shape to broadcast will + // yield us the dimension size of the output. + Value dim = getDimOp(rewriter, loc, input, j); + if (!useBroadcastToShape.empty()) { + if (useBroadcastToShape[i]) + dim = castIntToIndex(rewriter, loc, broadcastToShape[j]); + } + outShape.push_back(dim); } - // Non-broadcast case - Value dim = getDimOp(rewriter, loc, input, j); - Value isNegative = rewriter.create( - loc, arith::CmpIPredicate::slt, shapeValue, zero); - Value isEqual = rewriter.create( - loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim), - shapeValue); - Value isValid = rewriter.create(loc, isNegative, isEqual); - rewriter.create( - loc, isValid, - rewriter.getStringAttr( - "only broadcasting singleton dimensions supported")); - outShape.push_back(dim); - outExpr.push_back(mlir::getAffineDimExpr(i, context)); } Value outTensor = rewriter.create( loc, getAsOpFoldResult(outShape), elementType); SmallVector indexingMaps = { - AffineMap::get(broadcastToShape.size(), 0, outExpr, context), rewriter.getMultiDimIdentityMap(broadcastToShape.size())}; SmallVector iteratorTypes(broadcastToShape.size(), utils::IteratorType::parallel); result = rewriter .create( - loc, outTensor.getType(), input, outTensor, indexingMaps, - iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); + loc, outTensor.getType(), ValueRange(), outTensor, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // `loopIndices` contains IV of the linalg loops which + // would be used to extract values from the input tensor + // later on. + SmallVector loopIndices; + for (size_t i = 0; i < broadcastToShape.size(); ++i) { + if (i < diff) + continue; + loopIndices.push_back(b.create(loc, i)); + } + // `inputIndicesToExtract` contains i-th linalg loop IV if + // the i-th input dimension is not 1, else it contains a + // zero index. + SmallVector inputIndicesToExtract; + for (size_t i = 0, n = inputShape.size(); i < n; i++) { + if (inputShape[i] == 1) { + inputIndicesToExtract.push_back(zeroIndex); + } else { + Value inputDim = getDimOp(b, loc, input, i); + Value isEqual = b.create( + loc, arith::CmpIPredicate::eq, inputDim, oneIndex); + Value select = rewriter.create( + loc, isEqual, zeroIndex, loopIndices[i]); + inputIndicesToExtract.push_back(select); + } + } + // Extract and yield the value from input tensor at + // `inputIndicesToExtract` indices. 
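+          // Illustrative example (shapes assumed): broadcasting a
+          // tensor<?x1xf32> input to a [d0, d1] output reads, for output
+          // index (i, j), the element
+          //   input[select(dim(input, 0) == 1, 0, i), 0]
+          // Statically-unit dimensions get the constant zero index directly;
+          // only runtime-unit dimensions need the compare-and-select.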
+          Value result = b.create<tensor::ExtractOp>(
+              loc, input, inputIndicesToExtract);
+          b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
@@ -412,3 +435,16 @@ Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
   return b.create<tensor::CastOp>(
       loc, tensorType.clone(makeShapeLLVMCompatible(unknownSizes)), tensor);
 }
+
+Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
+                                                  Value tensor,
+                                                  Type elementType) {
+  auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
+                              ValueRange payloadArgs) {
+    Value elem =
+        convertScalarToDtype(builder, loc, payloadArgs[0], elementType);
+    builder.create<linalg::YieldOp>(loc, elem);
+  };
+  return torch_to_linalg::createElementwiseLinalgGeneric(
+      b, loc, {tensor}, elementType, dtypePromoteBody);
+}
diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h
index 5fd5538c264b..354012028b01 100644
--- a/lib/Conversion/TorchToLinalg/Utils.h
+++ b/lib/Conversion/TorchToLinalg/Utils.h
@@ -73,14 +73,19 @@ Value createElementwiseLinalgGeneric(
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
 
 // Broadcasts input tensor based on the broadcastToShape.
-LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
-                                    Value input,
-                                    SmallVector<Value> broadcastToShape,
-                                    Value &result);
+LogicalResult
+broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
+                      SmallVector<Value> broadcastToShape, Value &result,
+                      SmallVector<bool> useBroadcastToShape = {});
 
 // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
 // <?x?xf32>
 Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
+
+// Converts a tensor's element type to the specified `elementType`.
+Value convertTensorToElementType(OpBuilder &b, Location loc, Value tensor,
+                                 Type elementType);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
index 7c256c071ded..96e14f0fdd6e 100644
--- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp
+++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
@@ -77,7 +77,7 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
     if (op.isForLike())
       return failure();
 
-    TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter = getTypeConverter();
     SmallVector<Type> newResultTypes;
     if (failed(
             typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
@@ -217,7 +217,7 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
     if (!op.isForLike())
       return failure();
 
-    TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter = getTypeConverter();
     SmallVector<Type> newResultTypes;
     if (failed(
             typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
@@ -237,17 +237,17 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
     SmallVector<Type> regionArgTypes;
     SmallVector<Location> regionArgLocs;
-    for (Value value : scfForOp.getLoopBody().front().getArguments()) {
+    for (Value value : scfForOp.getRegion().front().getArguments()) {
       regionArgTypes.push_back(value.getType());
       regionArgLocs.push_back(value.getLoc());
     }
 
     // Populate the loop body region.
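    // Illustrative sketch (sketch only; SSA names assumed): a for-like
    //   torch.prim.Loop %n, %trueVal, init(%x0) { ^bb0(%i, %x): ... }
    // becomes
    //   %lb   = arith.constant 0 : index
    //   %step = arith.constant 1 : index
    //   scf.for %iv = %lb to %ub step %step iter_args(%x = %x0) { ... }
    // and the block created below receives the remapped loop-carried values.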
- if (!scfForOp.getLoopBody().empty()) - rewriter.eraseBlock(&scfForOp.getLoopBody().back()); + if (!scfForOp.getRegion().empty()) + rewriter.eraseBlock(&scfForOp.getRegion().back()); - auto *block = rewriter.createBlock(&scfForOp.getLoopBody(), - scfForOp.getLoopBody().begin(), + auto *block = rewriter.createBlock(&scfForOp.getRegion(), + scfForOp.getRegion().begin(), regionArgTypes, regionArgLocs); // Rewrite uses of the torch loop block arguments to the new for-loop diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 6ed3e5d7dc34..979182ae7fd7 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -13,6 +13,8 @@ #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -24,7 +26,6 @@ #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "utils/hlo_utils.h" #include #include @@ -33,6 +34,34 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; +namespace { + +template +static Value getConstantLike(OpBuilder &b, Location loc, T constant, + Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + auto getAttr = [&]() -> Attribute { + if (ty.isa()) + return b.getIntegerAttr(ty, constant); + if (ty.isa()) + return b.getFloatAttr(ty, constant); + if (auto complexTy = ty.dyn_cast()) + return complex::NumberAttr::get(complexTy, constant, 0); + llvm_unreachable("unhandled element type"); + }; + return b.create(loc, cast(getAttr()), + val); +} + +Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, + Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + return b.create(loc, b.getFloatAttr(ty, constant), + val); +} + +} // namespace + LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, mlir::Value &self, mlir::Value &other, size_t dimSizeIndexBits) { @@ -148,7 +177,7 @@ class ConvertAtenUnaryOp : public OpConversionPattern { auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - self = hlo::promoteType(rewriter, self, outType); + self = hlo::promoteType(rewriter, op.getLoc(), self, outType); rewriter.replaceOpWithNewOp(op, outType, self); return success(); } @@ -231,6 +260,48 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { } // namespace +namespace { +// Casts a tensor of exactly one element to an elemental type. 
+// Much code is borrowed from
+// `lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp`.
+template <typename AtenOpT>
+class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto inputType =
+        adaptor.getA().getType().template dyn_cast<RankedTensorType>();
+    if (!inputType)
+      return op.emitError("only Tensor types supported in StableHLO");
+    Location loc = op.getLoc();
+    Value input = adaptor.getA();
+    SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
+    int64_t inputRank = inputSizes.size();
+    Type inputDtype =
+        op.getA().getType().template cast<BaseTensorType>().getDtype();
+
+    // The input tensor must contain exactly one element.
+    Value constantOne =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
+    for (int64_t i = 0; i < inputRank; i++)
+      checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
+
+    Value constantZero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    SmallVector<Value> indices(inputRank, constantZero);
+    Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
+    Type resultType =
+        this->getTypeConverter()->convertType(op->getResult(0).getType());
+    rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
+                                                resultType, inputDtype));
+    return success();
+  }
+};
+} // namespace
+
 // The binary broadcast patterns
 namespace {
 template <typename AtenOpT, typename ChloOpT>
@@ -253,8 +324,8 @@ class ConvertAtenBinaryBroadcastOp : public OpConversionPattern<AtenOpT> {
         ->convertType(op.getType())
         .template cast<TensorType>();
 
-    lhs = hlo::promoteType(rewriter, lhs, outTy);
-    rhs = hlo::promoteType(rewriter, rhs, outTy);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
 
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
                                          /*broadcast_attr*/ nullptr);
@@ -300,8 +371,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
       }
     }
 
-    lhs = hlo::promoteType(rewriter, lhs, outType);
-    rhs = hlo::promoteType(rewriter, rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
 
     if (!skipMultiplyAlpha(op.getAlpha())) {
       Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
@@ -354,8 +425,8 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
                                               outElemTy);
     }
     DenseIntElementsAttr bcastDimensions;
-    lhs = hlo::promoteType(rewriter, lhs, outType);
-    rhs = hlo::promoteType(rewriter, rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
     auto loc = op.getLoc();
     Value result =
         rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@@ -427,7 +498,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     }
 
     // TODO: what is the PyTorch default type promotion?
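    // For now the rhs is simply promoted to the lhs element type. Sketch
    // (dtypes assumed): comparing a tensor<4xf32> lhs against an i64 scalar
    // rhs first materializes the scalar as a 0-d f32 tensor and then emits
    // the chlo broadcast-compare, rather than applying PyTorch's full type
    // promotion rules (e.g. f32 vs f64 promoting to f64).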
- rhs = hlo::promoteType(rewriter, rhs, lhsTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonDirectionAttr compareDirectionAttr; @@ -494,8 +565,10 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern { TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType); - Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType); + Value lhs = + hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType); + Value rhs = + hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, @@ -610,8 +683,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = getTypeConverter()->convertType(op.getType()).cast(); // promote self and other types - self = hlo::promoteType(rewriter, self, outType); - other = hlo::promoteType(rewriter, other, outType); + self = hlo::promoteType(rewriter, op.getLoc(), self, outType); + other = hlo::promoteType(rewriter, op.getLoc(), other, outType); if (failed( broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits))) @@ -760,6 +833,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenTensorIntOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTensorIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + Type outElementType = resultType.getElementType(); + Value innerValue = adaptor.getT(); + Value stablehloTensor = + hlo::scalarToStablehloTensor(rewriter, op, innerValue, outElementType); + rewriter.replaceOp(op, stablehloTensor); + return success(); +} + // AtenReciprocalOp // Reciprocal(x) = Div(1, x) template <> @@ -775,7 +864,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "for AtenReciprocalOp"); } - Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); + Value oneTensor = getConstantLike(rewriter, op->getLoc(), 1, input); rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } @@ -807,8 +896,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); } DenseIntElementsAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, lhs, outType); - rhs = hlo::promoteType(rewriter, rhs, outType); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); @@ -832,6 +921,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenScalarImplicitOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenScalarImplicitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Type inputDtype = + op.getA().getType().template cast().getDtype(); + Type resultType = + this->getTypeConverter()->convertType(op->getResult(0).getType()); + auto result = + rewriter.create(loc, adaptor.getA()); + + rewriter.replaceOp( + op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype)); + return success(); +} + // AtenContiguousOp // Ref: TosaToTosa.cpp for implementation details template <> @@ -866,7 +973,7 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( } Value zeroTensor; - zeroTensor = chlo::getConstantLike( + zeroTensor = getConstantLike( rewriter, op->getLoc(), APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), false), @@ -888,9 +995,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return op.emitError("only ranked tensor type is supported."); } - Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); - Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); + Value one = getConstantLike(rewriter, loc, 1.0, input); + Value two = getConstantLike(rewriter, loc, 2.0, input); + Value half = getConstantLike(rewriter, loc, 0.5, input); auto rsqrtTwo = rewriter.create(loc, two); auto erfElement = rewriter.create(loc, input, rsqrtTwo); auto erf = rewriter.create(loc, erfElement); @@ -921,7 +1028,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenBatchNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getInput(); - // shape = [N, C, H, W] auto inputTy = input.getType().cast(); Value weight = adaptor.getWeight(); Value bias = adaptor.getBias(); @@ -940,7 +1046,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } auto inputElemTy = inputTy.getElementType().cast(); - Value channelDim = rewriter.create(op->getLoc(), input, 1); + Value channelDim = + rewriter.create(op->getLoc(), input, feature_index); if (options.dimSizeIndexBits == 32) { auto channelDimI64 = rewriter.create( @@ -1016,12 +1123,36 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Type outputTy = getTypeConverter()->convertType(op.getType()); Type batchMeanOrVarTy = RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); - auto batchNormTrainingResult = - rewriter.create( - op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, - weight, bias, rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(feature_index)); - rewriter.replaceOp(op, batchNormTrainingResult.getResult(0)); + + Value output; + // supported mixed types, like input type is fp16 and weight type is fp32. + if (inputTy.getElementType() != weightTy.getElementType()) { + RankedTensorType convertedType = inputTy; + if (weightTy.getElementType().cast().getWidth() > + inputTy.getElementType().cast().getWidth()) { + convertedType = RankedTensorType::get(inputTy.getShape(), + weightTy.getElementType()); + } + input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType); + weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); + bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); + auto batchNormTrainingResult = + rewriter.create( + op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, + weight, bias, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(feature_index)); + output = hlo::promoteType(rewriter, op.getLoc(), + batchNormTrainingResult.getResult(0), + outputTy.cast()); + } else { + auto batchNormTrainingResult = + rewriter.create( + op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, + weight, bias, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(feature_index)); + output = batchNormTrainingResult.getResult(0); + } + rewriter.replaceOp(op, output); return success(); } else { Type outputTy = getTypeConverter()->convertType(op.getType()); @@ -1033,12 +1164,38 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // stablehlo::BatchNormInferenceOp. 
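   // Illustrative sketch of the mixed-precision path added below (dtypes and
   // shapes assumed): an f16 input with f32 weight/bias/running stats is
   // promoted to the wider f32 type first,
   //
   //   %in32  = stablehlo.convert %input : tensor<NxCxHxWxf16> -> tensor<NxCxHxWxf32>
   //   %out32 = "stablehlo.batch_norm_inference"(%in32, %weight, %bias, %mean, %var)
   //   %out   = stablehlo.convert %out32 : tensor<NxCxHxWxf32> -> tensor<NxCxHxWxf16>
   //
   // so that every operand of the StableHLO op shares one element type.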
Value inputCasted = rewriter.create(op.getLoc(), castTy, input); - Value output = rewriter.create( - op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, - runningMean, runningVar, - // 'epsilon' must satisfy constraint: 32-bit float attribute. - rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(feature_index)); + + Value output; + // supported mixed types, like input type is fp16 and weight type is fp32. + if (inputTy.getElementType() != weightTy.getElementType()) { + RankedTensorType convertedType = inputTy; + if (weightTy.getElementType().cast().getWidth() > + inputTy.getElementType().cast().getWidth()) { + convertedType = RankedTensorType::get(inputTy.getShape(), + weightTy.getElementType()); + } + input = + hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType); + weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); + bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); + runningMean = + hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType); + runningVar = + hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType); + Value bnResult = rewriter.create( + op.getLoc(), convertedType, input, weight, bias, runningMean, + runningVar, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(feature_index)); + output = hlo::promoteType(rewriter, op.getLoc(), bnResult, + outputTy.cast()); + } else { + output = rewriter.create( + op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, + runningMean, runningVar, + // 'epsilon' must satisfy constraint: 32-bit float attribute. + rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(feature_index)); + } rewriter.replaceOpWithNewOp(op, outputTy, output); return success(); } @@ -1212,7 +1369,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Promote type for (auto &v : builtinTensors) { - v = hlo::promoteType(rewriter, v, outType); + v = hlo::promoteType(rewriter, op->getLoc(), v, outType); } rewriter.replaceOpWithNewOp( @@ -1356,13 +1513,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } // Create constant value - Value kAlpha = - chlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input); + Value kAlpha = getConstantLike(rewriter, loc, 0.70710678118654752440, input); Value cstAlpha0 = - chlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input); - Value half = chlo::getConstantLike(rewriter, loc, .5, input); - Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); - Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input); + getConstantLike(rewriter, loc, 1.12837916709551257390, input); + Value half = getConstantLike(rewriter, loc, .5, input); + Value one = getConstantLike(rewriter, loc, 1.0, input); + Value negHalf = getConstantLike(rewriter, loc, -0.5, input); // Compute Value kBeta0 = @@ -1404,8 +1560,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outTy = this->getTypeConverter()->convertType(op.getType()).cast(); - lhs = hlo::promoteType(rewriter, lhs, outTy); - rhs = hlo::promoteType(rewriter, rhs, outTy); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); @@ -1474,15 +1630,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "memory_format is supported"); } - // TODO: Add support for device arg other than cpu. 
if (!op.getDevice().getType().isa()) { std::string device; if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) return rewriter.notifyMatchFailure( op, "unimplemented: device must be a constant str"); - else if (device != "cpu") - return rewriter.notifyMatchFailure( - op, "unimplemented: device is expected to be cpu"); } // TODO: Add support for non-strided layout. @@ -1498,7 +1650,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } Location loc = op.getLoc(); - TypeConverter *typeConverter = this->getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); SmallVector resultSizeTorchInt, resultSize, resultSizeIndex; if (!getListConstructElements(op.getSize(), resultSizeTorchInt)) { return rewriter.notifyMatchFailure( @@ -1513,8 +1665,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( typeConverter->convertType(op.getType()).cast(); Type resultElementType; if (op.getDtype().getType().isa()) { - resultElementType = - getDefaultDtypeForTorchScalar(Torch::FloatType::get(op->getContext())); + resultElementType = resultType.getElementType(); } else { int64_t dtypeInt; if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) @@ -1560,6 +1711,7 @@ class ConvertRuntimeAssertOp : public OpConversionPattern { }; } // namespace +// AtenFillScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenFillScalarOp op, OpAdaptor adaptor, @@ -1569,12 +1721,40 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto dtype = outType.getElementType(); Value scalarTensor = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype); - Value bcastScalar = rewriter.create( - op->getLoc(), outType, scalarTensor, rewriter.getI64TensorAttr({})); + Value shapeTensor = + rewriter.create(op->getLoc(), adaptor.getSelf()); + Value bcastScalar = rewriter.create( + op->getLoc(), outType, scalarTensor, shapeTensor, + rewriter.getI64TensorAttr({})); rewriter.replaceOp(op, bcastScalar); return success(); } +// AtenFlipOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenFlipOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.getSelf(); + auto outType = + getTypeConverter()->convertType(op.getType()).cast(); + + SmallVector dims; + if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) { + return rewriter.notifyMatchFailure(op, "dims must be a list of const int"); + } + for (unsigned i = 0, e = dims.size(); i < e; i++) { + dims[i] = toPositiveDim(dims[i], outType.getRank()); + if (!isValidDim(dims[i], outType.getRank())) { + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + } + } + + rewriter.replaceOpWithNewOp( + op, outType, self, rewriter.getI64TensorAttr(dims)); + return success(); +} + void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { @@ -1619,6 +1799,16 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN +#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context) + + INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp); + INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp); + INSERT_TENSOR_TO_SCALAR_PATTERN(AtenBoolTensorOp); +#undef INSERT_TENSOR_TO_SCALAR_PATTERN + #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ 
patterns.add>(typeConverter, context) @@ -1676,9 +1866,11 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenPermuteOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); + INSERT_ATENOP_PATTERN(AtenTensorIntOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenScalarImplicitOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); INSERT_ATENOP_PATTERN(AtenReluOp); @@ -1700,6 +1892,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenUniformOp); INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp); + INSERT_ATENOP_PATTERN(AtenFlipOp); #undef INSERT_ATENOP_PATTERN #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt index 84a560cd753d..0f9b8fabaa54 100644 --- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt +++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt @@ -20,7 +20,8 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo LINK_LIBS PUBLIC MLIRIR MLIRPass - MLIRBufferTransforms + MLIRComplexDialect + ChloOps StablehloOps TorchMLIRTorchDialect TorchMLIRConversionUtils diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index c2dc9561fa3c..9c8123bfdbad 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -29,6 +29,32 @@ using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; namespace { +static Value createInitialValueForGatherScatterOp(Operation *op, + RankedTensorType constType, + PatternRewriter &rewriter) { + auto elementTy = constType.getElementType(); + if (isa(op)) { + if (elementTy.isa()) { + auto constAttr = DenseElementsAttr::get( + constType, {APFloat::getZero( + elementTy.cast().getFloatSemantics(), + /*negative=*/false)}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } else if (elementTy.isa() && + elementTy.getIntOrFloatBitWidth() != 8) { + auto constAttr = DenseElementsAttr::get( + constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + } + + op->emitError("unimplemented lowering in " + "createInitialValueForGatherScatterOp"); + return nullptr; +} + Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, Value input, Value indices, int64_t axis, size_t dimSizeIndexBits) { @@ -217,6 +243,162 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenEmbeddingBagPaddingIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value weight = adaptor.getWeight(); + Value indices = adaptor.getIndices(); + Value offsets = adaptor.getOffsets(); + + auto weightTy = weight.getType().cast(); + if (weightTy && weightTy.hasStaticShape() && weightTy.getRank() != 2) + return rewriter.notifyMatchFailure( + op, "weight must be rank 2 tensor with static shapes"); + + auto indicesTy = indices.getType().cast(); + if (indicesTy && indicesTy.hasStaticShape() && indicesTy.getRank() != 1) + return rewriter.notifyMatchFailure( + op, "indices must be a vector with static shapes"); + + auto offsetsTy = offsets.getType().cast(); + if (offsetsTy && 
offsetsTy.getRank() != 1 && offsetsTy.hasStaticShape() && + offsetsTy.getShape()[0] == 1) + return rewriter.notifyMatchFailure( + op, "offsets must be a vector with static shape equal to 1"); + + if (!op.getPaddingIdx().getType().isa()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: padding_idx should be none"); + + if (!op.getPerSampleWeights().getType().isa()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: per_sample_weights should be none"); + + bool includeLastOffset; + if (!matchPattern(op.getIncludeLastOffset(), + m_TorchConstantBool(&includeLastOffset))) { + return rewriter.notifyMatchFailure( + op, "include_last_offset is expected to be a constant boolean value."); + } + if (includeLastOffset) + return rewriter.notifyMatchFailure( + op, "include_last_offset is currently not supported"); + + bool scaleGradByFreq; + if (!matchPattern(op.getScaleGradByFreq(), + m_TorchConstantBool(&scaleGradByFreq))) + return rewriter.notifyMatchFailure( + op, "only constant scale_grad_by_freq is currently supported"); + if (scaleGradByFreq) + return rewriter.notifyMatchFailure( + op, "scale gradients is currently not supported"); + + bool sparse; + if (!matchPattern(op.getSparse(), m_TorchConstantBool(&sparse))) + return rewriter.notifyMatchFailure( + op, "only constant sparse is currently supported"); + if (sparse) + return rewriter.notifyMatchFailure( + op, "sparse gradients is currently not supported"); + + int64_t modeInt; + if (!matchPattern(op.getMode(), m_TorchConstantInt(&modeInt))) { + return rewriter.notifyMatchFailure( + op, "mode is expected to be a constant integer value."); + } + if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) { + return rewriter.notifyMatchFailure(op, + "Unimplemented: Mean and Max mode are " + "not supported yet for EmbeddingBag."); + } + + const auto &options = + ConvertAtenOp::getOptions(); + auto weightDimSizes = + *hlo::getDimSizesOfTensor(rewriter, op, weight, options.dimSizeIndexBits); + auto indicesDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, indices, + options.dimSizeIndexBits); + auto offsetsDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, offsets, + options.dimSizeIndexBits); + + Value gatherOutput = gatherTensorAlongSingleAxis( + rewriter, op, weight, indices, 0, options.dimSizeIndexBits); + + Type elementTy = weightTy.getElementType(); + auto constType = RankedTensorType::get({}, elementTy); + Value initValue = + createInitialValueForGatherScatterOp(op, constType, rewriter); + if (!initValue) + return failure(); + + auto stablehloReduceOp = rewriter.create( + op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0})); + + Region ®ion = stablehloReduceOp.getBody(); + Block &block = region.emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, elementTy); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value addResult = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + rewriter.create(op->getLoc(), addResult); + } + + auto outShapeInfo = + hlo::getDimSizesOfTensor(rewriter, op, weight, options.dimSizeIndexBits); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto outShapeVec = *outShapeInfo; + auto one = rewriter.create( + 
op->getLoc(), rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + outShapeVec[0] = one; + auto outShapeTensor = + rewriter.create(op->getLoc(), outShapeVec); + auto resultA = rewriter.create( + loc, getTypeConverter()->convertType(op.getType(0)), + stablehloReduceOp.getResult(0), outShapeTensor); + + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(1).getType()) + .cast(); + Value resultB = + createInitialValueForGatherScatterOp(op, resultType, rewriter); + if (!resultB) + return failure(); + + resultType = getTypeConverter() + ->convertType(op->getResult(2).getType()) + .cast(); + Value resultC = + createInitialValueForGatherScatterOp(op, resultType, rewriter); + if (!resultC) + return failure(); + + resultType = getTypeConverter() + ->convertType(op->getResult(3).getType()) + .cast(); + Value resultD = + createInitialValueForGatherScatterOp(op, resultType, rewriter); + if (!resultD) + return failure(); + + rewriter.replaceOp(op, {resultA, resultB, resultC, resultD}); + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexSelectOp op, OpAdaptor adaptor, @@ -342,7 +524,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return failure(); Location loc = op.getLoc(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); @@ -376,6 +558,137 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenScatterSrcOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenScatterSrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value input = adaptor.getSelf(); + Value index = adaptor.getIndex(); + Value src = adaptor.getSrc(); + auto inputType = input.getType().cast(); + auto indexType = index.getType().cast(); + auto srcType = src.getType().cast(); + auto indexElemType = indexType.getElementType(); + + if (indexType.getRank() != inputType.getRank() || + inputType.getRank() != srcType.getRank()) { + return op.emitError( + "`index`, `input` and `src` param should have the same rank"); + } + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure( + op, "only constant int `dim` param supported"); + } + dim = toPositiveDim(dim, inputType.getRank()); + if (!isValidDim(dim, inputType.getRank())) { + return rewriter.notifyMatchFailure(op, "invalid `dim` param detected"); + } + + auto options = getOptions(); + + auto indexShapeInfo = + hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); + if (failed(indexShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dim sizes of `index` param"); + } + auto intType = rewriter.getIntegerType(options.dimSizeIndexBits); + + // slice src tensor to have the same shape bound of index tensor in the + // leading dimensions. PyTorch has guaranteed that src tensor size will not be + // smaller than that of index tensor. 
REF: + // https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_ + auto zero = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 0)); + auto one = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 1)); + SmallVector sliceIndicies(srcType.getRank(), zero); + SmallVector sliceStrides(srcType.getRank(), one); + + auto sliceIndiciesValue = + rewriter.create(loc, sliceIndicies); + auto sliceStridesValue = + rewriter.create(loc, sliceStrides); + auto sliceLimitIndiciesValue = + rewriter.create(loc, *indexShapeInfo); + + auto newSrcType = + RankedTensorType::get(indexType.getShape(), srcType.getElementType()); + src = rewriter.create( + loc, newSrcType, src, sliceIndiciesValue, sliceLimitIndiciesValue, + sliceStridesValue); + + // generate scatter indicies for stablehlo::Scatter op. + auto toConcatIndexShapeValueVec = *indexShapeInfo; + toConcatIndexShapeValueVec.push_back(one); + auto toConcatIndexShape = + rewriter.create(loc, toConcatIndexShapeValueVec); + + auto indexShape = indexType.getShape(); + SmallVector toConcatIndexShapeVec(indexShape.begin(), + indexShape.end()); + toConcatIndexShapeVec.push_back(1); + RankedTensorType toConcatIndexType = + RankedTensorType::get(toConcatIndexShapeVec, indexElemType); + + SmallVector toConcat; + for (int64_t i = 0; i < inputType.getRank(); ++i) { + if (i == dim) { + toConcat.push_back(rewriter.create( + loc, toConcatIndexType, index, toConcatIndexShape)); + } else { + toConcat.push_back(rewriter.create( + loc, toConcatIndexType, toConcatIndexShape, + rewriter.getI64IntegerAttr(i))); + } + } + + auto scatterIndicies = rewriter.create( + loc, toConcat, static_cast(inputType.getRank())); + SmallVector sliceSizes(inputType.getRank(), 1); + + // generate ScatterDimensionNumbers for stablehlo::Scatter op. + int64_t indexVecDim = inputType.getRank(); + SmallVector scatterDimOperandDimMap; + SmallVector insertedWindowDims; + for (int64_t i = 0; i < inputType.getRank(); ++i) { + scatterDimOperandDimMap.push_back(i); + insertedWindowDims.push_back(i); + } + auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get( + rewriter.getContext(), + /*updateWindowDims=*/{}, + /*insertedWindowDims=*/insertedWindowDims, + /*scatterDimsToOperandDim=*/scatterDimOperandDimMap, + /*indexVectorDim=*/indexVecDim); + + auto stablehloScatterOp = rewriter.create( + loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false); + + // config update computation function: just return the element from src. 
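+  // Illustrative sketch (element type assumed f32): stablehlo.scatter invokes
+  // the region below with the current value and the update taken from `src`,
+  // and stores whatever the region returns. Yielding the update
+  // unconditionally gives `torch.scatter`'s overwrite semantics:
+  //
+  //   ^bb0(%old: tensor<f32>, %new: tensor<f32>):
+  //     stablehlo.return %new : tensor<f32>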
+ Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock(); + // add block arguments + auto blockArgumentType = + RankedTensorType::get({}, inputType.getElementType()); + block.addArgument(blockArgumentType, loc); + block.addArgument(blockArgumentType, loc); + + auto *lhsArg = block.args_begin(); + auto *rhsArg = std::next(lhsArg); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + rewriter.create(loc, *rhsArg); + } + + rewriter.replaceOp(op, stablehloScatterOp.getResults()); + return success(); +} + // AtenIndexTensorOp // Convert AtenIndexTensorOp to StableHlo::GatherOp // Step 1: broadcast indices to the same shape @@ -402,8 +715,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Output: [[3, 3, 3], // [8, 8, 2]] template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenIndexTensorOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); Value input = adaptor.getSelf(); @@ -429,11 +742,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { auto indexTensor = indexTensors[i]; - auto indexTorchTensor = indicesTorchType[i]; - // TODO: add support for none index input - if (indexTorchTensor.getType().isa()) - return rewriter.notifyMatchFailure( - op, "Only list ranked tensor types index are supported"); auto indexTensorType = indexTensor.getType().cast(); for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) { if (size == kUnknownSize) @@ -539,9 +847,11 @@ void mlir::torch::torch_to_stablehlo:: target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenEmbeddingBagPaddingIdxOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp); INSERT_ATENOP_PATTERN(AtenGatherOp); INSERT_ATENOP_PATTERN(AtenSliceScatterOp); - INSERT_ATENOP_PATTERN(AtenIndexTensorOp); + INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenScatterSrcOp); #undef INSERT_ATENOP_PATTERN } diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 0786151cb217..71d679aeada4 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -785,7 +785,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { const auto &options = getOptions(); bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, options.dimSizeIndexBits); - bias = hlo::promoteType(rewriter, bias, outTy); + bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 4bfe6c6110ef..7c28a2fd3004 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -16,13 +16,13 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include 
"torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include #include @@ -35,7 +35,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); // Avg pooling - if (isa(op)) { + if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( @@ -373,168 +373,195 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -// AtenAvgPool2dOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenAvgPool2dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = input.getType().cast(); - auto inputElemTy = inputTy.getElementType(); - auto inputRank = inputTy.getRank(); - auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); - auto outShape = outTy.getShape(); - - if (inputRank <= 2) { - return op.emitError( - "avg_pooling2d only supports inputs with rank higher than 2"); - } - SmallVector padding, kernelSize, stride; - bool ceilMode = false; - bool countIncludePad = true; - - if (!(matchPattern(op.getKernelSize(), - m_TorchListOfConstantInts(kernelSize)))) { - return rewriter.notifyMatchFailure( - op, "non-const int kernel size unsupported!"); - } - if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { - return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); - } - if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { - return rewriter.notifyMatchFailure(op, - "non-const int padding unsupported!"); - } - if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { - return rewriter.notifyMatchFailure(op, - "non-const bool ceil_mode unsupported!"); - } - if (!(matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)))) { - return rewriter.notifyMatchFailure( - op, "non-const bool count_include_pad unsupported!"); - } - if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) { - return rewriter.notifyMatchFailure( - op, "only None divisor_override supported for now!"); - } - - // prepend 1 to kernelSize, stride, dilation until they are of same rank as - // input - SmallVector stablehloStride(inputRank, 1); - SmallVector stablehloDilation(inputRank, 1); - SmallVector stablehloKernelSize(inputRank, 1); - SmallVector stablehloPadding(inputRank * 2, 0); - - std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - 2); - std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - 2); - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; - - Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = 
DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); - DenseIntElementsAttr pad = DenseIntElementsAttr::get( - RankedTensorType::get( - {static_cast(inputRank), static_cast(2)}, - rewriter.getI64Type()), - stablehloPadding); - - auto reduceWindowSum = rewriter.create( - op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, - baseDilations, windowDilations, pad); - - Block &sumBlock = reduceWindowSum.getBody().emplaceBlock(); +namespace { +template +class ConvertAtenAvgPoolOp : public ConvertAtenOp { +public: + using ConvertAtenOp::ConvertAtenOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + RankedTensorType inputTy = input.getType().cast(); + Type inputElemTy = inputTy.getElementType(); + int64_t inputRank = inputTy.getRank(); + RankedTensorType outTy = ConvertAtenOp::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + auto outShape = outTy.getShape(); + + + if (inputRank <= Dim) { + return op.emitError( + "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); + } + SmallVector padding, kernelSize, stride; + bool ceilMode = false; + bool countIncludePad = true; + + if (!(matchPattern(op.getKernelSize(), + m_TorchListOfConstantInts(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); + } + if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { + return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + } + if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); + } + if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { + return rewriter.notifyMatchFailure(op, + "non-const bool ceil_mode unsupported!"); + } + if (!(matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)))) { + return rewriter.notifyMatchFailure( + op, "non-const bool count_include_pad unsupported!"); + } - // Add bb argument - auto blockArgumentType = RankedTensorType::get({}, inputElemTy); - sumBlock.addArgument(blockArgumentType, op->getLoc()); - sumBlock.addArgument(blockArgumentType, op->getLoc()); - auto *firstArg = sumBlock.args_begin(); - auto secondArg = sumBlock.args_rbegin(); + if constexpr (std::is_same()) { + if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) + return rewriter.notifyMatchFailure( + op, "only None divisor_override supported for now!"); + } - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sumBlock); + // Prepend 1 to kernelSize, stride, dilation until they are of same rank + // as input + SmallVector stablehloStride(inputRank, 1); + SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloKernelSize(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); + + std::copy(stride.begin(), stride.end(), + stablehloStride.begin() + inputRank - Dim); + std::copy(kernelSize.begin(), kernelSize.end(), + stablehloKernelSize.begin() + inputRank - Dim); + if (Dim == 1) { + stablehloPadding[stablehloPadding.size() - 2] = padding[0]; + stablehloPadding[stablehloPadding.size() - 1] = padding[0]; + } else { + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = 
padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + } - Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); - } + Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(stablehloKernelSize.size())}, + rewriter.getI64Type()), + stablehloKernelSize); + DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(stablehloStride.size())}, + rewriter.getI64Type()), + stablehloStride); + DenseIntElementsAttr baseDilations; + DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(stablehloDilation.size())}, + rewriter.getI64Type()), + stablehloDilation); + DenseIntElementsAttr pad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + stablehloPadding); + + auto reduceWindowSum = rewriter.create( + op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, + baseDilations, windowDilations, pad); + + Block &sumBlock = reduceWindowSum.getBody().emplaceBlock(); + + // Add bb argument + auto blockArgumentType = RankedTensorType::get({}, inputElemTy); + sumBlock.addArgument(blockArgumentType, op->getLoc()); + sumBlock.addArgument(blockArgumentType, op->getLoc()); + auto firstArg = *sumBlock.args_begin(); + auto secondArg = *sumBlock.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sumBlock); + + Value sumResult = + rewriter.create(op->getLoc(), firstArg, secondArg); + rewriter.create(op->getLoc(), sumResult); + } - // Use kernel size as the divisor - if (countIncludePad) { - Value divisor = hlo::getConstTensor( + // Use kernel size as the divisor + if (countIncludePad) { + Value divisor; + if (Dim == 1) { + divisor = + hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) + .value(); + } else { + divisor = hlo::getConstTensor( rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) .value(); - divisor = hlo::promoteType(rewriter, divisor, outTy); - DenseIntElementsAttr bcastDimensions; - rewriter.replaceOpWithNewOp( - op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); - return success(); - } - - // Use another stablehlo.ReduceWindowOp to get the divisor - Value windowSizeConst = - hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); - windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy); - const auto &options = getOptions(); - auto inputShapeVec = - *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); - auto inputShapeTensor = rewriter.create( - op->getLoc(), inputShapeVec); - - windowSizeConst = rewriter.create( - op->getLoc(), - RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), - windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); - - Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - auto reduceWindowSize = rewriter.create( - op->getLoc(), RankedTensorType::get(outShape, inputElemTy), - windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, - windowDilations, pad); - - Block &sizeBlock = reduceWindowSize.getBody().emplaceBlock(); + } + divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); + DenseIntElementsAttr bcastDimensions; + rewriter.replaceOpWithNewOp( + 
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); + return success(); + } - // Add bb argument - blockArgumentType = RankedTensorType::get({}, inputElemTy); - sizeBlock.addArgument(blockArgumentType, op->getLoc()); - sizeBlock.addArgument(blockArgumentType, op->getLoc()); - firstArg = sizeBlock.args_begin(); - secondArg = sizeBlock.args_rbegin(); + // Use another mhlo.ReduceWindowOp to get the divisor + Value windowSizeConst = + hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); + windowSizeConst = + hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy); + const auto &options = ConvertAtenOp::getOptions(); + auto inputShapeVec = + *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto inputShapeTensor = rewriter.create( + op->getLoc(), inputShapeVec); + + windowSizeConst = rewriter.create( + op->getLoc(), + RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), + windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); + + Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + auto reduceWindowSize = rewriter.create( + op->getLoc(), RankedTensorType::get(outShape, inputElemTy), + windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, + windowDilations, pad); + + Block &sizeBlock = reduceWindowSize.getBody().emplaceBlock(); + + // Add bb argument + blockArgumentType = RankedTensorType::get({}, inputElemTy); + sizeBlock.addArgument(blockArgumentType, op->getLoc()); + sizeBlock.addArgument(blockArgumentType, op->getLoc()); + firstArg = *sizeBlock.args_begin(); + secondArg = *sizeBlock.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sizeBlock); + + Value sumResult = + rewriter.create(op->getLoc(), firstArg, secondArg); + rewriter.create(op->getLoc(), sumResult); + } - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sizeBlock); + rewriter.replaceOpWithNewOp( + op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); + return success(); - Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); } - rewriter.replaceOpWithNewOp( - op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); - return success(); +}; } + // AtenCumsumOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -620,6 +647,8 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); + patterns.add>(typeConverter, context, options); target.addIllegalOp(); patterns.add>(typeConverter, context, options); target.addIllegalOp(); @@ -629,4 +658,11 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( context, options); target.addIllegalOp(); patterns.add>(typeConverter, context, options); +#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ + target.addIllegalOp(); \ + patterns.add>( \ + typeConverter, context, options) + INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1); + INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2); +#undef INSERT_ATEN_AVGPOOL_PATTERN } diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index ce0d1f371cb6..36f4d49e9a99 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ 
b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -68,6 +68,24 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } + if (isa(op)) { + if (elementTy.isa()) { + auto constAttr = DenseElementsAttr::get( + constType, {APFloat::getInf( + elementTy.cast().getFloatSemantics(), + /*negative=*/false)}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } else if (elementTy.isa() && + elementTy.getIntOrFloatBitWidth() != 8) { + auto constAttr = DenseElementsAttr::get( + constType, + {APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + } + op->emitError("unimplemented lowering in " "createInitialValueForReduceOp"); return nullptr; @@ -481,6 +499,68 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace +// AtenMinOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().dyn_cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + // Currently, (u)int8 dtype is not supported + if (inputElemTy.isa() && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenMinOp to StableHLO"); + } + + SmallVector dims; + for (int64_t i = 0; i < inputTy.getRank(); i++) { + dims.push_back(i); + } + + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) + return failure(); + llvm::sort(dims.begin(), dims.end()); + auto stablehloReduceOp = rewriter.create( + op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + + Block &block = stablehloReduceOp.getBody().emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value minResult = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + rewriter.create(op->getLoc(), minResult); + } + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + stablehloReduceOp.getResults()); + return success(); +} +} // namespace + // AtenSumDimIntListOp namespace { template <> @@ -838,6 +918,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp); #undef INSERT_ATEN_REDUCTION_OP_PATTERN diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index 785ae50e6b01..a25a66bbb293 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ 
b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -185,15 +185,14 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, dtype_tensor); } -Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { - Operation *op = input.getDefiningOp(); - TensorType in_type = input.getType().dyn_cast(); +Value promoteType(PatternRewriter &rewriter, Location loc, Value input, + TensorType outType) { + TensorType in_type = input.getType().cast(); if (in_type.getElementType() != outType.getElementType()) { TensorType promotedType = in_type.cloneWith(in_type.getShape(), outType.getElementType()); - return rewriter.create(op->getLoc(), promotedType, - input); + return rewriter.create(loc, promotedType, input); } return input; } diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index 434d55c760d3..4bcc02344e7d 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -44,6 +44,7 @@ class ConvertTorchToStablehlo registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } @@ -51,7 +52,8 @@ class ConvertTorchToStablehlo MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); + tensor::TensorDialect, arith::ArithDialect, + shape::ShapeDialect>(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index a34e2db8359b..d11a5524af7d 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -309,7 +309,7 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); Value self = adaptor.getSelf(); Value index = adaptor.getIndex(); Value src = adaptor.getSrc(); @@ -361,7 +361,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { return failure(); Location loc = op.getLoc(); MLIRContext *context = op->getContext(); - TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = getTypeConverter(); Value input = adaptor.getSelf(); Value torchTypeInput = op.getSelf(); Value minlength = adaptor.getMinlength(); @@ -1273,13 +1273,13 @@ class ConvertAtenScatterReduceTwoOp // Set the values in the input tensor to the smallest element of that // type TypedAttr minAttr = getNumericLimit(rewriter, srcType.getElementType(), - /*getMin=*/true); + /*getMin=*/true); normalizationValue = rewriter.create(loc, minAttr); } else if (reduceEnum == torch_upstream::ReductionType::MIN) { // Set the values in the input tensor to the largest element of that // type TypedAttr maxAttr = getNumericLimit(rewriter, srcType.getElementType(), - /*getMin=*/false); + /*getMin=*/false); normalizationValue = rewriter.create(loc, maxAttr); } @@ -1332,7 +1332,7 @@ class ConvertAtenScatterReduceTwoOp if (update.getType().isa()) { result = b.create(loc, update, current); } else if (update.getType().isa()) { - result = b.create(loc, update, current); + result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } 
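The hunk above and the one below swap the float combinators used by the scatter-reduce lowering. A minimal sketch of the dispatch being patched, assuming the bumped LLVM renamed the float min/max ops to arith::MinimumFOp / arith::MaximumFOp (the NaN-propagating variants); combineForReduce is a hypothetical name, not code from this patch:

// Sketch only; the integer paths keep arith::MaxSIOp / arith::MinSIOp.
static Value combineForReduce(OpBuilder &b, Location loc, Value update,
                              Value current, bool isMax) {
  if (update.getType().isa<mlir::IntegerType>())
    return isMax ? b.create<arith::MaxSIOp>(loc, update, current).getResult()
                 : b.create<arith::MinSIOp>(loc, update, current).getResult();
  if (update.getType().isa<mlir::FloatType>())
    return isMax
               ? b.create<arith::MaximumFOp>(loc, update, current).getResult()
               : b.create<arith::MinimumFOp>(loc, update, current).getResult();
  llvm_unreachable("Only integer/float types supported!");
}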
@@ -1340,7 +1340,7 @@ class ConvertAtenScatterReduceTwoOp if (update.getType().isa()) { result = b.create(loc, update, current); } else if (update.getType().isa()) { - result = b.create(loc, update, current); + result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } @@ -1498,11 +1498,29 @@ class ConvertAtenCumsumOp : public OpConversionPattern { matchAndRewrite(AtenCumsumOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); Value input = adaptor.getSelf(); - auto resultType = input.getType().cast(); + auto resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); Type elementType = resultType.getElementType(); + Type inputElementType = + input.getType().cast().getElementType(); + + // Converting the input element type to the result's element type. + // The only possible mismatch would be when the input element type is an + // integer but not `si64`. Therefore, we directly convert the input to + // `si64`. Rest all cases are handled in the dtype definition for this op. + if (elementType != inputElementType) { + Value torchInput = convertTensorToDtype( + rewriter, loc, op.getSelf(), + rewriter.getIntegerType(64, IntegerType::Signed)); + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(torchInput.getType()), + torchInput); + } + int64_t inputRank = resultType.getRank(); - Location loc = op->getLoc(); Value dtype = op.getDtype(); if (!dtype.getType().isa()) return rewriter.notifyMatchFailure( @@ -1533,10 +1551,10 @@ class ConvertAtenCumsumOp : public OpConversionPattern { Value result = createTMTensorScanOp( rewriter, loc, input, output, acc, dim, /*inclusive=*/true, [](OpBuilder &b, Location loc, Value input, Value acc) { - Value sum = (input.getType().isa() - ? b.create(loc, input, acc) - : b.create(loc, input, acc)) - ->getResult(0); + Value sum = + (input.getType().isa() + ? b.create(loc, input, acc)->getResult(0) + : b.create(loc, input, acc)->getResult(0)); b.create(loc, sum); }); diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index a8498a83bba2..51928163a27b 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -125,8 +125,8 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, return (doubleValue == static_cast(static_cast(doubleValue))); } else { assert(isInt); - return (intValue >= std::numeric_limits::min()) && - (intValue <= std::numeric_limits::max()); + return (intValue >= static_cast(std::numeric_limits::min())) && + (intValue <= static_cast(std::numeric_limits::max())); } return true; } @@ -149,12 +149,13 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, "Unable to extract the scalar constant"); if (dtype.isa()) { - tosaTensor = tosa::getConstTensor( - rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype) + tosaTensor = tosa::getConstTensor(rewriter, op, + (isFloat ? doubleValue : intValue), + dshape, dtype) .value(); } else if (auto intType = dtype.dyn_cast()) { auto w = intType.getWidth(); - if (w!= 1 && w != 32 && w != 64) + if (w != 1 && w != 32 && w != 64) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << "Unsupported integer type: " << intType; }); @@ -166,7 +167,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, "of destination type"); } bool d = isFloat ? 
static_cast(doubleValue) - : static_cast(intValue); + : static_cast(intValue); tosaTensor = tosa::getConstTensor(rewriter, op, {d}, dshape).value(); } else if (w == 32) { @@ -627,7 +628,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Negative slope needs to be a scalar constant for conversion to " "TOSA LeakyReLU operation"); - auto zero = tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()).value(); + auto zero = + tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()) + .value(); auto cond = rewriter.create( op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), @@ -1063,17 +1066,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); + auto outType = + getTypeConverter()->convertType(op.getType()).template cast(); + Value expTensor; Value expScalar = op.getExponent(); if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor, - selfTy.getElementType(), {}))) + outType.getElementType(), {}))) return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA Pow operation"); - auto outType = - getTypeConverter()->convertType(op.getType()).template cast(); - auto powOp = tosa::createBinaryOpAndCast(rewriter, op, outType, self, expTensor); rewriter.replaceOp(op, powOp.getResult()); @@ -2029,6 +2032,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto biasElemTy = inputElemTy.isa() ? inputElemTy : rewriter.getI32Type(); + int64_t groups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) { + return rewriter.notifyMatchFailure(op, "non-const group size unsupported"); + } + SmallVector stride; if (!matchPattern(adaptor.getStride(), m_TorchListOfConstantInts(stride))) return rewriter.notifyMatchFailure(op, "non-const stride list unsupported"); @@ -2048,11 +2056,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Unimplemented: only non-transposed convolutions supported"); - int64_t groups; - if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) - return rewriter.notifyMatchFailure( - op, "non-const group convolution unsupported"); - // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. // The Torch OFM computation uses 2*pad in each spatial direction, implying // the same t=b and l=r values for TOSA. @@ -2064,7 +2067,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "non-const dilation list unsupported"); - // TOSA works in NHWC and takes OHWI weights. Perform the necessary transpose. + // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. + // Perform the necessary transformations. 
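+ // Layout bookkeeping, as a sketch (example shapes are illustrative, not
+ // taken from this patch):
+ //   activations: NCHW -> NHWC via transpose perm {0, 2, 3, 1}
+ //   regular (group) conv weights: O(I/G)HW -> OHWI via the same perm
+ //   depthwise weights: O(I/G)HW -> HWO(I/G) via perm {2, 3, 0, 1}, then a
+ //   reshape HWO(I/G) -> HWIM with M = outputChannels / groups; e.g. an
+ //   8x1x3x3 weight with groups == 8 and 8 output channels becomes 3x3x8x1.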
std::optional nchwToNhwcTransposeConst = tosa::getConstTensor(rewriter, op, /*vec=*/{0, 2, 3, 1}, @@ -2081,26 +2085,80 @@ LogicalResult ConvertAtenOp::matchAndRewrite( nchwToNhwcTransposeConst.value()) .getResult(); - SmallVector transposedWeightShape( - {weightShape[0], weightShape[2], weightShape[3], weightShape[1]}); - auto transposedWeightType = RankedTensorType::get( - makeShapeLLVMCompatible(transposedWeightShape), weightElemTy); - auto transposedWeight = - rewriter - .create( - op->getLoc(), - getTypeConverter()->convertType(transposedWeightType), weight, - nchwToNhwcTransposeConst.value()) - .getResult(); + SmallVector transformedWeightShape; + RankedTensorType transformedWeightType; + Value transformedWeight; + int64_t outputCDim; + if (groups == 1 || weightShape[1] != 1) { + // full (group) convolution: O(I/G)HW-> OHWI + transformedWeightShape = {weightShape[0], weightShape[2], weightShape[3], + weightShape[1]}; + transformedWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); + transformedWeight = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), weight, + nchwToNhwcTransposeConst.value()) + .getResult(); + outputCDim = transformedWeightShape[0]; + } else { + // depthwise convolution: O(I/G)HW-> HWIM) + // transpose: O(I/G)HW -> HWO(I/G) + std::optional transposeConst = + tosa::getConstTensor(rewriter, op, + /*vec=*/{2, 3, 0, 1}, + /*shape=*/{static_cast(4)}); + SmallVector transposedWeightShape = { + weightShape[2], weightShape[3], weightShape[0], weightShape[1]}; + auto transposedWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(transposedWeightShape), weightElemTy); + auto transposedWeight = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transposedWeightType), weight, + transposeConst.value()) + .getResult(); + + // reshape: HWO(I/G) -> HWIM + outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1]; + if (outputCDim == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "number of output channels must be statically known for " + "depthwise convolutions"); + } + transformedWeightShape = { + transposedWeightShape[0], + transposedWeightShape[1], + groups, + outputCDim / groups, + }; + transformedWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); + transformedWeight = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), + transposedWeight, + rewriter.getDenseI64ArrayAttr(transformedWeightShape)) + .getResult(); + } int64_t outputHDim, outputWDim; if (inputTy.hasStaticShape()) { - outputHDim = (transposedInputShape[1] + padding[0] + padding[1] - - dilation[0] * (transposedWeightShape[1] - 1) - 1) / + int64_t inputHDim = inputShape[2]; + int64_t inputWDim = inputShape[3]; + int64_t weightHDim = weightShape[2]; + int64_t weightWDim = weightShape[3]; + outputHDim = (inputHDim + padding[0] + padding[1] - + dilation[0] * (weightHDim - 1) - 1) / stride[0] + 1; - outputWDim = (transposedInputShape[2] + padding[2] + padding[3] - - dilation[1] * (transposedWeightShape[2] - 1) - 1) / + outputWDim = (inputWDim + padding[2] + padding[3] - + dilation[1] * (weightWDim - 1) - 1) / stride[1] + 1; } else { @@ -2111,25 +2169,43 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Output shape is NHWC, to be transposed back to NCHW. Output elemTy for // quantized input is i32, which gets rescaled down to quantized output range. 
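// Worked example for the outputHDim computation above (values assumed for
// illustration): inputHDim = 32, padding = {1, 1}, dilation[0] = 1,
// weightHDim = 3, stride[0] = 1 yields
//   outputHDim = (32 + 1 + 1 - 1 * (3 - 1) - 1) / 1 + 1 = 32.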
SmallVector outputShape = {transposedInputShape[0], outputHDim, - outputWDim, transposedWeightShape[0]}; + outputWDim, outputCDim}; DenseI64ArrayAttr paddingAttr = rewriter.getDenseI64ArrayAttr(padding); DenseI64ArrayAttr strideAttr = rewriter.getDenseI64ArrayAttr(stride); DenseI64ArrayAttr dilationAttr = rewriter.getDenseI64ArrayAttr(dilation); + Value convOpResult; if (groups == 1) { + // full convolution auto convOpTy = - RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); + RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); convOpResult = rewriter .create(op->getLoc(), getTypeConverter()->convertType(convOpTy), - transposedInput, transposedWeight, bias, - paddingAttr, strideAttr, dilationAttr) + transposedInput, transformedWeight, bias, + paddingAttr, + strideAttr, + dilationAttr) + .getResult(); + } else if (weightShape[1] == 1) { + // depthwise convolution + auto convOpTy = + RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); + convOpResult = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, + paddingAttr, + strideAttr, + dilationAttr) .getResult(); } else { + // general group convolution convOpResult = createConvInGroups( - rewriter, op, outputTy, weightShape, transposedInput, transposedWeight, + rewriter, op, outputTy, weightShape, transposedInput, transformedWeight, bias, groups, paddingAttr, strideAttr, dilationAttr); } @@ -2275,7 +2351,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // reshaped so it sits on the same dim as 'C'. auto reshapeToNormInputDim = [&](Operation *op, ConversionPatternRewriter &rewriter, - TypeConverter *converter, Type outType, + const TypeConverter *converter, Type outType, const Value toBcast, Value &result) { RankedTensorType toBcastType = toBcast.getType().dyn_cast(); @@ -2324,11 +2400,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); - auto epsilonConst = - tosa::getConstTensor(rewriter, op.getOperation(), - {static_cast(eps)}, {}, - meanType.getElementType()) - .value(); + auto epsilonConst = tosa::getConstTensor(rewriter, op.getOperation(), + {static_cast(eps)}, {}, + meanType.getElementType()) + .value(); auto batchNorm = computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, @@ -2417,7 +2492,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op.getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(toReduceShape), inputType.getElementType()), - sumDiv, rewriter.getI64IntegerAttr(i)); + sumDiv, rewriter.getI32IntegerAttr(i)); } return rewriter.create( @@ -2642,7 +2717,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Constant value of ln2. 
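// For reference, the identity behind this hunk: log2(x) = ln(x) / ln(2), so
// the lowering scales tosa.log by the reciprocal of the ln 2 constant
// (0.69314718056) below.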
SmallVector ln2Shape(selfType.getRank(), 1); - auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056}, + auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056f}, ln2Shape, selfType.getElementType()) .value(); auto rcpOp = @@ -2873,21 +2948,25 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); - auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}, dtype).value(); + auto a1 = + tosa::getConstTensor(rewriter, op, 0.278393f, {}, dtype).value(); auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}, dtype).value(); + auto a2 = + tosa::getConstTensor(rewriter, op, 0.230389f, {}, dtype).value(); auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}, dtype).value(); + auto a3 = + tosa::getConstTensor(rewriter, op, 0.000972f, {}, dtype).value(); auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}, dtype).value(); + auto a4 = + tosa::getConstTensor(rewriter, op, 0.078108f, {}, dtype).value(); auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); @@ -2913,7 +2992,6 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x, Type dtype) { auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); - auto loc = op->getLoc(); // buildNormalCdf, mean = zero, sigma = one @@ -2922,13 +3000,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Value xMinusMean = rewriter.create(loc, outType, x, mean); // rsqrt of 2 Value rsqrt2 = - tosa::getConstTensor(rewriter, op, 0.70710678, {}, dtype).value(); + tosa::getConstTensor(rewriter, op, 0.70710678f, {}, dtype).value(); Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, /*shift=*/0); Value erf = approximateErfOp(rewriter, op, erfArg, dtype); Value erfPlus1 = rewriter.create(loc, outType, one, erf); - Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); + Value oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1, /*shift=*/0); @@ -2962,8 +3041,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); cdf = rewriter.createOrFold( - op->getLoc(), cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); - + op->getLoc(), + cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, @@ -2999,15 +3078,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto loc = op->getLoc(); - const double cstAlpha0 = 1.12837916709551257390; - const double cstAlpha1 = 0.70710678118654752440; - const double oneHalf = 0.5; - const double kAlpha = cstAlpha0 * cstAlpha1; + const float cstAlpha0 = 
1.12837916709551257390f; + const float cstAlpha1 = 0.70710678118654752440f; + const float oneHalf = 0.5f; + const float kAlpha = cstAlpha0 * cstAlpha1; - Value kAlphaHalf = - tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value(); + Value kAlphaHalf = tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, + {}, selfElemTy) + .value(); Value negOneHalf = - tosa::getConstTensor(rewriter, op, -0.5, {}, selfElemTy).value(); + tosa::getConstTensor(rewriter, op, -0.5f, {}, selfElemTy).value(); Value inputSquared = rewriter.create( loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); Value negHalfInputSquared = rewriter.create( @@ -3078,7 +3158,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Only scalar constant is supported"); } - Value replace = tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); + Value replace = + tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); Type outType = getTypeConverter()->convertType(op.getType()); Value lesser = rewriter.create( @@ -3286,7 +3367,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( prunedShape.push_back(en.value()); } - auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim); + auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim); auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape); Value reduceMax = rewriter.create( @@ -3360,14 +3441,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( start = toPositiveDim(start, sizeOfDim); start = std::clamp(start, (int64_t)0, sizeOfDim); + start = std::min(selfType.getShape()[dim], start); + int64_t end; if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { if (isa(op.getEnd().getDefiningOp())) - end = sizeOfDim; + end = selfType.getShape()[dim]; else return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); } - // support for end < 0 end = toPositiveDim(end, selfType.getShape()[dim]); // support for end out of upper bound @@ -3647,7 +3729,7 @@ class SimplifyAten_IndexPutImplOpNone Value newIndicesList = rewriter.create(op->getLoc(), op.getIndices().getType(), newIndices); - + newIndexPut = rewriter.create(op.getLoc(), op.getType(), newIndexPut, newIndicesList, op.getValues(), op.getAccumulate(), op.getUnsafe()); } @@ -3798,7 +3880,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Convert indicesTorchType to TOSA types auto indexTensors = getTypeConvertedValues( rewriter, op->getLoc(), getTypeConverter(), indicesTorchType); - + // the number of tensors in indexTensors is equal to the rank of outType if (indexTensors.size() != 1) { return rewriter.notifyMatchFailure(op, "Expected 1 indices "); @@ -3811,7 +3893,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Expected indices to have same shape as values"); - + auto outType = dyn_cast(getTypeConverter()->convertType(op.getType())); if (!outType) @@ -4084,8 +4166,8 @@ class ConvertAtenIndexTensorOpNone }; template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenIndexTensorOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // t = tf.constant([[1, 2, 3, 4, 5],[6,7,8,9,10], // [11,12,13,14,15],[16,17,18,19,20]]) # 4*5 @@ -4133,19 +4215,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { auto index = indexTensors[i]; - auto indexTorch = 
tensorsTorchType[i]; - // TODO add support for none index input like torch.ops.aten.index(x, - // (None, index1, index2, None)) - if (indexTorch.getType().isa()) - return rewriter.notifyMatchFailure( - op, "Only list ranked tensor types index are supported"); auto indexType = index.getType().dyn_cast(); auto indexShape = indexType.getShape(); indexesShape.push_back(makeShapeTorchCompatible(indexShape)); indexesRank.push_back(indexType.getRank()); - // index i64 to i32 for tosa compatible + // Make type of index tosa compatible, i64 to i32. if (indexType.getElementType() != rewriter.getIntegerType(32)) { index = rewriter.create( op->getLoc(), @@ -4206,12 +4282,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Support for multiple index auto index = indexTensors[0]; - auto indexTorch = tensorsTorchType[0]; - // TODO add support for none index input like torch.ops.aten.index(x, (None, - // index1, index2, None)) - if (indexTorch.getType().isa()) - return rewriter.notifyMatchFailure( - op, "Only list ranked tensor types index are supported"); auto indexType = index.getType().dyn_cast(); auto indexShape = indexType.getShape(); // index i64 to i32 for tosa compatible @@ -4387,7 +4457,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenArangeStartStepOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - TypeConverter *typeConverter = this->getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); @@ -4468,7 +4538,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( PrimNumToTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - TypeConverter *typeConverter = this->getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); @@ -5283,7 +5353,7 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenCatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - TypeConverter *typeConverter = this->getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); auto outType = typeConverter->convertType(op.getType()).cast(); int64_t rank = outType.getRank(); @@ -5317,7 +5387,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( in = tosa::promoteType(rewriter, in, outType); auto result = tosa::CreateOpAndInfer( - rewriter, loc, outType, builtinTensors, rewriter.getI64IntegerAttr(dim)); + rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim)); rewriter.replaceOp(op, result.getResult()); return success(); } @@ -5338,7 +5408,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .template cast(); auto elementType = resultType.getElementType(); - if (selfTy.getElementType().isa()) { + if (isa(selfTy.getElementType())) { self = rewriter.createOrFold( op->getLoc(), RankedTensorType::get(resultType.getShape(), elementType), self); @@ -5356,9 +5426,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenEmptyMemoryFormatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); + auto loc = op.getLoc(); MLIRContext* ctx = op->getContext(); - mlir::TypeConverter* typeConverter = this->getTypeConverter(); + const TypeConverter* typeConverter = this->getTypeConverter(); bool pinMemory; if (!op.getPinMemory().getType().template isa() && @@ -5440,7 +5510,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( emptyVal = 
DenseFPElementsAttr::get(resultType, {0.0}); else if (maybeResultElementType->isF32()) emptyVal = DenseFPElementsAttr::get(resultType, {0.0F}); - else + else return rewriter.notifyMatchFailure(op, "unsupported: dtype used for empty.memory_format is unsupported"); } @@ -5564,7 +5634,7 @@ class SimplifyAtenIndexTensorWithSliceIndex if (!input) { return rewriter.notifyMatchFailure(op, "requires tensor type"); } - + if (llvm::count_if(indices, [](Value v) { return !isa(v.getType()); }) == 1) { @@ -5722,9 +5792,12 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - // Mark constant ops as legal, so the error message about - // "failed to legalize" - // mentions the real problematic op and not the constants used by it. + // The following ops are never the primary reason why lowering fails. + // The backend contract only allows functions to return tensors, so there + // is always another op using them. + // When we have a chain of torch.constant.int followed by an unsupported + // torch op, we want the pass to mention the unsupported torch op + // in the error message. target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); @@ -5945,7 +6018,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenGatherOp); INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp); - INSERT_ATENOP_PATTERN(AtenIndexTensorOp); + INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); INSERT_ATENOP_PATTERN(AtenAbsOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenLeTensorOp); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index afc041263174..24e0e36fc474 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -382,7 +382,7 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixReducesumShape, indicesType.getElementType()), - flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1)); + flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1)); // And reshape to [N, W] // %7 = "tosa.reshape"(%6) {new_shape = [1, 8]} : (tensor<8x1xi32>) -> @@ -412,6 +412,277 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, .getResult(); } +// Lower index_put op to tosa::scatter op. +// Mostly taken from convertGatherNdOp() above. +std::optional convertScatterNdOp(PatternRewriter &rewriter, + Operation *op, Type outType, + Value paramsValue, Value indicesValue, + Value fillValues) { + auto resultType = outType.dyn_cast(); + auto paramsType = paramsValue.getType().dyn_cast(); + auto indicesType = indicesValue.getType().dyn_cast(); + auto fillValuesType = fillValues.getType().dyn_cast(); + + if (!resultType || !paramsType || !indicesType) + return std::nullopt; + + // N: number of batches + // Always 1 for ScatterOp + // + // Because TOSA's Scatter operator already uses the symbol 'N' for + // the number of batches, we will use the symbol 'ND' to specify the + // number of dimensions that are sliced from params instead of 'N' in + // the TF MLIR documentation.
+ // + // ND: indices.shape[-1] + // + // W: number of indices in each batch + // Computed as: + // product(indices.shape[0:-1]) (all but the last dimension) + // + // K: range of each index + // Computed as: + // product(params.shape[0:ND-1]) + // + // C: number of channels for each index + // Computed as: + // product(params.shape[ND:]) + // + // The params tensor needs to be reshaped, but not transposed, to move the + // dimensions into [N, K, C] order. + // + // The dimensions of the input params[] tensor are grouped in the following + // order to begin with: + // + // [ParamIndices, ParamChannels] + // |------------||-------------| + // K C + // + // The reshape simply flattens the params tensor into a 2D [K, C] shape. + // + // Indices needs to be put in the form of [N, W], but a simple flattening + // will not suffice, because the indices need to index into a [W]-shape + // vector instead of the params.shape[0:ND-1] tensor that we had before. + // + // To flatten the coordinates, first reshape indices to a [W, ND] matrix, + // where the matrix now represents W ND-dimensional coordinates into the + // params tensor. + // + // From here, we take each of the ND dimensions and multiply it with + // the size of the next params dimension (or 1 for the last + // dimension), then sum all these together with a reduce_sum + // operator. This is exactly the same mathematics as one would use + // flatten the indices of an N-dimensional row-major array into a + // 1-D array in C. + // + // More precisely, do an element-wise multiply with [params.shape[1 + // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a + // [W]-shaped tensor, then trivially reshape to [N=1, W] to be + // compatible with the scatter operator's shape. + // + // Then perform the tosa.scatter() operation. + // + // Now we have result = [N, K, C]. + // + // Reshape with a single, simple reshape to the final output shape of: + // [Indices, ParamChannels] + // + // Where, Indices is indices.shape[0:ND-1] + // + // For easy understanding, all following comments take an exact value for each + // argument Example: Take TF style indices as input + // torch.aten._index_put_impl %input, %indices, %fillValue, %false, %false : + // !torch.vtensor<[1,4],si64>, !torch.vtensor<[3,2],si64>, + // !torch.vtensor<[1,3],si64>, !torch.bool, !torch.bool -> + // !torch.vtensor<[1,4],si64> + // Detail algorithm visualization: + + int N = 1, W = 1, K = 1, fillK = 1, C = 1, ND = 1; + + int paramsRank = paramsType.getShape().size(); // 2 + int indicesRank = indicesType.getShape().size(); // 2 + + // ND: indices.shape[-1] + ND = indicesType.getShape()[indicesRank - 1]; // 2 depth of input + + if (ND > paramsRank) { + (void)rewriter.notifyMatchFailure( + op, "size of last dimension of indices must be <= params rank"); + return std::nullopt; + } + + // Calculate N, K, W, C. 
(N is always 1.) + // W: number of indices/selected values in each batch, + // product(indices.shape[0:-1]) (all but the last dimension): W = 3 + for (int i = 0; i < (indicesRank - 1); i++) { + W *= indicesType.getShape()[i]; + } + + // K: range of each index, total number of inputs (that could be scattered + // into) after flattening: K = 1*4 = 4 + for (int i = 0; i < ND; i++) { + K *= paramsType.getShape()[i]; + } + + // C: number of channels for each index, i.e. number of values inside each + // input (that could be scattered into). C = product(params.shape[ND:]); + // with ND = 2 and paramsRank = 2, C = 1 + for (int i = ND; i < paramsRank; i++) { + C *= paramsType.getShape()[i]; + } + + // int N = 1, W = 3, K = 4, fillK = 3, C = 1, ND = 2; + SmallVector tosaInputValuesShape({N, K, C}); // {1,4,1} + SmallVector tosaIndicesShape({N, W}); // {1,3} + SmallVector indicesMatrixShape({W, ND}); // {3,2} + SmallVector indicesMatrixReducesumShape({W, 1}); // {3,1} + + // Preprocess fill value. + // There are 2 cases of fillValues: + // 1. !torch.vtensor<[1,3],si64> + // [[0,0,0]] -> [[[0], [0], [0]]] + // 2. !torch.vtensor<[],si64> + // reshape(1) tile(3) reshape(1,3) reshape(1,3,1) + // [] -> [0] -> [0,0,0] -> [[0,0,0]] -> [[[0], [0], [0]]] + // reshape to [1] and then tile to the same length as indicesValue.shape[0], + // [1,1,1] + if (fillValuesType.getRank() == 0) { + // [] -> [0] + SmallVector oneShape({1}); // {1} + auto tosaFillValuesOneReshapeOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(oneShape, fillValuesType.getElementType()), + fillValues, rewriter.getDenseI64ArrayAttr(oneShape)); + + // [0] -> [0,0,0] + SmallVector tileShape({W}); // {3} + auto tosaFillValuesTileOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()), + tosaFillValuesOneReshapeOp.getResult(), + rewriter.getDenseI64ArrayAttr(tileShape)); + + // [0,0,0] -> [[0,0,0]] + SmallVector newTosaFillValuesShape({N, W}); // {1,3} + auto newTosaFillValuesReshapeOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(newTosaFillValuesShape, + fillValuesType.getElementType()), + tosaFillValuesTileOp.getResult(), + rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape)); + fillValues = newTosaFillValuesReshapeOp.getResult(); + fillValuesType = fillValues.getType().dyn_cast(); + } + + // fillK: range of each index, total number of fill inputs (to be + // scattered) after flattening: fillK = 1*3 = 3 + for (int i = 0; i < ND; i++) { + fillK *= fillValuesType.getShape()[i]; + } + SmallVector tosaFillValuesShape({N, fillK, C}); // {1,3,1} + + // Reshape/Flatten fillValues to 3d tensor + // [[0,0,0]] -> [[[0], [0], [0]]] + // %10 = "tosa.reshape"(%1) {new_shape = array} : + // (tensor<1x3xi64>) -> tensor<1x3x1xi64> + auto tosaFillValuesReshapeOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(tosaFillValuesShape, + fillValuesType.getElementType()), + fillValues, rewriter.getDenseI64ArrayAttr(tosaFillValuesShape)); + + // Reshape/Flatten input to 3d tensor + // [[1, 2, 3, 4]] -> [[[1], [2], [3], [4]]] + // %9 = "tosa.reshape"(%0) {new_shape = array} : + // (tensor<1x4xi64>) -> tensor<1x4x1xi64> + auto tosaValuesReshapeOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(tosaInputValuesShape, paramsType.getElementType()), + paramsValue, rewriter.getDenseI64ArrayAttr(tosaInputValuesShape)); + + // Reshape/Flatten the input indices tensor to a 2d [W, ND] matrix.
+ // [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]] + // %11 = "tosa.reshape"(%8) {new_shape = array} : (tensor<3x2xi32>) + // -> tensor<3x2xi32> + auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), + indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape)); + + SmallVector flattenedCoeffVec; // [4,1] + // flattenedCoeffVec = [4,1] + for (int i = 1; i < ND; i++) { + flattenedCoeffVec.push_back(paramsType.getShape()[i]); + } + flattenedCoeffVec.push_back(1); + + // flattenedCoeffVec = [4,1] + for (int i = ND - 1; i > 0; i--) { + flattenedCoeffVec[i - 1] *= flattenedCoeffVec[i]; + } + + // Create the tosaConstTensor for the flattenedCoeffVec. + // %12 = "tosa.const"() {value = dense<[4, 1]> : tensor<2xi32>} : () -> + // tensor<2xi32> + auto flattenedCoeffValue = + getConstTensor(rewriter, op, flattenedCoeffVec, + {static_cast(flattenedCoeffVec.size())}); + + if (!flattenedCoeffValue) + return std::nullopt; + + // Multiply the coefficients by the coordinates. + // [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]] + // %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>, + // tensor<2xi32>) -> tensor<3x2xi32> + auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), + indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0); + + // Sum up the products of the coefficients and coordinates + // [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]] + // %14 = "tosa.reduce_sum"(%13) {axis = 1 : i64} : (tensor<3x2xi32>) -> + // tensor<3x1xi32> + auto flattenedIndicesReduceOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(indicesMatrixReducesumShape, + indicesType.getElementType()), + flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1)); + + // And reshape to [N, W] + // [[1],[2],[3]] -> [[1,2,3]] + // %15 = "tosa.reshape"(%14) {new_shape = array} : + // (tensor<3x1xi32>) -> tensor<1x3xi32> + auto tosaIndicesReshapeOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()), + flattenedIndicesReduceOp.getResult(), + rewriter.getDenseI64ArrayAttr(tosaIndicesShape)); + + // Now the Scatter op itself + // %16 = "tosa.scatter"(%9, %15, %10) : (tensor<1x4x1xi64>, tensor<1x3xi32>, + // tensor<1x3x1xi64>) -> tensor<1x4x1xi64> input = [[[1], [2], [3], [4]]], + // indices = [[1,2,3]], fillValues= [[[0], [0], [0]]] result = [[[1], [0], + // [0], [0]]] + auto tosaScatterOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + GetTypeFromTensorShape(tosaInputValuesShape, resultType.getElementType()), + tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult(), + tosaFillValuesReshapeOp.getResult()); + + // Finally, reshape back to the original output shape of [Indices, + // ParamChannels]. + // [[1, 0, 0, 0]] + // %17 = "tosa.reshape"(%16) {new_shape = array} : + // (tensor<1x4x1xi64>) -> tensor<1x4xi64> + return tosa::CreateOpAndInfer( + rewriter, op->getLoc(), resultType, tosaScatterOp.getResult(), + rewriter.getDenseI64ArrayAttr(resultType.getShape())) + .getResult(); +} + + // Common function for lowering reduce operations to TOSA ops. 
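// (Lowering strategy, summarized from the body below: for every axis in
// axes_elems, shape_vec[axis] is set to 1 and one keep-dims reduce op is
// emitted, so e.g. tensor<2x3x4xf32> reduced over axes {1, 2} becomes
// tensor<2x1x4xf32>, then tensor<2x1x1xf32>; assuming the caller did not ask
// to keep dims, a final reshape then drops the unit dimensions.)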
template std::optional convertReduceOpCommon( @@ -453,7 +724,7 @@ std::optional convertReduceOpCommon( int64_t axis_val = axes_elems.getValues()[i].getInt(); if (axis_val < 0) axis_val += input_rank; - auto axis_attr = rewriter.getI64IntegerAttr(axis_val); + auto axis_attr = rewriter.getI32IntegerAttr(axis_val); shape_vec[axis_val] = 1; RankedTensorType reduce_type = RankedTensorType::get( diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index b71378fa5ad4..ed7f6b2a9539 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -236,7 +236,6 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); - if (dtype) { return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); @@ -264,7 +263,6 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); - if (dtype) { return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 1f6a889b5567..c192ff33a25f 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -104,8 +104,8 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim, Type lhsType = lhsDim.getType(); Type rhsType = rhsDim.getType(); auto checkIntOrIndex = [](Type type) { - assert(type.isa() || - type.isa() && "must be either integer or index type"); + assert((type.isa() || type.isa()) && + "must be either integer or index type"); }; checkIntOrIndex(lhsType); checkIntOrIndex(rhsType); @@ -230,7 +230,7 @@ SmallVector getAsConstantIndexValues(OpBuilder &b, Location loc, // convert their elements to valid target type. // TODO: remove this when list gets full support. 
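// (Usage sketch, assumptions flagged: given elements of type !torch.int and
// a converter whose backend type conversion maps !torch.int to i64, the
// map_range below sends each element through materializeTargetConversion and
// the returned vector holds the converted i64 values.)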
 SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
-                                          TypeConverter *converter,
+                                          const TypeConverter *converter,
                                           SmallVectorImpl<Value> &vs) {
   return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
     return converter->materializeTargetConversion(
diff --git a/lib/Dialect/Torch/IR/CMakeLists.txt b/lib/Dialect/Torch/IR/CMakeLists.txt
index cf54afe06c2e..00210e4fd379 100644
--- a/lib/Dialect/Torch/IR/CMakeLists.txt
+++ b/lib/Dialect/Torch/IR/CMakeLists.txt
@@ -16,6 +16,9 @@ add_mlir_library(TorchMLIRTorchDialect
   Core

   LINK_LIBS PUBLIC
+  MLIRBytecodeOpInterface
+  MLIRBytecodeReader
+  MLIRBytecodeWriter
   MLIRFuncDialect
   MLIRIR
   MLIRSupport
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 35f1a753b46b..c4bae1f9c1c0 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -156,6 +156,8 @@ static Value getScalarIntValue(Value input, Location loc,
   } else if (auto primNumToTensorScalarOp =
                  input.getDefiningOp<PrimNumToTensorScalarOp>()) {
     return primNumToTensorScalarOp.getA();
+  } else if (auto tensorIntOp = input.getDefiningOp<AtenTensorIntOp>()) {
+    return tensorIntOp.getT();
   }
   return nullptr;
 }
@@ -299,23 +301,20 @@ LogicalResult ClassTypeOp::verify() {
 // PrimLoopOp
 //===----------------------------------------------------------------------===//

-OperandRange
-PrimLoopOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
-  assert(index.has_value() && index.value() == 0);
+OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(point == getRegion());
   return getIterArgsInit();
 }

 void PrimLoopOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
-  (void)operands;
-
-  if (!index.has_value()) {
-    regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  Region &region = getRegion();
+  if (!point.getRegionOrNull()) {
+    regions.emplace_back(&region, region.getArguments().slice(1));
     return;
   }
-  assert(*index == 0);
-  regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
+  assert(point == region);
+  regions.emplace_back(&region, region.getArguments().slice(1));
   regions.emplace_back(getResults());
 }

@@ -328,8 +327,8 @@ bool PrimLoopOp::isForLike() {
 // PrimLoopConditionOp
 //===----------------------------------------------------------------------===//

-MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands(
-    std::optional<unsigned> index) {
+MutableOperandRange
+PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
   // Pass all operands except the condition to the successor which is the
   // parent loop op.
   return getIterArgsMutable();
@@ -378,19 +377,18 @@ void PrimIfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }

-void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
-                                   ArrayRef<Attribute> operands,
+void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
   // The `then` and the `else` region branch back to the parent operation.
-  if (index.has_value()) {
+  if (point.getRegionOrNull()) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }

   // If the condition is constant, we can give a more precise answer.
-  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
-    Region *executedRegion =
-        condAttr.getValue().isOne() ? &getThenRegion() : &getElseRegion();
+  bool condition;
+  if (matchPattern(getCondition(), m_TorchConstantBool(&condition))) {
+    Region *executedRegion = condition ? &getThenRegion() : &getElseRegion();
     regions.push_back(RegionSuccessor(executedRegion));
     return;
   }
@@ -712,20 +710,6 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }

-//===----------------------------------------------------------------------===//
-// AtenTypeAsOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
-  Type inType = getSelf().getType();
-  Type newType = getOther().getType();
-
-  if (inType == newType)
-    return getSelf();
-
-  return nullptr;
-}
-
 //===----------------------------------------------------------------------===//
 // AtenToDtypeOp
 //===----------------------------------------------------------------------===//
@@ -860,6 +844,26 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
   });
 }

+//===----------------------------------------------------------------------===//
+// AtenToOtherOp
+//===----------------------------------------------------------------------===//
+
+void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  // Canonicalize `aten.to.other` to `aten.to.device`
+  patterns.add(+[](AtenToOtherOp op, PatternRewriter &rewriter) {
+    auto lhs = op.getSelf();
+    auto rhs = op.getOther();
+    auto getRhsDevice = rewriter.create<PrimDeviceOp>(op.getLoc(), rhs);
+    auto getRhsDtype = rewriter.create<PrimDtypeOp>(op.getLoc(), rhs);
+    rewriter.replaceOpWithNewOp<AtenToDeviceOp>(
+        op, op.getType(), lhs, getRhsDevice.getResult(),
+        getRhsDtype.getResult(), op.getNonBlocking(),
+        op.getCopy(), op.getMemoryFormat());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenViewOp
 //===----------------------------------------------------------------------===//
@@ -925,6 +929,34 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }

+//===----------------------------------------------------------------------===//
+// AtenMinOtherOp
+//===----------------------------------------------------------------------===//
+
+void AtenMinOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  // `aten.min.other` -> `aten.minimum`
+  patterns.add(+[](AtenMinOtherOp op, PatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<AtenMinimumOp>(op, op.getType(), op.getSelf(),
+                                               op.getOther());
+    return success();
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// AtenMaxOtherOp
+//===----------------------------------------------------------------------===//
+
+void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  // `aten.max.other` -> `aten.maximum`
+  patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), op.getSelf(),
+                                               op.getOther());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLenStrOp
 //===----------------------------------------------------------------------===//
@@ -1105,6 +1137,19 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
   });
 }

+//===----------------------------------------------------------------------===//
+// Aten__Or__TensorOp
+//===----------------------------------------------------------------------===//
+
+void Aten__Or__TensorOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](Aten__Or__TensorOp op, PatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<AtenBitwiseOrTensorOp>(
+        op, op.getType(), op.getSelf(), op.getOther());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenScalarImplicitOp
 //===----------------------------------------------------------------------===//
@@ -1444,6 +1489,24 @@ OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }

+//===----------------------------------------------------------------------===//
+// AtenAnyBoolOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
+  auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
+  if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
+    return nullptr;
+  // If any operand is a constant true, return true.
+  for (auto operand : inputConstruct.getOperands()) {
+    bool b = false;
+    if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
+      return getI1IntegerAttr(getContext(), true);
+    }
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenFloatScalarOp
 //===----------------------------------------------------------------------===//
@@ -1546,7 +1609,9 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
+  auto attr = properties.as<Properties *>()
+                  ->getValue()
+                  .dyn_cast_or_null<ElementsAttr>();
   if (!attr)
     return failure();
   RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
@@ -1586,7 +1651,9 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
+  auto attr = properties.as<Properties *>()
+                  ->getValue()
+                  .dyn_cast_or_null<ElementsAttr>();
   if (!attr)
     return failure();
   RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
@@ -2095,7 +2162,16 @@ void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
     if (!tupleConstruct)
       return failure();

-    rewriter.replaceOp(op, tupleConstruct.getElements());
+    llvm::SmallVector<Value> derefinedElements;
+    // The result types may be supertypes of the tuple element types.
+    // Ensure we maintain the exact type, with identity `derefine`s being
+    // folded.
+    for (auto [type, element] :
+         llvm::zip(op.getResultTypes(), tupleConstruct.getElements())) {
+      derefinedElements.push_back(
+          rewriter.createOrFold<DerefineOp>(op.getLoc(), type, element));
+    }
+    rewriter.replaceOp(op, derefinedElements);
     return success();
   });
 }
@@ -2233,6 +2309,14 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
   return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs));
 }

+//===----------------------------------------------------------------------===//
+// AtenAliasOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) {
+  return getOperand();
+}
+
 //===----------------------------------------------------------------------===//
 // AtenFloordivIntOp
 //===----------------------------------------------------------------------===//
@@ -2292,6 +2376,25 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
   return list.getElements()[0];
 }

+//===----------------------------------------------------------------------===//
+// AtenBroadcastToOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
+  auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
+  auto outType = getResult().getType().dyn_cast<BaseTensorType>();
+  if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
+    return nullptr;
+  if (inType.getSizes().size() != outType.getSizes().size() ||
+      !inType.areAllSizesKnown() || !outType.areAllSizesKnown())
+    return nullptr;
+  for (size_t i = 0; i < inType.getSizes().size(); ++i) {
+    if (inType.getSizes()[i] != outType.getSizes()[i])
+      return nullptr;
+  }
+  return getOperand(0);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSliceTensorOp
 //===----------------------------------------------------------------------===//
@@ -2335,6 +2438,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }

+//===----------------------------------------------------------------------===//
+// AtenMulFloatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) {
+  return atenBinaryFloatOperatorFoldHelper(
+      adaptor.getOperands(), [](double a, double b) { return a * b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSubFloatOp
 //===----------------------------------------------------------------------===//
@@ -2344,6 +2456,25 @@ OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [](double a, double b) { return a - b; });
 }

+//===----------------------------------------------------------------------===//
+// AtenAddOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getA() || !adaptor.getB()) {
+    return nullptr;
+  }
+
+  if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
+    return atenBinaryIntOperatorFoldHelper(
+        adaptor.getOperands(),
+        [](int64_t a, int64_t b) -> int64_t { return a + b; });
+  }
+  return atenBinaryFloatOperatorFoldHelper(
+      adaptor.getOperands(),
+      [](double a, double b) -> double { return a + b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSubOp
 //===----------------------------------------------------------------------===//
@@ -2378,6 +2509,18 @@ OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
       [](double a, double b) -> double { return a / b; });
 }

+//===----------------------------------------------------------------------===//
+// AtenAddFloatIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getA() || !adaptor.getB()) {
+    return nullptr;
+  }
+  return atenBinaryFloatOperatorFoldHelper(
+      adaptor.getOperands(), [](double a, double b) { return a + b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenPowIntFloatOp
 //===----------------------------------------------------------------------===//
@@ -2418,6 +2561,21 @@ OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }

+//===----------------------------------------------------------------------===//
+// AtenNegFloatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getA()) {
+    return nullptr;
+  }
+  auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>();
+  if (!value) {
+    return nullptr;
+  }
+  return getF64FloatAttr(getContext(), -value.getValue().convertToDouble());
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSqrtIntOp
 //===----------------------------------------------------------------------===//
@@ -2519,6 +2677,43 @@ void AtenBroadcastToOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }

+//===----------------------------------------------------------------------===//
+// AtenCudaOp
+//===----------------------------------------------------------------------===//
+
+void AtenCudaOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                             MLIRContext *context) {
+  patterns.add(+[](AtenCudaOp op, PatternRewriter &rewriter) {
+    // Device information isn't relevant to torch-mlir
+    auto inputTensor = op.getSelf();
+    rewriter.replaceOp(op, inputTensor);
+    return success();
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// AtenDeviceWithIndexOp
+//===----------------------------------------------------------------------===//
+
+void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](AtenDeviceWithIndexOp op, PatternRewriter &rewriter) {
+    std::string type;
+    int64_t index;
+    if (!matchPattern(op.getType(), m_TorchConstantStr(type))) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: type must be a constant string");
+    }
+    if (!matchPattern(op.getIndex(), m_TorchConstantInt(&index))) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: index must be a constant integer");
+    }
+    rewriter.replaceOpWithNewOp<Torch::ConstantDeviceOp>(
+        op, type + ":" + std::to_string(index));
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//
@@ -2528,6 +2723,8 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
   // aten.Int.Tensor, fold to the scalar number.
   if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
     return numToTensorScalar.getA();
+  if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
+    return tensorIntOp.getT();
   return nullptr;
 }
@@ -2651,28 +2848,26 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {

 template <typename CalculateOp>
 static void
-getSuccessorRegionsForCalculateOp(CalculateOp op, std::optional<unsigned> index,
-                                  ArrayRef<Attribute> operands,
+getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
-  if (!index.has_value()) {
+  if (!point.getRegionOrNull()) {
     // First thing the op does is branch into the calculation.
     regions.emplace_back(&op.getCalculation());
     return;
   }
-  if (*index == 0) {
+  if (point == op.getBody()) {
     // Body returns control to the outer op, passing through results.
     regions.emplace_back(op.getResults());
     return;
   }
-  assert(*index == 1);
+  assert(point == op.getCalculation());
   // Calculation branches to the body.
   regions.emplace_back(&op.getBody());
 }

 void ShapeCalculateOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
-  getSuccessorRegionsForCalculateOp(*this, index, operands, regions);
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  getSuccessorRegionsForCalculateOp(*this, point, regions);
 }

 //===----------------------------------------------------------------------===//
@@ -2680,9 +2875,8 @@ void ShapeCalculateOp::getSuccessorRegions(
 //===----------------------------------------------------------------------===//

 void DtypeCalculateOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
-  getSuccessorRegionsForCalculateOp(*this, index, operands, regions);
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  getSuccessorRegionsForCalculateOp(*this, point, regions);
 }

 //===----------------------------------------------------------------------===//
@@ -2690,7 +2884,7 @@ void DtypeCalculateOp::getSuccessorRegions(
 //===----------------------------------------------------------------------===//

 MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
-    std::optional<unsigned> index) {
+    RegionBranchPoint point) {
   // The shape operands don't get forwarded to the body.
   // MutableOperandRange always has an owning operation, even if empty, so
   // create a 0-length range.
@@ -2709,7 +2903,7 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
 //===----------------------------------------------------------------------===//

 MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
-    std::optional<unsigned> index) {
+    RegionBranchPoint point) {
   // The dtype operands don't get forwarded to the body.
   // MutableOperandRange always has an owning operation, even if empty, so
   // create a 0-length range.
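The TorchOps.cpp hunks above all track the same upstream MLIR change: the RegionBranchOpInterface hooks no longer take a (std::optional<unsigned> index, ArrayRef<Attribute> operands) pair but a single RegionBranchPoint, so constant branch conditions are now recovered from the IR via matchPattern rather than from the folded-operand list. A minimal sketch of the new hook shape for a hypothetical single-region looping op; MyLoopOp is illustrative only and not part of this patch:

    // Hypothetical op with one body region that either iterates again or
    // returns its results to the parent; mirrors the PrimLoopOp update above.
    void MyLoopOp::getSuccessorRegions(RegionBranchPoint point,
                                       SmallVectorImpl<RegionSuccessor> &regions) {
      // A null region means control is entering from the parent op.
      if (!point.getRegionOrNull()) {
        regions.emplace_back(&getBody(), getBody().getArguments());
        return;
      }
      // Otherwise the branch originates in the body: it may loop again or
      // yield control (and results) back to the parent op.
      assert(point == getBody());
      regions.emplace_back(&getBody(), getBody().getArguments());
      regions.emplace_back(getResults());
    }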
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index 8eb844cbd00b..cee9705af24a 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -194,13 +194,13 @@ static bool isValidTorchDtype(Type dtype) {
     if (type.isSignless() && type.getWidth() == 1)
       return true;
     if (type.isSigned()) {
-      for (unsigned width : {8, 16, 32, 64}) {
+      for (unsigned width : {4, 8, 16, 32, 64}) {
         if (type.getWidth() == width)
           return true;
       }
     }
     if (type.isUnsigned()) {
-      return type.getWidth() == 8;
+      return type.getWidth() == 8 || type.getWidth() == 4;
     }
   }
   return false;
@@ -404,20 +404,8 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
   } else if (auto integerType = dtype.dyn_cast<IntegerType>()) {
     return IntegerType::get(context, integerType.getWidth(),
                             IntegerType::Signless);
-  } else if (auto complexType = dtype.dyn_cast<ComplexType>()) {
-    // torch-complex types add the precision of the real and imag values to
-    // get the final precision i.e., if the real and imag value is of `float`
-    // type then the complex value is of `complex<double>` type. OTOH, MLIR
-    // built in complex type doesn't add the precision i.e., if the real and
-    // imag value is of float type then the resulting complex value is of
-    // complex<float> type.
-    auto floatType = complexType.getElementType().dyn_cast<mlir::FloatType>();
-    if (floatType.getWidth() == 32)
-      return ComplexType::get(mlir::FloatType::getF16(context));
-    else if (floatType.getWidth() == 64)
-      return ComplexType::get(mlir::FloatType::getF32(context));
-    else if (floatType.getWidth() == 128)
-      return ComplexType::get(mlir::FloatType::getF64(context));
+  } else if (dtype.isa<mlir::ComplexType>()) {
+    return dtype;
   }
   emitError(UnknownLoc::get(context))
       << "unimplemented: conversion of dtype " << dtype
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 2414538eaf6f..697ad6bbd7ef 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -17,7 +17,7 @@ using namespace mlir;

 StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
-#ifndef _MSC_VER
+#if defined(__clang__)
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Woverlength-strings"
 #endif
@@ -6290,14 +6290,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.asin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.hardtanh\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -6445,6 +6437,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.native_dropout\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<bool>) -> !torch.tuple<list<int>, list<int>> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"    return %1 : !torch.tuple<list<int>, list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.gelu\"(%arg0: !torch.list<int>, %arg1: !torch.str) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -6541,6 +6538,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.elu\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.prelu\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -6565,10 +6566,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.min\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.min.other\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.max\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max.other\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.sum\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -6688,11 +6701,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
 "    return %1 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.prims.sum\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
-"    %false = torch.constant.bool false\n"
-"    %0 = torch.derefine %arg2 : !torch.optional<int> to !torch.any\n"
-"    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %false, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
-"    return %1 : !torch.list<int>\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.prod.dim_int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
+"    %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>\n"
+"    %1 = torch.derefine %0 : !torch.list<int> to 
!torch.optional>\n" +" %2 = torch.derefine %arg3 : !torch.optional to !torch.any\n" +" %3 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %1, %arg2, %2) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %3 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" @@ -6803,20 +6817,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %6 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.repeat_interleave.Tensor\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %none = torch.constant.none\n" -" %0 = torch.prim.Uninitialized : !torch.int\n" -" %1 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.int) {\n" -" %4 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" -" torch.prim.If.yield %4 : !torch.int\n" +" func.func @\"__torch_mlir_shape_fn.aten.tile\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.lt.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.list) {\n" +" %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" +" %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list, !torch.int) -> !torch.list\n" +" %8 = torch.aten.add.t %7, %arg1 : !torch.list, !torch.list -> !torch.list\n" +" torch.prim.If.yield %8 : !torch.list\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.int\n" +" torch.prim.If.yield %arg1 : !torch.list\n" " }\n" -" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list\n" -" return %3 : !torch.list\n" +" %4 = call @\"__torch_mlir_shape_fn.aten.repeat\"(%arg0, %3) : (!torch.list, !torch.list) -> !torch.list\n" +" return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.roll\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" @@ -6867,6 +6883,163 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" " return %arg2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.avg_pool1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int 
-2\n" +" %int-3 = torch.constant.int -3\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n" +" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" }\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %24 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" %9 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %13 = torch.aten.eq.int %12, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" }\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %15 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %16 = torch.aten.eq.int %15, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %18 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" 
+" %19 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %20 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%19, %2, %11, %8, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %21 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.list) {\n" +" %24 = torch.prim.ListConstruct %18, %20 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %24 : !torch.list\n" +" } else {\n" +" %24 = torch.prim.ListConstruct %17, %18, %20 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %24 : !torch.list\n" +" }\n" +" return %23 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.adaptive_avg_pool1d(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.ne.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %8 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %9 = torch.aten.__getitem__.t %arg1, %int0 : 
!torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %6, %9 : !torch.list, !torch.int -> !torch.list\n" +" return %6 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.optional) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7047,12 +7220,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.empty.memory_format\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.empty_strided\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.full_like\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.new_full\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" +" return %arg1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.zeros_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7105,6 +7284,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bernoulli.float\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -7172,28 +7354,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch.jit._shape_functions.arange_end(%0, %1, %2, %3, %4) : (!torch.union, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list\n" " return %5 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine_cachemask\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" -" %0 = call 
@__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" return %2 : !torch.tuple, list>\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine_cachemask\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" -" %int11 = torch.constant.int 11\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" -" %1 = torch.prim.TupleConstruct %0, %int11 : !torch.int, !torch.int -> !torch.tuple\n" -" return %1 : !torch.tuple\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" -" return %0 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7226,6 +7386,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.__or__.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.minimum\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7337,7 +7501,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.union, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.number, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" " %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" @@ -7536,6 +7700,45 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.view_as_real\"(%arg0: !torch.list) -> !torch.list {\n" +" %int2 = torch.constant.int 2\n" +" %0 = 
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list\n" +" %1 = torch.aten.add.t %arg0, %0 : !torch.list, !torch.list -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.view_as_real\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int7 = torch.constant.int 7\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() -> !torch.list {\n" +" %int10 = torch.constant.int 10\n" +" %int9 = torch.constant.int 9\n" +" %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7593,9 +7796,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %1, %2, %int1) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" " return %3 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.narrow.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.int) -> !torch.list {\n" +" %0 = torch.aten._set_item.t %arg0, %arg1, %arg3 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice_scatter\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.masked_scatter\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.select.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %0 = 
call @__torch__.torch.jit._shape_functions.select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7625,10 +7835,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.optional) -> !torch.tuple, list, list, list> {\n" -" %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int) -> !torch.tuple, list, list, list>\n" +" %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4, %arg6, %arg8) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int, !torch.optional>, !torch.optional) -> !torch.tuple, list, list, list>\n" " return %0 : !torch.tuple, list, list, list>\n" " }\n" -" func.func @__torch__._embedding_bag_helper(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int) -> !torch.tuple, list, list, list> {\n" +" func.func @__torch__._embedding_bag_helper(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.optional>, %arg6: !torch.optional) -> !torch.tuple, list, list, list> {\n" +" %false = torch.constant.bool false\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int2 = torch.constant.int 2\n" @@ -7675,8 +7886,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %19 = torch.aten.append.t %12, %int0 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield %12 : !torch.list\n" " } else {\n" -" %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list) -> !torch.list\n" -" torch.prim.If.yield %19 : !torch.list\n" +" %19 = torch.aten.__is__ %arg5, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.bool) {\n" +" %22 = torch.aten.__is__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %22 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %21 = torch.prim.If %20 -> (!torch.list) {\n" +" %22 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %22 : !torch.list\n" +" } else {\n" +" %22 = func.call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %22 : !torch.list\n" +" }\n" +" torch.prim.If.yield %21 : !torch.list\n" " }\n" " %15 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list\n" " %16 = torch.aten.eq.int %arg4, %int2 : !torch.int, !torch.int -> !torch.bool\n" @@ -7691,8 +7915,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %18 : !torch.tuple, list, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten._embedding_bag\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.int) -> !torch.tuple, list, list, list> {\n" -" %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int) -> !torch.tuple, list, list, list>\n" -" return %0 : !torch.tuple, list, 
list, list>\n" +" %0 = torch.derefine %arg8 : !torch.int to !torch.optional\n" +" %1 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4, %arg6, %0) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int, !torch.optional>, !torch.optional) -> !torch.tuple, list, list, list>\n" +" return %1 : !torch.tuple, list, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.list, !torch.optional>, !torch.int) -> !torch.tuple, list>\n" @@ -7837,16 +8062,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %10 = torch.aten.len.t %arg1 : !torch.list>> -> !torch.int\n" " %11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list\n" " %12 = torch.prim.min.self_int %11 : !torch.list -> !torch.int\n" -" %13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) {\n" -" ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int):\n" +" %13:3 = torch.prim.Loop %12, %true, init(%true, %int-1, %int-1) {\n" +" ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.int):\n" " %16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list>>, !torch.int -> !torch.optional>\n" " %17 = torch.aten.__isnot__ %16, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n" +" %18:3 = torch.prim.If %17 -> (!torch.bool, !torch.int, !torch.int) {\n" " %19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool\n" " %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n" " torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int\n" " } else {\n" -" %21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.sub.int %arg2, %arg5 : !torch.int, !torch.int -> !torch.int\n" " %22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool\n" " %23 = torch.prim.If %22 -> (!torch.bool) {\n" " torch.prim.If.yield %false : !torch.bool\n" @@ -7855,12 +8080,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int\n" " }\n" -" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n" +" torch.prim.If.yield %20#0, %20#1, %arg2 : !torch.bool, !torch.int, !torch.int\n" " } else {\n" -" torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int\n" +" torch.prim.If.yield %arg3, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int\n" " }\n" -" torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int)\n" +" torch.prim.Loop.condition %true, iter(%18#0, %18#1, %18#2 : !torch.bool, !torch.int, !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.bool, !torch.int, !torch.int) -> (!torch.bool, !torch.int, !torch.int)\n" " %14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool\n" " %15 = torch.prim.If %14 -> (!torch.list) {\n" " %16 = torch.aten.add.t %6, %4 : !torch.list, !torch.list -> !torch.list\n" @@ -7934,6 +8159,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " return %none : !torch.none\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.nonzero\"(%arg0: !torch.list) -> !torch.list {\n" +" 
%0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.masked_select\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" +" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.nonzero_static\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.ListConstruct %arg1, %0 : (!torch.int, !torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.linalg_vector_norm\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {\n" " %0 = torch.derefine %arg4 : !torch.optional to !torch.any\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" @@ -8005,17 +8246,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() : () -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" -" }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() -> !torch.list {\n" -" %int10 = torch.constant.int 10\n" -" %int9 = torch.constant.int 9\n" -" %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -8036,16 +8266,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.asin\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.acos\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple 
-> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -8086,7 +8306,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" @@ -8162,19 +8382,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.prims.sum\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %none = torch.constant.none\n" -" %0 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" return %1#1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int9 = torch.constant.int 9\n" @@ -8203,6 +8410,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -8239,7 +8454,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -8251,7 +8466,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -8263,7 +8478,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -8279,7 +8494,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8322,6 +8537,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_dropout\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.optional) -> !torch.tuple {\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int11 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.expand_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -8330,7 +8551,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8382,7 +8603,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" @@ -8393,7 +8614,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int11 = torch.constant.int 11\n" @@ -8414,6 +8635,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._unsafe_index_put.hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -8448,7 +8673,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" @@ -8463,11 +8688,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8497,6 +8722,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.narrow.Tensor\"(%arg0: 
!torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.neg\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -8575,7 +8804,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.repeat_interleave.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tile\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8615,7 +8844,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8662,7 +8895,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8690,6 +8923,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._unsafe_view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 
: !torch.int\n" @@ -8718,12 +8963,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.number) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" -" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.number) -> !torch.int {\n" +" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.number -> !torch.tensor\n" " %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" " return %1 : !torch.int\n" " }\n" @@ -8781,7 +9026,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -8789,11 +9034,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -8805,7 +9050,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -8829,7 +9074,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -8849,15 +9094,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = 
torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.number, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list>\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" @@ -8906,11 +9151,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" @@ -8923,7 +9168,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__or__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = 
torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" @@ -9247,7 +9500,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" @@ -9255,7 +9508,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n" " %int11 = torch.constant.int 11\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" @@ -9429,7 +9682,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nonzero\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nonzero_static\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9447,7 +9708,7 @@ StringRef 
mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int11 = torch.constant.int 11\n" @@ -9480,7 +9741,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" " return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9496,39 +9757,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %7 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> 
!torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" @@ -9539,16 +9800,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, 
!torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9561,30 +9822,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list\n" " %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.union, %arg1: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" -" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : 
!torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int11 = torch.constant.int 11\n" @@ -9597,7 +9858,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" @@ -9616,16 +9877,64 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.elu\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.number) -> !torch.int {\n" +" %int3 = torch.constant.int 3\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" %3 = torch.prim.ListConstruct %arg1, %arg2, %arg3 : (!torch.number, !torch.number, !torch.number) -> !torch.list\n" +" torch.prim.Loop %int3, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %7 = torch.aten.__getitem__.t %3, %arg4 : !torch.list, !torch.int -> !torch.number\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%7) : (!torch.number) -> !torch.int\n" +" %9 = torch.aten.append.t %2, %8 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %4 = torch.prim.ListConstruct : () -> !torch.list\n" +" %5 = torch.aten.len.t %2 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %7 = torch.aten.__getitem__.t %2, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" %9 = torch.aten.append.t %4, %8 : !torch.list, !torch.bool -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } 
: (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.aten.any.bool %4 : !torch.list -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" @@ -9670,14 +9979,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int4 = torch.constant.int 4\n" " %false = torch.constant.bool false\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" 
" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %5 : !torch.bool\n" " } else {\n" @@ -9690,20 +9999,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" @@ -9755,6 +10064,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %7 = torch.prim.TupleConstruct %0#1, %0#1, %6 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %7 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.one_hot\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: 
!torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9767,7 +10090,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0#1, %0#1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %3 : !torch.tuple\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.union, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.number, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int6 = torch.constant.int 6\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -9785,7 +10108,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" @@ -9796,7 +10119,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int6 = torch.constant.int 6\n" " %true = torch.constant.bool true\n" @@ -9815,12 +10138,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %7 : !torch.bool\n" " }\n" @@ -9833,7 
+10156,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int6 = torch.constant.int 6\n" " %true = torch.constant.bool true\n" @@ -9852,19 +10175,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %8 : !torch.bool\n" " }\n" " %5 = torch.prim.If %4 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" " %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %8 : !torch.bool\n" " }\n" @@ -9900,6 +10223,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.prod.dim_int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> 
!torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -9930,10 +10272,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.min\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.min.other\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0, %arg1) : (!torch.tuple, !torch.tuple) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.max\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max.other\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0, %arg1) : (!torch.tuple, !torch.tuple) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.amax\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" " return %0 : !torch.int\n" @@ -9976,7 +10330,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" @@ -9991,7 +10345,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" @@ -10001,7 +10355,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func 
@\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -10127,7 +10481,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.number, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" " %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" @@ -10135,7 +10489,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" @@ -10182,7 +10536,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.empty_strided\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" @@ -10194,6 +10560,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.new_full\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.number, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.new_zeros\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -10290,7 +10668,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.to.dtype\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" " return %arg1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.nvprims.convert_element_type\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.prims.convert_element_type\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" " return %arg1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.to.dtype_layout\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.int {\n" @@ -10379,7 +10757,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.tuple {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.tuple {\n" " %int7 = torch.constant.int 7\n" " %int10 = torch.constant.int 10\n" " %int6 = torch.constant.int 6\n" @@ -10473,7 +10851,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" return %0#1 : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.int {\n" " %true = torch.constant.bool true\n" @@ -10551,8 +10933,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %5 : !torch.int\n" " }\n" -" func.func 
@\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.number) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" @@ -10652,7 +11034,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { "}\n" ""; // clang-format on -#ifndef _MSC_VER +#if defined(__clang__) #pragma clang diagnostic pop #endif } diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 8f310da08983..30cc4db44181 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -187,53 +187,8 @@ class AdjustCallingConventionForReturn }; } // namespace -static bool isValidNonContainerResultType(Type resultType) { - return resultType.isa() || - resultType.isa() || - resultType.isa() || - resultType.isa() || - resultType.isa(); -} - -static LogicalResult validateReturns(func::FuncOp func) { - if (func.getResultTypes().size() > 1) { - return func->emitError( - "Functions directly imported from Python should only ever return one " - "item. Multiple return values are returned as a tuple."); - } - - // Allow returns of nothing. This shouldn't be possible from Python, but it - // can happen in IR that's been directly constructed. - if (func.getResultTypes().size() == 0) - return success(); - - const auto& resultType = func.getResultTypes().front(); - - // Allow single tensor, scalar, and bool returns - if (isValidNonContainerResultType(resultType)) { - return success(); - } - - // Allow multi-tensor/scalar/bool tuple returns - if (auto tuple = resultType.dyn_cast()) { - const auto& containedTypes = tuple.getContainedTypes(); - bool containsValidTypes = llvm::all_of( - tuple.getContainedTypes(), isValidNonContainerResultType); - if (containedTypes.size() >= 2 && containsValidTypes) { - return success(); - } - } - - return func->emitError( - "Functions must return a single tensor-like value, multiple tensor-like " - "values, or a tuple of more than one tensor-like value. 
Tensor-like " - "values: tensors, scalars, bools, and Nones."); -} - static LogicalResult adjustCallingConventions(func::FuncOp func, TypeBoundMap &typeBoundMap) { - if (failed(validateReturns(func))) - return failure(); MLIRContext *context = func.getContext(); RewritePatternSet patterns(context); TypeConverter typeConverter; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6285ee02fb05..db4c2dff914a 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -219,13 +219,18 @@ class DecomposeAtenAmaxOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "Expected a constant boolean value for keepDim"); - Value input = op.getSelf(); + Value input = op.getSelf(); + auto inputTy = input.getType().dyn_cast(); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "Expected input type having sizes"); + } // For every dimension included in `dim` of the op, iterated over in // reverse order, we create a call to aten.max.dim. std::sort(dims.begin(), dims.end()); std::reverse(dims.begin(), dims.end()); for (int64_t dimInt : dims) { - int64_t inputRank = input.getType().cast().getSizes().size(); + int64_t inputRank = inputTy.getSizes().size(); dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -335,6 +340,27 @@ class DecomposeAtenNarrowOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose `aten.narrow.Tensor` to `aten.narrow` op +class DecomposeAtenNarrowTensorOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNarrowTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *context = op.getContext(); + // PyTorch makes sure that `start` param is an 0-dim integral tensor. + // REF: https://pytorch.org/docs/stable/generated/torch.narrow.html. + auto start = rewriter.create( + loc, Torch::IntType::get(context), op.getStart()); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getDim(), start, op.getLength()); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenZeroOp : public OpRewritePattern { @@ -418,15 +444,28 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - if (!op.getDtype().getType().isa()) + BaseTensorType resultTensorType = op.getType().cast(); + if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( - op, "Unimplemented non-None dtype for softmax"); + op, "expected result type to have a dtype"); + } + Type resultTensorDtype = resultTensorType.getDtype(); + if (!resultTensorDtype.isa()) + return rewriter.notifyMatchFailure(op, + "Only support floating-point type"); - BaseTensorType tensorType = self.getType().cast(); - if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) - return rewriter.notifyMatchFailure(op, "Only support floating type"); + // If `dtype` arg is non-none then convert the input to `dtype`. 
+    if (!op.getDtype().getType().isa<Torch::NoneType>()) {
+      Location loc = op.getLoc();
+      Value none = rewriter.create<ConstantNoneOp>(loc);
+      Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+      self = rewriter.create<AtenToDtypeOp>(
+          loc, resultTensorType, self,
+          getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
+          /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
+    }
-    Value result = getSoftmaxResult(op, self, tensorType, rewriter);
+    Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
     if (!result)
       return failure();
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -1036,6 +1075,46 @@ class DecomposeAtenLeakyReluBackwardOp
 };
 } // namespace
 
+// Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1)
+namespace {
+class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEluOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    Value alpha = op.getAlpha();
+    Value scale = op.getScale();
+    Value inputScale = op.getInputScale();
+    auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    Value constantZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value constantOne =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
+    Value maxZeroX =
+        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
+    Value positiveOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
+    Value minZeroX =
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
+    Value scaledMinZeroX =
+        rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
+    Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledMinZeroX);
+    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
+                                                    constantOne, constantOne);
+    Value scaledExpXM1 =
+        rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
+    Value negativeOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
+
+    Value eluOutput = rewriter.create<AtenAddTensorOp>(
+        loc, resType, positiveOutput, negativeOutput, constantOne);
+
+    rewriter.replaceOp(op, eluOutput);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
 public:
@@ -1253,8 +1332,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
     SmallVector<Value> unsqueezedSizes, expandedSizes, reshapedSizes;
     SmallVector<int64_t> unsqueezedIntSizes, expandedIntSizes;
+    assert(repeats.size() >= rank && "leadingRank should be non-negative");
     auto leadingRank = repeats.size() - rank;
-    assert(leadingRank >= 0 && "leadingRank should greater than 0");
     for (size_t i = 0; i < leadingRank; ++i) {
       insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one});
       insertDimSizes(expandedSizes, expandedIntSizes,
@@ -2123,6 +2202,58 @@ class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
     return success();
   }
 };
+
+class DecomposeAtenNativeDropoutOp
+    : public OpRewritePattern<AtenNativeDropoutOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op->getContext();
+    Value input = op.getInput();
+    Value prob = op.getP();
+    bool train = false;
+    if (!op.getTrain().getType().isa<Torch::NoneType>()) {
+      if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
+        return rewriter.notifyMatchFailure(
+            op, "train must be a boolean constant or none");
+      }
+    }
+    Value noneVal =
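+    // Sketch of the training-mode path below (inverted dropout), assuming a
+    // floating-point input: keep_mask = bernoulli(input, 1 - p) and
+    // output = input * keep_mask / (1 - p). For example, with p = 0.2 the
+    // surviving elements are scaled by 1 / 0.8 = 1.25, preserving the
+    // expected activation magnitude; the i1 mask is returned as the second
+    // result.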
rewriter.create(loc); + if (!train) { + Value i1Type = + getDtypeIntValueForType(rewriter, loc, IntegerType::get(context, 1)); + Value inputSize = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), input); + Value trueValue = rewriter.create(loc, 1); + Value trueMask = rewriter.create( + loc, op->getResultTypes()[1], inputSize, trueValue, i1Type, + /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); + rewriter.replaceOp(op, ArrayRef{input, trueMask}); + return success(); + } + BaseTensorType inputType = input.getType().cast(); + if (!inputType.hasDtype() || !inputType.getDtype().isa()) { + return rewriter.notifyMatchFailure( + op, "only support floating type input for training mode"); + } + Value floatOne = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value oneMinusP = rewriter.create(loc, floatOne, prob); + Value boolMask = rewriter.create( + loc, inputType, input, oneMinusP, /*generator=*/noneVal); + Value maskedInput = + rewriter.create(loc, inputType, boolMask, input); + Value output = rewriter.create( + loc, op->getResultTypes()[0], maskedInput, oneMinusP); + rewriter.replaceOp( + op, ArrayRef{ + output, convertTensorToDtype(rewriter, loc, boolMask, + IntegerType::get(context, 1))}); + return success(); + } +}; } // namespace // Decompose aten.var into: aten.var.dim op. @@ -3035,6 +3166,33 @@ class DecomposeAtenFullLikeOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose `aten.new_full` op into `aten.full` op. +class DecomposeAtenNewFullOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNewFullOp op, + PatternRewriter &rewriter) const override { + Value dtype = op.getDtype(); + if (dtype.getType().isa()) { + BaseTensorType tensorType = op.getSelf().getType().cast(); + if (!tensorType.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "expected input tensor to have a dtype"); + } + dtype = + getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype()); + } + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(), + op.getPinMemory()); + + return success(); + + } +}; +} // namespace + namespace { // Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op. class DecomposeAtenIndexPutOp : public OpRewritePattern { @@ -3108,7 +3266,7 @@ class DecomposeAtenCopyOp : public OpRewritePattern { auto srcTy = op.getSrc().getType().cast(); if (!srcTy.hasSizes() || !srcTy.hasDtype()) { return rewriter.notifyMatchFailure( - op, "expected src type to have a known rank"); + op, "expected src type to have a known rank and dtype"); } Type resultDtype = resultType.getDtype(); Value srcToDtype = @@ -3180,6 +3338,25 @@ class DecomposeAten_IndexPutImpl_HackedTwinOp }; } // namespace +namespace { +// Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl` +// op. +class DecomposeAten_UnsafeIndexPutHackedTwinOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op, + PatternRewriter &rewriter) const override { + Value cstFalse = rewriter.create(op.getLoc(), false); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), + op.getAccumulate(), + /*unsafe=*/cstFalse); + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.pad` op into `aten.constantPadNd` op. 
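 // For example (illustrative sketch, constant mode only): aten.pad(x, [1, 2],
 // mode="constant", value=0.0) on a rank-1 tensor becomes a constant-pad with
 // the same pad list and fill value, i.e. constant_pad_nd(x, [1, 2], 0.0).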
 class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
@@ -3220,10 +3397,15 @@ class DecomposeAtenToDtypeLayoutOp
           op, "unimplemented: pinMemory is expected to be false");
     }
 
-    // TODO: Add support for non-None device arg.
+    // TODO: Add support for device arg other than cpu.
     if (!op.getDevice().getType().isa<Torch::NoneType>()) {
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: device arg must be None");
+      std::string device;
+      if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: device must be a constant str");
+      else if (device != "cpu")
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: device is expected to be cpu");
     }
 
     // TODO: Add support for non-strided layout.
@@ -3265,6 +3447,85 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
 };
 } // namespace
 
+namespace {
+// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
+
+// The logic of this decomposition is the same as in
+// DecomposeAtenAdaptiveAvgPool2dOp, which means that currently only the
+// following two cases are supported:
+// 1. inputSize = outputSize
+// 2. outputSize = 1
+class DecomposeAtenAdaptiveAvgPool1dOp
+    : public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getSelf();
+    std::optional<unsigned> maybeRank = getTensorRank(input);
+    if (!maybeRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    }
+    unsigned rank = *maybeRank;
+    Value sizeDim = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rank - 1));
+    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
+
+    Value outputShape = op.getOutputSize();
+    SmallVector<Value> outputShapeSizesTorchInt;
+    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+    Value outputSize = outputShapeSizesTorchInt[0];
+
+    Value constantOne = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constantZero = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Value constantFalse = rewriter.create<ConstantBoolOp>(loc, false);
+    Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
+
+    int64_t outputSizeInt;
+    if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
+      return rewriter.notifyMatchFailure(
+          op, "the output size of adaptive_avg_pool1d must be a constant int");
+    }
+
+    SmallVector<Value> kernelSize;
+    if (outputSizeInt == 1) {
+      BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
+      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+      kernelSize.push_back(
+          inputShape[rank - 1] == kUnknownSize
+              ? inputSize
+              : rewriter.create<ConstantIntOp>(
+                    loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
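+      // Sketch of the supported cases, assuming input length L along the
+      // last dim: for output size 1 the kernel is L (a global average); for
+      // output size L (asserted at runtime in the else branch) the kernel is
+      // 1 with stride 1, so each window averages a single element.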
+    } else {
+      Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
+      rewriter.create<RuntimeAssertOp>(
+          loc, cond,
+          "unimplemented: only support cases where input and output size are "
+          "equal for non-unit output size");
+      kernelSize.push_back(constantOne);
+    }
+
+    Value kernelSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
+    Value strideList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantOne});
+    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantZero});
+
+    rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
+        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
+        /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
 //
@@ -3747,21 +4008,6 @@ class DecomposeAtenLiftFreshCopyOp
 };
 } // namespace
 
-namespace {
-// Decompose `aten.index.Tensor_hacked_twin` op into `aten.index.Tensor` op.
-class DecomposeAtenIndexTensorHackedTwinOp
-    : public OpRewritePattern<AtenIndexTensorHackedTwinOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenIndexTensorHackedTwinOp op,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(),
-                                                   op.getSelf(),
-                                                   op.getIndices());
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
 public:
@@ -3902,11 +4148,11 @@ class DecomposeAtenRandintOp : public OpRewritePattern<AtenRandintOp> {
 
     Value low = rewriter.create<ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(0));
-    
+
     rewriter.replaceOpWithNewOp<AtenRandintLowOp>(
         op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
-    
+
     return success();
   }
 };
@@ -4096,6 +4342,39 @@ class DecomposeAtenRandnLikeOp : public OpRewritePattern<AtenRandnLikeOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRandOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    auto resultType = op.getType().cast<BaseTensorType>();
+
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Value noneVal = rewriter.create<ConstantNoneOp>(loc);
+    Value low = rewriter.create<ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr((double)0.0));
+    Value high = rewriter.create<ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr((double)1.0));
+    Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
+        loc, resultType, op.getSize(), /*dtype=*/op.getDtype(),
+        /*layout=*/op.getLayout(),
+        /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
+        /*memory_format=*/noneVal);
+    rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
+                                               /*from=*/low,
+                                               /*to=*/high,
+                                               /*generator=*/noneVal);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
 public:
@@ -4159,6 +4438,53 @@ class DecomposeAtenNewEmptyStridedOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenEmptyStridedOp
+    : public OpRewritePattern<AtenEmptyStridedOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t> sizeListInts, strideListInts;
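+    // Worked default-stride check (sketch): for size = [2, 3, 4] the
+    // expected contiguous strides are [3 * 4, 4, 1] = [12, 4, 1]; the loop
+    // below recomputes them per dimension and bails out on any mismatch.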
+    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
+      return rewriter.notifyMatchFailure(
+          op, "all size list elements must be constant ints");
+    if (!matchPattern(op.getStride(),
+                      m_TorchListOfConstantInts(strideListInts)))
+      return rewriter.notifyMatchFailure(
+          op, "all stride list elements must be constant ints");
+
+    // We only support the cases with default stride values.
+    // For example: aten.empty_strided(size=[2, 3, 4], stride=[12, 4, 1]).
+    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
+    // stride[2] == 1.
+    bool isDefaultStride = true;
+    for (unsigned i = 0; i < strideListInts.size(); i++) {
+      int64_t defaultStride = 1;
+      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
+        defaultStride *= sizeListInts[j];
+      if (defaultStride != strideListInts[i]) {
+        isDefaultStride = false;
+        break;
+      }
+    }
+    if (!isDefaultStride)
+      return rewriter.notifyMatchFailure(
+          op, "only default strides supported for empty_strided op");
+
+    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
+
+    rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
+        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposePrimsSqueezeOp : public OpRewritePattern<PrimsSqueezeOp> {
 public:
@@ -4330,7 +4656,6 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: num_classes must be constant");
     Value none = rewriter.create<ConstantNoneOp>(loc);
-    Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
 
     // arange tensor
     auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
@@ -4358,11 +4683,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
         loc, eqType, unsqueezeTensor, arangeTensor);
 
     // convert to si64
-    Value si64TypeValue =
-        Torch::getDtypeIntValueForType(rewriter, loc, si64Type);
-    Value result = rewriter.create<AtenToDtypeOp>(
-        loc, op.getType(), eqTensor, si64TypeValue, /*non_blocking=*/falseValue,
-        /*copy=*/falseValue, /*memory_format=*/none);
+    Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -4645,6 +4966,29 @@ class DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp
 };
 } // namespace
 
+namespace {
+// Unconditionally decompose `torch.type_as` into `prim.dtype` +
+// `torch.to.dtype`.
+class DecomposeAtenTypeAsOp : public OpRewritePattern<AtenTypeAsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTypeAsOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input = op.getSelf();
+    auto other = op.getOther();
+    Location loc = op.getLoc();
+
+    Value targetDtype = rewriter.create<PrimDtypeOp>(loc, other);
+    Value nonBlocking = rewriter.create<ConstantBoolOp>(loc, false);
+    Value copy = rewriter.create<ConstantBoolOp>(loc, false);
+    Value memoryFormat = rewriter.create<ConstantNoneOp>(loc);
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
+        op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose aten.max_pool2d_with_indices
 // into aten.max_pool2d
@@ -4670,6 +5014,264 @@ class DecomposeAtenMaxPool2dWithIndicesOp
 };
 } // namespace
 
+// AtenIndexTensorOp
+namespace {
+// The goal of this pattern is to eliminate none indices in
+// aten.index.Tensor's `indices` param for the ease of various backends.
+// The detailed steps are:
+// 1. reorder the input tensor so that the non-none indices appear at adjacent
+// positions.
+// 2. manually generate index tensors with some ops like iota, to replace the
+// none indices in `indices`.
+// 3. replace the old aten.index.Tensor with a new
+// aten.index.Tensor_hacked_twin.
+class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  // TODO: It might be better to use aten.view op instead of multiple
+  // aten.unsqueeze. But currently, torch-to-linalg pass has limited support
+  // for view on dynamic shapes, such as [?] -> [?,1,1,1]. Using aten.view op
+  // will cause relevant e2e tests to fail.
+  static FailureOr<Value>
+  unsqueezeTensorAtTrailingDim(Operation *op, PatternRewriter &rewriter,
+                               Value input, int count) {
+    Location loc = op->getLoc();
+    Value constMinusOne = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(-1));
+    Value result = input;
+    while (count--) {
+      auto unsqzTensorInfo =
+          unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
+      if (failed(unsqzTensorInfo)) {
+        return failure();
+      }
+
+      result = *unsqzTensorInfo;
+    }
+    return result;
+  }
+
+  static Value createIndexToReplaceNone(Operation *op,
+                                        PatternRewriter &rewriter, Value input,
+                                        int dimInt, int64_t dimSize) {
+    Location loc = op->getLoc();
+    MLIRContext *context = op->getContext();
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    auto int64Dtype = getDtypeIntValueForType(
+        rewriter, loc,
+        rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
+
+    auto resultType = ValueTensorType::get(
+        context, {dimSize},
+        rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
+    auto dim = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(dimInt));
+    auto end = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    auto v = rewriter.create<AtenArangeOp>(
+        loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
+        /*device=*/none, /*pin_memory=*/none);
+    return v;
+  }
+
+  LogicalResult matchAndRewrite(AtenIndexTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+    SmallVector<Value> indices;
+    if (!getListConstructElements(op.getIndices(), indices))
+      return rewriter.notifyMatchFailure(
+          op, "failed to get elements of `indices`");
+
+    auto input = op.getSelf();
+    auto inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "only input with shape information is supported");
+    }
+    auto inputSizes = inputType.getSizes();
+    int64_t inputRank = inputSizes.size();
+    auto outputType = op.getType().cast<BaseTensorType>();
+    if (!outputType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "only output with shape information is supported");
+    }
+    auto outputRank = outputType.getSizes().size();
+
+    auto isTensor = [](Value v) {
+      return v.getType().isa<BaseTensorType>();
+    };
+
+    // directly replace aten.index.Tensor with aten.index.Tensor_hacked_twin
+    if (llvm::all_of(indices, isTensor)) {
+      if (indices.size() == 0) {
+        return rewriter.notifyMatchFailure(
+            op, "the indices list is empty; it should fold to a no-op");
+      }
+      // By default, we regard the first index type as the list element type.
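+      // Fast-path sketch: x[idx0, idx1] with all-tensor indices becomes
+      // aten.index.Tensor_hacked_twin(x, [idx0, idx1]) directly; the list
+      // element type below is idx0's tensor type with sizes and dtype erased.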
+ auto indexElemType = indices[0] + .getType() + .template cast() + .getWithSizesAndDtype(std::nullopt, nullptr); + auto newIndex = rewriter.create( + loc, Torch::ListType::get(indexElemType), indices); + rewriter.replaceOpWithNewOp(op, op.getType(), + input, newIndex); + return success(); + } + + SmallVector indexUsed = + llvm::to_vector(llvm::map_range(indices, isTensor)); + for (int64_t i = indices.size(); i < inputRank; ++i) + indexUsed.emplace_back(false); + bool indexIsConsecutive = true; + int64_t firstUsedIndex = -1; + for (size_t i = 0; i < indices.size(); ++i) { + if (indexUsed[i] && firstUsedIndex == -1) { + firstUsedIndex = i; + } else if (indexUsed[i] && !indexUsed[i - 1]) { + indexIsConsecutive = false; + break; + } + } + + // use aten.permute to reorder the input + Value newInput; + // `dims` stores the mapping from new index to the old index of input + // tensor. + SmallVector dims; + if (!indexIsConsecutive) { + SmallVector dimValues; + SmallVector permutedSizes; + for (int i = 0; i < inputRank; i++) { + if (indexUsed[i]) { + dims.emplace_back(i); + dimValues.emplace_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + permutedSizes.emplace_back(inputSizes[i]); + } + } + for (int i = 0; i < inputRank; i++) { + if (!indexUsed[i]) { + dims.emplace_back(i); + dimValues.emplace_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + permutedSizes.emplace_back(inputSizes[i]); + } + } + auto dimValueList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), dimValues); + newInput = rewriter.create( + loc, + inputType.getWithSizesAndDtype(permutedSizes, + inputType.getOptionalDtype()), + input, dimValueList); + } else { + newInput = input; + for (int i = 0; i < inputRank; i++) { + dims.emplace_back(i); + } + } + + // manually generate new indices. + SmallVector listElements(inputRank); + + int64_t trailingDimCnt = 0; + int64_t i; + // handle trailing none index. + for (i = inputRank - 1; i >= 0; --i) { + int64_t oldI = dims[i]; + if (indexUsed[oldI]) + break; + Value v = + createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]); + auto vInfo = + unsqueezeTensorAtTrailingDim(op, rewriter, v, trailingDimCnt); + if (failed(vInfo)) { + return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor"); + } + listElements[i] = *vInfo; + trailingDimCnt++; + } + // handle non-none index in between. + for (; i >= 0; --i) { + int64_t oldI = dims[i]; + if (!indexUsed[oldI]) + break; + auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, indices[oldI], + trailingDimCnt); + if (failed(vInfo)) { + return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor"); + } + listElements[i] = *vInfo; + } + + // handle possible leading none dimensions. 
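+    // Illustrative walk-through (assumed shapes): for input [3, 4, 5] and
+    // indices [idx0, None, idx1], the used dims 0 and 2 are permuted to the
+    // front (shape [3, 5, 4]), the remaining None is replaced by an arange of
+    // size 4, and every index tensor is unsqueezed at trailing dims so that
+    // all of them broadcast together in the final
+    // aten.index.Tensor_hacked_twin.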
+ for (; i >= 0; --i) { + int64_t oldI = dims[i]; + if (indexUsed[oldI]) { + return rewriter.notifyMatchFailure( + op, "the indices are still unconsecutive after reordering input " + "tensor"); + } + Value v = + createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]); + auto vInfo = + unsqueezeTensorAtTrailingDim(op, rewriter, v, outputRank - 1 - i); + if (failed(vInfo)) { + return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor"); + } + listElements[i] = *vInfo; + } + + auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr); + auto newIndexList = rewriter.create( + loc, Torch::ListType::get(listElemType), listElements); + rewriter.replaceOpWithNewOp( + op, op.getType(), newInput, newIndexList); + return success(); + } +}; +} // namespace + +namespace { +// Unconditionally decompose `aten.tile` into `aten.repeat`. +class DecomposeAtenTileOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTileOp op, + PatternRewriter &rewriter) const override { + auto input = op.getSelf(); + auto repeats = op.getDims(); + SmallVector dimsElements; + if (!getListConstructElements(repeats, dimsElements)) { + return rewriter.notifyMatchFailure( + op, "failed to get elements of `dims` param"); + } + auto dimsSize = dimsElements.size(); + auto inputType = input.getType().cast(); + if (!inputType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "only support input tensor with shape information"); + } + auto inputRank = inputType.getSizes().size(); + if (dimsSize < inputRank) { + auto constantOne = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(1)); + for (auto i = dimsSize; i < inputRank; ++i) { + dimsElements.insert(dimsElements.begin(), constantOne); + } + repeats = rewriter.create( + op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), + dimsElements); + } + rewriter.replaceOpWithNewOp(op, op.getType(), input, + repeats); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -4786,17 +5388,21 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -4810,10 +5416,9 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal( - patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -4822,13 +5427,16 @@ class 
DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -4848,6 +5456,9 @@ class DecomposeComplexOpsPass DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenMaxPool2dWithIndicesOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 4890c6a8cad9..5efbc69834a7 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -17,8 +17,8 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "llvm/Support/Debug.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" #define DEBUG_TYPE "torch-lower-to-backend-contract" @@ -426,6 +426,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -436,15 +437,19 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -458,9 +463,10 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -468,11 +474,13 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -482,6 +490,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( OperationName(kTorchOpPrefix + opName.first().str(), context)); diff --git 
a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
index c3e88e1a925d..69c8715442a7 100644
--- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -18,6 +18,21 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+namespace {
+
+// calculate: (a + b - 1) // b
+// both a and b should be of type !torch.int
+Value getIntCeilDiv(PatternRewriter &rewriter, Location loc, Value a, Value b) {
+  Value cstOne =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+  Value dividend = rewriter.create<AtenAddIntOp>(loc, a, b);
+  dividend = rewriter.create<AtenSubIntOp>(loc, dividend, cstOne);
+  Value result = rewriter.create<AtenFloordivIntOp>(loc, dividend, b);
+  return result;
+}
+
+} // namespace
+
 namespace {
 class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
 public:
@@ -151,14 +166,26 @@ class RecomposeUnbindListUnpack : public OpRewritePattern<PrimListUnpackOp> {
   LogicalResult matchAndRewrite(PrimListUnpackOp op,
                                 PatternRewriter &rewriter) const override {
     // recompose AtenUnbindOp + PrimListUnpackOp to select.int
-    auto unbind = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
-    if (!unbind)
+    auto unbindOp = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
+    if (!unbindOp)
       return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
-    if (isListPotentiallyMutated(unbind.getResult()))
+    if (isListPotentiallyMutated(unbindOp.getResult()))
       return rewriter.notifyMatchFailure(
           op, "AtenUnbindIntOp result is potentially mutated");
-    Value dim = unbind.getDim();
-    Value input = unbind.getSelf();
+    Location loc = op.getLoc();
+    Value dim = unbindOp.getDim();
+    Value input = unbindOp.getSelf();
+
+    // add runtime.assert to check unbind's dim size == numResults
+    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    Value cstNumResults = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(op.getNumResults()));
+    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, totalSize, cstNumResults);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("unbind's dim size should be equal to "
+                               "prim.list_unpack's num results"));
+
     SmallVector<Value> slices;
     for (size_t i = 0; i < op.getNumResults(); i++) {
       // rewrite to select.int op
@@ -170,8 +197,8 @@ class RecomposeUnbindListUnpack : public OpRewritePattern<PrimListUnpackOp> {
       slices.push_back(newSelect);
     }
     rewriter.replaceOp(op, slices);
-    if (unbind.getResult().use_empty())
-      rewriter.eraseOp(unbind);
+    if (unbindOp.getResult().use_empty())
+      rewriter.eraseOp(unbindOp);
     return success();
   }
 };
@@ -192,10 +219,21 @@ class RecomposeUnbindGetItem : public OpRewritePattern<Aten__Getitem__TOp> {
     if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
       return rewriter.notifyMatchFailure(
           op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
+    if (index < 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` of `Aten__Getitem__TOp` to be a non-negative int");
 
     Location loc = op.getLoc();
     Value dim = unbind.getDim();
     Value input = unbind.getSelf();
+
+    // add runtime.assert to check: index < unbind's dim size
+    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    Value ltOrNot = rewriter.create<AtenLtIntOp>(loc, op.getIdx(), totalSize);
+    rewriter.create<RuntimeAssertOp>(
+        loc, ltOrNot,
+        rewriter.getStringAttr("index should be less than unbind's dim size"));
+
     // rewrite to slice op
     auto resultTy = op.getResult().getType();
     Value newSelect = rewriter.create<AtenSelectIntOp>(loc, resultTy, input,
@@ -270,6 +308,9 @@ class RecomposeSplitTensorGetItemOp
     if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
       return rewriter.notifyMatchFailure(
           op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
+    if (index < 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` of `Aten__Getitem__TOp` to be a non-negative int");
 
     int64_t splitSize;
     if (!matchPattern(splitTensorOp.getSplitSize(),
@@ -279,6 +320,19 @@
           "Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
 
     Location loc = op.getLoc();
+    Value input = splitTensorOp.getSelf();
+    Value dim = splitTensorOp.getDim();
+
+    // add runtime.assert to check rank constraint: index < split_result_size
+    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    Value splitResultSize =
+        getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize());
+    Value ltOrNot =
+        rewriter.create<AtenLtIntOp>(loc, op.getIdx(), splitResultSize);
+    rewriter.create<RuntimeAssertOp>(
+        loc, ltOrNot,
+        rewriter.getStringAttr("index should be less than split_result_size"));
+
     Value step =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     Value start = rewriter.create<ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(index * splitSize));
     Value end = rewriter.create<ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(index * splitSize + splitSize));
     Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
-        loc, op.getResult().getType(), splitTensorOp.getSelf(),
-        splitTensorOp.getDim(), start, end, step);
+        loc, op.getResult().getType(), input, dim, start, end, step);
     rewriter.replaceOp(op, sliceTensorOp);
     if (splitTensorOp.getResult().use_empty())
       rewriter.eraseOp(splitTensorOp);
@@ -318,8 +371,24 @@ class RecomposeSplitTensorListUnpack
           "Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
 
     Location loc = op.getLoc();
-    Value step =
+    Value input = splitTensorOp.getSelf();
+    Value dim = splitTensorOp.getDim();
+
+    // add runtime.assert to check rank constraint
+    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    Value cstNumResults = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(op.getNumResults()));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    // assert: numResults == floordiv(totalSize + splitSize - 1, splitSize)
+    Value splitResultSize =
+        getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize());
+    Value eqOrNot =
+        rewriter.create<AtenEqIntOp>(loc, splitResultSize, cstNumResults);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("numResults should be equal to floordiv("
+                               "totalSize + splitSize - 1, splitSize)"));
 
     SmallVector<Value> slices;
     for (size_t i = 0; i < op.getNumResults(); i++) {
       auto start = rewriter.create<ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(i * splitSize));
       auto end = rewriter.create<ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr((i + 1) * splitSize));
       Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
-          loc, resultTy, splitTensorOp.getSelf(), splitTensorOp.getDim(), start,
-          end, step);
+          loc, resultTy, input, dim, start, end, /*step=*/cstOne);
       slices.push_back(sliceTensorOp);
     }
     rewriter.replaceOp(op, slices);
@@ -341,31 +409,125 @@
   }
 };
 
+class RecomposeSplitWithSizesListUnpack
+    : public OpRewritePattern<PrimListUnpackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimListUnpackOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenSplitWithSizesOp + PrimListUnpackOp to AtenSliceTensorOps
+    auto splitOp =
+        dyn_cast<AtenSplitWithSizesOp>(op.getOperand().getDefiningOp());
+    if (!splitOp) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Input is not AtenSplitWithSizesOp");
+    }
+    if (isListPotentiallyMutated(splitOp.getResult())) {
+      return rewriter.notifyMatchFailure(
+          op, "splitWithSizesOp result is potentially mutated");
+    }
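+    // Recompose sketch (assuming constant split sizes): a prim.ListUnpack of
+    // aten.split.with_sizes(x, [2, 3], dim) becomes
+    // aten.slice.Tensor(x, dim, 0, 2, 1) and aten.slice.Tensor(x, dim, 2, 5, 1),
+    // guarded by a runtime assert that size(x, dim) == 2 + 3.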
+    if (isListPotentiallyMutated(splitOp.getSplitSizes())) {
+      return rewriter.notifyMatchFailure(
+          op, "splitWithSizesOp's split_sizes is potentially mutated");
+    }
+    auto splitSizesConstruct =
+        splitOp.getSplitSizes().getDefiningOp<PrimListConstructOp>();
+    if (!splitSizesConstruct) {
+      return rewriter.notifyMatchFailure(
+          op, "split_sizes is not from PrimListConstructOp");
+    }
+
+    int64_t sumSplitSize = 0;
+    SmallVector<int64_t> splitSizes;
+    for (auto operand : splitSizesConstruct.getOperands()) {
+      int64_t value = -1;
+      // TODO: support when split_sizes are not constant int
+      if (!matchPattern(operand, m_TorchConstantInt(&value))) {
+        return rewriter.notifyMatchFailure(
+            op, "one of split_sizes is not a constant int");
+      }
+      if (value < 0) {
+        return rewriter.notifyMatchFailure(op,
                                            "all of split_sizes must be >= 0");
+      }
+      sumSplitSize += value;
+      splitSizes.push_back(value);
+    }
+    if (splitSizes.size() != op.getNumResults()) {
+      return rewriter.notifyMatchFailure(
+          op, "the number of split_sizes must match splitOp's result count");
+    }
+
+    Location loc = op.getLoc();
+    Value input = splitOp.getSelf();
+    Value dim = splitOp.getDim();
+
+    // add runtime.assert to check rank constraint
+    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    Value cstSumSplitSize = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(sumSplitSize));
+    Value eqOrNot =
+        rewriter.create<AtenEqIntOp>(loc, totalSize, cstSumSplitSize);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr(
+            "the size of the split dim must equal the sum of split_sizes"));
+
+    // calculate each slice op's lower and upper bounds
+    SmallVector<int64_t> boundaryOfSliceOp(splitSizes.size() + 1, 0);
+    for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) {
+      boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1];
+    }
+    SmallVector<Value> slices;
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    for (size_t i = 0; i < op.getNumResults(); i++) {
+      auto resultTy = op.getResult(i).getType();
+      auto start = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[i]));
+      auto end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr((boundaryOfSliceOp[i + 1])));
+      Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
+          loc, resultTy, input, dim, start, end, /*step=*/cstOne);
+      slices.push_back(sliceTensorOp);
+    }
+    rewriter.replaceOp(op, slices);
+    // erase splitOp if no user left
+    if (splitOp.getResult().use_empty())
+      rewriter.eraseOp(splitOp);
+    return success();
+  }
+};
+
 class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(PrimListUnpackOp op,
                                 PatternRewriter &rewriter) const override {
     // recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps
-    auto chunk = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
-    if (!chunk)
+    auto chunkOp = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
+    if (!chunkOp)
       return rewriter.notifyMatchFailure(op, "Input is not AtenChunkOp");
-    if (isListPotentiallyMutated(chunk.getResult()))
+    if (isListPotentiallyMutated(chunkOp.getResult()))
       return rewriter.notifyMatchFailure(
           op, "AtenChunkOp result is potentially mutated");
-    Value dim = chunk.getDim();
-    Value input = chunk.getSelf();
-    Value chunks = chunk.getChunks();
-    Location loc = chunk.getLoc();
+    Value dim = chunkOp.getDim();
+    Value input = chunkOp.getSelf();
+    Value chunks = chunkOp.getChunks();
+    Location loc = chunkOp.getLoc();
     Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
-    // chunkSize = floordiv(totalSize + chunks - 1, chunks)
+    Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks);
+
+    // add runtime.assert to check chunks == NumResults
+    Value cstNumResults = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(op.getNumResults()));
+    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, chunks, cstNumResults);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr(
+            "chunks should be equal to prim.list_unpack's num results"));
+
     Value cstOne =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
-    Value dividend = rewriter.create<AtenAddIntOp>(loc, totalSize, chunks);
-    dividend = rewriter.create<AtenSubIntOp>(loc, dividend, cstOne);
-    Value chunkSize = rewriter.create<AtenFloordivIntOp>(loc, dividend, chunks);
-
     SmallVector<Value> slices;
     for (size_t i = 0; i < op.getNumResults(); i++) {
       // rewrite to slice op with
@@ -383,13 +545,13 @@ class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
         end = rewriter.create<AtenMulIntOp>(loc, nextIdx, chunkSize);
       }
       Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
-          loc, resultTy, input, dim, start, end, cstOne);
+          loc, resultTy, input, dim, start, end, /*step=*/cstOne);
       slices.push_back(sliceTensorOp);
     }
     rewriter.replaceOp(op, slices);
     // erase chunkOp if no user left
-    if (chunk.getResult().use_empty())
-      rewriter.eraseOp(chunk);
+    if (chunkOp.getResult().use_empty())
+      rewriter.eraseOp(chunkOp);
     return success();
   }
 };
@@ -453,6 +615,7 @@ class RecomposeComplexOpsPass
     patterns.add(context);
     patterns.add(context);
     patterns.add(context);
+    patterns.add<RecomposeSplitWithSizesListUnpack>(context);
     patterns.add(context);
     patterns.add(context);
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp
index 5109a8c5735e..cfa4e40ee908 100644
--- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp
+++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp
@@ -62,7 +62,8 @@ class RefinePublicReturnPass
     OpBuilder builder(returnOp);
     for (auto operand : returnOp.getOperands()) {
       Value newOperand = operand;
-      // Look through TensorStaticInfoCastOp's and CopyToNonValueTensorOp's.
+      // Look through TensorStaticInfoCastOp's, CopyToNonValueTensorOp's, and
+      // DerefineOp's.
       for (;;) {
         if (auto cast = newOperand.getDefiningOp<TensorStaticInfoCastOp>()) {
           newOperand = cast.getOperand();
@@ -76,6 +77,8 @@ class RefinePublicReturnPass
           if (users.size() != 1)
             break;
           newOperand = copy.getOperand();
+        } else if (auto derefine = newOperand.getDefiningOp<DerefineOp>()) {
+          newOperand = derefine.getOperand();
         } else {
           break;
         }
diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp
index 8e6b5888bb02..290beb1da7c9 100644
--- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp
+++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp
@@ -176,10 +176,17 @@ FailureOr<Value> Torch::adjustFunctionArg(
     return b.create<DerefineOp>(loc, desiredType, operand).getResult();
   }
 
-  // !torch.union<int, float> or !torch.union<int, float, none> is the type
-  // used for (optional) `Scalar` inputs. At compile time, such inputs will
-  // usually be resolved to an `int` or a `float` so we need to derefine to
-  // match the library function signature.
+  // The type `!torch.number` can be an `int`, `float`, or `complex`.
+  // TODO: Add a new type `Torch::ComplexType` to handle the complex case.
+  if (desiredType.isa<Torch::NumberType>() &&
+      operandType.isa<Torch::IntType, Torch::FloatType>()) {
+    return b.create<DerefineOp>(loc, desiredType, operand).getResult();
+  }
+
+  // !torch.union<int, float, none> is the type used for optional
+  // `Scalar` inputs. At compile time, such inputs will usually be
+  // resolved to an `int`, `float`, or `None` so we need to derefine
+  // to match the library function signature.
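+  // For example (sketch): passing a concrete !torch.int into a parameter
+  // declared as !torch.union<int, float, none> inserts a derefine:
+  //   %0 = torch.derefine %arg : !torch.int to !torch.union<int, float, none>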
if (auto unionType = desiredType.dyn_cast()) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { return containedType diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 43f2b22a3d66..6860fbb6eee8 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -171,6 +171,11 @@ class RefineNumToTensorScalarOpType return rewriter.notifyMatchFailure( op, "`PrimNumToTensorScalarOp` already has a dtype"); + if (op.getA().getType().isa()) { + return rewriter.notifyMatchFailure(op, + "`PrimNumToTensorScalarOp`'s input " + "should have concrete Scalar Type."); + } Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType()); auto impliedTypeFromInputType = originalResultType.cast() diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index f4aafe773923..751d9d790caa 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -100,11 +100,11 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::Char; if (type.isa()) { mlir::Type complexElemType = type.cast().getElementType(); - if (complexElemType.isF32()) + if (complexElemType.isF16()) return torch_upstream::ScalarType::ComplexHalf; - if (complexElemType.isF64()) + if (complexElemType.isF32()) return torch_upstream::ScalarType::ComplexFloat; - if (complexElemType.isF128()) + if (complexElemType.isF64()) return torch_upstream::ScalarType::ComplexDouble; } llvm::report_fatal_error("unhandled type for getScalarTypeForType"); @@ -144,11 +144,11 @@ Torch::getTypeForScalarType(MLIRContext *context, case torch_upstream::ScalarType::Char: return mlir::IntegerType::get(context, 8, signedness); case torch_upstream::ScalarType::ComplexHalf: - return mlir::ComplexType::get(Float32Type::get(context)); + return mlir::ComplexType::get(Float16Type::get(context)); case torch_upstream::ScalarType::ComplexFloat: - return mlir::ComplexType::get(Float64Type::get(context)); + return mlir::ComplexType::get(Float32Type::get(context)); case torch_upstream::ScalarType::ComplexDouble: - return mlir::ComplexType::get(Float128Type::get(context)); + return mlir::ComplexType::get(Float64Type::get(context)); case torch_upstream::ScalarType::Undefined: return failure(); default: @@ -241,8 +241,9 @@ bool Torch::isViewLikeOp(Operation *op) { AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp, TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, - AtenNarrowOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp, - PrimsViewOfOp, AtenRealOp, AtenImagOp, AtenViewAsComplexOp>(op); + AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, + AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp, + AtenViewAsComplexOp, AtenViewAsRealOp>(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index 1f7f4e8f8294..a286d5bbd7a9 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -18,13 +18,17 @@ set(LinkedLibs ) if(TORCH_MLIR_ENABLE_STABLEHLO) - list(APPEND LinkedLibs ChloPasses) + list(APPEND LinkedLibs + StablehloOps + ) endif() add_mlir_library(TorchMLIRTorchConversionPasses BackendTypeConversion.cpp 
BackendTypeConversionPasses.cpp Passes.cpp + ConvertCustomQuantOp.cpp + UnpackQuantTensor.cpp VerifyLinalgOnTensorsBackendContract.cpp VerifyTosaBackendContract.cpp VerifyStablehloBackendContract.cpp diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp new file mode 100644 index 000000000000..175a3cd14804 --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -0,0 +1,226 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { +class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getName().str() != "quant.matmul_rhs_group_quant") { + return failure(); + } + Location loc = op->getLoc(); + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) { + return failure(); + } + + // get inputs: lhs, rhsQuant, scales, zps + Value lhs = adaptor.getOperands()[0]; + auto lhsType = lhs.getType().cast(); + if (!lhsType) { + return failure(); + } + auto lhsShape = lhsType.getShape(); + int lhsReductDimSize = lhsShape.back(); + + Value rhsQuant = adaptor.getOperands()[1]; + auto rhsType = rhsQuant.getType().cast(); + if (!rhsType) { + return failure(); + } + auto rhsShape = rhsType.getShape(); + int rhsReductDimSize = rhsShape.back(); + Type rhsElementType = rhsType.getElementType(); + + Value scales = adaptor.getOperands()[2]; + Value zps = adaptor.getOperands()[3]; + Value unpackedTypeWidth = adaptor.getOperands()[4]; + Value groupSize = adaptor.getOperands()[5]; + + auto getConstantIntegerFromDefiningOp = [](Value operand, + int &extractedInt) { + auto castOp = dyn_cast(operand.getDefiningOp()); + if (!castOp) { + return failure(); + } + auto constOp = + dyn_cast(castOp.getOperand(0).getDefiningOp()); + if (!constOp) { + return failure(); + } + extractedInt = constOp.getValue(); + return success(); + }; + + int gs; + if (failed(getConstantIntegerFromDefiningOp(groupSize, gs))) { + return failure(); + } + int unpackedBitWidth; + if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, unpackedBitWidth))) { + return failure(); + } + if (unpackedBitWidth != + static_cast(rhsElementType.getIntOrFloatBitWidth())) { + return failure(); + } + + // get outputs + Type 
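+    // Shape bookkeeping sketch (gs = group size, names from the code below):
+    // lhs [B, M, K] is expanded to [B, M, K/gs, gs] and the quantized rhs
+    // [N, K] to [N, K/gs, gs], so each group of gs weights shares one scale
+    // and zero point during dequantization.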
newResultType = getTypeConverter()->convertType(op.getType(0)); + auto resultType = newResultType.cast(); + if (!resultType) { + return failure(); + } + auto resultShape = resultType.getShape(); + Type elementType = resultType.getElementType(); + + // expand lhs + std::vector lhsExpandedShape = {lhsShape[0], lhsShape[1], + lhsReductDimSize / gs, gs}; + RankedTensorType lhsExpandedType = RankedTensorType::get(lhsExpandedShape, elementType); + SmallVector lhsReassociation = {{0}, {1}, {2, 3}}; + Value lhsExpanded = rewriter.create( + loc, lhsExpandedType, lhs, lhsReassociation); + + // expand rhs + std::vector rhsExpandedShape = {rhsShape[0], rhsReductDimSize/gs, gs}; + RankedTensorType rhsExpandedType = RankedTensorType::get(rhsExpandedShape, rhsElementType); + SmallVector rhsReassociation = {{0}, {1, 2}}; + Value rhsExpanded = rewriter.create( + loc, rhsExpandedType, rhsQuant, rhsReassociation); + Value cst0 = rewriter.create( + loc, FloatAttr::get(elementType, 0.0)); + + Value emptyDequant = rewriter.create( + loc, rhsExpandedShape, elementType); + SmallVector dynDims; + for (int i = 0; i < lhsType.getRank(); i++) { + if (lhsType.isDynamicDim(i)) { + dynDims.push_back(rewriter.create(loc, lhs, i)); + } + } + Value empty = rewriter.create( + loc, resultShape, elementType, dynDims); + Value output = rewriter.create( + loc, cst0, empty).getResult(0); + + AffineExpr d0, d1, d2, d3, d4; + bindDims(getContext(), d0, d1, d2, d3, d4); + auto c0 = rewriter.getAffineConstantExpr(0); + auto map = AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()); + auto map1 = AffineMap::get(3, 0, {d0, d1, c0}, rewriter.getContext()); + auto map2 = AffineMap::get(5, 0, {d0, d1, d3, d4}, rewriter.getContext()); + auto map3 = AffineMap::get(5, 0, {d2, d3, d4}, rewriter.getContext()); + auto map4 = AffineMap::get(5, 0, {d0, d1, d2}, rewriter.getContext()); + SmallVector dqIndexingMaps = {map, map1, map1, map}; + SmallVector matIndexingMaps = {map2, map3, map4}; + + SmallVector dequantIteratorTypes(3, utils::IteratorType::parallel); + SmallVector matmulIteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::reduction, + utils::IteratorType::reduction + }; + + Value rhsDequant = + rewriter + .create( + loc, emptyDequant.getType(), + ValueRange{rhsExpanded, scales, zps}, emptyDequant, + /*indexingMaps=*/dqIndexingMaps, + /*iteratorTypes=*/dequantIteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value w = args[0], scale = args[1], zeroPoint = args[2]; + Value extw = b.create(loc, rewriter.getI32Type(), w); + Value fp_extw = b.create(loc, rewriter.getF16Type(), extw); + Value shifted = b.create(loc, fp_extw, zeroPoint); + Value dqw = b.create(loc, shifted, scale); + b.create(loc, dqw); + }) + .getResult(0); + + Value matmulDequant = + rewriter + .create( + loc, output.getType(), + ValueRange{lhsExpanded, rhsDequant}, output, + /*indexingMaps=*/matIndexingMaps, + /*iteratorTypes=*/matmulIteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value l = args[0], r = args[1], out = args[2]; + Value pd = b.create(loc, l, r); + Value ac = b.create(loc, pd, out); + b.create(loc, ac); + }) + .getResult(0); + + rewriter.replaceOpWithNewOp(op, resultType, matmulDequant); + return success(); + } +}; +} // namespace + +namespace { +class ConvertCustomQuantOpPass + : public TorchConversion::ConvertCustomQuantOpBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + 
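+  // Pass wiring sketch: quant.matmul_rhs_group_quant is the only op marked
+  // illegal below, so partial conversion fails unless the pattern above
+  // rewrites it into the linalg dequantize + matmul form.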
registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::TorchConversion::createConvertCustomQuantOpPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp new file mode 100644 index 000000000000..25f325399f12 --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -0,0 +1,143 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { +class UnpackQuantizedMatmulWeights + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ValueTensorLiteralOp constOp, + PatternRewriter &rewriter) const override { + if (!constOp->hasOneUse()) + return failure(); + + OpOperand *use = constOp.getResult().use_begin().getOperand(); + auto op = dyn_cast(use->getOwner()); + if (!op) { + return failure(); + } + if (op.getName().str() != "quant.matmul_rhs_group_quant") { + return failure(); + } + + if (use->getOperandNumber() != 1) { + return failure(); + } + + Value rhs = op.getOperand(1); + Value bitWidth = op.getOperand(4); + + auto getConstantIntegerFromDefiningOp = [](Value operand, + int &extractedInt) { + auto constOp = dyn_cast(operand.getDefiningOp()); + if (!constOp) { + return failure(); + } + extractedInt = constOp.getValue(); + return success(); + }; + int unpackedBitWidth; + if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth))) + return failure(); + + auto rhsType = rhs.getType().dyn_cast(); + if (!rhsType) + return failure(); + + if (!rhsType.hasDtype()) + return failure(); + + Type dType = rhsType.getDtype(); + int dTypeWidth = dType.getIntOrFloatBitWidth(); + if (dTypeWidth == unpackedBitWidth) + return failure(); + + if (!rhsType.hasSizes()) + return failure(); + + SmallVector tensorShape(rhsType.getSizes()); + if (tensorShape.back() == kUnknownSize) + return failure(); + int packRatio = dTypeWidth / unpackedBitWidth; + + tensorShape[tensorShape.size() 
- 1] *= packRatio; + Type unpackedElementType; + if (dType.isSignedInteger()) + unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, true); + else + unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, false); + ValueTensorType newRhsType = ValueTensorType::get( + rewriter.getContext(), tensorShape, unpackedElementType); + + auto elements = constOp.getValueAttr().dyn_cast(); + if (!elements) + return failure(); + + auto attrType = RankedTensorType::get(tensorShape, unpackedElementType); + + // TODO: Materialize IR that does the conversion from quantized type to + // pure integer type which relies on constant evaluation in backends + auto data = elements.getRawData(); + std::vector newData(data.size() * packRatio, + APInt(unpackedBitWidth, 0)); + for (int i = 0, e = data.size(); i < e; ++i) { + auto el = data[i]; + char mask = (1 << unpackedBitWidth) - 1; + for (int b = 0; b < packRatio; b++) { + newData[i * packRatio + b] = + APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b)); + mask = mask << unpackedBitWidth; + } + } + rewriter.replaceOpWithNewOp( + constOp, newRhsType, + DenseElementsAttr::get(attrType, ArrayRef(newData))); + return success(); + } +}; +} // namespace + +namespace { +class UnpackQuantTensorPass + : public TorchConversion::UnpackQuantTensorBase { + using UnpackQuantTensorBase::UnpackQuantTensorBase; + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::TorchConversion::createUnpackQuantTensorPass() { + return std::make_unique(); +} diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 43b45d32eaff..1d67bdfe236c 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/InitAll.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Dialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" @@ -20,15 +21,12 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/RefBackend/Passes.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "mhlo/transforms/passes.h" -#endif - void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) { registry.insert(); registry.insert(); registry.insert(); registry.insert(); + mlir::func::registerInlinerExtension(registry); } void mlir::torch::registerAllPasses() { @@ -38,12 +36,4 @@ void mlir::torch::registerAllPasses() { mlir::torch::registerConversionPasses(); mlir::torch::RefBackend::registerRefBackendPasses(); mlir::torch::TMTensor::registerPasses(); - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO - mlir::mhlo::registerSymbolicShapeOptimizationPass(); - mlir::mhlo::registerStablehloLegalizeToHloPass(); - mlir::mhlo::registerChloLegalizeToHloPass(); - mlir::mhlo::registerHloLegalizeToLinalgPass(); - mlir::mhlo::registerTestUnfuseBatchNormPass(); -#endif // TORCH_MLIR_ENABLE_STABLEHLO } diff --git a/python/torch_mlir/_dynamo_fx_importer.py b/python/torch_mlir/_dynamo_fx_importer.py index 84219cf84599..15efda2d9b52 100644 --- a/python/torch_mlir/_dynamo_fx_importer.py +++ b/python/torch_mlir/_dynamo_fx_importer.py @@ -147,6 +147,8 @@ def _convert_dtype_to_mlir_type(dtype: torch.dtype) -> str: if dtype ==
torch.quint8: return "!torch.quint8" if dtype == torch.complex64: + return "complex" + if dtype == torch.complex128: return "complex" @@ -205,9 +207,9 @@ def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType: torch.float64: 7, # torch.complex_half 8 - torch.complex32: - 9, torch.complex64: + 9, + torch.complex128: 10, torch.bool: 11, diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index f1314d25c06f..310ad6b73731 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -27,13 +27,7 @@ def get_module_name_for_debug_dump(module): class TorchMlirCompilerError(Exception): - def __init__(self, value: str): - super().__init__() - self.value = value - - def __str__(self) -> str: - return self.value - + pass def run_pipeline_with_repro_report(module, pipeline: str, diff --git a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt index 3293c6e2f663..81a8383949c7 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt +++ b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt @@ -69,8 +69,13 @@ add_library(torch_mlir_ltc_backend SHARED backend_impl.cpp dynamic_ir.cpp mlir_node.cpp + tensor.cpp ops/device_data.cpp ops/generic.cpp + ops/index.cpp + ops/ivalue.cpp + ops/split.cpp + ops/unbind_int.cpp utils/jit_utils.cpp utils/tensor_utils.cpp ) diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index 0182952f898a..4823b4929ab1 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include "torch-mlir-c/Registration.h" #include "torch-mlir-c/Transforms.h" #include "mlir-c/IR.h" @@ -205,13 +206,46 @@ void TorchMlirLoweringContext::AssignOutputOp( const Output& output, torch::jit::Value* op) { PRINT_FUNCTION(); - // TODO (antoniojkim): Do we need this? 
- // auto torch_mlir_node = - // NodeCast(output.node, output.node->op()); - // if (!torch_mlir_node->getPythonStacktrace().empty()) { - // op->node()->s_( - // c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace()); - // } + auto torch_mlir_node = + NodeCast(output.node, output.node->op()); + + std::vector source_files, functions; + std::vector line_numbers; + const auto& metadata = torch_mlir_node->metadata(); + const auto& frames = metadata.frame_info; + if (!frames.empty()) { + static std::vector g_roots = + string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":"); + + std::for_each(frames.rbegin(), frames.rend(), + [&](const torch::lazy::SourceLocation& location) { + functions.push_back(location.function); + line_numbers.push_back(location.line); + + std::string file_name = location.file; + for (const std::string& root : g_roots) { + if (startswith(file_name, root)) { + // location.file starts with root, strip it off + file_name = file_name.substr(root.size()); + break; + } + } + source_files.push_back(file_name); + }); + + if (!source_files.empty()) { + op->node()->ss_( + c10::Symbol::attr("source_files"), source_files); + op->node()->ss_( + c10::Symbol::attr("functions"), functions); + op->node()->is_( + c10::Symbol::attr("line_numbers"), line_numbers); + } + } + auto scope = ::c10::Symbol::scope(metadata.scope); + op->node()->setScope( + c10::make_intrusive()->push(scope)); + emitted_outputs_[output] = std::move(op); } @@ -424,7 +458,11 @@ const std::string TorchMlirComputation::to_string() const { *ss_ptr << std::string(part.data, part.length); }; std::stringstream ss; - mlirOperationPrint(mlirModuleGetOperation(module_op_), print_callback, &ss); + + // Setup flags for MLIR serialization. + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + mlirOpPrintingFlagsEnableDebugInfo(flags, FLAGS_torch_lazy_ir_debug, false); + mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, print_callback, &ss); return ss.str(); } diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp index 28152bbb517c..d06ad5963919 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -10,6 +10,8 @@ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp //===----------------------------------------------------------------------===// +#include +#include #include #include #include @@ -28,12 +30,62 @@ #include #include +#include "generated/LazyIr.h" #include "generated/LazyNativeFunctions.h" #include "generated/shape_inference.h" #include "ops/to_copy.h" +#include "ops/unbind_int.h" +#include "ops/split.h" +#include "ops/index.h" +#include "ops/ivalue.h" #include "utils/exception.h" #include "utils/sys_utils.h" +namespace { +at::Tensor to_meta(const at::Tensor& tensor) { + // undefined tensors can't be converted to the meta device, since they don't + // have sizes/strides + if (!tensor.defined()) + return tensor; + auto out = at::native::empty_strided_meta_symint( + tensor.sym_sizes(), tensor.sym_strides(), + /*dtype=*/c10::make_optional(tensor.scalar_type()), + /*layout=*/c10::make_optional(tensor.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + // needs to handle wrapped numbers, so dtype promotion works properly. 
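+ // (A "wrapped number" is what a Python scalar becomes in the dispatcher:
+ // the 2.0 in `x + 2.0` arrives as a 0-dim tensor with this bit set, and
+ // type promotion treats it as a scalar rather than a full tensor, so the
+ // bit has to survive the copy to the meta device.)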
+ if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) { + out.unsafeGetTensorImpl()->set_wrapped_number(true); + } + return out; +} + +c10::optional to_meta(const c10::optional& tensor) { + if (tensor.has_value()) { + return to_meta(*tensor); + } + return c10::nullopt; +} + +std::vector to_meta(at::ITensorListRef t_list) { + std::vector outs; + outs.reserve(t_list.size()); + for (const auto& tensor : t_list) { + outs.push_back(to_meta(tensor)); + } + return outs; +} + +c10::List> to_meta(const c10::List>& t_list) { + c10::List> outs; + outs.reserve(t_list.size()); + for (const auto& tensor : t_list) { + outs.push_back(to_meta(tensor)); + } + return outs; +} +} // namespace + namespace torch { namespace lazy { @@ -92,32 +144,6 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { } // namespace -// at::Tensor LazyNativeFunctions::bernoulli( -// const at::Tensor& self, c10::optional generator) { -// TORCH_LAZY_FN_COUNTER("lazy::"); -// if (generator.has_value() && generator->defined()) { -// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli has generator value"); -// } -// auto self_tensor = torch::lazy::TryGetLtcTensor(self); - -// UNIMPLEMENTED_FUNCTION_ERROR(); -// // return torch::lazy::CreateAtenFromLtcTensor( -// // torch::lazy::bernoulli(self_tensor)); -// } - -// at::Tensor& LazyNativeFunctions::bernoulli_( -// at::Tensor& self, double p, c10::optional generator) { -// TORCH_LAZY_FN_COUNTER("lazy::"); -// if (generator.has_value() && generator->defined()) { -// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli_ has generator value"); -// } -// auto self_tensor = torch::lazy::TryGetLtcTensor(self); - -// UNIMPLEMENTED_FUNCTION_ERROR(); -// // torch::lazy::bernoulli_(self_tensor, p); -// // return self; -// } - // clone is special in LT because we make it a no-op. // This should be safe to do, because every operator in the LT is functional. at::Tensor LazyNativeFunctions::clone( @@ -301,62 +327,217 @@ at::Tensor LazyNativeFunctions::_to_copy( } }; -at::Tensor LazyNativeFunctions::empty_symint( - at::SymIntArrayRef sym_size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - c10::optional memory_format) { - // TODO: support this directly - auto size = C10_AS_INTARRAYREF_SLOW(sym_size); - const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType(); - at::TensorOptions options = at::TensorOptions() - .device(c10::Device(device_type)) - .layout(layout) - .pinned_memory(pin_memory) - .dtype(dtype); - auto x_result = at::empty(size, options, memory_format); - auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device)); - // See Note [Lazy Tensor Functionalization] - if (c10::impl::tls_local_dispatch_key_set().excluded_.has( - c10::DispatchKey::Functionalize)) { - // Invariant: if the functionalization key is in the exclude set, then we're expected - // to return an ordinary tensor, which will be "lifted" into a functional wrapper later. 
- return tensor; - } else { - auto wrapped = at::functionalization::impl::to_functional_tensor(tensor); - return wrapped; +at::Tensor LazyNativeFunctions::_unsafe_view( + const at::Tensor& self, at::IntArrayRef size) { + TORCH_LAZY_FN_COUNTER("lazy::"); + return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size)); +} + +at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { + TORCH_LAZY_FN_COUNTER("lazy::"); + return at::functionalization::functionalize_aten_op::call(self); +} + +std::vector LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) { + TORCH_LAZY_FN_COUNTER("lazy::"); + auto common_device = torch::lazy::GetBackendDevice(self); + TORCH_INTERNAL_ASSERT(common_device); + + LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), dim); + if (!node) { + auto self_meta = to_meta(self); + auto out_meta = at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim); + + std::vector shapes; + for (const auto & shape : out_meta) { + shapes.push_back( + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) + ); + } + + if(torch::lazy::symbolicShapeEnabled()){ + std::vector inputs = { self, dim }; + const char* schema_str = "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + } + + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), dim, std::move(shapes)); + CacheNode(node); + } + + std::vector result; + for (size_t i = 0; i < node->num_outputs(); ++i) { + result.push_back( + torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) + ) + ); + } + + return result; +} + +std::vector LazyNativeFunctions::split_with_sizes_copy_symint(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { + TORCH_LAZY_FN_COUNTER("lazy::"); + auto common_device = torch::lazy::GetBackendDevice(self); + TORCH_INTERNAL_ASSERT(common_device); + + LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim); + if (!node) { + auto self_meta = to_meta(self); + auto out_meta = at::compositeexplicitautogradnonfunctional::split_with_sizes_copy_symint(self_meta, split_sizes, dim); + + std::vector shapes; + for (const auto & shape : out_meta) { + shapes.push_back( + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) + ); + } + + if(torch::lazy::symbolicShapeEnabled()){ + std::vector inputs = { self, split_sizes, dim }; + const char* schema_str = "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + } + + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, std::move(shapes)); + CacheNode(node); } + + std::vector result; + for (size_t i = 0; i < node->num_outputs(); ++i) { + result.push_back( + torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) + ) + ); + } + + return result; } -at::Tensor LazyNativeFunctions::empty_strided( - at::IntArrayRef size, at::IntArrayRef stride, - c10::optional dtype, c10::optional layout, - c10::optional device, c10::optional pin_memory) { +std::vector LazyNativeFunctions::split_copy_symint(const at::Tensor & 
self, c10::SymInt split_size, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); - at::Tensor t = empty_symint( - c10::fromIntArrayRefSlow(size), - dtype, layout, device, pin_memory, c10::nullopt); - return t.as_strided(size, stride, /*storage_offset=*/0); + auto common_device = torch::lazy::GetBackendDevice(self); + TORCH_INTERNAL_ASSERT(common_device); + LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim); + if (!node) { + auto self_meta = to_meta(self); + auto out_meta = at::compositeexplicitautogradnonfunctional::split_copy_symint(self_meta, split_size, dim); + + std::vector shapes; + for (const auto & shape : out_meta) { + shapes.push_back( + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) + ); + } + const size_t num_outputs = shapes.size(); + + if(torch::lazy::symbolicShapeEnabled()){ + std::vector inputs = { self, split_size, dim }; + const char* schema_str = "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + } + + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, std::move(shapes), num_outputs); + CacheNode(node); + } + + std::vector result; + for (size_t i = 0; i < node->num_outputs(); ++i) { + result.push_back( + torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) + ) + ); + } + return result; } -at::Tensor& -LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) { +at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List> & indices) { TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); + auto common_device = torch::lazy::GetBackendDevice(self); + TORCH_INTERNAL_ASSERT(common_device); + LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + + std::vector values; + for (const auto & it : indices) { + c10::optional tensor = it; + LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); + values.push_back(lazy_tensor ? 
lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode(c10::IValue()), 0)); + } + + auto list = MakeNode(values); + + torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), list); - torch::lazy::Value constant = - torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( - value, self_tensor->shape(), self_tensor->GetDevice()); - self_tensor->SetInPlaceIrValue(std::move(constant)); - return self; + if (!node) { + auto self_meta = to_meta(self); + auto indices_meta = to_meta(indices); + auto out_meta = at::meta::index(self_meta, indices_meta); + + std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; + TORCH_INTERNAL_ASSERT(shapes.size() == 1); + if(torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = { self, indices }; + const char* schema_str = "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + } + + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), list, std::move(shapes)); + CacheNode(node); + } + + auto result = torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::LazyTensor::Create(std::move(node), *common_device)); + + return result; } -at::Tensor LazyNativeFunctions::_unsafe_view( - const at::Tensor& self, at::IntArrayRef size) { +at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::List> & indices, const at::Tensor & values, bool accumulate) { TORCH_LAZY_FN_COUNTER("lazy::"); - return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size)); + auto common_device = torch::lazy::GetBackendDevice(self); + TORCH_INTERNAL_ASSERT(common_device); + LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + LazyTensorPtr lazy_values = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device); + + std::vector indices_vector; + for (const auto & it : indices) { + c10::optional tensor = it; + LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); + indices_vector.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode(c10::IValue()), 0)); + } + + auto indices_list = MakeNode(indices_vector); + + torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), indices_list, lazy_values->GetIrValue(), accumulate); + + if (!node) { + auto self_meta = to_meta(self); + auto indices_meta = to_meta(indices); + auto values_meta = to_meta(values); + + auto out_meta = at::compositeexplicitautograd::index_put(self_meta, indices_meta, values_meta, accumulate); + + std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; + TORCH_INTERNAL_ASSERT(shapes.size() == 1); + if(torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = { self, indices, values }; + const char* schema_str = "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + } + + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), indices_list, lazy_values->GetIrValue(), accumulate, std::move(shapes)); + CacheNode(node); + } + + auto result = torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::LazyTensor::Create(std::move(node), *common_device)); + + return result; } // This is needed by the torch.tensor constructor.
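// The new lowerings above (unbind_copy, split_with_sizes_copy_symint,
// split_copy_symint, index, index_put) all share one trace-then-cache
// recipe. A minimal sketch for a hypothetical single-output MyNode (the
// node class and the my_op_meta helper are illustrative assumptions, not
// code from this patch):
//
//   torch::lazy::NodePtr node =
//       torch::lazy::ReuseNode<MyNode>(lazy_self->GetIrValue());
//   if (!node) {
//     at::Tensor out_meta = my_op_meta(to_meta(self)); // infer shape on the
//                                                      // meta device
//     std::vector<torch::lazy::Shape> shapes{
//         torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
//     node = torch::lazy::MakeNode<MyNode>(lazy_self->GetIrValue(),
//                                          std::move(shapes));
//     CacheNode(node); // later traces hit the IR cache via ReuseNode
//   }
//   auto result = torch::lazy::CreateAtenFromLtcTensor(
//       torch::lazy::LazyTensor::Create(std::move(node), *common_device));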
@@ -390,9 +571,18 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint( c10::optional layout, c10::optional device, c10::optional pin_memory) { - return at::functionalization:: - functionalize_aten_op_symint::call( - self, size, stride, dtype, layout, device, pin_memory); + if (!device || device->type() == c10::DeviceType::Lazy) { + return at::functionalization::functionalize_aten_op_symint< + ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout, + device, pin_memory); + } + // For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu") + // we need to avoid explicit functionalization. To do that we create regular cpu tensors. + at::Tensor t = at::empty_symint( + size, (dtype ? dtype : c10::optional(self.scalar_type())), + (layout ? layout : c10::optional(self.layout())), device, + pin_memory, c10::nullopt); + return t.as_strided_symint(size, stride, /*storage_offset=*/0); } at::Tensor LazyNativeFunctions::narrow_copy_symint( @@ -476,4 +666,4 @@ at::Tensor& LazyNativeFunctions::logsumexp_out( void InitializeAtenBindings() {} } // namespace lazy -} // namespace torch +} // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp index 8009e677a6c6..e4b75e5d53d1 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp @@ -116,7 +116,40 @@ torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower( } auto graph = function->graph(); auto listnode = - graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list)); + graph->insertNode(graph->createList(c10::TensorType::get(), tensor_list)); + return {listnode->output()}; +} + +/////////////////////////////////////////////////////////////////////////////// +// TorchMlirOptionalTensorList +/////////////////////////////////////////////////////////////////////////////// + +OpKind TorchMlirOptionalTensorList::ClassOpKind() { + // Note: this OpKind is separate from ltc_ops.h since it would be a circular + // import otherwise + static const OpKind tensor_list_opkind = + OpKind::Get("lazy_tensors::optional_tensor_list"); + return tensor_list_opkind; +} + +TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values) + : TorchMlirNode( + /*op=*/TorchMlirOptionalTensorList::ClassOpKind(), + /*operands=*/values, + /*shapes=*/std::vector(), + /*num_outputs=*/1, + /*hash_seed=*/kHashSeed) {} + +torch::lazy::TorchMlirOpVector TorchMlirOptionalTensorList::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + std::vector tensor_list; + CHECK(!operands().empty()); + for (const torch::lazy::Output& operand : operands()) { + tensor_list.emplace_back(loctx->GetOutputOp(operand)); + } + auto graph = function->graph(); + auto listnode = + graph->insertNode(graph->createList(c10::OptionalType::create(c10::TensorType::get()), tensor_list)); return {listnode->output()}; } diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h index fcabf0e5a0b0..dbf3117dbb13 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h @@ -91,5 +91,18 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode { TorchMlirLoweringContext* loctx) const override; }; +// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also represent +// optional tensors, so the 
output type for this op is !torch.list>. +struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode { + static OpKind ClassOpKind(); + + TorchMlirOptionalTensorList() = delete; + TorchMlirOptionalTensorList(OpList values); + + torch::lazy::TorchMlirOpVector Lower( + TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const override; +}; + } // namespace lazy } // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp index 6bed4513dbce..c15efb7a7a57 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp @@ -43,7 +43,12 @@ TorchMlirOpVector LowerTorchMlirBuiltin( for (auto arg : arguments) { torch::jit::Value* value = arg.value(dummy_graph); if (value->type()->kind() == c10::TypeKind::ListType) { - value->setType(c10::ListType::create(c10::TensorType::get())); + auto list_element_type = value->type()->cast()->getElementType(); + if (list_element_type->cast()) { + value->setType(c10::ListType::create(c10::OptionalType::create(c10::TensorType::get()))); + } else { + value->setType(c10::ListType::create(c10::TensorType::get())); + } } } @@ -55,8 +60,17 @@ TorchMlirOpVector LowerTorchMlirBuiltin( CHECK(sv); TorchMlirOpVector results; - if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) { - // Op returns multiple values. + if (sv->getValue()->type()->kind() == c10::TypeKind::ListType) { + // Unpack dynamic multi-output operations like aten::split with Tensor[] output type. + // This is required to have consistent input types for multi-output node consumers. + torch::jit::Node * node = function->graph()->createListUnpack(sv->getValue(), tensor_types.size()); + function->graph()->insertNode(node); + for (const auto & output : node->outputs()) { + results.push_back(output); + } + } else if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) { + // Op returns multiple values and the number of outputs is static and defined + // by the operation schema. const auto tuple_call_result = sv->asTuple({}, *function); for (const auto& tuple_component : tuple_call_result) { auto tuple_component_sv = diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/index.cpp b/python/torch_mlir/csrc/base_lazy_backend/ops/index.cpp new file mode 100644 index 000000000000..34af3e590162 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/index.cpp @@ -0,0 +1,99 @@ +//===- index.cpp ----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "index.h" + +namespace torch { +namespace lazy { + +IndexTensor::IndexTensor(const torch::lazy::Value& self, + const torch::lazy::Value& indices, + std::vector&& shapes) + : torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(), + OpList{self, indices}, std::move(shapes), + /* num_outputs */ 1, torch::lazy::MHash()) {} + +std::string IndexTensor::ToString() const { + std::stringstream ss; + ss << torch::lazy::TorchMlirNode::ToString(); + return ss.str(); +} + +bool IndexTensor::CanBeReused(const torch::lazy::Value& self, + const torch::lazy::Value& indices) const { + return false; +} + +TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const { + PRINT_FUNCTION(); + std::vector arguments; + std::vector kwarguments; + arguments.reserve(2); + kwarguments.reserve(0); + + size_t i = 0; + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + + torch::lazy::TorchMlirOpVector index_out = torch::lazy::LowerTorchMlirBuiltin( + function, op().op, shapes(), arguments, kwarguments); + TORCH_CHECK_EQ(index_out.size(), 1); + + return index_out; +} + +IndexPut::IndexPut(const torch::lazy::Value& self, + const torch::lazy::Value& indices, + const torch::lazy::Value& values, bool accumulate, + std::vector&& shapes) + : torch::lazy::TorchMlirNode( + IndexPut::ClassOpKind(), OpList{self, indices, values}, + std::move(shapes), + /* num_outputs */ 1, torch::lazy::MHash(accumulate)), + accumulate(accumulate) {} + +std::string IndexPut::ToString() const { + std::stringstream ss; + ss << torch::lazy::TorchMlirNode::ToString(); + ss << ", accumulate=" << accumulate; + return ss.str(); +} + +bool IndexPut::CanBeReused(const torch::lazy::Value& self, + const torch::lazy::Value& indices, + const torch::lazy::Value& values, + bool accumulate) const { + return false; +} + +TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const { + PRINT_FUNCTION(); + std::vector arguments; + std::vector kwarguments; + arguments.reserve(4); + kwarguments.reserve(0); + + size_t i = 0; + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + arguments.emplace_back("accumulate", accumulate); + + torch::lazy::TorchMlirOpVector index_out = torch::lazy::LowerTorchMlirBuiltin( + function, op().op, shapes(), arguments, kwarguments); + + TORCH_CHECK_EQ(index_out.size(), 1); + + return index_out; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/index.h b/python/torch_mlir/csrc/base_lazy_backend/ops/index.h new file mode 100644 index 000000000000..e97760fc37ad --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/index.h @@ -0,0 +1,58 @@ +//===- index.h ------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "../mlir_node.h" + +namespace torch { +namespace lazy { + +class IndexTensor : public torch::lazy::TorchMlirNode { + public: + static torch::lazy::OpKind ClassOpKind() { + return torch::lazy::OpKind(at::aten::index); + } + + IndexTensor(const torch::lazy::Value& self, const torch::lazy::Value& indices, + std::vector&& shapes); + + std::string ToString() const override; + + bool CanBeReused(const torch::lazy::Value& self, + const torch::lazy::Value& indices) const; + + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const override; +}; + +class IndexPut : public torch::lazy::TorchMlirNode { + public: + static torch::lazy::OpKind ClassOpKind() { + return torch::lazy::OpKind(at::aten::index_put); + } + + IndexPut(const torch::lazy::Value& self, const torch::lazy::Value& indices, + const torch::lazy::Value& values, bool accumulate, + std::vector&& shapes); + + std::string ToString() const override; + + bool CanBeReused(const torch::lazy::Value& self, + const torch::lazy::Value& indices, + const torch::lazy::Value& values, bool accumulate) const; + + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const override; + + bool accumulate; +}; + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.cpp b/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.cpp new file mode 100644 index 000000000000..0653e4467313 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.cpp @@ -0,0 +1,36 @@ +//===- ivalue.cpp +//----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "ivalue.h" + +#include + +namespace torch { +namespace lazy { + +IValueConstant::IValueConstant(const c10::IValue& value) + : torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{}, + std::vector{}, + /* num_outputs */ 1, torch::lazy::MHash()), + value(value) {} + +std::string IValueConstant::ToString() const { + std::stringstream ss; + ss << torch::lazy::TorchMlirNode::ToString(); + return ss.str(); +} + +TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const { + return {loctx->graph()->insertConstant(value)}; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.h b/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.h new file mode 100644 index 000000000000..8a8453d3a347 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.h @@ -0,0 +1,37 @@ +//===- ivalue.h -----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE.
+// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "../mlir_node.h" + +namespace torch { +namespace lazy { + +// IValueConstant IR Node represents a `prim::Constant` constructed with an IValue +// parameter, which is helpful in various use cases where we need custom +// lowering of native ops to torch-mlir IR nodes. +class IValueConstant : public torch::lazy::TorchMlirNode { + public: + static torch::lazy::OpKind ClassOpKind() { + return torch::lazy::OpKind(at::prim::Constant); + } + + IValueConstant(const c10::IValue& value); + + std::string ToString() const override; + + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const override; + + c10::IValue value; +}; + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/split.cpp b/python/torch_mlir/csrc/base_lazy_backend/ops/split.cpp new file mode 100644 index 000000000000..d20d298dfdd0 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/split.cpp @@ -0,0 +1,101 @@ +//===- split.cpp ----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "split.h" + +namespace torch { +namespace lazy { + +SplitWithSizesCopy::SplitWithSizesCopy( + const torch::lazy::Value& self, const ::std::vector& split_sizes, + const int64_t& dim, std::vector&& shapes) + : torch::lazy::TorchMlirNode(SplitWithSizesCopy::ClassOpKind(), + OpList{ self }, std::move(shapes), + split_sizes.size() /* num_outputs */, + torch::lazy::MHash(split_sizes, dim)), + split_sizes(split_sizes), dim(dim) {} + +std::string SplitWithSizesCopy::ToString() const { + std::stringstream ss; + ss << torch::lazy::TorchMlirNode::ToString(); + ss << ", split_sizes=" << split_sizes; + ss << ", dim=" << dim; + return ss.str(); +} + +bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value& self, + const ::std::vector& split_sizes, + const int64_t& dim) const { + return false; +} + +TorchMlirOpVector +SplitWithSizesCopy::Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const { + PRINT_FUNCTION(); + std::vector arguments; + std::vector kwarguments; + arguments.reserve(3); + kwarguments.reserve(0); + size_t i = 0; + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + arguments.emplace_back("split_sizes", split_sizes); + arguments.emplace_back("dim", dim); + + torch::lazy::TorchMlirOpVector split_with_sizes_copy_out = + torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, + kwarguments); + + return split_with_sizes_copy_out; +} + +SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value& self, + const torch::lazy::Value& split_size, + const int64_t& dim, + std::vector&& shapes, + const size_t num_outputs) + : torch::lazy::TorchMlirNode(SplitCopyTensor::ClassOpKind(), + OpList{ self, split_size }, std::move(shapes), + num_outputs, torch::lazy::MHash(dim)), + dim(dim) {} + +std::string SplitCopyTensor::ToString() const { + std::stringstream ss; + ss << torch::lazy::TorchMlirNode::ToString(); + ss << ", dim=" << dim; + return ss.str(); +} + +bool SplitCopyTensor::CanBeReused(const torch::lazy::Value& self, + const torch::lazy::Value&
split_size, + const int64_t& dim) const { + return false; +} + +TorchMlirOpVector +SplitCopyTensor::Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const { + PRINT_FUNCTION(); + std::vector arguments; + std::vector kwarguments; + arguments.reserve(3); + kwarguments.reserve(0); + size_t i = 0; + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + arguments.emplace_back("dim", dim); + + torch::lazy::TorchMlirOpVector split_copy_out = + torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, + kwarguments); + return split_copy_out; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/split.h b/python/torch_mlir/csrc/base_lazy_backend/ops/split.h new file mode 100644 index 000000000000..8593d5628c2e --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/split.h @@ -0,0 +1,65 @@ +//===- split.h ------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "../mlir_node.h" + +namespace torch { +namespace lazy { + +class SplitWithSizesCopy : public torch::lazy::TorchMlirNode { +public: + static torch::lazy::OpKind ClassOpKind() { + return torch::lazy::OpKind(at::aten::split_with_sizes_copy); + } + + SplitWithSizesCopy(const torch::lazy::Value& self, + const ::std::vector& split_sizes, + const int64_t& dim, + std::vector&& shapes); + + std::string ToString() const override; + + bool CanBeReused(const torch::lazy::Value& self, + const ::std::vector& split_sizes, + const int64_t& dim) const; + + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const override; + + std::vector split_sizes; + int64_t dim; +}; + +class SplitCopyTensor : public torch::lazy::TorchMlirNode { +public: + static torch::lazy::OpKind ClassOpKind() { + return torch::lazy::OpKind(at::aten::split_copy); + } + + SplitCopyTensor(const torch::lazy::Value& self, + const torch::lazy::Value& split_size, const int64_t& dim, + std::vector&& shapes, + const size_t num_outputs = 1); + + std::string ToString() const override; + + bool CanBeReused(const torch::lazy::Value& self, + const torch::lazy::Value& split_size, + const int64_t& dim) const; + + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const override; + + int64_t dim; +}; + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.cpp b/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.cpp new file mode 100644 index 000000000000..a5526366cd2b --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.cpp @@ -0,0 +1,54 @@ +//===- unbind_int.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "unbind_int.h" + +namespace torch { +namespace lazy { + +UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim, + std::vector&& shapes) + : torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{ self }, + std::move(shapes), + self.shape().size(dim), /* num_outputs */ + torch::lazy::MHash(dim)), + dim(dim) {} + +std::string UnbindCopyInt::ToString() const { + std::stringstream ss; + ss << torch::lazy::TorchMlirNode::ToString(); + ss << ", dim=" << dim; + return ss.str(); +} + +bool UnbindCopyInt::CanBeReused(const torch::lazy::Value& self, + const int64_t& dim) const { + return false; +} + +TorchMlirOpVector UnbindCopyInt::Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const { + PRINT_FUNCTION(); + std::vector arguments; + std::vector kwarguments; + arguments.reserve(2); + kwarguments.reserve(0); + size_t i = 0; + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + arguments.emplace_back("dim", dim); + + torch::lazy::TorchMlirOpVector unbind_copy_out = + torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, + kwarguments); + + return unbind_copy_out; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.h b/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.h new file mode 100644 index 000000000000..766752c16517 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.h @@ -0,0 +1,37 @@ +//===- unbind_int.h ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "../mlir_node.h" + +namespace torch { +namespace lazy { + +class UnbindCopyInt : public torch::lazy::TorchMlirNode { +public: + static torch::lazy::OpKind ClassOpKind() { + return torch::lazy::OpKind(at::aten::unbind_copy); + } + + UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim, + std::vector&& shapes); + + std::string ToString() const override; + + bool CanBeReused(const torch::lazy::Value& self, const int64_t& dim) const; + + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext* loctx) const override; + + int64_t dim; +}; + +} // namespace lazy +} // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp index 97d35cdcd3b4..043094c67e0a 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include +#include #include #include @@ -17,47 +18,127 @@ namespace torch { namespace lazy { -// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future. +// TODO(henrytu): Upstream these shape inference functions to PyTorch in the +// future. 
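+// Each compute_shape_* helper below returns one torch::lazy::Shape per op
+// output. Two idioms recur: elementwise ops simply echo the input's dtype
+// and sizes, while shape- or dtype-dependent ops are re-run on meta tensors
+// (see compute_shape_where) so that ATen itself performs the inference.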
-std::vector -compute_shape_div(const at::Tensor& self, const at::Scalar& other) { +std::vector compute_shape_add(const at::Tensor& self, + const at::Scalar& other, + const at::Scalar& alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector -compute_shape_mse_loss_backward( - const at::Tensor& grad_output, - const at::Tensor& self, - const at::Tensor& target, - int64_t reduction) { + +std::vector compute_shape_sub(const at::Tensor& self, + const at::Scalar& other, + const at::Scalar& alpha) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_div(const at::Tensor& self, + const at::Scalar& other) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_max_pool3d_with_indices( + const at::Tensor& self, at::IntArrayRef kernel_size, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool ceil_mode) { + auto in_sizes = self.sizes().vec(); + std::vector dhw(3, 0); + std::vector paddings = padding.vec(); + std::vector ksizes = kernel_size.vec(); + std::vector dilations = dilation.vec(); + std::vector strides = stride.vec(); + TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ", + in_sizes); + TORCH_CHECK(kernel_size.size() == 3 && + stride.size() == 3 && + padding.size() == 3 && + dilation.size() == 3, "max_pool3d requires 3D operands, but got ", + kernel_size, stride, padding, dilation); + int64_t batch = in_sizes[0]; + int64_t channel = in_sizes[1]; // NCDHW + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html + for (auto i = 0UL; i<3; ++i) { + double out_size = (in_sizes[2+i] + 2 * paddings[i] - dilations[i] * + (ksizes[i] - 1) - 1) / (double)strides[i] + 1; + if (ceil_mode) + dhw[i] = (int64_t)std::ceil(out_size); + else + dhw[i] = (int64_t)std::floor(out_size); + } + auto out_sizes = {batch, channel, dhw[0], dhw[1], dhw[2]}; + // `with_indices` returns output and index Tensor + return {Shape(self.scalar_type(), out_sizes), Shape(at::kLong, out_sizes)}; +} + +std::vector compute_shape_max_pool3d_with_indices_backward( + const at::Tensor & grad_output, const at::Tensor & self, + at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor & indices) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector -compute_shape_mul(const at::Tensor& self, const at::Scalar& other) { +std::vector compute_shape_mse_loss_backward( + const at::Tensor& grad_output, const at::Tensor& self, + const at::Tensor& target, int64_t reduction) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_mul(const at::Tensor& self, + const at::Scalar& other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_var( const at::Tensor& self, at::OptionalIntArrayRef dim, - c10::optional correction, bool keepdim) { + const c10::optional & correction, bool keepdim) { // Result of variance is scalar tensor. 
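+ // e.g. torch.var(x) reducing over all dims of a [4, 5] input yields a
+ // rank-0 tensor, which the empty size list below encodes.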
return {Shape(self.scalar_type(), {})}; } std::vector compute_shape_hardtanh( - const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val -) { + const at::Tensor& self, const at::Scalar& min_val, + const at::Scalar& max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_where( - const at::Tensor & condition, - const at::Tensor & self, - const at::Tensor & other) { +std::vector compute_shape_hardtanh_backward( + const at::Tensor& grad_output, const at::Tensor& self, + const at::Scalar& min_val, const at::Scalar& max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape_where(const at::Tensor& condition, + const at::Tensor& self, + const at::Tensor& other) { + // There are cases like - + // torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>, + // !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>. + // So the result shape is the broadcast of all three operand shapes. + auto condition_meta = at::native::empty_strided_meta_symint( + condition.sym_sizes(), condition.sym_strides(), + /*dtype=*/c10::make_optional(condition.scalar_type()), + /*layout=*/c10::make_optional(condition.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto self_meta = at::native::empty_strided_meta_symint( + self.sym_sizes(), self.sym_strides(), + /*dtype=*/c10::make_optional(self.scalar_type()), + /*layout=*/c10::make_optional(self.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto other_meta = at::native::empty_strided_meta_symint( + other.sym_sizes(), other.sym_strides(), + /*dtype=*/c10::make_optional(other.scalar_type()), + /*layout=*/c10::make_optional(other.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto out_meta = at::where(condition_meta, self_meta, other_meta); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + std::vector compute_shape_bucketize( const at::Tensor& self, const at::Tensor& boundaries, bool out_int32, bool right) { @@ -65,50 +146,64 @@ std::vector compute_shape_bucketize( return {Shape(dtype, self.sizes().vec())}; } -std::vector compute_shape_copy( - const at::Tensor& self, - const at::Tensor& src, - bool non_blocking) { +std::vector compute_shape_copy(const at::Tensor& self, + const at::Tensor& src, + bool non_blocking) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_floor_divide( + const at::Tensor& self, const at::Tensor& other) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_fmod(const at::Tensor& self, + const at::Scalar& other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_native_group_norm( - const at::Tensor& input, - const c10::optional& weight, - const c10::optional& bias, - int64_t N, int64_t C, int64_t HxW, - int64_t group, double eps) { - - TORCH_CHECK( - input.sizes().size() >= 2, - "Input tensor must have at least batch and channel dimensions!"); + const at::Tensor& input, const c10::optional& weight, + const c10::optional& bias, int64_t N, int64_t C, int64_t HxW, + int64_t group, double eps) { + + TORCH_CHECK(input.sizes().size() >= 2, + "Input tensor must have at least batch and channel dimensions!"); std::vector shapes; shapes.reserve(3); shapes.emplace_back(input.scalar_type(), input.sizes().vec()); // A separate mean and var needs to be kept for
each group per N. - shapes.emplace_back( - at::get_default_dtype_as_scalartype(), - std::vector{N, group}); + shapes.emplace_back(at::get_default_dtype_as_scalartype(), + std::vector{N, group}); - shapes.emplace_back( - at::get_default_dtype_as_scalartype(), - std::vector{N, group}); + shapes.emplace_back(at::get_default_dtype_as_scalartype(), + std::vector{N, group}); return shapes; } +std::vector compute_shape_im2col( + const at::Tensor& self, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + + auto self_meta = at::native::empty_strided_meta_symint( + self.sym_sizes(), self.sym_strides(), + /*dtype=*/c10::make_optional(self.scalar_type()), + /*layout=*/c10::make_optional(self.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + + auto out_meta = at::im2col(self_meta, kernel_size, dilation, padding, stride); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + std::vector compute_shape_native_group_norm_backward( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& mean, - const at::Tensor& rstd, - const c10::optional& weight, - int64_t N, int64_t C, int64_t HxW, - int64_t group, ::std::array output_mask) { - - TORCH_CHECK( - input.sizes().size() >= 2, - "Input tensor must have at least batch and channel dimensions!"); + const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean, + const at::Tensor& rstd, const c10::optional& weight, int64_t N, + int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { + + TORCH_CHECK(input.sizes().size() >= 2, + "Input tensor must have at least batch and channel dimensions!"); std::vector shapes; shapes.reserve(3); shapes.emplace_back(input.scalar_type(), input.sizes().vec()); @@ -116,15 +211,180 @@ std::vector compute_shape_native_group_norm_backward( int64_t num_features = input.size(1); // `weight` and `bias` are vectors of length C (number of channels) - shapes.emplace_back( - at::get_default_dtype_as_scalartype(), - std::vector{num_features}); - shapes.emplace_back( - at::get_default_dtype_as_scalartype(), - std::vector{num_features}); + shapes.emplace_back(at::get_default_dtype_as_scalartype(), + std::vector{num_features}); + shapes.emplace_back(at::get_default_dtype_as_scalartype(), + std::vector{num_features}); return shapes; } +std::vector compute_shape_remainder( + const at::Tensor& self, const at::Scalar& other) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_uniform( + const at::Tensor& self, double from, double to, + c10::optional generator) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_normal_functional( + const at::Tensor& self, double mean, double std, + c10::optional generator) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_multinomial( + const at::Tensor& self, int64_t num_samples, bool replacement, + c10::optional generator) { + // Input tensor can be either 1D or 2D. The last dim of output + // should be 'num_samples'. So the output shape can be either + // [num_samples] or [m, num_samples]. + // Output type can only be long tensor.
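+ // e.g. self: [100] with num_samples=8 -> [8]; self: [4, 100] -> [4, 8].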
+ auto ishape = self.sizes().vec(); + ishape.back() = num_samples; + return {Shape(at::kLong, ishape)}; +} + +std::vector<torch::lazy::Shape> compute_shape_eye( + int64_t n, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory) { + auto out_meta = + at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_eye( + int64_t n, int64_t m, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory) { + auto out_meta = + at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_arange( + const at::Scalar& end, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory) { + auto out_meta = + at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_arange( + const at::Scalar& start, const at::Scalar& end, + c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, + c10::optional<at::Device> device, c10::optional<bool> pin_memory) { + auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta), + pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_arange( + const at::Scalar& start, const at::Scalar& end, const at::Scalar& step, + c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, + c10::optional<at::Device> device, c10::optional<bool> pin_memory) { + auto out_meta = at::arange(start, end, step, dtype, layout, + c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_full( + at::IntArrayRef size, const at::Scalar& fill_value, + c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, + c10::optional<at::Device> device, c10::optional<bool> pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_ones( + at::IntArrayRef size, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_zeros( + at::IntArrayRef size, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_empty( + at::IntArrayRef size, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory, + c10::optional<at::MemoryFormat> memory_format) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_empty_strided( + at::IntArrayRef size, at::IntArrayRef stride, + c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, + c10::optional<at::Device> device, c10::optional<bool> pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self, + const at::Scalar& value) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self, + const at::Tensor& value) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_randn( + at::IntArrayRef size, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory) { + return {
+ Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_randint( + int64_t high, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_randint( + int64_t low, int64_t high, at::IntArrayRef size, + c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, + c10::optional<at::Device> device, c10::optional<bool> pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_resize( + const at::Tensor & self, at::IntArrayRef size, + c10::optional<at::MemoryFormat> memory_format) { + return {Shape(self.scalar_type(), size.vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_bernoulli( + const at::Tensor& self, const at::Tensor &p, + c10::optional<at::Generator> generator) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector<torch::lazy::Shape> compute_shape_scalar_tensor( + const at::Scalar & s, c10::optional<at::ScalarType> dtype, + c10::optional<at::Layout> layout, c10::optional<at::Device> device, + c10::optional<bool> pin_memory) { + return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})}; +} -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp b/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp new file mode 100644 index 000000000000..82ae6cc27f4a --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp @@ -0,0 +1,29 @@ +//===- tensor.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include <ATen/FunctionalTensorWrapper.h> + +#include "tensor.h" + +namespace torch { +namespace lazy { + +at::Tensor CreateFunctionalizedAtenFromLtcTensor( + const LazyTensorPtr& ltc_tensor) { + at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor); + if (!c10::impl::tls_is_dispatch_key_excluded( + c10::DispatchKey::Functionalize) && + !at::functionalization::impl::isFunctionalTensor(tensor)) { + return at::functionalization::impl::to_functional_tensor(tensor); + } + return tensor; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/tensor.h b/python/torch_mlir/csrc/base_lazy_backend/tensor.h new file mode 100644 index 000000000000..4e39dd095aa5 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/tensor.h @@ -0,0 +1,24 @@ +//===- tensor.h -----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include <torch/csrc/lazy/core/tensor.h> + +namespace torch { +namespace lazy { + +// Ops like torch.ones/zeros etc. which produce a new tensor as an output +// should have explicit tensor functionalization.
Otherwise we can get +// unfunctionalized primitives, or in the worst case, if we apply in-place +// operations to an unfunctionalized tensor, they won't be captured in the LTC graph. +TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h index c4c2ea79d6ab..281331992e49 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h @@ -22,6 +22,24 @@ std::string string_join(const std::vector<T>& v, const std::string& delimiter) { return joined.str(); } +inline std::vector<std::string> string_split( + const std::string& str, + const std::string& sep +) { + std::vector<std::string> tokens; + std::size_t pos1 = str.find_first_not_of(sep); + while (pos1 != std::string::npos) { + std::size_t pos2 = str.find_first_of(sep, pos1); + if (pos2 == std::string::npos) { + tokens.push_back(str.substr(pos1)); + pos1 = pos2; + } else { + tokens.push_back(str.substr(pos1, pos2 - pos1)); + pos1 = str.find_first_not_of(sep, pos2 + 1); + } + } + return tokens; +} /* * Returns true if str starts with prefix diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h index 6cb47895af92..5ae14904909a 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h @@ -14,6 +14,14 @@ static T GetEnv(const std::string& name, const T& default_value = T(0)) { return T(std::atoi(env)); } +static std::string GetEnvString(const std::string& name, const std::string& default_value) { + const char* env = std::getenv(name.c_str()); + if (!env) { + return default_value; + } + return std::string(env); +} + static bool GetEnvBool(const char* name, bool defval) { const char* env = std::getenv(name); if (env == nullptr) { diff --git a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 3bc8465eafc1..1064a3d1e1ac 100644 --- a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -28,6 +28,11 @@ using namespace torch::lazy; namespace torch { namespace lazy { +/// Returns true if a string begins with another.
+inline bool beginswith(const std::string& s, const std::string& t) { + return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; +} + struct ReferenceLazyBackendDeviceType : public BackendDeviceType { ReferenceLazyBackendDeviceType(c10::DeviceType device_type) : device_type_(device_type) {} @@ -104,7 +109,25 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { // // JIT Execution adopted from: // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp - torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), ""); + std::shared_ptr<torch::jit::Graph> graph = mlir_computation->graph(); + for (auto* node : graph->nodes()) { + // Convert any lazy devices to cpu devices to ensure + // that the values are actually computed + if (node->outputs().size() == 1 && + node->output()->type()->kind() == + c10::TypeKind::DeviceObjType) { + auto value_sym = torch::jit::Symbol::attr("value"); + TORCH_CHECK(node->hasAttribute(value_sym), + "Expected node to have 'value' attribute."); + TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s, + "Expected 'value' attribute to be a string."); + if (beginswith(node->s(value_sym), "lazy")) { + node->s_(value_sym, "cpu"); + } + } + } + + torch::jit::GraphExecutor graph_executor(graph, ""); + std::vector<torch::jit::IValue> stack; for (const auto& argument : arguments) { const auto mlir_data = diff --git a/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp b/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp index b2ff81c67a22..c575d9dd299b 100644 --- a/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp +++ b/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch/csrc/jit/python/pybind.h" +#include "torch/csrc/lazy/core/config.h" #include "torch/csrc/lazy/backend/backend_interface.h" #include @@ -25,6 +26,7 @@ namespace py = pybind11; namespace { bool verbose = sys_util::GetEnv("VERBOSE", false); +bool ir_debug = sys_util::GetEnv("LTC_IR_DEBUG", false); struct NoGilSection { NoGilSection() : state(PyEval_SaveThread()) {} @@ -52,6 +54,11 @@ void Initialize() { if (verbose) { std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl; } + + if (ir_debug) { + FLAGS_torch_lazy_ir_debug = true; + std::cout << "Enabled lazy tensor IR debugging."
<< std::endl; + } } /** diff --git a/python/torch_mlir/dialects/TorchBinding.td b/python/torch_mlir/dialects/TorchBinding.td index 2de5dcd5615f..e2dbe0f14162 100644 --- a/python/torch_mlir/dialects/TorchBinding.td +++ b/python/torch_mlir/dialects/TorchBinding.td @@ -10,7 +10,6 @@ #ifndef PYTHON_BINDINGS_TORCH_OPS #define PYTHON_BINDINGS_TORCH_OPS -include "mlir/Bindings/Python/Attributes.td" include "torch-mlir/Dialect/Torch/IR/TorchOps.td" #endif // PYTHON_BINDINGS_TORCH_OPS diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 5917dba72302..cbd62af70899 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -19,7 +19,10 @@ # ============================================================================== # TODO: upstream this -def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int): +def _embedding_bag_helper(weight: List[int], indices: List[int], + offsets: List[int], include_last_offset: bool, + mode: int, per_sample_weights: Optional[List[int]], + padding_idx: Optional[int]): assert len(weight) == 2 assert len(indices) == 1 assert len(offsets) == 1 @@ -35,7 +38,10 @@ def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[i if mode == 1: offset2bag_shape.append(0) else: - offset2bag_shape = upstream_shape_functions._copy(indices) + if per_sample_weights is None and padding_idx is None: + offset2bag_shape = [0] + else: + offset2bag_shape = upstream_shape_functions._copy(indices) bag_size_shape = upstream_shape_functions._copy(offsets) @@ -209,6 +215,10 @@ def aten〇type_as〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇dropout〡shape(input: List[int], p: float, train: bool) -> List[int]: return upstream_shape_functions.unary(input) +def aten〇native_dropout〡shape(input: List[int], p: float, train: Optional[bool]) -> Tuple[List[int], List[int]]: + shape = upstream_shape_functions.unary(input) + return shape, shape + def aten〇gelu〡shape(self: List[int], approximate: str = "none") -> List[int]: return upstream_shape_functions.unary(self) @@ -284,6 +294,9 @@ def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1 def aten〇leaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, input_scale: float = 1) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -302,9 +315,18 @@ def aten〇any〡shape(self: List[int]) -> List[int]: def aten〇all〡shape(self: List[int]) -> List[int]: return [] +def aten〇min〡shape(self: List[int]) -> List[int]: + return [] + +def aten〇min〇other〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇max〡shape(self: List[int]) -> List[int]: return [] +def aten〇max〇other〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇sum〡shape(self: List[int], dtype: Optional[int] = None) -> List[int]: return [] @@ -384,6 +406,9 @@ def aten〇sum〇dim_IntList〡shape(self: List[int], dim: 
Optional[List[int]], def prims〇sum〡shape(inp: List[int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(inp, dims, False, output_dtype) +def aten〇prod〇dim_int〡shape(self: List[int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, [dim], keepdim, dtype) + def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) @@ -449,11 +474,22 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]: for i in range(tensor_dim): out.append(self[i] * repeats[i + leading_rank]) return out - + def aten〇repeat_interleave〇Tensor〡shape(repeats: List[int], output_size: Optional[int] = None) -> List[int]: assert output_size is not None return [output_size] +@check_shape_function([ + Invocation(TensorOfShape(3, 2, 8), [2, 2]), # dims_length < self_length + Invocation(TensorOfShape(3, 2, 8), [2, 2, 2]) # dims_length >= self_length +]) +def aten〇tile〡shape(self: List[int], dims: List[int]) -> List[int]: + dims_length = len(dims) + self_length = len(self) + if dims_length < self_length: + dims = [1] * (self_length - dims_length) + dims + return aten〇repeat〡shape(self, dims) + def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]: return upstream_shape_functions.unary(self) @@ -481,10 +517,10 @@ def aten〇_unsafe_view〡shape(self: List[int], size: List[int]) -> List[int]: def aten〇resize_〡shape(self: List[int], size: List[int], memory_format: Optional[int] = None) -> List[int]: return size -def aten〇max_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> List[int]: +def aten〇max_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> List[int]: return upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode) -def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[List[int], List[int]]: +def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]: maxpool2d = indices = upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode) return maxpool2d, indices @@ -538,7 +574,57 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd else: return [nbatch, nInputPlane, outputHeight, outputWidth] -def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]: +# TODO: This should be upstreamed. +# See https://github.com/pytorch/pytorch/pull/76889 for an example. 
+def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool): + assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int" + kL = kernel_size[0] + + assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int" + dL = kL if len(stride) == 0 else stride[0] + + assert len(padding) == 1, "avg_pool1d: padding must be a single int" + padL = padding[0] + + dilationL = 1 + + assert len(input) == 2 or len(input) == 3 + + nbatch = input[-3] if len(input) == 3 else 1 + nInputPlane = input[-2] + inputLength = input[-1] + + outputLength = upstream_shape_functions.pooling_output_shape( + inputLength, kL, padL, dL, dilationL, ceil_mode) + + if len(input) == 2: + return [nInputPlane, outputLength] + else: + return [nbatch, nInputPlane, outputLength] + +# TODO: This should be upstreamed. +# See https://github.com/pytorch/pytorch/pull/76889 for an example. +def adaptive_avg_pool1d(self: List[int], out: List[int]): + assert len(out) == 1 + assert len(self) == 2 or len(self) == 3 + + for i in range(len(self)): + assert self[i] != 0 + + shape: List[int] = [] + for i in range(len(self) - 1): + shape.append(self[i]) + shape.append(out[0]) + + return shape + +def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]: + return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad) + +def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]: + return adaptive_avg_pool1d(self, output_size) + +def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]: return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) -> List[int]: @@ -570,13 +656,17 @@ def aten〇ones〡shape(size: List[int], dtype: Optional[int] = None, layout: Op def aten〇empty〇memory_format〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return size - +def aten〇empty_strided〡shape(size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + return size def aten〇full〡shape(size: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size def aten〇full_like〡shape(self: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return self +def aten〇new_full〡shape(self: List[int], size: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + return size + def aten〇zeros_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: 
Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -622,6 +712,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]: return self +def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + return size + @not_present_in_registry def aten〇bernoulli〇float〡shape(self: List[int], p: float = 0.5, generator: Any = None) -> List[int]: return self @@ -710,6 +803,9 @@ def aten〇atan2〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇__and__〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇__or__〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇minimum〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -795,8 +891,8 @@ def aten〇tensor〇bool〡shape(t: bool, dtype: Optional[int] = None, device: O def aten〇scalar_tensor〡shape(s: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return [] -@check_dtype_function([Invocation(-1), Invocation(-1.0)]) -def aten〇scalar_tensor〡dtype(s: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +@check_dtype_function([Invocation(-1), Invocation(-1.0)]) +def aten〇scalar_tensor〡dtype(s: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: return dtype else: @@ -887,10 +983,22 @@ def aten〇view_as_complex〡dtype(self_rank_dtype: Tuple[int, int]) -> int: else: assert False, "Unsupported dtype" -def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]: +def aten〇view_as_real〡shape(self: List[int]) -> List[int]: + return self + [2] +def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex64: + return torch.float + elif self_dtype == torch.complex128: + return torch.double + else: + assert False, "Unsupported dtype" + + +def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) -def aten〇conv_transpose2d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> List[int]: +def aten〇conv_transpose2d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> List[int]: return 
upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation) def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]: @@ -924,9 +1032,17 @@ def aten〇sort〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1, descend def aten〇narrow〡shape(self: List[int], dim: int, start: int, length: int) -> List[int]: return upstream_shape_functions.slice(self, dim, start, start + length, 1) +# This shape function is a little hacky, because we don't know the start index which is determined by a tensor param. +def aten〇narrow〇Tensor〡shape(self: List[int], dim: int, start: List[int], length: int) -> List[int]: + self[dim] = length + return self + def aten〇slice_scatter〡shape(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return self +def aten〇masked_scatter〡shape(self: List[int], mask: List[int], source: List[int]) -> List[int]: + return self + def aten〇select〇int〡shape(self: List[int], dim: int, index: int) -> List[int]: return upstream_shape_functions.select(self, dim, index) @@ -955,10 +1071,12 @@ def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse) def aten〇embedding_bag〇padding_idx〡shape(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]: - return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode) + return _embedding_bag_helper(weight, indices, offsets, include_last_offset, + mode, per_sample_weights, padding_idx) def aten〇_embedding_bag〡shape(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool = False, mode: int = 0, sparse: bool = False, per_sample_weights: Optional[List[int]] = None, include_last_offset: bool = False, padding_idx: int = -1) -> Tuple[List[int], List[int], List[int], List[int]]: - return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode) + return _embedding_bag_helper(weight, indices, offsets, include_last_offset, + mode, per_sample_weights, padding_idx) @check_shape_function([ Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case. 
@@ -1043,13 +1161,15 @@ def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> Li return broadcasted_shape first_index_tensor_location = -1 + last_used_index_location = -1 index_tensors_are_together = True for e, index_tensor_shape in enumerate(indices): if index_tensor_shape is not None: if first_index_tensor_location == -1: first_index_tensor_location = e - elif e - first_index_tensor_location != 1: + elif e - last_used_index_location != 1: index_tensors_are_together = False + last_used_index_location = e if not index_tensors_are_together: return broadcasted_shape + unused_dim_sizes @@ -1136,6 +1256,15 @@ def hacky_get_unknown_dimension_size(): def aten〇bincount〡shape(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]: return [hacky_get_unknown_dimension_size()] +def aten〇nonzero〡shape(self: List[int]) -> List[int]: + return [hacky_get_unknown_dimension_size(), len(self)] + +def aten〇masked_select〡shape(self: List[int], mask: List[int]) -> List[int]: + return [hacky_get_unknown_dimension_size()] + +def aten〇nonzero_static〡shape(self: List[int], size: int, fill_value: int = -1) -> List[int]: + return [size, len(self)] + def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) @@ -1272,6 +1401,18 @@ def _get_dtype_of_floating_point_op(input_dtype: int) -> int: return input_dtype return torch.float32 +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=[ + torch.float64, torch.float32, torch.float16, torch.bfloat16, + torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool +])) +def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert is_complex_dtype(self_dtype) + if self_dtype == torch.complex64: + return torch.float + else: + return torch.double + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1348,7 +1489,7 @@ def aten〇erf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: +def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, threshold: Union[int, float, complex] = 20) -> int: self_rank, self_dtype = self_rank_dtype if is_integer_dtype(self_dtype): return self_dtype @@ -1391,13 +1532,23 @@ def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.float32 return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2])) +def aten〇adaptive_avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2])) +def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype 
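For intuition, the output length that upstream_shape_functions.pooling_output_shape computes for the new avg_pool1d shape function above follows PyTorch's standard pooling arithmetic. A minimal standalone sketch of that arithmetic (illustrative only, not part of this patch; the helper name is hypothetical):

import math

def pooling_output_length(in_size, kernel, pad, stride, dilation, ceil_mode):
    # floor (or ceil, when ceil_mode is set) of
    # (L_in + 2*pad - dilation*(kernel - 1) - 1) / stride, plus 1.
    numer = in_size + 2 * pad - dilation * (kernel - 1) - 1
    out = (math.ceil(numer / stride) if ceil_mode else numer // stride) + 1
    # In ceil_mode, drop a last window that would start entirely in the padding.
    if ceil_mode and (out - 1) * stride >= in_size + pad:
        out -= 1
    return out

# A [N, C, 16] input with kernel 3, stride 2, padding 1 yields length 8.
assert pooling_output_length(16, 3, 1, 2, 1, False) == 8

This also shows why the avg_pool1d shape function above substitutes kL for the stride when the stride list is empty: with no explicit stride, pooling steps by the kernel size.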
+ @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) def aten〇adaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) -def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: +def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1438,21 +1589,21 @@ def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0)) -def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float]) -> int: +def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype if self_dtype == torch.bool: return torch.int64 return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=0)) -def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float]) -> int: +def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype if self_dtype == torch.bool: return torch.int64 return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1)) -def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> int: +def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float, complex]] = None, max: Optional[Union[int, float, complex]] = None) -> int: self_rank, self_dtype = self_rank_dtype if self_dtype == torch.bool: return torch.int64 @@ -1464,7 +1615,7 @@ def aten〇clone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Option return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) -def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float] = 0) -> int: +def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float, complex] = 0) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1505,6 +1656,11 @@ def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: b input_rank, input_dtype = input_rank_dtype return input_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False)) +def aten〇native_dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: Optional[bool]) -> Tuple[int, int]: + input_rank, input_dtype = input_rank_dtype + return input_dtype, torch.bool + @check_dtype_function(_check_two_tensor_op()) def aten〇expand_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1516,7 +1672,7 @@ def aten〇expand〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], imp return self_dtype 
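The recurring widening of scalar parameters from Union[int, float] to Union[int, float, complex] in these dtype functions tracks PyTorch's scalar type promotion: an int scalar does not widen an integer tensor, a float scalar promotes integer tensors to the default float dtype, and a complex scalar promotes the result to a complex dtype. A quick illustration using torch.result_type (for intuition only, not part of this patch):

import torch

t = torch.ones(3, dtype=torch.int32)
print(torch.result_type(t, 2))         # torch.int32: an int scalar leaves the tensor dtype alone
print(torch.result_type(t, 2.0))       # torch.float32: a float scalar promotes to the default float dtype
print(torch.result_type(t, 2.0 + 1j))  # torch.complex64: a complex scalar promotes to a complex dtype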
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, value=0)) -def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: +def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1575,14 +1731,14 @@ def aten〇hardswish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype @check_dtype_function(_check_two_tensor_op(min_val=0.2, max_val=0.5)) -def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float], max_val: Union[int, float]) -> int: +def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float, complex], max_val: Union[int, float, complex]) -> int: grad_output_rank, grad_output_dtype = grad_output_rank_dtype if is_integer_dtype(grad_output_dtype): return torch.float32 return grad_output_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.uint8, torch.bool})) -def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float] = -1, max_val: Union[int, float] = 1) -> int: +def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float, complex] = -1, max_val: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype assert self_dtype not in [torch.uint8, torch.bool] return self_dtype @@ -1602,6 +1758,11 @@ def aten〇index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], ind self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_index_put_invocations) +def aten〇_unsafe_index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_index_put_invocations) def aten〇_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int: self_rank, self_dtype = self_rank_dtype @@ -1635,7 +1796,7 @@ def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shap return input_dtype @check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False)) -def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float], self_is_result: bool) -> int: +def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex], self_is_result: bool) -> int: grad_output_rank, grad_output_dtype = grad_output_rank_dtype self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [grad_output_rank, self_rank] @@ -1655,12 +1816,12 @@ def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, return input_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0)) -def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: +def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int: 
self_rank, self_dtype = self_rank_dtype return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0)) -def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: +def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1678,12 +1839,12 @@ def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dty return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) -def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> int: +def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) -def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[int, int]: +def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[int, int]: self_rank, self_dtype = self_rank_dtype return self_dtype, torch.int64 @@ -1697,6 +1858,11 @@ def aten〇narrow〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start: int self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function([Invocation(TensorOfShape(3, 4, dtype=dtype, device=torch.device("cpu")), 0, ZeroDTensorWithDtype(dtype=torch.int64, device=torch.device("cpu")), 1) for dtype in _SORTED_TORCH_TYPES]) +def aten〇narrow〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start_rank_dtype: Tuple[int, int], length: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇neg〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1759,6 +1925,11 @@ def aten〇repeat_interleave〇Tensor〡dtype(repeats_rank_dtype: Tuple[int, int repeats_rank, repeats_dtype = repeats_rank_dtype return repeats_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[1])) +def aten〇tile〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1])) def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1808,7 +1979,13 @@ def aten〇scatter〇src〡dtype(self_rank_dtype: Tuple[int, int], dim: int, ind @check_dtype_function( [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), 1.0) for dtype in _SORTED_TORCH_TYPES]) -def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, 
int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: +def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(3, dtype=dtype), TensorOfShape(3, dtype=torch.bool), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES]) +def aten〇masked_scatter〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1862,7 +2039,7 @@ def aten〇tanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output return promoted_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, threshold=0, value=0)) -def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float], value: Union[int, float]) -> int: +def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float, complex], value: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1896,6 +2073,12 @@ def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇rand〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) def aten〇_unsafe_view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1932,7 +2115,7 @@ def aten〇zero_〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype @check_dtype_function([Invocation(-1), Invocation(-1.0)]) -def prim〇abs〇Scalar〡dtype(a: Union[int, float]) -> int: +def prim〇abs〇Scalar〡dtype(a: Union[int, float, complex]) -> int: return get_dtype_of_scalar(a) @check_dtype_function(_check_tensors_with_the_same_dtype( @@ -1973,7 +2156,7 @@ def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int: @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) @@ -1983,13 +2166,13 @@ def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def 
aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) @@ -2003,7 +2186,7 @@ def aten〇ge〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) @@ -2014,7 +2197,7 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool -@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)])) +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)])) def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int: _, query_dtype = query_rank_dtype return query_dtype @@ -2030,7 +2213,7 @@ def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) @@ -2052,7 +2235,7 @@ def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function([ @@ -2061,7 +2244,7 @@ def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[in Invocation(0, 0.0), # int, float Invocation(0, 0), # int, int ]) -def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: +def aten〇add〡dtype(a: Union[int, float, complex], b: Union[int, float, complex]) -> int: ranks: List[Optional[int]] = [None, None] dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)] return promote_dtypes(ranks, dtypes) @@ -2086,7 +2269,7 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) @@ 
-2099,7 +2282,15 @@ def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op()) -def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: +def aten〇__or__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int: other_rank, other_dtype = other_rank_dtype self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] @@ -2280,7 +2471,7 @@ def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[in return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op()) -def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: +def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int: other_rank, other_dtype = other_rank_dtype self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] @@ -2291,7 +2482,7 @@ def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty # https://github.com/pytorch/pytorch/issues/100921 # TODO: This should be fixed by switching to FakeTensor instead of Meta tensor @check_dtype_function(_check_two_tensor_op(tensor_device="cpu", input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}, threshold=0)) -def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int: +def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype grad_output_rank, grad_output_dtype = grad_output_rank_dtype assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex" @@ -2377,7 +2568,7 @@ def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) ]) -def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> int: +def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> int: input_rank, input_dtype = input_rank_dtype return input_dtype @@ -2388,7 +2579,7 @@ def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), 
Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) ]) -def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int: +def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> int: input_rank, input_dtype = input_rank_dtype return input_dtype @@ -2462,6 +2653,14 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype return torch.int64 return torch.float64 +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device=torch.device("cpu"))) +def aten〇nonzero〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=5, tensor_device=torch.device("cpu"))) +def aten〇nonzero_static〡dtype(self_rank_dtype: Tuple[int, int], size: int, fill_value: int = -1) -> int: + return torch.int64 + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + # Different width @@ -2475,7 +2674,7 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(4, 3, dtype=torch.float32))]) -def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: +def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype mat1_rank, mat1_dtype = mat1_rank_dtype mat2_rank, mat2_dtype = mat2_rank_dtype @@ -2519,7 +2718,7 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, 3, dtype=torch.float32))]) -def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: +def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype tensor1_rank, tensor1_dtype = tensor1_rank_dtype tensor2_rank, tensor2_dtype = tensor2_rank_dtype @@ -2545,7 +2744,7 @@ def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, 3, dtype=torch.float32))]) -def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: +def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float, complex] = 
1) -> int: self_rank, self_dtype = self_rank_dtype tensor1_rank, tensor1_dtype = tensor1_rank_dtype tensor2_rank, tensor2_dtype = tensor2_rank_dtype @@ -2559,7 +2758,7 @@ def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2568,7 +2767,7 @@ def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2576,7 +2775,7 @@ def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2584,7 +2783,7 @@ def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2596,7 +2795,7 @@ def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2605,26 +2804,22 @@ def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[ @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, 
other=1.0)) -def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype) ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] return promote_dtypes(ranks, dtypes) -@check_dtype_function([ - Invocation(2.0, TensorOfShape(3, 4, dtype=torch.float64)), - Invocation(2.0, TensorOfShape(3, 4, dtype=torch.bfloat16)), - Invocation(2, TensorOfShape(4, dtype=torch.int32))]) -def aten〇pow〇Scalar〡dtype(self: Union[int, float], exponent_rank_dtype: Tuple[int, int]) -> int: - exp_rank, exp_dtype = exponent_rank_dtype - ranks: List[Optional[int]] = [exp_rank, None] - dtypes = [exp_dtype, get_dtype_of_scalar(self)] +def aten〇pow〇Scalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int: + exponent_rank, exponent_dtype = exponent_rank_dtype + ranks: List[Optional[int]] = [None, exponent_rank] + dtypes = [get_dtype_of_scalar(self), exponent_dtype] return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0)) -def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int: +def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(exponent)] @@ -2633,7 +2828,7 @@ def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponen @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, negative_slope=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}, negative_slope=1.0)) -def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float] = 0.01) -> int: +def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex] = 0.01) -> int: self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.bool ranks: List[Optional[int]] = [self_rank, None] @@ -2643,10 +2838,21 @@ def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: dtypes = [self_dtype, negative_slope_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, alpha=1, scale=1, input_scale=2) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}, alpha=1.0, scale=1.0, input_scale=2.0)) +def aten〇elu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1, scale: Union[int, float, complex] = 1, input_scale: Union[int, float, complex] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + param_dtypes = [get_dtype_of_scalar(p) for p in [alpha, scale, input_scale]] + if any([is_float_dtype(d) for d in param_dtypes]): + assert not is_integer_dtype(self_dtype) + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, 
other=1.0)) -def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2663,7 +2869,7 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int64, device="cpu")), ErrorInvocation( TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.bfloat16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"))]) -def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: +def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int: batch1_rank, batch1_dtype = batch1_rank_dtype batch2_rank, batch2_dtype = batch2_rank_dtype assert batch1_dtype not in [torch.bool, torch.float16] @@ -2689,7 +2895,7 @@ def aten〇where〇self〡dtype(condition_rank_dtype: Tuple[int, int], self_rank Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0.0), Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0), Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0.0)]) -def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other: Union[int, float]) -> int: +def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float, complex], other: Union[int, float, complex]) -> int: if is_integer_dtype(get_dtype_of_scalar(self)) and is_integer_dtype(get_dtype_of_scalar(other)): return torch.int64 return torch.float32 @@ -2698,7 +2904,7 @@ def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: U Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), 0.0), Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), 0), Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float64), 0.0)]) -def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2708,7 +2914,7 @@ def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], se Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.int64)), Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.float16)), Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.float64))]) -def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other_rank_dtype: Tuple[int, int]) -> int: +def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float, complex], other_rank_dtype: Tuple[int, int]) -> int: other_rank, 
other_dtype = other_rank_dtype ranks: List[Optional[int]] = [None, other_rank] dtypes = [get_dtype_of_scalar(self), other_dtype] @@ -2760,6 +2966,13 @@ def aten〇native_layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normaliz result_dtype = torch.float64 return input_dtype, input_dtype, result_dtype +# note: one_hot doesn't support "meta" device, use "cpu" instead. +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, num_classes=2, tensor_device="cpu", error_types={torch.complex128, torch.complex64, torch.float64, torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool})) +def aten〇one_hot〡dtype(self_rank_dtype: Tuple[int, int], num_classes: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype == torch.int64 + return torch.int64 + @check_dtype_function( [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), @@ -2800,7 +3013,7 @@ def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r ErrorInvocation(end=0, dtype=torch.complex64), # Dtype specified Invocation(end=0, dtype=torch.float16), # Dtype specified Invocation(end=0, dtype=torch.int16)]) # Dtype specified -def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇arange〡dtype(end: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: assert not is_complex_dtype(dtype) return dtype @@ -2814,7 +3027,7 @@ def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, l ErrorInvocation(start=0, end=10, dtype=torch.complex64), # Dtype specified Invocation(start=0, end=10, dtype=torch.float16), # Dtype specified Invocation(start=0, end=10, dtype=torch.int16)]) # Dtype specified -def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇arange〇start〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: assert not is_complex_dtype(dtype) return dtype @@ -2830,7 +3043,7 @@ def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, floa ErrorInvocation(start=0, end=10, step=1, dtype=torch.complex64), # Dtype specified Invocation(start=0, end=10, step=1, dtype=torch.float16), # Dtype specified Invocation(start=0, end=10, step=1, dtype=torch.int16)]) # Dtype specified -def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float], step: Union[int, float] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇arange〇start_step〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], step: Union[int, float, complex] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: assert not is_complex_dtype(dtype) return dtype @@ -2859,6 +3072,18 @@ def 
aten〇sum〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = def aten〇sum〇dim_IntList〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: return aten〇sum〡dtype(self_rank_dtype, dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) +def aten〇prod〇dim_int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, @@ -2884,11 +3109,24 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim return self_dtype return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇min〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return aten〇minimum〡dtype(self_rank_dtype, other_rank_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇max〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_two_tensor_op()) +def aten〇max〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return aten〇maximum〡dtype(self_rank_dtype, other_rank_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int: return aten〇max〡dtype(self_rank_dtype) @@ -2921,7 +3159,7 @@ def aten〇std〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Lis return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: +def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> int: return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) @@ -2933,7 +3171,7 @@ def aten〇var〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Lis return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: +def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> int: return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], 
correction=0.0)) @@ -2951,7 +3189,7 @@ def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) -def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: +def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float, complex] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype assert not is_integer_dtype(self_dtype) if dtype is not None: @@ -3016,7 +3254,7 @@ def aten〇empty〇memory_format〡dtype(size: List[int], dtype: Optional[int] = Invocation([1], 0.0, dtype=torch.int32), Invocation([1], 0.0, dtype=torch.float16), Invocation([1], 0.0, dtype=torch.complex64)]) -def aten〇full〡dtype(size: List[int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇full〡dtype(size: List[int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: return dtype fill_value_dtype = get_dtype_of_scalar(fill_value) @@ -3048,13 +3286,30 @@ def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[ self_rank, self_dtype = self_rank_dtype return self_dtype if dtype is None else dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.complex64)) +def aten〇empty_strided〡dtype(size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.float16) + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.int32) + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.complex64)) -def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: +def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + 
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.complex64)) +def aten〇new_full〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype if dtype is None else dtype @@ -3129,7 +3384,7 @@ def aten〇to〇dtype〡dtype(self_rank_dtype: Tuple[int, int], dtype: int, non_ _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) -def nvprims〇convert_element_type〡dtype(a_rank_dtype: Tuple[int, int], dtype: int) -> int: +def prims〇convert_element_type〡dtype(a_rank_dtype: Tuple[int, int], dtype: int) -> int: return dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + @@ -3188,7 +3443,7 @@ def aten〇randn〇generator〡dtype(size: List[int], generator: Any, dtype: Opt return dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes())) -def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> Tuple[int, int]: +def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> Tuple[int, int]: self_rank, self_dtype = self_rank_dtype assert not is_integer_dtype(self_dtype) if self_dtype == torch.complex64: @@ -3229,7 +3484,10 @@ def aten〇atan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - return input_dtype + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype @check_dtype_function( [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), @@ -3265,7 +3523,7 @@ def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int: assert False, "Unexpected dtype!" 
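The `promote_dtypes(ranks, dtypes)` pattern above, with a `None` rank marking a Python-scalar operand, mirrors PyTorch's own type promotion: a scalar can raise the result's dtype category but never widens a tensor within its category. A quick plain-PyTorch spot-check of the behavior these refinement functions encode (illustrative only, not part of the patch):

import torch

# An int tensor combined with a float scalar promotes to the default float dtype.
assert torch.result_type(torch.ones(3, dtype=torch.int32), 2.0) == torch.float32
# A float scalar does not widen an existing float tensor: float16 stays float16.
assert torch.result_type(torch.ones(3, dtype=torch.float16), 2.0) == torch.float16
# A complex scalar (now expressible via the Union[int, float, complex] signatures
# above) promotes a real tensor to the matching complex dtype.
assert torch.result_type(torch.ones(3, dtype=torch.float32), 2j) == torch.complex64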
@check_dtype_function([Invocation(0), Invocation(0.0)]) -def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int: +def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float, complex]) -> int: return get_dtype_of_scalar(a) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + @@ -3392,14 +3650,14 @@ def main(args): using namespace mlir; StringRef mlir::torch::Torch::getAbstractInterpLibrary() {{ -#ifndef _MSC_VER +#if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Woverlength-strings" #endif // clang-format off return {asm}; // clang-format on -#ifndef _MSC_VER +#if defined(__clang__) #pragma clang diagnostic pop #endif }}""") diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index 3cfc4a24aa74..74eb520e22d4 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -6,6 +6,7 @@ import inspect import re from typing import List, Optional, Union, Any, Dict +import codecs import torch @@ -63,7 +64,7 @@ def get_priority_of_dtype(dtype: int) -> int: return 11 assert False, "Cannot determine priority of dtype" -def get_dtype_of_scalar(scalar: Union[int, float]) -> int: +def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int: # This is hacky. `NumToTensor` is the only PyTorch op for scalars # that when `jit.script`ed converts a float scalar to a tensor # with dtype that corresponds to Python's `float`. @@ -234,10 +235,22 @@ def generate_library(functions: Dict[str, Any]) -> str: # defined symbols. Since all of our shape functions conveniently have # a `〇` in them, we replace the torch namespace with our prefix. E.g.: # __torch__.aten〇add〇Scalar -> __torch_mlir_shape_fn.aten〇add〇Scalar - asm = re.sub(r"__torch__\.([^.(]+)\\E3\\80\\87([^.(]+)\\E3\\80\\A1([^.(\"]+)", - r"__torch_mlir_\3_fn.\1\\E3\\80\\87\2", + + # Encoding for: 〇 + circle = r"\\E3\\80\\87" + # Encoding for: 〡 + line = r"\\E3\\80\\A1" + name = r"[^.(]+" + # Sometimes PyTorch will insert namespaces into the function name in + # the format: `__torch__.{namespace_1}.{namespace_2}...{op_name}` + # The extra namespaces are not part of the abstract interpretation + # function name, so we simply drop them here. + namespace = fr"(?:{name}\.)" + + asm = re.sub(fr'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"', + fr'@"__torch_mlir_\3_fn.\1{circle}\2"', asm) # Put the `〇` back to a regular `.`. - asm = asm.replace("\\E3\\80\\87", ".") + asm = asm.replace(codecs.decode(circle, "unicode_escape"), ".") return asm diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py index 0396df1a0081..2291f27e32f2 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py @@ -32,8 +32,14 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str: # testing against the real ops, and tuples work fine in all # the places this kicks in (e.g. conv dilations -- we aren't # mutating those lists).
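To see what the generate_library() substitution above does to a single symbol, here is a runnable sketch; the regex pieces are copied from the hunk, while the input symbol (including the extra "extra_ns" namespace that gets dropped) is made up:

import codecs
import re

circle = r"\\E3\\80\\87"  # escaped encoding of '〇' as it appears in the MLIR asm
line = r"\\E3\\80\\A1"    # escaped encoding of '〡'
name = r"[^.(]+"
namespace = fr"(?:{name}\.)"

asm = r'@"__torch__.extra_ns.aten\E3\80\87add\E3\80\87Scalar\E3\80\A1shape"'
asm = re.sub(fr'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"',
             fr'@"__torch_mlir_\3_fn.\1{circle}\2"', asm)
asm = asm.replace(codecs.decode(circle, "unicode_escape"), ".")
print(asm)  # @"__torch_mlir_shape_fn.aten.add.Scalar"

Similarly, the default-value rewrite in _get_default_value() below boils down to this small helper (hypothetical name, not part of the patch; the trailing comma is what keeps a one-element default a tuple):

def list_default_to_tuple(default_debug: str) -> str:
    # "()" is the only valid spelling of an empty tuple; "(,)" is a syntax error.
    if default_debug == "[]":
        return "()"
    return default_debug.replace("[", "(").replace("]", ",)")

assert list_default_to_tuple("[]") == "()"
assert list_default_to_tuple("[0]") == "(0,)"        # a tuple, not a parenthesized 0
assert list_default_to_tuple("[1, 1]") == "(1, 1,)"  # matches the (1, 1,) defaults above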
- default_debug = arg["default_debug"].replace( - '[', '(').replace(']', ')') + default_list = arg["default_debug"] + # (,) is not a valid empty tuple construction in Python, so + # we must handle the empty case separately. + if default_list == "[]": + default_debug = "()" + else: + default_debug = default_list.replace( + "[", "(").replace("]", ",)") elif arg["pytype"] == "str": default_debug = repr(arg["default_debug"]).replace("'", '"') else: @@ -43,7 +49,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str: def _pytype_to_fn_pytype_common(pytype: str) -> str: if "number" in pytype: - return pytype.replace("number", "Union[int, float]") + return pytype.replace("number", "Union[int, float, complex]") # `torch.device` is lowercase. if pytype == "Device": return "device" @@ -191,9 +197,13 @@ def _get_function_signature(self, function_kind: str, def_name = "〇".join(mlir_op_name.split(".")) def_name += f"〡{function_kind}" parameter_decls = list(map(parameter_decl_builder, self.arguments)) + parameter_decls = list(filter(None, parameter_decls)) ret_decls = list(map(ret_decl_builder, self.returns)) + ret_decls = list(filter(None, ret_decls)) parameters = ", ".join(parameter_decls) result = ", ".join(ret_decls) + if len(ret_decls) == 0: + result = "None" if len(ret_decls) >= 2: result = f"Tuple[{result}]" @@ -279,7 +289,7 @@ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: return "" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: - return "None" + return "" return self._get_function_signature( "has_value_semantics", parameter_decl_builder, ret_decl_builder) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 007df85d11eb..95f8d68cd2d6 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -242,15 +242,18 @@ def emit_with_mutating_variants(key, **kwargs): for key in [ "aten::tanh : (Tensor) -> (Tensor)", "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)", + "aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)", "aten::relu : (Tensor) -> (Tensor)", "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", "aten::log : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", "aten::sign : (Tensor) -> (Tensor)", + "aten::sgn : (Tensor) -> (Tensor)", "aten::hardsigmoid : (Tensor) -> (Tensor)", "aten::hardswish : (Tensor) -> (Tensor)", "aten::erf : (Tensor) -> (Tensor)", + "aten::erfinv : (Tensor) -> (Tensor)", "aten::silu : (Tensor) -> (Tensor)", "aten::sin : (Tensor) -> (Tensor)", "aten::exp : (Tensor) -> (Tensor)", @@ -289,7 +292,9 @@ def emit_with_mutating_variants(key, **kwargs): "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", "aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)", "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", + "aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::clamp_max : (Tensor, Scalar) -> (Tensor)", + "aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::log2 : (Tensor) -> (Tensor)", "aten::sqrt : (Tensor) -> (Tensor)", "aten::log1p : (Tensor) -> (Tensor)", @@ -311,17 +316,18 @@ def emit_with_mutating_variants(key, **kwargs): # variants. emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?)
-> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) - + emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::maximum : (Tensor, Tensor) -> (Tensor)") emit("aten::minimum : (Tensor, Tensor) -> (Tensor)") emit("aten::mish : (Tensor) -> (Tensor)") + emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit("aten::gelu : (Tensor, str) -> (Tensor)") emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)") @@ -334,19 +340,29 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::real : (Tensor) -> (Tensor)") emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)") + emit("aten::view_as_real : (Tensor) -> (Tensor)") + + # Ops with dynamic number of outputs + emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") + emit("aten::split_copy.Tensor : (Tensor, int, int) -> (Tensor[])") + emit("aten::split_with_sizes_copy : (Tensor, int[], int) -> (Tensor[])") # Random number generation emit_with_mutating_variants("aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)") emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::rand : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)") emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)") emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)") + emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)") emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)") emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)") emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::random : (Tensor, Generator?) -> (Tensor)") + emit("aten::random.from : (Tensor, int, int?, Generator?) 
-> (Tensor)") emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)") @@ -355,6 +371,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)") emit_with_mutating_variants( "aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)") + emit("aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)") # Non-elementwise tensor compute ops emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)") @@ -389,6 +406,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)" ) + emit( + "aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)", + ) emit( "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" ) @@ -401,9 +421,30 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) + emit( + "aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)" + ) + emit( + "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" + ) + emit( + "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" + ) + emit( + "aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)" + ) emit( "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" ) + emit( + "aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + ) + emit( + "aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + ) + emit( + "aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + ) emit( "aten::softmax.int : (Tensor, int, int?) -> (Tensor)" ) @@ -415,7 +456,14 @@ def emit_with_mutating_variants(key, **kwargs): ) emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") + emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)") + emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)") @@ -426,6 +474,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)") emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)") emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") + emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") emit("aten::mean : (Tensor, int?) 
-> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)") @@ -443,11 +492,21 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)") + emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)") emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)") emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)") emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)") + emit("aten::nonzero : (Tensor) -> (Tensor)") + emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])") + emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)") + emit("aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)") + emit("aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)") + emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)") + emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") + emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)") # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") @@ -463,6 +522,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)") emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") @@ -471,6 +532,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::isnan : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)") emit("aten::all.bool : (bool[]) -> (bool)") + emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)") emit("aten::any : (Tensor) -> (Tensor)") emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)") emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") @@ -478,6 +540,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)") emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") + emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)") emit("aten::one_hot : (Tensor, int) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::clone : (Tensor, int?) -> (Tensor)") @@ -486,6 +549,8 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)") emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) 
-> (Tensor)") emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True) + emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True) + emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)") emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)") emit("aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)") @@ -497,7 +562,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)", has_canonicalizer=True) emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)") - emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_canonicalizer=True) + emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)") emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") @@ -508,22 +573,30 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::numel : (Tensor) -> (int)") emit("aten::repeat : (Tensor, int[]) -> (Tensor)") emit("aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor)") + emit("aten::tile : (Tensor, int[]) -> (Tensor)") emit("aten::reshape : (Tensor, int[]) -> (Tensor)") emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") + emit("aten::resize : (Tensor, int[], int?) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") emit("aten::select.int : (Tensor, int, int) -> (Tensor)") emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) emit("aten::sum : (Tensor, int?) -> (Tensor)") emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)") + emit("aten::prod.dim_int : (Tensor, int, bool, int?) -> (Tensor)") emit("aten::max : (Tensor) -> (Tensor)") + emit("aten::max.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::amax : (Tensor, int[], bool) -> (Tensor)") + emit("aten::min : (Tensor) -> (Tensor)") + emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)") + emit("aten::amin : (Tensor, int[], bool) -> (Tensor)") emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True) emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer = True) - emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)") + emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True) emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)") emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) 
-> (Tensor)") - emit("aten::type_as : (Tensor, Tensor) -> (Tensor)", has_folder=True) + emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)") @@ -547,11 +620,15 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::numpy_T : (Tensor) -> (Tensor)") emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") + emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") + emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") + emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True) emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)") emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)") emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)") @@ -568,6 +645,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::view_copy : (Tensor, int[]) -> (Tensor)") emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)") emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)") + emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)") + emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)") emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") @@ -594,10 +673,11 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::slice.t : (t[], int?, int?, int) -> (t[])", has_canonicalizer=True) emit("aten::insert.t : (t[], int, t) -> ()") emit("aten::ne.int_list : (int[], int[]) -> (bool)") - emit("aten::any.bool : (bool[]) -> (bool)") + emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True) emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True) emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])") + emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])") emit("aten::unbind.int : (Tensor, int) -> (Tensor[])") emit("aten::chunk : (Tensor, int, int) -> (Tensor[])") @@ -628,17 +708,18 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True) emit("aten::remainder.int : (int, int) -> (int)", has_folder=True) emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)") + emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True) emit("aten::mul.int : (int, int) -> (int)", has_folder=True) emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") - emit("aten::add.float_int : (float, int) -> (float)") + emit("aten::add.float_int : (float, int) -> (float)", has_folder=True) emit("aten::sub.float : (float, float) -> 
(float)", has_folder=True) - emit("aten::mul.float : (float, float) -> (float)") + emit("aten::mul.float : (float, float) -> (float)", has_folder=True) emit("aten::div.float : (float, float) -> (float)", has_folder=True) - emit("aten::neg.float : (float) -> (float)") + emit("aten::neg.float : (float) -> (float)", has_folder=True) emit("aten::eq.float : (float, float) -> (bool)", has_folder=True) emit("aten::gt.float : (float, float) -> (bool)", has_folder=True) emit("aten::ge.float : (float, float) -> (bool)", has_folder=True) @@ -659,7 +740,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::div : (Scalar, Scalar) -> (float)", has_folder=True) - emit("aten::add : (Scalar, Scalar) -> (Scalar)") + emit("aten::add : (Scalar, Scalar) -> (Scalar)", has_folder=True) emit("aten::sub : (Scalar, Scalar) -> (Scalar)", has_folder=True) emit("aten::ceil.Scalar : (Scalar) -> (Scalar)", has_folder=True) emit("aten::sqrt.int : (int) -> (float)", has_folder=True) @@ -669,6 +750,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::eq.device : (Device, Device) -> (bool)") emit("aten::ceil.float : (float) -> (int)", has_folder=True) emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)") + emit("aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)") emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True) emit("aten::fake_quantize_per_tensor_affine_cachemask : (Tensor, float, int, int, int) -> (Tensor, Tensor)") @@ -685,6 +767,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)") emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)") emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)") + emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)") emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") # ========================================================================== diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp index e0420022d58a..afac7b164b36 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp @@ -11,6 +11,8 @@ #include "function_importer.h" #include "ivalue_importer.h" +#include + #include #include @@ -55,11 +57,12 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context, case ScalarType::QUInt8: return torchMlirTorchQUInt8TypeGet(context); case ScalarType::ComplexHalf: - return mlirComplexTypeGet(mlirF32TypeGet(context)); + return mlirComplexTypeGet(mlirF16TypeGet(context)); case ScalarType::ComplexFloat: + return mlirComplexTypeGet(mlirF32TypeGet(context)); + case ScalarType::ComplexDouble: return mlirComplexTypeGet(mlirF64TypeGet(context)); - // Cannot support ScalarType::ComplexDouble because there is no MLIR C API - // to generate F128 types. 
+ default: { return {nullptr}; } @@ -407,15 +410,53 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc, MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context, torch::jit::Node *node) { - auto flc = node->sourceRange().file_line_col(); - if (flc) { + MlirLocation loc = mlirLocationUnknownGet(context); + + if (node->hasAttribute(c10::Symbol::attr("source_files"))) { + const auto &sourceFiles = node->ss(c10::Symbol::attr("source_files")); + const auto &lineNumbers = node->is(c10::Symbol::attr("line_numbers")); + const auto &functions = node->ss(c10::Symbol::attr("functions")); + + // Chain a sequence of calls to construct single MlirLocation. + for (const auto i : c10::irange(sourceFiles.size())) { + MlirLocation newLoc = mlirLocationNameGet( + context, toMlirStringRef(functions[i]), + mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]), + lineNumbers[i], + 0 /* column is not available */ + )); + loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc)); + } + if (sourceFiles.size() == 1) { + // Somehow a callstack depth of 1... + // Disambiguate function name from scope name below. + loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context)); + } + } else if (auto flc = node->sourceRange().file_line_col()) { const std::string &file = std::get<0>(*flc); int line = std::get<1>(*flc); int col = std::get<2>(*flc); - return mlirLocationFileLineColGet(context, toMlirStringRef(file), line, - col); + loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col); } - return mlirLocationUnknownGet(context); + + std::string locationName; + auto scopeName = node->scopeName(); + if (!scopeName.empty()) { + locationName = scopeName; + } + + if (const c10::FunctionSchema *schema = node->maybeSchema()) { + if (!locationName.empty()) { + locationName += "/"; + } + locationName += schema->operator_name().name; + } + + if (!locationName.empty()) { + loc = mlirLocationNameGet(context, toMlirStringRef(locationName), loc); + } + + return loc; } std::vector diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 23c727405a60..1b9dbb0d2c51 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -163,7 +163,6 @@ def invoke(*args): "func.func(convert-math-to-llvm)", # Handle some complex mlir::math ops (e.g. atan2) "convert-math-to-libm", - "convert-linalg-to-llvm", "expand-strided-metadata", "finalize-memref-to-llvm", "lower-affine", diff --git a/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py deleted file mode 100644 index 6a36dd196386..000000000000 --- a/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ /dev/null @@ -1,50 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. 
- -from torch_mlir.ir import * -from torch_mlir.passmanager import * -from torch_mlir.compiler_utils import run_pipeline_with_repro_report - -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( - RefBackendLinalgOnTensorsBackend, -) - -from .abc import StablehloBackend - -__all__ = [ - "LinalgOnTensorsStablehloBackend", -] - - -class LinalgOnTensorsStablehloBackend(StablehloBackend): - """Main entry-point for the linalg-on-tensors based StableHLO backend. - - This currently uses the linalg-on-tensors RefBackend for actual execution. - """ - - def __init__(self): - super().__init__() - self.refbackend = RefBackendLinalgOnTensorsBackend() - - def compile(self, imported_module: Module): - """Compiles an imported module that satisfied the StableHLO backend contract. - - Args: - imported_module: The MLIR module consisting of funcs in the StableHLO - dialect. - Returns: - An opaque, backend specific compiled artifact object that can be - passed to `load`. - """ - run_pipeline_with_repro_report( - imported_module, - "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,mhlo-test-unfuse-batch-norm,canonicalize,hlo-legalize-to-linalg,canonicalize))", - "Lowering StableHLO to Linalg-on-Tensors", - ) - return self.refbackend.compile(imported_module) - - def load(self, module): - """Loads a compiled artifact into the runtime.""" - return self.refbackend.load(module) diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 43992573e8fc..6ae664d4165b 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -13,10 +13,10 @@ # ============================================================================== class ScalarConstantTupleModule(torch.nn.Module): - + def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -60,7 +60,7 @@ def MmModule_chained(module, tu: TestUtils): # ============================================================================== -class BmmModule(torch.nn.Module): +class BmmFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -75,11 +75,31 @@ def forward(self, lhs, rhs): return torch.bmm(lhs, rhs) -@register_test_case(module_factory=lambda: BmmModule()) -def BmmModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: BmmFloatModule()) +def BmmFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4)) +class BmmIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return torch.bmm(lhs, rhs) + + +@register_test_case(module_factory=lambda: BmmIntModule()) +def BmmIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, high=100), tu.randint(3, 5, 4, high=100)) + + # ============================================================================== @@ -353,6 +373,28 @@ def FlattenDynamicModule_basic(module, tu: TestUtils): # ============================================================================== +class AliasModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inp_tensor): + return torch.ops.aten.alias(inp_tensor) + + +@register_test_case(module_factory=lambda: 
AliasModule()) +def AliasModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, low=-1)) + + +# ============================================================================== + + class ConstantPad2dStaticModule(torch.nn.Module): def __init__(self): @@ -1122,6 +1164,25 @@ def SoftmaxIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4)) +class SoftmaxIntNonNoneDtypeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, tensor): + return torch.ops.aten.softmax(tensor, dim=2, dtype=torch.float64) + + +@register_test_case(module_factory=lambda: SoftmaxIntNonNoneDtypeModule()) +def SoftmaxIntNonNoneDtypeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4)) + + # ============================================================================== @@ -1440,6 +1501,30 @@ def BroadcastListConstructWithMinusOneModule_basic(module, tu: TestUtils): # ============================================================================== +class BroadcastDynamicDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, -1, 1, -1], torch.float32, True), + ([1, -1, 1, -1], torch.float32, True), + ]) + def forward(self, x, y): + dim_at_index_1 = torch.ops.aten.size(x, 1) + dim_at_index_3 = torch.ops.aten.size(x, 3) + res = torch.ops.aten.broadcast_to(y, [1, dim_at_index_1, 1, dim_at_index_3]) + return res + + +@register_test_case(module_factory=lambda: BroadcastDynamicDimModule()) +def BroadcastDynamicDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 1, 4), tu.rand(1, 1, 1, 1)) + +# ============================================================================== + class BroadcastDifferentRankWithMinusOneModule(torch.nn.Module): def __init__(self): @@ -1586,6 +1671,47 @@ def RepeatInterleaveStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class TileSmallDimsSizeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 1, 2], torch.float32, True), + ]) + def forward(self, x): + return x.tile([3, 4]) + + +@register_test_case(module_factory=lambda: TileSmallDimsSizeModule()) +def TileSmallDimsSizeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 2)) + +# ============================================================================== + +class TileBigDimsSizeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 1, 2], torch.float32, True), + ]) + def forward(self, x): + return x.tile([3, 4, 5, 6]) + + +@register_test_case(module_factory=lambda: TileBigDimsSizeModule()) +def TileBigDimsSizeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 2)) + +# ============================================================================== + + class ExpandModule(torch.nn.Module): def __init__(self): @@ -1936,6 +2062,94 @@ def forward(self, x): def DropoutTrainModule_basic(module, tu: TestUtils): module.forward(tu.rand(1024, 1536)) +# ============================================================================== + + +class DropoutTrainStaticShapeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1024, 1536], torch.float32, True), + ]) + def forward(self, x): + res = torch.dropout(x, 0.3, train=True) + 
return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: DropoutTrainStaticShapeModule()) +def DropoutTrainStaticShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 1536)) + +# ============================================================================== + + +class NativeDropoutEvalFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.native_dropout(x, 0.1, train=False) + + +@register_test_case(module_factory=lambda: NativeDropoutEvalFloatModule()) +def NativeDropoutEvalFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class NativeDropoutTrainModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + res = torch.native_dropout(x, 0.3, train=True) + return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32)) + + +@register_test_case(module_factory=lambda: NativeDropoutTrainModule()) +def NativeDropoutTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 1536)) + + +# ============================================================================== + + +class NativeDropoutTrainStaticShapeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1024, 1536], torch.float32, True), + ]) + def forward(self, x): + res = torch.native_dropout(x, 0.3, train=True) + return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32)) + + +@register_test_case(module_factory=lambda: NativeDropoutTrainStaticShapeModule()) +def NativeDropoutTrainStaticShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 1536)) # ============================================================================== @@ -2251,6 +2465,8 @@ def IndexTensorStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4)) # ============================================================================== + + class IndexTensorMultiIndexStaticModule(torch.nn.Module): def __init__(self): @@ -2318,6 +2534,102 @@ def IndexTensorModule3dInputStatic_basic(module, tu: TestUtils): # ============================================================================== +class IndexTensorStaticContiguousWithNoneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 4, 5, 32], torch.float32, True), + ([1, 2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ]) + def forward(self, x, index, index1): + return torch.ops.aten.index(x, (None, index, index1, None)) + + +@register_test_case(module_factory=lambda: IndexTensorStaticContiguousWithNoneModule()) +def IndexTensorStaticContiguousWithNoneModule_basic(module, tu: TestUtils): + + module.forward(tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0],[1]]]), torch.tensor([[0],[1]])) + +# ============================================================================== + + +class IndexTensorDyanmicInputContiguousWithNoneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([1, 2, 1], torch.int64, True), + 
([2, 1], torch.int64, True), + ]) + def forward(self, x, index, index1): + return torch.ops.aten.index(x, (None, index, index1, None)) + + +@register_test_case(module_factory=lambda: IndexTensorDyanmicInputContiguousWithNoneModule()) +def IndexTensorDyanmicInputContiguousWithNoneModule_basic(module, tu: TestUtils): + + module.forward(tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0],[1]]]), torch.tensor([[0],[1]])) + +# ============================================================================== + + +class IndexTensorStaticNonContiguousWithNoneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 4, 5, 32], torch.float32, True), + ([1, 2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ]) + def forward(self, x, index, index1, index2): + return torch.ops.aten.index(x, (None, index, index1, None, index2)) + + +@register_test_case(module_factory=lambda: IndexTensorStaticNonContiguousWithNoneModule()) +def IndexTensorStaticNonContiguousWithNoneModule_basic(module, tu: TestUtils): + + module.forward(tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0],[1]]]), torch.tensor([[0],[1]]), torch.tensor([[0],[1]])) + +# ============================================================================== + +class IndexTensorDyanmicInputNonContiguousWithNoneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([1, 2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ]) + def forward(self, x, index, index1, index2): + return torch.ops.aten.index(x, (None, index, index1, None, index2)) + + +@register_test_case(module_factory=lambda: IndexTensorDyanmicInputNonContiguousWithNoneModule()) +def IndexTensorDyanmicInputNonContiguousWithNoneModule_basic(module, tu: TestUtils): + + module.forward(tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0],[1]]]), torch.tensor([[0],[1]]), torch.tensor([[0],[1]])) + +# ============================================================================== + class IndexTensorSelectDimModule(torch.nn.Module): @@ -2584,6 +2896,29 @@ def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils): # ============================================================================== +class IndexTensorNegativeIndexModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 2, 3, 2], torch.float32, True), + ([1], torch.int64, True), + ]) + def forward(self, x, index): + return torch.ops.aten.index(x, (None, None, index)) + + +@register_test_case(module_factory=lambda: IndexTensorNegativeIndexModule()) +def IndexTensorNegativeIndexModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 3, 2), tu.randint(1, low=-2, high=0)) + + +# ============================================================================== + + class IndexTensorHackedTwinModule(torch.nn.Module): def __init__(self): @@ -3085,6 +3420,48 @@ def forward(self, x): def FlipModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4)) +# ============================================================================== + + +class FlipModuleStaticShape(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 2, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.flip(x, [1, 2]) + + +@register_test_case(module_factory=lambda: FlipModuleStaticShape()) +def 
FlipModuleStaticShape_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4)) + +# ============================================================================== + + +class FlipNegativeIndexModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 2, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.flip(x, [-1]) + + +@register_test_case(module_factory=lambda: FlipNegativeIndexModule()) +def FlipNegativeIndexModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4)) + # ============================================================================== @@ -3477,6 +3854,42 @@ def forward(self, lhs): def NumpyTRank0Module_basic(module, tu: TestUtils): module.forward(torch.tensor(7, dtype=torch.float32)) + +# ============================================================================== + + +class AtenEmbeddingBagStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4, 2], torch.float32, True), + ([3], torch.int64, True), + ([1], torch.int64, True), + ]) + def forward(self, weight, indices, offsets): + return torch.ops.aten.embedding_bag(weight, + indices, + offsets, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=None) + + +@register_test_case(module_factory=lambda: AtenEmbeddingBagStaticModule()) +def AtenEmbeddingBagStaticModule_basic(module, tu: TestUtils): + weight = tu.rand(4, 2) + indices = torch.LongTensor([3, 0, 1]) + offsets = torch.LongTensor([0]) + module.forward(weight, indices, offsets) + + class AtenEmbeddingBagSumExample(torch.nn.Module): def __init__(self): @@ -3490,15 +3903,26 @@ def __init__(self): ([-1], torch.int64, True), ]) def forward(self, weight, indices, offsets): - return torch.ops.aten.embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None) + return torch.ops.aten.embedding_bag(weight, + indices, + offsets, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=None) + @register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample()) def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils): - weight = tu.rand(100, 10) - indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) + weight = tu.rand(100, 10) + indices = torch.LongTensor( + [0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) module.forward(weight, indices, offsets) + class Aten_EmbeddingBagExample(torch.nn.Module): def __init__(self): @@ -3514,13 +3938,16 @@ def __init__(self): def forward(self, weight, indices, offsets): return torch.ops.aten._embedding_bag(weight, indices, offsets) + @register_test_case(module_factory=lambda: Aten_EmbeddingBagExample()) def Aten_EmbeddingBagExample_basic(module, tu: TestUtils): - weight = tu.rand(100, 10) - indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) + weight = tu.rand(100, 10) + indices = torch.LongTensor( + [0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) module.forward(weight, indices, offsets) + # ============================================================================== class CumsumModule(torch.nn.Module): @@ -3574,6 +4001,23 @@ def 
forward(self, val): def CumsumStaticNegativeDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 7, 4)) +class CumsumInputDtypeInt32Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 7, 4], torch.int32, True), + ]) + def forward(self, val): + return torch.ops.aten.cumsum(val, 1) + +@register_test_case(module_factory=lambda: CumsumInputDtypeInt32Module()) +def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 7, 4).to(torch.int32)) + # ============================================================================== class AtenToDeviceModule(torch.nn.Module): @@ -4033,7 +4477,7 @@ class OneHotModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([None, ([-1], torch.long, True)]) def forward(self, x): @@ -4268,15 +4712,49 @@ def forward(self, x): def AtenComplexViewModule_basic(module, tu: TestUtils): module.forward(tu.rand(5,2)) +# ============================================================================== +class AtenRealView128Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.complex128, True), + ]) + def forward(self, x): + return torch.view_as_real(x) + + +@register_test_case(module_factory=lambda: AtenRealView128Module()) +def AtenRealView128Module_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 6, 1).to(torch.complex128)) # ============================================================================== +class AtenRealView64Module(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.complex64, True), + ]) + def forward(self, x): + return torch.view_as_real(x) + + +@register_test_case(module_factory=lambda: AtenRealView64Module()) +def AtenRealView64Module_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 6, 1).to(torch.complex64)) + +# ============================================================================== class Add_Module(torch.nn.Module): def __init__(self): super().__init__() - self.tensor = torch.ones(2, 3) + self.register_buffer('tensor', torch.ones(2, 3)) @export @annotate_args([ @@ -4305,7 +4783,7 @@ def __init__(self): ([-1, -1, -1, -1], torch.float32, True), ]) def forward(self, x): - return torch.ops.aten.im2col(x, [9, 1], [1, 1], [4, 0], [1, 1]); + return torch.ops.aten.im2col(x, [9, 1], [1, 1], [4, 0], [1, 1]); @register_test_case(module_factory=lambda: Im2Col_Module()) def Im2ColModule_basic(module, tu: TestUtils): diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 1b92c8f17135..552e2aa0862e 100644 --- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1093,6 +1093,126 @@ def forward(self, a): def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward(tu.randint(10, 4, high=100)) +# ============================================================================== + + +class NewFullModuleDefaultDtype(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.new_full(a, (3,4), 5) + + +@register_test_case(module_factory=lambda: NewFullModuleDefaultDtype()) +def NewFullModuleDefaultDtype_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3)) + + +class 
NewFullModuleInt2D(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.ops.aten.new_full(a, (3,4), 10.5) + + +@register_test_case(module_factory=lambda: NewFullModuleInt2D()) +def NewFullModuleInt2D_basic(module, tu: TestUtils): + module.forward(tu.randint(4, 5, high=10)) + + +class NewFullModuleInt3D(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.ops.aten.new_full(a, (3,4), 5.0, dtype=torch.int64) + + +@register_test_case(module_factory=lambda: NewFullModuleInt3D()) +def NewFullModuleInt3D_basic(module, tu: TestUtils): + module.forward(tu.randint(10, 4, 5, high=100).to(torch.int32)) + + +class NewFullModuleFloat3D(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, a): + return torch.ops.aten.new_full(a, (3,4), 15, dtype=torch.float32) + + +@register_test_case(module_factory=lambda: NewFullModuleFloat3D()) +def NewFullModuleFloat3D_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float64)) + + +class NewFullModuleFloat3DStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5], torch.float64, True), + ]) + def forward(self, a): + return torch.ops.aten.new_full(a, (3,4), 15.3, dtype=torch.float32) + + +@register_test_case(module_factory=lambda: NewFullModuleFloat3DStatic()) +def NewFullModuleFloat3DStatic_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float64)) + + +class NewFullModuleFalsePinMemory(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.ops.aten.new_full(a, + (3,4), + 5, + dtype=torch.int64, + pin_memory=False) + + +@register_test_case(module_factory=lambda: NewFullModuleFalsePinMemory()) +def NewFullModuleFalsePinMemory_basic(module, tu: TestUtils): + module.forward(tu.randint(10, 4, high=100)) + # ============================================================================== @@ -1528,6 +1648,7 @@ def forward(self, a): def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== @@ -1543,3 +1664,26 @@ def forward(self): @register_test_case(module_factory=lambda: EyeStaticModule()) def EyeStaticModule_basic(module, tu: TestUtils): module.forward() + +# ============================================================================== + + +class EmptyStridedModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 4], torch.float32, True), + ]) + def forward(self, a): + x = torch.ops.aten.empty_strided(a.size(), stride=[12, 4, 1]) + y = x.copy_(a) + return y + + +@register_test_case(module_factory=lambda: EmptyStridedModule()) +def EmptyStridedModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index 64116d059cc2..b9ba1c0947bc 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -177,32 +177,56 @@ def 
Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils): class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module): - def __init__(self): + def __init__(self, out_channels, groups): super().__init__() torch.manual_seed(0) - self.conv = torch.nn.Conv2d(in_channels=2, - out_channels=10, + self.conv = torch.nn.Conv2d(in_channels=4, + out_channels=out_channels, kernel_size=3, padding=3, stride=2, dilation=3, - bias=False) + bias=False, + groups=groups) self.train(False) @export @annotate_args([ None, - ([5, 2, 10, 20], torch.float32, True), + ([5, 4, 10, 20], torch.float32, True), ]) def forward(self, x): return self.conv(x) @register_test_case( - module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule()) + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=10, groups=1)) def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils): - t = tu.rand(5, 2, 10, 20) - module.forward(t) + module.forward(tu.rand(5, 4, 10, 20)) + + +@register_test_case( + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=4)) +def Conv2dWithPaddingDilationStrideStaticModule_depthwise(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 10, 20)) + + +@register_test_case( + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=4)) +def Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 10, 20)) + + +@register_test_case( + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=2)) +def Conv2dWithPaddingDilationStrideStaticModule_grouped(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 10, 20)) + + +@register_test_case( + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=2)) +def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 10, 20)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index 723a87d1eec6..71f2a32ac00d 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -473,6 +473,47 @@ def forward(self, x): def ElementwiseLeakyReluStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, low=-1)) +# ============================================================================== + + +class ElementwiseEluNonDefaultModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.elu(x, scale=1.5, alpha=2.0, input_scale=3.0) + +@register_test_case(module_factory=lambda: ElementwiseEluNonDefaultModule()) +def ElementwiseEluNonDefaultModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5,3, low=-1, high=1)) + + +# ============================================================================== + + +class ElementwiseEluModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.elu(x) + +@register_test_case(module_factory=lambda: ElementwiseEluModule()) +def ElementwiseEluModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5,3, low=-1, high=1)) 
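+
+# Hedged reference sketch (assumed semantics, not part of the test suite):
+#   def elu_ref(x, alpha=1.0, scale=1.0, input_scale=1.0):
+#       return scale * (x.clamp(min=0) +
+#                       (alpha * ((input_scale * x).exp() - 1)).clamp(max=0))
+# The two tests above exercise the default and non-default values of alpha,
+# scale, and input_scale accepted by torch.ops.aten.elu.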
+ # ============================================================================== @@ -612,6 +653,52 @@ def ElementwiseMinimumIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseMinOtherModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, x, y): + return x.min(y) + + +@register_test_case(module_factory=lambda: ElementwiseMinOtherModule()) +def ElementwiseMinOtherModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5), tu.rand(3, 5)) + + +# ============================================================================== + + +class ElementwiseMinOtherIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, x, y): + return x.min(y) + + +@register_test_case(module_factory=lambda: ElementwiseMinOtherIntModule()) +def ElementwiseMinOtherIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10)) + + +# ============================================================================== + + class ElementwiseMaximumModule(torch.nn.Module): def __init__(self): @@ -658,6 +745,52 @@ def ElementwiseMaximumIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseMaxOtherModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, x, y): + return x.max(y) + + +@register_test_case(module_factory=lambda: ElementwiseMaxOtherModule()) +def ElementwiseMaxOtherModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5), tu.rand(3, 5)) + + +# ============================================================================== + + +class ElementwiseMaxOtherIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, x, y): + return x.max(y) + + +@register_test_case(module_factory=lambda: ElementwiseMaxOtherIntModule()) +def ElementwiseMaxOtherIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10)) + + +# ============================================================================== + + class ElementwiseClampModule(torch.nn.Module): def __init__(self): @@ -1003,6 +1136,28 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseMulTensorComplexModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.complex64, True), + ([-1], torch.complex64, True), + ]) + def forward(self, a, b): + return torch.mul(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexModule()) +def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, high=10).type(torch.complex64), tu.randint(4, high=10).type(torch.complex64)) + + +# ============================================================================== class 
ElementwiseMishModule(torch.nn.Module): @@ -1451,23 +1606,6 @@ def ElementwiseSignModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwisePowScalarModule(torch.nn.Module): - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True) - ]) - def forward(self, x): - return torch.ops.aten.pow(0.5, x) - -@register_test_case(module_factory=lambda: ElementwisePowScalarModule()) -def ElementwisePowScalarModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4)) - - -# ============================================================================== - - class ElementwisePowModule(torch.nn.Module): def __init__(self): @@ -1582,6 +1720,28 @@ def ElementwisePowTensorBroadcastStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwisePowScalarModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.float32, True), + ]) + def forward(self, exp): + return torch.pow(2.0, exp) + + +@register_test_case(module_factory=lambda: ElementwisePowScalarModule()) +def ElementwisePowScalarModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ElementwiseToDtypeF32ToI64Module(torch.nn.Module): def __init__(self): @@ -2070,6 +2230,56 @@ def ElementwiseBitwiseOrStaticShapeModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseOrTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.__or__(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseOrTensorModule()) +def ElementwiseOrTensorModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32), + tu.randint(3, 4, low=-10, high=10)) + + +# ============================================================================== + + +class ElementwiseOrTensorStaticShapeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int32, True), + ([4], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.__or__(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseOrTensorStaticShapeModule()) +def ElementwiseOrTensorStaticShapeModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32), + tu.randint(4, low=-10, high=10)) + + +# ============================================================================== + + class ElementwiseBitwiseXorModule(torch.nn.Module): def __init__(self): @@ -2721,7 +2931,7 @@ def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils): class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -2740,7 +2950,7 @@ def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils): class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -3334,3 +3544,32 @@ def forward(self, tensor, value): @register_test_case(module_factory=lambda: Fill_TensorFloat32WithInt64()) 
def Fill_TensorFloat32WithInt64_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4), tu.randint()) + + +# ============================================================================== + + +class TupleModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a, b): + cond = True + if cond: + tuple = a, b + else: + tuple = a + b, None + _, y = tuple + return y + + +@register_test_case(module_factory=lambda: TupleModule()) +def TupleModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 2), tu.rand(2, 2)) diff --git a/python/torch_mlir_e2e_test/test_suite/pooling.py b/python/torch_mlir_e2e_test/test_suite/pooling.py index 69073c6ab6c2..dd18545b0bc4 100644 --- a/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -700,3 +700,159 @@ def forward(self, x): @register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule()) def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) + + +# ============================================================================== + + +class AvgPool1dFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap1d = torch.nn.AvgPool1d(kernel_size=6, + stride=2, + padding=3, + ceil_mode=False, + count_include_pad=True) + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.ap1d(x) + +@register_test_case(module_factory=lambda: AvgPool1dFloatModule()) +def AvgPool1dFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 20, low=-1)) + + +class AvgPool1dIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap1d = torch.nn.AvgPool1d(kernel_size=6, + stride=2, + padding=3, + ceil_mode=False, + count_include_pad=True) + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, x): + return self.ap1d(x) + +@register_test_case(module_factory=lambda: AvgPool1dIntModule()) +def AvgPool1dIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 4, 20, high=100)) + + +class AvgPool1dStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap1d = torch.nn.AvgPool1d(kernel_size=6, + stride=2, + padding=3, + ceil_mode=False, + count_include_pad=True) + + @export + @annotate_args([ + None, + ([2, 4, 20], torch.int64, True), + ]) + def forward(self, x): + return self.ap1d(x) + +@register_test_case(module_factory=lambda: AvgPool1dStaticModule()) +def AvgPool1dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 4, 20, high=100)) + + +# ============================================================================== + + +class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(7) + + @export + @annotate_args([ + None, + ([1, 512, 7], torch.float32, True), + ]) + def forward(self, x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeStaticModule()) +def AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 7)) + +class AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = 
torch.nn.AdaptiveAvgPool1d(7) + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule()) +def AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 7)) + +class AdaptiveAvgPool1dUnitOutputSizeStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(1) + + @export + @annotate_args([ + None, + ([1, 512, 7], torch.float32, True), + ]) + def forward(self, x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeStaticModule()) +def AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 7)) + +class AdaptiveAvgPool1dUnitOutputSizeDynamicModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(1) + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeDynamicModule()) +def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 7)) \ No newline at end of file diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index 1f459affd5ec..06159324b304 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -313,6 +313,26 @@ def forward(self, a): def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=100)) + +# ============================================================================== + +class ReduceProdDimIntFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.prod(a, 1, dtype=torch.float32) + + +@register_test_case(module_factory=lambda: ReduceProdDimIntFloatModule()) +def ReduceProdDimIntFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float32)) + # ============================================================================== class ReduceMaxAlongDim(torch.nn.Module): @@ -591,6 +611,58 @@ def ReduceAmaxKeepDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceMinFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a) +@register_test_case(module_factory=lambda: ReduceMinFloatModule()) +def ReduceMinFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceMinSignedIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a) + +@register_test_case(module_factory=lambda: ReduceMinSignedIntModule()) +def ReduceMinSignedIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, 
low=-100, high=100)) + +# ============================================================================== + +class ReduceMinUnsignedIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a) + +@register_test_case(module_factory=lambda: ReduceMinUnsignedIntModule()) +def ReduceMinUnsignedIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, high=100)) + +# ============================================================================== class ReduceL1NormModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py index 22076e0310f9..1baa462462f1 100644 --- a/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/python/torch_mlir_e2e_test/test_suite/rng.py @@ -6,6 +6,28 @@ # ============================================================================== +class RandModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1024, 512], torch.float, True) + ]) + def forward(self, x): + size = x.size() + a = torch.rand(size) + return torch.std(a), torch.mean(a) + + +@register_test_case(module_factory=lambda: RandModule()) +def RandModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 512)) + +# ============================================================================== + class UniformModule(torch.nn.Module): def __init__(self): @@ -44,6 +66,44 @@ def UniformModule_basic(module, tu: TestUtils): # ============================================================================== +class UniformStaticShapeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([256, 512, 12], torch.float64, True), + ([512, 1024, 12], torch.float64, True), + ([512, 256, 12], torch.float64, True), + ]) + def forward(self, x, y, z): + a = torch.ops.aten.uniform_(x, 1.0, 10.0) + b = torch.ops.aten.uniform_(y, -20.0, -5.0) + c = torch.ops.aten.uniform_(z, -15.0, 3.0) + std = torch.cat([ + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + torch.flatten(torch.std(c)) + ]) + mean = torch.cat([ + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + torch.flatten(torch.mean(c)) + ]) + return std, mean + + +@register_test_case(module_factory=lambda: UniformStaticShapeModule()) +def UniformStaticShapeModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(256, 512, 12).double(), + tu.rand(512, 1024, 12).double(), + tu.rand(512, 256, 12).double()) + +# ============================================================================== + class UniformNoCorrelationModule(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/scatter.py b/python/torch_mlir_e2e_test/test_suite/scatter.py index 5e3ea6e8c44f..176ad8506b53 100644 --- a/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -872,6 +872,35 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), tu.randint(5, 8, 6, high=1000)) + +# ============================================================================== +# UnsafeIndexPutHackedTwin tests are using the aten._unsafe_index_put.hacked_twin operator. 
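+# Hedged note (assumed semantics, not asserted by these tests): in eager mode,
+# torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)
+# behaves like input.index_put_([index], value, accumulate=False), minus the
+# index bounds checks that the "_unsafe" prefix waives.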
+
+
+class UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._unsafe_index_put(input, [index],
+                                                value,
+                                                accumulate=False)
+
+
+@register_test_case(
+    module_factory=lambda: UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule())
+def UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250))
+
+
 # ==============================================================================

 class ScatterSrcStaticModule(torch.nn.Module):
diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py
index 25f3bca7a306..b13f23a1c014 100644
--- a/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -585,6 +585,42 @@ def NarrowVerticalTest2_basic(module, tu: TestUtils):

 # ==============================================================================

+class NarrowTensorHorizontalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True)
+    ])
+    def forward(self, x):
+        return torch.narrow(x, dim=1, start=torch.tensor(0), length=2)
+
+@register_test_case(module_factory=lambda: NarrowTensorHorizontalModule())
+def NarrowTensorHorizontalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4))
+
+# ==============================================================================
+
+class NarrowTensorVerticalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True)
+    ])
+    def forward(self, x):
+        return torch.narrow(x, dim=1, start=torch.tensor(1), length=2)
+
+@register_test_case(module_factory=lambda: NarrowTensorVerticalModule())
+def NarrowTensorVerticalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4))
+
+# ==============================================================================
+
 class SliceCopy_Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -872,6 +908,72 @@ def SplitTensorListUnpackModule_basic(module, tu: TestUtils):

 # ==============================================================================

+
+class SplitTensorLastSmallerModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([8, 10, 12], torch.float32, True)
+    ])
+    def forward(self, x):
+        s0, s1, s2 = torch.ops.aten.split(x, 3, dim=0)
+        return s2
+
+
+@register_test_case(module_factory=lambda: SplitTensorLastSmallerModule())
+def SplitTensorLastSmallerModule_basic(module, tu: TestUtils):
+    # Splitting the first dimension with 8 elements into chunks of 3
+    # will leave the last chunk with 2 elements in that dimension.
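+    # Worked example: 8 = 3 + 3 + 2, so the chunks have first-dimension sizes
+    # [3, 3, 2] and the returned s2 has shape (2, 10, 12).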
+ module.forward(tu.rand(8, 10, 12)) + +# ============================================================================== + + +class SplitTensorNegativeDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([10, 12, 6], torch.float32, True) + ]) + def forward(self, x): + s0, s1, s2 = torch.ops.aten.split(x, 2, -1) + return s1 + + +@register_test_case(module_factory=lambda: SplitTensorNegativeDimModule()) +def SplitTensorNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 12, 6)) + +# ============================================================================== + +class SplitWithSizesListUnpackModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([10, 12], torch.float32, True) + ]) + def forward(self, x): + s0, s1, s2 = torch.ops.aten.split_with_sizes(x, [3, 4, 5], -1) + return (s0, s1, s2) + +@register_test_case(module_factory=lambda: SplitWithSizesListUnpackModule()) +def SplitWithSizesListUnpackModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 12)) + +# ============================================================================== + class ChunkListUnpack_Module(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 6e15da5a4804..6e04c5fa8700 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -169,6 +169,28 @@ def forward(self, x): def ToDtypeLayoutNoneModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) +class ToDtypeLayoutCPUModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.to(x, + dtype=torch.float64, + layout=None, + device="cpu", + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None) + + +@register_test_case(module_factory=lambda: ToDtypeLayoutCPUModule()) +def ToDtypeLayoutCPUModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5)) + class ToDtypeLayoutStridedModule(torch.nn.Module): @@ -235,6 +257,27 @@ def forward(self, x, y): def TypeAsSameModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5), tu.rand(3, 5)) +class TypeAsDifferentModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.type_as(x, y) + + +@register_test_case(module_factory=lambda: TypeAsDifferentModule()) +def TypeAsDifferentModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 5, low=0, high=10, dtype=torch.int), + tu.randint(3, 5, low=0, high=10, dtype=torch.int64) + ) # ============================================================================== diff --git a/pytorch-hash.txt b/pytorch-hash.txt index ae0f2b8dffe0..754078490fe0 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -69565763c841e4e8d07fd338c9bf6515005b3880 +90c406a3a198b8f45682a9979b4c091ec5dc647e diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index b6b107d405e2..7e93f7c8ce66 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -4,4 +4,4 @@ # release page, and we use this page as an additional source for the wheels. 
-f https://xilinx.github.io/torch-mlir/package-index/ --pre -torch==2.1.0.dev20230710 +torch==2.2.0.dev20230922 diff --git a/setup.py b/setup.py index 047d6dd8bfeb..046e5d5ff6e9 100644 --- a/setup.py +++ b/setup.py @@ -167,7 +167,7 @@ def build_extension(self, ext): ext_modules=[ CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else [CMakeExtension("torch_mlir._mlir_libs._torchMlir")], - install_requires=["numpy", ] + ( + install_requires=["numpy", "packaging"] + ( [f"torch=={torch.__version__}".split("+", 1)[0], ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else []), zip_safe=False, ) diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index 52936c53b9b1..933031e16e9e 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -259,27 +259,6 @@ func.func @torch.aten.sqrt.int(%arg0: !torch.int) -> !torch.float { return %0 : !torch.float } -// CHECK-LABEL: func.func @torch.aten.any.bool() -> !torch.bool { -// CHECK: %[[CST_FALSE:.*]] = arith.constant false -// CHECK: %[[FALSE:.*]] = torch_c.from_i1 %[[CST_FALSE]] -// CHECK: %[[CST_TRUE:.*]] = arith.constant true -// CHECK: %[[TRUE:.*]] = torch_c.from_i1 %[[CST_TRUE]] -// CHECK: %[[INPUT:.*]] = torch.prim.ListConstruct %[[FALSE]], %[[TRUE]], %[[FALSE]] : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list -// CHECK: %[[TMP1:.*]] = torch_c.to_i1 %[[FALSE]] -// CHECK: %[[TMP2:.*]] = torch_c.to_i1 %[[TRUE]] -// CHECK: %[[TMP3:.*]] = torch_c.to_i1 %[[FALSE]] -// CHECK: %[[CMP:.*]] = arith.ori %[[TMP1]], %[[TMP2]] : i1 -// CHECK: %[[CMP_RESULT:.*]] = arith.ori %[[CMP]], %[[TMP3]] : i1 -// CHECK: %[[RESULT:.*]] = torch_c.from_i1 %[[CMP_RESULT]] -// CHECK: return %[[RESULT]] : !torch.bool -func.func @torch.aten.any.bool() -> !torch.bool { - %false = torch.constant.bool false - %true = torch.constant.bool true - %input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %0 = torch.aten.any.bool %input : !torch.list -> !torch.bool - return %0 : !torch.bool -} - // CHECK-LABEL: func.func @torch.aten.Bool.float( // CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.bool { // CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]] diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index 71090ea6ed7b..d95b7e1d87cf 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -264,39 +264,4 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> { %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> return %0 : !torch.vtensor<[?,?],f16> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.index.Tensor -// CHECK-SAME: (%[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>, %[[ARG2:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> { -// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> tensor -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[INDICES:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[NONE]], %[[ARG2]] : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list> -// CHECK: %[[INDEX1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor -// CHECK: 
%[[INDEX2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[?],si64> -> tensor -// CHECK: %[[CST0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM0:.*]] = tensor.dim %[[INDEX1]], %[[CST0]] : tensor -// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM1:.*]] = tensor.dim %[[INDEX2]], %[[CST0_0]] : tensor -// CHECK: %[[CST1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM2:.*]] = tensor.dim %[[T]], %[[CST1]] : tensor -// CHECK: %[[OUT_T:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]]) : tensor -// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[INDEX1]], %[[INDEX2]] : tensor, tensor) outs(%[[OUT_T]] : tensor) { -// CHECK: ^bb0(%[[IN1:.*]]: i64, %[[IN2:.*]]: i64, %[[IN3:.*]]: f32): -// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[IN1]] : i64 to index -// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index -// CHECK: %[[INDEX_3:.*]] = arith.index_cast %[[IN2]] : i64 to index -// CHECK: %[[RESULT:.*]] = tensor.extract %[[T]][%[[INDEX_1]], %[[INDEX_2]], %[[INDEX_3]]] : tensor -// CHECK: linalg.yield %[[RESULT]] : f32 -// CHECK: } -> tensor -// CHECK: %[[OUT_CAST:.*]] = tensor.cast %[[OUT]] : tensor to tensor -// CHECK: %[[VALUE_OUT_CAST:.*]] = torch_c.from_builtin_tensor %[[OUT_CAST]] : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[VALUE_OUT_CAST]] : !torch.vtensor<[?,?,?],f32> -func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> { - %none = torch.constant.none - %1 = torch.prim.ListConstruct %arg1, %none, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list> - %2 = torch.aten.index.Tensor %arg0, %1 : !torch.vtensor<[?,?,?],f32>, !torch.list> -> !torch.vtensor<[?,?,?],f32> - return %2 : !torch.vtensor<[?,?,?],f32> -} +} \ No newline at end of file diff --git a/test/Conversion/TorchToStablehlo/scatter.mlir b/test/Conversion/TorchToStablehlo/scatter.mlir new file mode 100644 index 000000000000..a3fb1af6df03 --- /dev/null +++ b/test/Conversion/TorchToStablehlo/scatter.mlir @@ -0,0 +1,35 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @forward( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],si64>, %[[ARG_1:.*]]: !torch.vtensor<[?,?],si64>, %[[ARG_2:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAR_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAR_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAR_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %[[INDEX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[VAR_1]], %[[INDEX_0]] : tensor +// CHECK: %[[VAR_3:.*]] = arith.index_cast %[[DIM_0]] : index to i64 +// CHECK: %[[INDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM_1:.*]] = tensor.dim %1, %[[INDEX_1]] : tensor +// CHECK: %[[VAR_4:.*]] = arith.index_cast %[[DIM_1]] : index to i64 +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xi64> +// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : 
tensor<2xi64> +// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]] : tensor<2xi64> +// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor +// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]], %[[CONSTANT_1]] : tensor<3xi64> +// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor +// CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor, tensor) -> tensor +// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) ({ +// CHECK: ^bb0(%arg3: tensor, %[[ARG_4:.*]]: tensor): +// CHECK: stablehlo.return %[[ARG_4]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64> +func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { + %int0 = torch.constant.int 0 + %0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64> + return %0 : !torch.vtensor<[?,?],si64> +} \ No newline at end of file diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 2705f453bdf5..63fdd9368d27 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func.func @torch.aten.tanh$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.tanh %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -16,7 +16,7 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK-LABEL: func.func @torch.aten.sigmoid$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sigmoid"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.sigmoid %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -29,7 +29,7 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. 
 // CHECK-LABEL: func.func @torch.aten.relu$basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.clamp"(%[[ARG_BUILTIN]]) <{max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.clamp %[[ARG_BUILTIN]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -46,9 +46,9 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
 // CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e-01
 // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_4:.*]] = "tosa.greater_equal"(%[[VAL_0]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
-// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_2]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_0]], %[[VAL_5]]) : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_0]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_0]], %[[VAL_5]] : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -64,7 +64,7 @@ func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
 // CHECK-LABEL: func.func @torch.aten.log$basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.log"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.log %[[ARG_BUILTIN]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -77,7 +77,7 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 // CHECK-LABEL: func.func @torch.aten.exp$basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.exp"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.exp %[[ARG_BUILTIN]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -90,7 +90,7 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 // CHECK-LABEL: func.func @torch.aten.neg$basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.negate"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.negate %[[ARG_BUILTIN]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -103,7 +103,7 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 // CHECK-LABEL: func.func @torch.aten.floor$basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.floor"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.floor %[[ARG_BUILTIN]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.floor$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -116,7 +116,7 @@ func.func @torch.aten.floor$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
 // CHECK-LABEL: func.func @torch.aten.bitwise_not$basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.bitwise_not"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.bitwise_not %[[ARG_BUILTIN]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -129,7 +129,7 @@ func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to
 // CHECK-LABEL: func.func @torch.aten.ceil$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_2:.*]] = "tosa.ceil"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.ceil %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -143,7 +143,7 @@ func.func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
 // CHECK-LABEL: func.func @torch.aten.reciprocal$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.reciprocal %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -161,8 +161,8 @@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_7:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_6]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -181,8 +181,8 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_2]], %[[VAL_6]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -199,7 +199,7 @@ func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
@@ -214,8 +214,8 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RCP:.*]] = "tosa.reciprocal"(%[[ARG1_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[RCP]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RCP:.*]] = tosa.reciprocal %[[ARG1_BUILTIN]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[RCP]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
@@ -244,8 +244,8 @@ func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
 // CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false
 // CHECK: %[[ARG3:.*]] = torch.constant.int 0
 // CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<int>
-// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) <{new_shape = array}> : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[SUM:.*]] = tosa.reduce_sum %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[SUM]] {new_shape = array} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
 func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -263,11 +263,11 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
 // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
 // CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
-// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_sum"(%[[REDUCE1]]) <{axis = 1 : i64}> : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32>
-// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_sum"(%[[REDUCE2]]) <{axis = 2 : i64}> : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32>
-// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_sum"(%[[REDUCE3]]) <{axis = 3 : i64}> : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) <{new_shape = array<i64: 1>}> : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
+// CHECK: %[[REDUCE1:.*]] = tosa.reduce_sum %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+// CHECK: %[[REDUCE2:.*]] = tosa.reduce_sum %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32>
+// CHECK: %[[REDUCE3:.*]] = tosa.reduce_sum %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32>
+// CHECK: %[[REDUCE4:.*]] = tosa.reduce_sum %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array<i64: 1>} : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
 func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
@@ -281,11 +281,11 @@ func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK-LABEL: func.func @test_reduce_all$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
 // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
-// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_all"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
-// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_all"(%[[REDUCE1]]) <{axis = 1 : i64}> : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
-// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_all"(%[[REDUCE2]]) <{axis = 2 : i64}> : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
-// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_all"(%[[REDUCE3]]) <{axis = 3 : i64}> : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) <{new_shape = array<i64: 1>}> : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
+// CHECK: %[[REDUCE1:.*]] = tosa.reduce_all %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
+// CHECK: %[[REDUCE2:.*]] = tosa.reduce_all %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
+// CHECK: %[[REDUCE3:.*]] = tosa.reduce_all %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
+// CHECK: %[[REDUCE4:.*]] = tosa.reduce_all %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array<i64: 1>} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
 func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
@@ -300,8 +300,8 @@ func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.
 // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
 // CHECK: %[[ARG1:.*]] = torch.constant.int 0
 // CHECK: %[[ARG2:.*]] = torch.constant.bool false
-// CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) <{new_shape = array}> : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
+// CHECK: %[[REDUCE:.*]] = tosa.reduce_any %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE]] {new_shape = array} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1>
 func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
@@ -316,11 +316,11 @@ func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !to
 // CHECK-LABEL: func.func @test_reduce_any$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
 // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
-// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
-// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_any"(%[[REDUCE1]]) <{axis = 1 : i64}> : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
-// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_any"(%[[REDUCE2]]) <{axis = 2 : i64}> : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
-// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_any"(%[[REDUCE3]]) <{axis = 3 : i64}> : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) <{new_shape = array<i64: 1>}> : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
+// CHECK: %[[REDUCE1:.*]] = tosa.reduce_any %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
+// CHECK: %[[REDUCE2:.*]] = tosa.reduce_any %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
+// CHECK: %[[REDUCE3:.*]] = tosa.reduce_any %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
+// CHECK: %[[REDUCE4:.*]] = tosa.reduce_any %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array<i64: 1>} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
 func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
@@ -333,7 +333,7 @@ func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.
 // CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.rsqrt %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -349,7 +349,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.maximum %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -365,7 +365,7 @@ func.func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !to
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = "tosa.minimum"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.minimum %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -381,7 +381,7 @@ func.func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !to
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_4:.*]] = "tosa.pow"(%[[VAL_1]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_1]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -400,8 +400,8 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>)
 // CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00
 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -421,8 +421,8 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -440,7 +440,7 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
 // CHECK: }
@@ -456,7 +456,7 @@ func.func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_3]], %[[VAL_2]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
 // CHECK: }
@@ -472,7 +472,7 @@ func.func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_4:.*]] = tosa.equal %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
 // CHECK: }
@@ -488,7 +488,7 @@ func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int -1 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32> // CHECK: } @@ -510,17 +510,17 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_6:.*]] = torch.constant.bool true // CHECK: %[[VAL_7:.*]] = torch.constant.bool false -// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_3]]) <{new_shape = array}> : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_3]]) <{new_shape = array}> : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor -// CHECK: %[[VAL_13:.*]] = "tosa.sub"(%[[VAL_1]], %[[VAL_8]]) : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.add"(%[[VAL_9]], %[[VAL_12]]) : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.rsqrt"(%[[VAL_14]]) : (tensor<4x1xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_13]], %[[VAL_15]]) <{shift = 0 : i32}> : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_17:.*]] = "tosa.mul"(%[[VAL_16]], %[[VAL_10]]) <{shift = 0 : i32}> : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.add"(%[[VAL_17]], %[[VAL_11]]) : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_9]], %[[VAL_12]] : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.rsqrt %[[VAL_14]] : (tensor<4x1xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_17]], %[[VAL_11]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> // CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32> // CHECK: } @@ -542,7 +542,7 @@ func.func 
@torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32> // CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32> @@ -568,28 +568,28 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor // CHECK: %[[VAL_8:.*]] = torch.constant.int 2 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<1.200000e+01> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_11:.*]] = "tosa.reciprocal"(%[[VAL_10]]) : (tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) <{axis = 3 : i64}> : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) <{axis = 2 : i64}> : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) <{axis = 1 : i64}> : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) <{new_shape = array}> : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) <{shift = 0 : i32}> : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) <{shift = 0 : i32}> : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) <{axis = 3 : i64}> : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) <{axis = 2 : i64}> : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) <{axis = 1 : i64}> : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) <{new_shape = array}> : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) <{shift = 0 : i32}> : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) <{new_shape = array}> : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) <{new_shape = array}> : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reciprocal %[[VAL_10]] : (tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : 
(tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_26:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_29:.*]] = "tosa.rsqrt"(%[[VAL_28]]) : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_30:.*]] = "tosa.mul"(%[[VAL_27]], %[[VAL_29]]) <{shift = 0 : i32}> : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_31:.*]] = "tosa.mul"(%[[VAL_30]], %[[VAL_24]]) <{shift = 0 : i32}> : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_32:.*]] = "tosa.add"(%[[VAL_31]], %[[VAL_25]]) : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_27:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_28:.*]] = tosa.add %[[VAL_23]], %[[VAL_26]] : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_29:.*]] = tosa.rsqrt %[[VAL_28]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_32:.*]] = tosa.add %[[VAL_31]], %[[VAL_25]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> // CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32> // CHECK: } @@ -609,8 +609,8 @@ func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor< // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : 
!torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = "tosa.logical_not"(%[[VAL_4]]) : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.equal %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.logical_not %[[VAL_4]] : (tensor) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1> // CHECK: } @@ -629,7 +629,7 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK: %[[VAL_7:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_6]]) : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32> // CHECK: } @@ -649,7 +649,7 @@ func.func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_3]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> // CHECK: } @@ -664,9 +664,9 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32> // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.reciprocal"(%[[VAL_2]]) : (tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_4:.*]] = "tosa.log"(%[[VAL_1]]) : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i32} : (tensor, tensor<1x1xf32>) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -683,7 +683,7 @@ func.func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vt // CHECK: %[[VAL_2:.*]] = torch.constant.none // CHECK: %[[VAL_3:.*]] = 
torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32> -// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: } @@ -702,7 +702,7 @@ func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<4x3xi32>) -> tensor<4x3x1xi32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> // CHECK: } @@ -719,7 +719,7 @@ func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !to // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int -1 -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<4x3xi32>) -> tensor<4x3x1xi32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> // CHECK: } @@ -752,7 +752,7 @@ func.func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !to // CHECK: %[[VAL_2:.*]] = torch.constant.none // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1> : tensor<3x4xi32>}> : () -> tensor<3x4xi32> -// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: } @@ -772,7 +772,7 @@ func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.bool false -// CHECK: %[[VAL_4:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -798,10 +798,10 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch // CHECK: %[[VAL_9:.*]] = 
torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) <{acc_type = f32, kernel = array, pad = array, stride = array}> : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_11]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> +// CHECK: %[[VAL_13:.*]] = tosa.avg_pool2d %[[VAL_12]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> // CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> // CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32> // CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32> @@ -828,9 +828,9 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> // CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true // CHECK: %[[VAL_I2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) <{axis = 2 : i64}> : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) <{axis = 2 : i64}> : (tensor<3x2x3xf32>) -> tensor<3x2xi64> -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) <{new_shape = array}> : (tensor<3x2xi64>) -> tensor<3x2x1xi64> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> // CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> // CHECK: return %[[VAL_6]] : tensor<3x2x1xf32> @@ -861,7 +861,7 @@ func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> { // CHECK: %[[CST5:.*]] = torch.constant.int 5 // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>}> : () -> tensor<5xi64> -// CHECK: %[[VAL_1:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<5xi64>) -> tensor<5xi64> +// CHECK: %[[VAL_1:.*]] = tosa.cast %[[VAL_0]] : (tensor<5xi64>) -> tensor<5xi64> // CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %1 : tensor<5xi64> -> !torch.vtensor<[5],si64> // CHECK: return %[[VAL_2]] : !torch.vtensor<[5],si64> func.func @torch.aten.arange.start_step() 
-> !torch.vtensor<[5],si64> { @@ -897,11 +897,11 @@ func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> { // CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_2:.*]] = "tosa.equal"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.logical_not"(%[[VAL_2]]) : (tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.equal %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.logical_not %[[VAL_2]] : (tensor) -> tensor // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x5x5xi8>}> : () -> tensor<1x1x5x5xi8> -// CHECK: %[[VAL_5:.*]] = "tosa.equal"(%[[INP]], %[[VAL_4]]) : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1> -// CHECK: %[[VAL_6:.*]] = "tosa.logical_not"(%[[VAL_5]]) : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_5:.*]] = tosa.equal %[[INP]], %[[VAL_4]] : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_6:.*]] = tosa.logical_not %[[VAL_5]] : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1> // CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1> func.func @torch.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> { @@ -927,8 +927,8 @@ func.func @torch.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtens // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x5xi64>}> : () -> tensor<3x5xi64> -// CHECK: %[[VAL_1:.*]] = "tosa.equal"(%[[INP]], %[[VAL_0]]) : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1> -// CHECK: %[[VAL_2:.*]] = "tosa.logical_not"(%[[VAL_1]]) : (tensor<3x5xi1>) -> tensor<3x5xi1> +// CHECK: %[[VAL_1:.*]] = tosa.equal %[[INP]], %[[VAL_0]] : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1> +// CHECK: %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<3x5xi1>) -> tensor<3x5xi1> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1> // CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1> func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> { @@ -946,7 +946,7 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_3:.*]] = torch.constant.none // CHECK: %[[VAL_4:.*]] = torch.constant.bool false -// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x128xi1>) -> tensor<1x128xi64> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x128xi1>) -> tensor<1x128xi64> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x128xi64> -> !torch.vtensor<[1,128],si64> // CHECK: return %[[VAL_6]] : !torch.vtensor<[1,128],si64> // CHECK: } @@ -966,19 +966,19 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64> // CHECK: %[[VAL_4:.*]] = torch.constant.int -1 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false -// CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_3]]) : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> -// CHECK: %[[VAL_7:.*]] 
= "tosa.reshape"(%[[VAL_6]]) <{new_shape = array}> : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.concat"(%[[VAL_8]], %[[VAL_9]], %[[VAL_7]]) <{axis = 3 : i64}> : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> -// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) <{new_shape = array}> : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_10:.*]] = tosa.concat %[[VAL_8]], %[[VAL_9]], %[[VAL_7]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_14:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_13]]) <{shift = 0 : i32}> : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> -// CHECK: %[[VAL_15:.*]] = "tosa.reduce_sum"(%[[VAL_14]]) <{axis = 1 : i64}> : (tensor<8x3xi32>) -> tensor<8x1xi32> -// CHECK: %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) <{new_shape = array}> : (tensor<8x1xi32>) -> tensor<1x8xi32> -// CHECK: %[[VAL_17:.*]] = "tosa.gather"(%[[VAL_11]], %[[VAL_16]]) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) <{new_shape = array}> : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> +// CHECK: %[[VAL_17:.*]] = tosa.gather %[[VAL_11]], %[[VAL_16]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> // CHECK: return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32> // CHECK: } @@ -997,9 +997,9 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> -// CHECK: %[[VAL_7:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_6]]) : (tensor<2x2xi32>, tensor<2x2xi32>) -> 
tensor<2x2xi32> -// CHECK: %[[VAL_8:.*]] = "tosa.cast"(%[[VAL_7]]) : (tensor<2x2xi32>) -> tensor<2x2xi64> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> +// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<2x2xi32>) -> tensor<2x2xi64> // CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> // CHECK: return %[[VAL_9]] : !torch.vtensor<[2,2],si64> // CHECK: } @@ -1016,12 +1016,12 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 256 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor}> : () -> tensor -// CHECK: %[[VAL_4_CAST:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor) -> tensor +// CHECK: %[[VAL_4_CAST:.*]] = tosa.cast %[[VAL_4]] : (tensor) -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_4_CAST]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_8:.*]] = "tosa.add"(%[[VAL_7]], %[[VAL_6]]) : (tensor<1x1x128x128xi32>, tensor) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_9:.*]] = "tosa.cast"(%[[VAL_8]]) : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4_CAST]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<1x1x128x128xi32>, tensor) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> // CHECK: return %[[VAL_10]] : !torch.vtensor<[1,1,128,128],si64> // CHECK: } @@ -1038,7 +1038,7 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 511 -// CHECK: %[[VAL_4:.*]] = "tosa.clamp"(%[[VAL_1]]) <{max_fp = 5.110000e+02 : f32, max_int = 511 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 5.110000e+02 : f32, max_int = 511 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> // CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64> // CHECK: } @@ -1057,8 +1057,8 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: 
%[[VAL_6:.*]] = "tosa.cast"(%[[VAL_5]]) : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_6]], %[[VAL_2]]) : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_3]], %[[VAL_6]], %[[VAL_2]] : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } @@ -1076,7 +1076,7 @@ func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f3 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_3]]) : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_5]], %[[VAL_3]] : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } @@ -1089,7 +1089,7 @@ func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f3 // CHECK-LABEL: func.func @torch.aten.abs( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[15,15],si64> -> tensor<15x15xi64> -// CHECK: %[[VAL_2:.*]] = "tosa.abs"(%[[VAL_1]]) : (tensor<15x15xi64>) -> tensor<15x15xi64> +// CHECK: %[[VAL_2:.*]] = tosa.abs %[[VAL_1]] : (tensor<15x15xi64>) -> tensor<15x15xi64> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<15x15xi64> -> !torch.vtensor<[15,15],si64> // CHECK: return %[[VAL_3]] : !torch.vtensor<[15,15],si64> // CHECK: } @@ -1106,7 +1106,7 @@ func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1> // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,5,5],f32> // CHECK: } @@ -1117,15 +1117,15 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to // ----- // CHECK-LABEL: func.func @torch.aten.remainder.Scalar( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> 
@@ -1117,15 +1117,15 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to // ----- // CHECK-LABEL: func.func @torch.aten.remainder.Scalar( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 2 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32> -// CHECK: %[[VAL_6:.*]] = "tosa.reciprocal"(%[[VAL_5:.*]]) : (tensor<f32>) -> tensor<f32> -// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_3:.*]], %[[VAL_6:.*]]) <{shift = 0 : i32}> : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_8:.*]] = "tosa.floor"(%[[VAL_7]]) : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_5]], %[[VAL_8]]) <{shift = 0 : i32}> : (tensor<f32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_10:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_9]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5:.*]] : (tensor<f32>) -> tensor<f32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3:.*]], %[[VAL_6:.*]] {shift = 0 : i32} : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_5]], %[[VAL_8]] {shift = 0 : i32} : (tensor<f32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_10:.*]] = tosa.sub %[[VAL_3]], %[[VAL_9]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32> // CHECK: } diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir index 57312ee298f9..312554b246ae 100644 --- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir +++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: torch.aten.mul.Scalar$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16> // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> -// CHECK: %[[VAL_2:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_1]]) <{shift = 0 : i32}> : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> +// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i32} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> { %float2.000000e00 = torch.constant.float 2.000000e+00 %0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16>
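// The mixed-type tests in this file all share one pattern: the operand whose dtype differs from the
// expected result dtype is first converted with tosa.cast, mirroring PyTorch type promotion.
// Sketch (hypothetical values, bf16 result as in the hunk below):
//   %c = tosa.cast %f32_val : (tensor<6xf32>) -> tensor<6xbf16>
//   %r = tosa.add %bf16_val, %c : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16>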
@@ -15,8 +15,8 @@ func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> // CHECK-LABEL: torch.aten.add.Tensor$mixed_type_fp // CHECK-SAME: %[[VAL_0:.*]]: tensor<6xbf16> // CHECK-SAME: %[[VAL_1:.*]]: tensor<6xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<6xf32>) -> tensor<6xbf16> -// CHECK: %[[VAL_4:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_3]]) : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16> +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_1]] : (tensor<6xf32>) -> tensor<6xbf16> +// CHECK: %[[VAL_4:.*]] = tosa.add %[[VAL_0]], %[[VAL_3]] : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16> func.func @torch.aten.add.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[6],bf16>, %arg1: !torch.vtensor<[6],f32>, %arg2: !torch.float) -> !torch.vtensor<[6],bf16> { %float1 = torch.constant.float 1.000000e+00 %0 = torch.aten.add.Tensor %arg0, %arg1, %float1 : !torch.vtensor<[6],bf16>, !torch.vtensor<[6],f32>, !torch.float -> !torch.vtensor<[6],bf16> @@ -28,8 +28,8 @@ func.func @torch.aten.add.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[6],bf16>, // CHECK-LABEL: torch.aten.add.Tensor$mixed_type_int // CHECK-SAME: %[[VAL_0:.*]]: tensor<5xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<5xbf16> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<5xbf16>) -> tensor<5xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_2]]) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<5xbf16>) -> tensor<5xf32> +// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_0]], %[[VAL_2]] : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> func.func @torch.aten.add.Tensor$mixed_type_int(%arg0: !torch.vtensor<[5],f32>, %arg1: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],f32> { %int1 = torch.constant.int 1 %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],bf16>, !torch.int -> !torch.vtensor<[5],f32> @@ -41,8 +41,8 @@ func.func @torch.aten.add.Tensor$mixed_type_int(%arg0: !torch.vtensor<[5],f32>, // CHECK-LABEL: torch.aten.Scalar$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x32x64xi16> // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<256> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x1x32x64xi16>) -> tensor<1x1x32x64xi32> -// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x1x32x64xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x32x64xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<1x1x32x64xi16>) -> tensor<1x1x32x64xi32> +// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_2]], %[[VAL_1]] : (tensor<1x1x32x64xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x32x64xi32> func.func @torch.aten.Scalar$mixed_type(%arg0: !torch.vtensor<[1,1,32,64],si16>) -> !torch.vtensor<[1,1,32,64],si32> { %int1 = torch.constant.int 1 %int256 = torch.constant.int 256 @@ -55,7 +55,7 @@ func.func @torch.aten.Scalar$mixed_type(%arg0: !torch.vtensor<[1,1,32,64],si16>) // CHECK-LABEL: torch.aten.sub.Scalar$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor<bf16>, // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<bf16>}> : () -> tensor<bf16> -// CHECK: %[[VAL_3:.*]] = "tosa.sub"(%[[VAL_0]], %[[VAL_2]]) : (tensor<bf16>, tensor<bf16>) -> tensor<bf16> +// CHECK: %[[VAL_3:.*]] = tosa.sub %[[VAL_0]], %[[VAL_2]] : (tensor<bf16>, tensor<bf16>) -> tensor<bf16> func.func @torch.aten.sub.Scalar$mixed_type(%arg0: !torch.vtensor<[],bf16>, %arg1: !torch.vtensor<[],bf16>) -> !torch.vtensor<[],bf16> { %int1 = torch.constant.int 1 %0 = torch.aten.sub.Scalar %arg0, %int1, %int1 : !torch.vtensor<[],bf16>, !torch.int, !torch.int -> !torch.vtensor<[],bf16> @@ -67,8 +67,8 @@ func.func @torch.aten.sub.Scalar$mixed_type(%arg0: !torch.vtensor<[],bf16>, %arg // CHECK-LABEL: torch.aten.maximum$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x1xi32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x1xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x3x1xi32>) -> tensor<1x3x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x3x1xf32>, tensor<1x3x1xf32>) -> tensor<1x3x1xf32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<1x3x1xi32>) -> tensor<1x3x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.maximum %[[VAL_2]], %[[VAL_1]] : (tensor<1x3x1xf32>, tensor<1x3x1xf32>) -> tensor<1x3x1xf32> func.func
@torch.aten.maximum$mixed_type(%arg0: !torch.vtensor<[1,3,1],si32>, %arg1: !torch.vtensor<[1,3,1],f32>) -> !torch.vtensor<[1,3,1],f32> { %0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[1,3,1],si32>, !torch.vtensor<[1,3,1],f32> -> !torch.vtensor<[1,3,1],f32> return %0 : !torch.vtensor<[1,3,1],f32> @@ -79,8 +79,8 @@ func.func @torch.aten.maximum$mixed_type(%arg0: !torch.vtensor<[1,3,1],si32>, %a // CHECK-LABEL: torch.aten.bitwise_and.Tensor$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32> -// CHECK: %[[VAL_3:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<?x?xi16>) -> tensor<?x?xi32> +// CHECK: %[[VAL_3:.*]] = tosa.bitwise_and %[[VAL_2]], %[[VAL_1]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],si16>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { %0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si16>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> return %0 : !torch.vtensor<[?,?],si32> @@ -91,9 +91,9 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?], // CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<?x?xi32>) -> tensor<?x?xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.reciprocal"(%[[VAL_2]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> -// CHECK: %[[VAL_4:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<?x?xi32>) -> tensor<?x?xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32> @@ -104,8 +104,8 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32> // CHECK-LABEL: torch.aten.div.Tensor$mixed_type_int // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32> -// CHECK: %[[VAL_3:.*]] = "tosa.div"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<?x?xi16>) -> tensor<?x?xi32> +// CHECK: %[[VAL_3:.*]] = tosa.div %[[VAL_2]], %[[VAL_1]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32> return %0 : !torch.vtensor<[?, ?],si32> @@ -116,8 +116,8 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1 // CHECK-LABEL: torch.aten.div.Scalar$int_input_fp_output // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64> // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<7.812500e-03> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi64>) -> tensor<?x?xf32> -// CHECK: %[[VAL_5:.*]] =
"tosa.mul"(%[[VAL_3]], %[[VAL_1]]) <{shift = 0 : i32}> : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_0]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i32} : (tensor, tensor<1x1xf32>) -> tensor func.func @torch.aten.div.Scalar$int_input_fp_output(%arg0: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],f32> { %int128 = torch.constant.int 128 %0 = torch.aten.div.Scalar %arg0, %int128 : !torch.vtensor<[?, ?],si64>, !torch.int -> !torch.vtensor<[?, ?],f32> @@ -129,8 +129,8 @@ func.func @torch.aten.div.Scalar$int_input_fp_output(%arg0: !torch.vtensor<[?, ? // CHECK-LABEL: torch.aten.pow.Tensor$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.pow"(%[[VAL_2]], %[[VAL_1]]) : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.cast %arg0 : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.pow %[[VAL_2]], %[[VAL_1]] : (tensor, tensor<1x1xf32>) -> tensor func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 3.000000e+00 %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f16>, !torch.float -> !torch.vtensor<[?,?],f32> diff --git a/test/Dialect/Torch/adjust-calling-conventions.mlir b/test/Dialect/Torch/adjust-calling-conventions.mlir index 6f07530d2a09..5ee5bbf6f446 100644 --- a/test/Dialect/Torch/adjust-calling-conventions.mlir +++ b/test/Dialect/Torch/adjust-calling-conventions.mlir @@ -97,20 +97,3 @@ func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vte %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple return %0 : !torch.tuple } - -// ----- - -// Single tensor tuple return -// expected-error @+1 {{Functions must return}} -func.func @single_tensor_tuple_return(%arg0: !torch.tensor) -> !torch.tuple { - %0 = torch.prim.TupleConstruct %arg0 : !torch.tensor -> !torch.tuple - return %0 : !torch.tuple -} - -// ----- - -// Multiple, non-tuple return -// expected-error @+1 {{should only ever return one item}} -func.func @multiple_non_tuple_return(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) { - return %arg0, %arg0 : !torch.tensor, !torch.tensor -} \ No newline at end of file diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index b1e9886d369e..21e0500f4eb5 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -975,6 +975,15 @@ func.func @torch.prim.TupleUnpack(%arg0: !torch.tensor, %arg1: !torch.tensor) -> return %124#0 : !torch.tensor } +// CHECK-LABEL: func.func @torch.prim.TupleUnpack.Derefined( +// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.optional { +// CHECK: %[[DEREFINED:.+]] = torch.derefine %[[ARG]] : !torch.tensor to !torch.optional +// CHECK: return %[[DEREFINED]] : !torch.optional +func.func @torch.prim.TupleUnpack.Derefined(%arg: !torch.tensor) -> !torch.optional { + %tuple = torch.prim.TupleConstruct %arg : !torch.tensor -> !torch.tuple + %optional_tensor = torch.prim.TupleUnpack %tuple : !torch.tuple -> !torch.optional + return %optional_tensor : !torch.optional +} // CHECK-LABEL: func.func @torch.aten.__contains__.str( // CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor, @@ -1036,6 +1045,16 @@ 
func.func @torch.aten.add.int() -> !torch.int { return %ret : !torch.int } +// CHECK-LABEL: func.func @torch.aten.add.float_int() -> !torch.float { +// CHECK: %[[CST9:.*]] = torch.constant.float 9.000000e+00 +// CHECK: return %[[CST9]] : !torch.float +func.func @torch.aten.add.float_int() -> !torch.float { + %cst4 = torch.constant.float 4.0 + %cst5 = torch.constant.int 5 + %ret = torch.aten.add.float_int %cst4, %cst5: !torch.float, !torch.int -> !torch.float + return %ret : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.sub.int() -> !torch.int { // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: return %[[CST1]] : !torch.int @@ -1056,6 +1075,25 @@ func.func @torch.aten.mul.int() -> !torch.int { return %ret : !torch.int } +// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float { +// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01 +// CHECK: return %[[CST30]] : !torch.float +func.func @torch.aten.mul.float() -> !torch.float { + %cst6 = torch.constant.float 6.0 + %cst5 = torch.constant.float 5.0 + %ret = torch.aten.mul.float %cst6, %cst5: !torch.float, !torch.float -> !torch.float + return %ret : !torch.float +} + +// CHECK-LABEL: func.func @torch.aten.neg.float() -> !torch.float { +// CHECK: %[[CST_6:.*]] = torch.constant.float -6.000000e+00 +// CHECK: return %[[CST_6]] : !torch.float +func.func @torch.aten.neg.float() -> !torch.float { + %cst6 = torch.constant.float 6.0 + %ret = torch.aten.neg.float %cst6: !torch.float -> !torch.float + return %ret : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.mul.int$with_zero() -> !torch.int { // CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: return %[[CST0]] : !torch.int @@ -1383,14 +1421,6 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to return %0 : !torch.tensor<[],f32> } -// CHECK-LABEL: func.func @torch.aten.type_as$same( -// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { -// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32> -func.func @torch.aten.type_as$same(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { - %0 = torch.aten.type_as %arg0, %arg0 : !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32> - return %0 : !torch.tensor<[?,?],f32> -} - // CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> { // CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32> @@ -1414,6 +1444,21 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch return %0 : !torch.tensor } +// CHECK-LABEL: func.func @torch.aten.to.other$basic( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[CPU:.*]] = torch.constant.device "cpu" +// CHECK: %[[VAR_0:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int +// CHECK: %[[VAR_1:.*]] = torch.aten.to.device %[[ARG_0]], %[[CPU]], %[[VAR_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.Device, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor +// CHECK: return %[[VAR_1]] : !torch.tensor +func.func @torch.aten.to.other$basic(%arg0 : !torch.tensor, %arg1 : !torch.tensor) -> !torch.tensor { + %none = torch.constant.none + %false = torch.constant.bool false + %0 = torch.aten.to.other %arg0, %arg1, %false, %false, %none : !torch.tensor, !torch.tensor, !torch.bool, !torch.bool, 
!torch.none -> !torch.tensor + return %0 : !torch.tensor +} + // CHECK-LABEL: func.func @torch.aten.view$1D( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32> { // CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?],f32> @@ -1926,6 +1971,18 @@ func.func @torch.aten.cat$fold_single_operand(%arg0: !torch.tensor) -> !torch.te return %1: !torch.tensor } +// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32> +func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[3,4,2],f32>, !torch.list<int> -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> +} + // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32> // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32> @@ -2014,3 +2071,42 @@ func.func @torch.prims.view_of$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !torc %0 = torch.prims.view_of %arg0 : !torch.vtensor<[3,4,2],f32> -> !torch.vtensor<[3,4,2],f32> return %0 : !torch.vtensor<[3,4,2],f32> } + +// CHECK-LABEL: func.func @torch.aten.cuda$canonicalize +// CHECK-SAME: %[[ARG:.*]]: !torch.tensor +// CHECK-NEXT: return %[[ARG]] : !torch.tensor +func.func @torch.aten.cuda$canonicalize(%arg0: !torch.tensor) -> !torch.tensor { + %0 = torch.aten.cuda %arg0 : !torch.tensor -> !torch.tensor + return %0 : !torch.tensor +} + +// CHECK-LABEL: func.func @torch.aten.device.with_index$canonicalize +// CHECK-NEXT: %[[VAL:.*]] = torch.constant.device "cuda:0" +// CHECK-NEXT: return %[[VAL]] : !torch.Device +func.func @torch.aten.device.with_index$canonicalize() -> !torch.Device { + %str = torch.constant.str "cuda" + %int0 = torch.constant.int 0 + %0 = torch.aten.device.with_index %str, %int0 : !torch.str, !torch.int -> !torch.Device + return %0 : !torch.Device +} + +// CHECK-LABEL: func.func @torch.aten.add$fold() -> !torch.float { +// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 3.000000e+00 +// CHECK: return %[[FLOAT_1]] : !torch.float +func.func @torch.aten.add$fold() -> !torch.float { + %float1 = torch.constant.float 1.0 + %float2 = torch.constant.float 2.0 + %0 = torch.aten.add %float1, %float2 : !torch.float, !torch.float -> !torch.float + return %0 : !torch.float +} + +// CHECK-LABEL: func.func @torch.aten.any.bool$fold() -> !torch.bool { +// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[CST_TRUE]] : !torch.bool +func.func @torch.aten.any.bool$fold() -> !torch.bool { + %false = torch.constant.bool false + %true = torch.constant.bool true + %input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool> + %0 = torch.aten.any.bool %input : !torch.list<bool> -> !torch.bool + return %0 : !torch.bool +} \ No newline at end of file
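// For context on the canonicalize.mlir additions above: they exercise new constant folders, e.g.
// `torch-mlir-opt -canonicalize` (assumed invocation) now folds
//   %f = torch.constant.float 4.0
//   %i = torch.constant.int 5
//   %r = torch.aten.add.float_int %f, %i : !torch.float, !torch.int -> !torch.float
// to a single `torch.constant.float 9.000000e+00`, and similarly for mul.float, neg.float,
// any.bool, and the same-shape broadcast_to fold.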
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 5fa1a5df5d08..e5d5ca19d8a2 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -118,3 +118,27 @@ func.func @torch.aten.acos$float_type(%arg0: !torch.vtensor<[2, 2],f32>, %arg1: %0 = torch.aten.acos %arg0 : !torch.vtensor<[2, 2],f32> -> !torch.vtensor<[2, 2],f32> return %0 : !torch.vtensor<[2, 2],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.type_as$basic( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int +// CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor +// CHECK: return %[[VAR]] : !torch.tensor +func.func @torch.aten.type_as$basic(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor { + %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor + return %0 : !torch.tensor +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.type_as$fold( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor<[?],f16>, %[[ARG_1:.*]]: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> { +// CHECK: return %[[ARG_0]] : !torch.tensor<[?],f16> +func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> { + %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16> + return %0 : !torch.tensor<[?], f16> +} diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 254a348cdec4..f22d5b785746 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -25,7 +25,7 @@ torch.class_type @c { } %c0 = torch.constant.int 0 %0 = torch.nn_module { - // expected-error @+1 {{'torch.slot' op is expected to match type and name of '"torch.attr"() {name = "g", type = !torch.int} : () -> ()}} + // expected-error @+1 {{'torch.slot' op is expected to match type and name of '"torch.attr"() <{name = "g", type = !torch.int}> : () -> ()}} torch.slot "f", %c0 : !torch.int } : !torch.nn.Module<"c"> diff --git a/test/Dialect/Torch/refine-public-return.mlir b/test/Dialect/Torch/refine-public-return.mlir index ad810ec97ccb..b3a225962785 100644 --- a/test/Dialect/Torch/refine-public-return.mlir +++ b/test/Dialect/Torch/refine-public-return.mlir @@ -9,6 +9,14 @@ func.func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { return %2 : !torch.tensor } +// CHECK-LABEL: func.func @refine_optional( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> { +// CHECK: return %[[ARG]] : !torch.vtensor<[2],f32> +func.func @refine_optional(%arg: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> { + %res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>> + return %res : !torch.optional<vtensor<[2],f32>> +} + // CHECK-LABEL: func.func @multiple_use_non_value_tensor( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor { @@ -34,6 +42,17 @@ func.func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.t return %2 : !torch.tensor } +// No conversion on private function.
+// CHECK-LABEL: func.func private @dont_refine_private( +// CHECK-SAME: %[[ARG:.+]]: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> { +// CHECK: %[[RES:.+]] = torch.derefine %[[ARG]] : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>> +// CHECK: return %[[RES]] : !torch.optional<vtensor<[2],f32>> +// CHECK: } +func.func private @dont_refine_private(%arg: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> { + %res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>> + return %res : !torch.optional<vtensor<[2],f32>> +} + // ----- // Call to public function. diff --git a/test/Dialect/Torch/reify-dtype-calculations.mlir b/test/Dialect/Torch/reify-dtype-calculations.mlir index 265497ddf324..9aec26662b69 100644 --- a/test/Dialect/Torch/reify-dtype-calculations.mlir +++ b/test/Dialect/Torch/reify-dtype-calculations.mlir @@ -72,3 +72,18 @@ func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: ! %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor return %0 : !torch.vtensor } + +// ----- + +// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.arange( + +// CHECK-LABEL: func.func @derefine_int_to_number() -> !torch.vtensor { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[NUMBER:.*]] = torch.derefine %[[INT1]] : !torch.int to !torch.number +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.arange(%[[NUMBER]], {{.*}}) : (!torch.number, {{.*}}) -> !torch.int +func.func @derefine_int_to_number() -> !torch.vtensor { + %int1 = torch.constant.int 1 + %none = torch.constant.none + %0 = torch.aten.arange %int1, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor + return %0 : !torch.vtensor +} diff --git a/test/Dialect/Torch/simplify-dtype-calculations.mlir b/test/Dialect/Torch/simplify-dtype-calculations.mlir index 238699943c76..e7e860a3fb72 100644 --- a/test/Dialect/Torch/simplify-dtype-calculations.mlir +++ b/test/Dialect/Torch/simplify-dtype-calculations.mlir @@ -285,18 +285,18 @@ func.func @refine_dtype$derefine_result_type(%arg0: !torch.int, %arg1: !torch.in } // CHECK-LABEL: func.func @refine_dtype$complex_type( -// CHECK: {{.*}} = torch.aten.fft_fft{{.*}}-> !torch.vtensor<*,complex<f64>> +// CHECK: {{.*}} = torch.aten.fft_fft{{.*}}-> !torch.vtensor<*,complex<f32>> func.func @refine_dtype$complex_type(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { // dtype for ComplexFloat, a.k.a Complex64 %int9 = torch.constant.int 9 %none = torch.constant.none %int-1 = torch.constant.int -1 %0 = torch.dtype.calculate { - %2 = torch.aten.fft_fft %arg0, %none, %int-1, %none : !torch.vtensor<*,f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<*,unk> - torch.dtype.calculate.yield %2 : !torch.vtensor<*,unk> + %2 = torch.aten.fft_fft %arg0, %none, %int-1, %none : !torch.vtensor<*,f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<*,complex<f32>> + torch.dtype.calculate.yield %2 : !torch.vtensor<*,complex<f32>> } dtypes { torch.dtype.calculate.yield.dtypes %int9 : !torch.int - } : !torch.vtensor<*,unk> + } : !torch.vtensor<*,complex<f32>> - %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<*,unk> to !torch.vtensor + %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<*,complex<f32>> to !torch.vtensor return %1 : !torch.vtensor }
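// Dtype reference for the hunk above (PyTorch ScalarType integers): 8 = ComplexHalf -> complex<f16>,
// 9 = ComplexFloat (a.k.a. Complex64) -> complex<f32>, 10 = ComplexDouble -> complex<f64>;
// hence fft_fft with dtype %int9 refines to !torch.vtensor<*,complex<f32>>.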
diff --git a/test/Dialect/Torch/verify-backend-contract-error.mlir b/test/Dialect/Torch/verify-backend-contract-error.mlir index eb9c6c581a99..22fdd2ec7149 100644 --- a/test/Dialect/Torch/verify-backend-contract-error.mlir +++ b/test/Dialect/Torch/verify-backend-contract-error.mlir @@ -1,7 +1,36 @@ // RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s + func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor { // expected-error @below {{unsupported by backend contract: tensor with unknown rank}} // expected-note @below {{this is likely due to a missing transfer function}} %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor return %t : !torch.vtensor } + +// ----- + +// expected-error @below {{invalid dtype 'i9'}} +func.func @bad_element_type(%arg: !torch.vtensor<[?],i9>) -> !torch.vtensor<[?],i9> { + return %arg : !torch.vtensor<[?],i9> +} + +// ----- + +// expected-error @below {{unsupported by backend contract: non-value tensor type}} +// expected-note @below {{this is likely due to a missing case in the MaximizeValueSemantics pass}} +func.func @non_value_tensor(%arg0: !torch.tensor) -> !torch.tensor { + return %arg0 : !torch.tensor +} + +// ----- + +func.func @valid_tuple(%arg0: !torch.vtensor<[?],f32>) -> !torch.tuple<vtensor<[?],f32>> { + %0 = torch.prim.TupleConstruct %arg0 : !torch.vtensor<[?],f32> -> !torch.tuple<vtensor<[?],f32>> + return %0 : !torch.tuple<vtensor<[?],f32>> +} + +// ----- + +func.func @valid_multiple_ret_values(%arg0: !torch.vtensor<[?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) { + return %arg0, %arg0 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> +} diff --git a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir new file mode 100644 index 000000000000..4f72f24e8868 --- /dev/null +++ b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir @@ -0,0 +1,45 @@ +// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)> +// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> +// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)> +// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-LABEL: func @forward +func.func @forward(%arg0: !torch.vtensor<[1,1,2],f16>) -> !torch.vtensor<[1,1,2],f16> { + %q_rhs = torch.vtensor.literal(dense<[[0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8> + %scales = torch.vtensor.literal(dense<1.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16> + %zps = torch.vtensor.literal(dense<0.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16> + %bit_width = torch.constant.int 8 + %group_size = torch.constant.int 2 + %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,2],f16>, !torch.vtensor<[2,2],ui8>, !torch.vtensor<[2,1,1],f16>, !torch.vtensor<[2,1,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,2],f16> + // CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2],f16> -> tensor<1x1x2xf16> + // CHECK: %[[TENSOR1:.*]] = torch.vtensor.literal(dense<{{\[\[}}0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8> + // CHECK: %[[QUANT_RHS:.*]] = torch_c.to_builtin_tensor %[[TENSOR1]] : !torch.vtensor<[2,2],ui8> -> tensor<2x2xi8> + // CHECK: %[[TENSOR2:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16> + // CHECK: %[[SCALES:.*]] = torch_c.to_builtin_tensor %[[TENSOR2]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16> + // CHECK: %[[TENSOR3:.*]] =
torch.vtensor.literal(dense<0.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16> + // CHECK: %[[ZPS:.*]] = torch_c.to_builtin_tensor %[[TENSOR3]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16> + // CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16> + // CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] : tensor<2x2xi8> into tensor<2x1x2xi8> + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f16 + // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x1x2xf16> + // CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x1x2xf16> + // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[CST]] : f16) outs(%[[EMPTY2]] : tensor<1x1x2xf16>) -> tensor<1x1x2xf16> + // CHECK: %[[DEQUANT_RHS:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[EXPANDED_RHS]], %[[SCALES]], %[[ZPS]] : tensor<2x1x2xi8>, tensor<2x1x1xf16>, tensor<2x1x1xf16>) outs(%[[EMPTY1]] : tensor<2x1x2xf16>) { + // CHECK-NEXT: ^bb0(%[[WEIGHTS:.*]]: i8, %[[SCALES:.*]]: f16, %[[ZPS:.*]]: f16, %{{.*}}: f16): + // CHECK-NEXT: %[[EXTUI:.*]] = arith.extui %[[WEIGHTS]] : i8 to i32 + // CHECK-NEXT: %[[UITOFP:.*]] = arith.uitofp %[[EXTUI]] : i32 to f16 + // CHECK-NEXT: %[[SUBF:.*]] = arith.subf %[[UITOFP]], %[[ZPS]] : f16 + // CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[SUBF]], %[[SCALES]] : f16 + // CHECK-NEXT: linalg.yield %[[MULF]] : f16 + // CHECK-NEXT: } -> tensor<2x1x2xf16> + // CHECK: %[[MATMUL:.*]] = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[EXPANDED_LHS]], %[[DEQUANT_RHS]] : tensor<1x1x1x2xf16>, tensor<2x1x2xf16>) outs(%[[OUT]] : tensor<1x1x2xf16>) { + // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: f16, %[[OUT:.*]]: f16): + // CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f16 + // CHECK-NEXT: %[[ADDF:.*]] = arith.addf %[[MULF]], %[[OUT]] : f16 + // CHECK-NEXT: linalg.yield %[[ADDF]] : f16 + // CHECK-NEXT: } -> tensor<1x1x2xf16> + // CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<1x1x2xf16> to tensor<1x1x2xf16> + return %output : !torch.vtensor<[1,1,2],f16> +} diff --git a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir new file mode 100644 index 000000000000..0ca64ae09397 --- /dev/null +++ b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir @@ -0,0 +1,13 @@ +// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @forward +func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> { + %q_rhs = torch.vtensor.literal(dense<[[57, 128, 249, 244], [7, 243, 27, 15], [1, 2, 159, 71], [159, 253, 160, 231], [248, 224, 191, 228], [96, 15, 158, 220], [240, 250, 47, 208], [127, 192, 239, 176]]> : tensor<8x4xui8>) : !torch.vtensor<[8,4],ui8> + // CHECK: %[[C0:.*]] = torch.vtensor.literal(dense<{{\[\[}}9, 3, 0, 8, 9, 15, 4, 15], [7, 0, 3, 15, 11, 1, 15, 0], [1, 0, 2, 0, 15, 9, 7, 4], [15, 9, 13, 15, 0, 10, 7, 14], [8, 15, 0, 14, 15, 11, 4, 14], [0, 6, 15, 0, 14, 9, 12, 13], [0, 15, 10, 15, 15, 2, 0, 13], [15, 7, 0, 12, 15, 14, 0, 11]]> : tensor<8x8xui4>) : !torch.vtensor<[8,8],ui4> + %scales = torch.vtensor.literal(dense<1.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16> + %zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16> + 
%bit_width = torch.constant.int 4 + %group_size = torch.constant.int 2 + %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16> + return %output : !torch.vtensor<[1,1,8],f16> +} diff --git a/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir b/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir index 2a55a3231548..c489375268b9 100644 --- a/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir +++ b/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir @@ -2,7 +2,7 @@ // CHECK: func.func @tanh func.func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { - %0 = "tosa.tanh"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> + %0 = tosa.tanh %arg0 : (tensor<?x?xf32>) -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } diff --git a/test/python/custom_op_shape_dtype_fn.py b/test/python/custom_op_shape_dtype_fn.py index d955ec7a2a9a..a46f1c594031 100644 --- a/test/python/custom_op_shape_dtype_fn.py +++ b/test/python/custom_op_shape_dtype_fn.py @@ -3,6 +3,7 @@ from typing import List, Tuple import torch +import torch.multiprocessing as mp import torch.utils.cpp_extension import torch_mlir from torch_mlir_e2e_test.annotations import export, annotate_args @@ -51,15 +52,40 @@ def forward(self, a): mod = CustomOpExampleModule() mod.eval() -module = torch_mlir.compile( - mod, - torch.ones(3, 4), - output_type="torch", - backend_legal_ops=["goofy.identity"], - extra_library=extra_library, -) +def run(): + mod = CustomOpExampleModule() + mod.eval() -print(module) + module = torch_mlir.compile( + mod, + torch.ones(3, 4), + output_type="torch", + backend_legal_ops=["goofy.identity"], + extra_library=extra_library, + ) + + print(module) + +run() + +# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} { +# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +# CHECK: %{{.*}} = torch.constant.int 2 +# CHECK: %{{.*}} = torch.aten.mul.Scalar %{{.*}}, %{{.*}} : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> +# CHECK: %{{.*}} = torch.operator "goofy.identity"(%{{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK: return %1 : !torch.vtensor<[3,4],f32> +# CHECK: } +# CHECK: } + +# Using `torch.multiprocessing` adds extra namespaces to the abstract +# interpretation functions when they are imported into MLIR: +# `func @"__torch__.__mp_main__.{name}...` +# This tests that the extra namespaces are removed correctly. +if __name__ == "__main__": + mp.set_start_method("spawn") + p = mp.Process(target=run, args=()) + p.start() + p.join() # CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} { # CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { diff --git a/test/python/importer/jit_ir/node_import/debug-info.py b/test/python/importer/jit_ir/node_import/debug-info.py index b6543ed61733..f7b441a12da0 100644 --- a/test/python/importer/jit_ir/node_import/debug-info.py +++ b/test/python/importer/jit_ir/node_import/debug-info.py @@ -17,14 +17,11 @@ @mb.import_function @torch.jit.script def add3(t0, t1, t2): - # TODO: Checks for debug info are quite hard with the new trailing debug - # attribute print. See if this can be improved.
- # CHECK: loc({{.*}}debug-info.py":[[# @LINE + 1]] + # CHECK-DAG: torch.aten.add.Tensor {{.*}} loc("aten::add"({{.*}}debug-info.py":[[# @LINE + 1]] intermediate = t0 + t1 - # CHECK: loc({{.*}}debug-info.py":[[# @LINE + 1]] - final = intermediate + t2 - return final + # CHECK-DAG: torch.aten.mul.Tensor {{.*}} loc("aten::mul"({{.*}}debug-info.py":[[# @LINE + 1]] + return intermediate * t2 # Verify again with debug info present. Just checking that it makes it in there. -mb.module.operation.print(enable_debug_info=True) +mb.module.operation.print(enable_debug_info=True, use_local_scope=True) print() diff --git a/tools/torch-mlir-lsp-server/CMakeLists.txt b/tools/torch-mlir-lsp-server/CMakeLists.txt index 3ee29438e906..d53519c8a047 100644 --- a/tools/torch-mlir-lsp-server/CMakeLists.txt +++ b/tools/torch-mlir-lsp-server/CMakeLists.txt @@ -9,6 +9,7 @@ COMPONENT torch-mlir-lsp-server) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(torch-mlir-lsp-server PRIVATE MLIRLspServerLib @@ -17,6 +18,7 @@ target_link_libraries(torch-mlir-lsp-server PRIVATE # TODO: Remove these in favor of interface deps. ${dialect_libs} ${conversion_libs} + ${extension_libs} ) mlir_check_all_link_libraries(torch-mlir-lsp-server) diff --git a/tools/torch-mlir-lsp-server/torch-mlir-lsp-server.cpp b/tools/torch-mlir-lsp-server/torch-mlir-lsp-server.cpp index ca76900250c1..a6d88a355483 100644 --- a/tools/torch-mlir-lsp-server/torch-mlir-lsp-server.cpp +++ b/tools/torch-mlir-lsp-server/torch-mlir-lsp-server.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" #include "torch-mlir/InitAll.h" @@ -18,6 +19,7 @@ using namespace mlir; int main(int argc, char **argv) { DialectRegistry registry; registerAllDialects(registry); + registerAllExtensions(registry); mlir::torch::registerAllDialects(registry); return failed(MlirLspServerMain(argc, argv, registry)); } diff --git a/tools/torch-mlir-opt/CMakeLists.txt b/tools/torch-mlir-opt/CMakeLists.txt index 3fb003633431..94c547d0eb2d 100644 --- a/tools/torch-mlir-opt/CMakeLists.txt +++ b/tools/torch-mlir-opt/CMakeLists.txt @@ -7,6 +7,12 @@ COMPONENT torch-mlir-opt) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + +set(dependency_libraries) +if(TORCH_MLIR_ENABLE_STABLEHLO) + list(APPEND dependency_libraries StablehloRegister) +endif() target_link_libraries(torch-mlir-opt PRIVATE MLIROptLib @@ -15,4 +21,6 @@ target_link_libraries(torch-mlir-opt PRIVATE TorchMLIRTorchPasses ${dialect_libs} ${conversion_libs} + ${extension_libs} + ${dependency_libraries} ) diff --git a/tools/torch-mlir-opt/torch-mlir-opt.cpp b/tools/torch-mlir-opt/torch-mlir-opt.cpp index af76cc56d7fa..fa6a41a7097e 100644 --- a/tools/torch-mlir-opt/torch-mlir-opt.cpp +++ b/tools/torch-mlir-opt/torch-mlir-opt.cpp @@ -8,13 +8,12 @@ //===----------------------------------------------------------------------===// #include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "torch-mlir/InitAll.h" #ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "mhlo/IR/hlo_ops.h" -#include 
"mhlo/transforms/passes.h" #include "stablehlo/dialect/Register.h" #endif @@ -26,16 +25,11 @@ int main(int argc, char **argv) { DialectRegistry registry; registerAllDialects(registry); + registerAllExtensions(registry); mlir::torch::registerAllDialects(registry); #ifdef TORCH_MLIR_ENABLE_STABLEHLO mlir::stablehlo::registerAllDialects(registry); - registry.insert(); - mlir::mhlo::registerSymbolicShapeOptimizationPass(); - mlir::mhlo::registerStablehloLegalizeToHloPass(); - mlir::mhlo::registerChloLegalizeToHloPass(); - mlir::mhlo::registerHloLegalizeToLinalgPass(); - mlir::mhlo::registerTestUnfuseBatchNormPass(); #endif return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "MLIR modular optimizer driver\n", registry)); diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index a8a81d7ccfaa..a7fde6168b8c 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -4,4 +4,4 @@ # release page, and we use this page as an additional source for the wheels. -f https://xilinx.github.io/torch-mlir/package-index/ --pre -torchvision==0.16.0.dev20230710 +torchvision==0.17.0.dev20230922 diff --git a/utils/bazel/WORKSPACE.bazel b/utils/bazel/WORKSPACE.bazel index 374de7d39769..f7a81a4faf29 100644 --- a/utils/bazel/WORKSPACE.bazel +++ b/utils/bazel/WORKSPACE.bazel @@ -24,7 +24,7 @@ new_local_repository( path = "../../externals/llvm-project", ) -load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") +load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") llvm_configure( name = "llvm-project", @@ -36,11 +36,9 @@ llvm_configure( ], ) -llvm_disable_optional_support_deps() - local_repository( - name = "mlir-hlo", - path = "../../externals/mlir-hlo/", + name = "stablehlo", + path = "../../externals/stablehlo/", ) new_local_repository( @@ -125,3 +123,14 @@ maybe( "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz", ], ) + +maybe( + http_archive, + name = "llvm_zlib", + build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", + sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", + strip_prefix = "zlib-ng-2.0.7", + urls = [ + "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", + ], +) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index abfd3ea613a3..fa8fccd01500 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -104,6 +104,7 @@ cc_library( deps = [ ":MLIRTorchOpsIncGen", ":MLIRTorchTypesIncGen", + "@llvm-project//mlir:CastInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", @@ -448,9 +449,8 @@ cc_library( ":TorchMLIRTorchBackendTypeConversion", ":TorchMLIRTorchConversionDialect", "@llvm-project//mlir:Dialect", - "@mlir-hlo//:mlir_hlo", - "@mlir-hlo//:transforms_passes", - "@mlir-hlo//stablehlo:register", + "@stablehlo//:register", + "@stablehlo//:stablehlo_passes", ], ) @@ -810,6 +810,7 @@ cc_library( ":TorchMLIRTorchConversionPasses", ":TorchMLIRTorchDialect", ":TorchMLIRTorchPasses", + "@llvm-project//mlir:AllExtensions", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", @@ -826,6 +827,7 @@ cc_binary( ":TorchMLIRInitAll", ":TorchMLIRTorchDialect", ":TorchMLIRTorchPasses", + "@llvm-project//mlir:AllExtensions", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", ],