diff --git a/README.md b/README.md
index 70268ba729f0..b9d7a47595fa 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,23 @@ pip install torch-mlir -f https://github.com/llvm/torch-mlir-release/releases/ex
 ## Demos
 
+### FxImporter ResNet18
+```shell
+# Get the latest example if you haven't checked out the code
+wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/fximporter_resnet18.py
+
+# Run ResNet18 as a standalone script.
+python projects/pt1/examples/fximporter_resnet18.py
+
+# Output
+load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
+...
+PyTorch prediction
+[('Labrador retriever', 70.65674591064453), ('golden retriever', 4.988346099853516), ('Saluki, gazelle hound', 4.477451324462891)]
+torch-mlir prediction
+[('Labrador retriever', 70.6567153930664), ('golden retriever', 4.988325119018555), ('Saluki, gazelle hound', 4.477458477020264)]
+```
+
 ### TorchScript ResNet18
 
 Standalone script to Convert a PyTorch ResNet18 model to MLIR and run it on the CPU Backend:
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index f5844e442d29..377795d843d9 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -1819,36 +1819,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
-template <>
-LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
-    AtenUniformOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  Value generator = adaptor.getGenerator();
-  Location loc = op.getLoc();
-
-  if (!isa<Torch::NoneType>(generator.getType()))
-    return rewriter.notifyMatchFailure(
-        op, "The generator has to be None because only global default "
-            "generator is supported");
-
-  auto elements = cast<RankedTensorType>(self.getType()).getShape();
-  if (llvm::any_of(elements,
-                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
-    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
-  auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
-      loc, rewriter.getI64TensorAttr(elements));
-  auto outTy = getTypeConverter()->convertType(op.getType());
-  auto outElemTy = cast<TensorType>(outTy).getElementType();
-  Value from =
-      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
-  Value to =
-      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
-  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
-      op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
-  return success();
-}
-
 // Converts `aten.empty.memory_format` to `tensor.empty` op.
 template <>
 LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
@@ -2240,7 +2210,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenToDtypeOp);
   INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
   INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
-  INSERT_ATENOP_PATTERN(AtenUniformOp);
+  INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
   INSERT_ATENOP_PATTERN(AtenFillScalarOp);
   INSERT_ATENOP_PATTERN(AtenFlipOp);
diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt
index 566f1d15b6ad..b200063e1785 100644
--- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt
+++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
   Linear.cpp
   ViewLike.cpp
   Reduction.cpp
+  Rng.cpp
   Pooling.cpp
   Utils.cpp
diff --git a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h
index fc28acfde29f..112d5d0ed374 100644
--- a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h
+++ b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h
@@ -62,6 +62,11 @@ void populatePoolingOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options);
 
+void populateRngOpPatternsAndLegality(TypeConverter &typeConverter,
+                                      RewritePatternSet &patterns,
+                                      ConversionTarget &target,
+                                      const TorchToStablehloOptions &options);
+
 } // namespace torch_to_stablehlo
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToStablehlo/Rng.cpp b/lib/Conversion/TorchToStablehlo/Rng.cpp
new file mode 100644
index 000000000000..3cd440c957e9
--- /dev/null
+++ b/lib/Conversion/TorchToStablehlo/Rng.cpp
@@ -0,0 +1,140 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
+
+#include "../PassDetail.h"
+#include "./PopulatePatterns.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "stablehlo/dialect/StablehloOps.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+using namespace mlir::torch::torch_to_stablehlo;
+
+template <>
+LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
+    AtenUniformOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value self = adaptor.getSelf();
+  Value generator = adaptor.getGenerator();
+  Location loc = op.getLoc();
+
+  if (!isa<Torch::NoneType>(generator.getType()))
+    return rewriter.notifyMatchFailure(
+        op, "The generator has to be None because only global default "
+            "generator is supported");
+
+  auto elements = cast<RankedTensorType>(self.getType()).getShape();
+  if (llvm::any_of(elements,
+                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
+    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
+  auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
+      loc, rewriter.getI64TensorAttr(elements));
+  auto outTy = getTypeConverter()->convertType(op.getType());
+  auto outElemTy = cast<TensorType>(outTy).getElementType();
+  Value from =
+      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
+  Value to =
+      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
+  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
+      op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenRandnGeneratorOp>::matchAndRewrite(
+    AtenRandnGeneratorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value generator = adaptor.getGenerator();
+  Location loc = op.getLoc();
+
+  if (!isa<Torch::NoneType>(generator.getType())) {
+    return rewriter.notifyMatchFailure(
+        op, "The generator has to be None because only global default "
+            "generator is supported");
+  }
+  llvm::SmallVector<int64_t> shape;
+  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) {
+    return rewriter.notifyMatchFailure(op, "size must be constant");
+  }
+
+  auto outTy = getTypeConverter()->convertType(op.getType());
+  auto outElemTy = cast<TensorType>(outTy).getElementType();
+  if (!isa<mlir::FloatType>(outElemTy)) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only support output with float type");
+  }
+  auto scalarTy = RankedTensorType::get({}, outElemTy);
+
+  Value shapeTensor = rewriter.create<stablehlo::ConstantOp>(
+      loc, rewriter.getI64TensorAttr(shape));
+  Value mean = rewriter.create<stablehlo::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 0.0)));
+  Value var = rewriter.create<stablehlo::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 1.0)));
+
+  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
+      op, outTy, mean, var, shapeTensor, stablehlo::RngDistribution::NORMAL);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenNormalFunctionalOp>::matchAndRewrite(
+    AtenNormalFunctionalOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value self = adaptor.getSelf();
+  Value generator = adaptor.getGenerator();
+  Location loc = op.getLoc();
+
+  if (!isa<Torch::NoneType>(generator.getType()))
+    return rewriter.notifyMatchFailure(
+        op, "The generator has to be None because only global default "
+            "generator is supported");
+
+  auto elements = cast<RankedTensorType>(self.getType()).getShape();
+  if (llvm::any_of(elements,
+                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
+    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
+  auto shapeTensor = rewriter.create<stablehlo::ConstantOp>(
+      loc, rewriter.getI64TensorAttr(elements));
+  auto outTy = getTypeConverter()->convertType(op.getType());
+  auto outElemTy = cast<TensorType>(outTy).getElementType();
+  Value mean =
+      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getMean(), outElemTy);
+  Value std =
+      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStd(), outElemTy);
+  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
+      op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL);
+  return success();
+}
+
+void mlir::torch::torch_to_stablehlo::populateRngOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToStablehloOptions &options) {
+  MLIRContext *context = patterns.getContext();
+
+#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
+
+  INSERT_ATENOP_PATTERN(AtenUniformOp);
+  INSERT_ATENOP_PATTERN(AtenRandnGeneratorOp);
+  INSERT_ATENOP_PATTERN(AtenNormalFunctionalOp);
+#undef INSERT_ATENOP_PATTERN
+}
diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
index 4bcc02344e7d..9a3360bf9069 100644
--- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
+++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
@@ -75,6 +75,8 @@ class ConvertTorchToStablehlo
         typeConverter, patterns, target, options);
     torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
         typeConverter, patterns, target, options);
+    torch_to_stablehlo::populateRngOpPatternsAndLegality(
+        typeConverter, patterns, target, options);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 0c352d31ca80..38bc4d275bf1 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -13,6 +13,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include <stack>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -27,98 +28,113 @@ template <typename SrcOp> struct QuantInfo {
 template <> struct QuantInfo<AtenReluOp> {
   static constexpr unsigned operandsToQuantize[1] = {0};
 };
-template <typename SrcOp>
-class QuantizeOperands : public OpRewritePattern<SrcOp> {
-public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SrcOp op,
-                                PatternRewriter &rewriter) const override {
-    llvm::SmallVector<Value> operands(op->getOperands());
-
-    bool dequanted = false;
-    auto f = [&dequanted](Value operand) {
-      if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
-        operand = dequant.getOperand();
-        dequanted = true;
-      }
-      if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
-        operand = dequant.getOperand();
-        dequanted = true;
-      }
-      return operand;
-    };
-
-    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
-      operands[i] = f(operands[i]);
-    }
-    if (!dequanted) {
-      return rewriter.notifyMatchFailure(op, "no dequantizations found");
-    }
-
-    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
-    return success();
-  }
-};
+// A QCommutingOp is an Op satisfying:
+// 1. Has at most one tensor operand at index 0
+// 2. Has a single output, which is a tensor
+// 3. Satisfies the commutation relation:
+//    [MPTQT -> Dequant -> Op(float)] = [Op(int) -> MPTQT -> Dequant]
+// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
+// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
+bool isQCommutingOp(mlir::Operation *op) {
+  // if adding a new commuting op here, be sure to add a
+  // RemoveUnused pattern for that op to clean up afterwards
+  return llvm::isa<AtenSliceTensorOp, AtenReshapeOp, AtenTransposeIntOp>(op);
+}
 
-template <typename SrcOp>
-class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
+// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
+// -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... ->
+// Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops
+// {Op1,Op2,...,Opk} with k <= depth.
+// With depth = 0, this conversion will simply fuse any immediately quantizable
+// operands: [MPTQT -> Dequant -> SrcOp (float operands)] to [MPTQT -> SrcOp(int
+// operands)]
+template <typename SrcOp, unsigned depth>
+class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
 public:
   using OpRewritePattern<SrcOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SrcOp op,
                                 PatternRewriter &rewriter) const override {
+    mlir::Location loc = op.getLoc();
     llvm::SmallVector<Value> operands(op->getOperands());
-    unsigned numOperands = operands.size();
     bool dequanted = false;
-    for (unsigned i = 0; i < numOperands; i++) {
-      if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
-        auto transOperands = trans.getOperands();
-        Value dequantOperand;
-        if (auto dequant =
-                transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
-          dequantOperand = dequant.getOperand();
-          if (auto quant =
-                  dequantOperand
-                      .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
-            auto quantOperands = quant.getOperands();
-            auto qType = quantOperands[0]
-                             .getType()
-                             .cast<ValueTensorType>()
-                             .getOptionalDtype();
-            auto torchQType =
-                cast<ValueTensorType>(quant.getType()).getOptionalDtype();
-            auto transQTy =
-                rewriter.getType<ValueTensorType>(trans.getResult()
-                                                      .getType()
-                                                      .cast<ValueTensorType>()
-                                                      .getOptionalSizes(),
-                                                  qType);
-            auto newQuantTy =
-                rewriter.getType<ValueTensorType>(trans.getResult()
-                                                      .getType()
-                                                      .cast<ValueTensorType>()
-                                                      .getOptionalSizes(),
-                                                  torchQType);
-            Value newTrans = rewriter.create<AtenTransposeIntOp>(
-                op.getLoc(), transQTy, quantOperands[0], transOperands[1],
-                transOperands[2]);
-            Value newQuant =
-                rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
-                    op.getLoc(), newQuantTy, newTrans, quantOperands[1],
-                    quantOperands[2]);
-            operands[i] = newQuant;
-            dequanted = true;
-          }
-        }
-      }
+
+    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
+      Value operand = operands[i];
+      std::stack<mlir::Operation *> commutingOpStack;
+      Value dequantOpd, MPTQTOpd;
+      for (unsigned k = 0; k < depth + 1; k++) {
+        auto currOp = operand.getDefiningOp();
+        // Case 0 : currOp is a nullptr (e.g., operand is a block argument)
+        if (!currOp)
+          break;
+        // Case 1 : currOp is a q commuting op (continue loop)
+        if (isQCommutingOp(currOp)) {
+          commutingOpStack.push(currOp);
+          // set operand to currOp for next k-iteration
+          operand = currOp->getOperand(0);
+          continue;
+        }
+        // Case 2 : currOp is a dequant op (end loop)
+        if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
+          dequantOpd = currOp->getOperand(0);
+          auto MPTQTOp =
+              dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
+          MPTQTOpd = MPTQTOp.getOperand(0);
+        }
+        // either a dequant was found or chain broken, so break loop
+        break;
+      }
+
+      // move to next operand if this trace was unsuccessful
+      if (!MPTQTOpd)
+        continue;
+
+      // a successful trace occurred, so set dequanted to true
+      dequanted = true;
+
+      // rewrite stack
+      Value oldOpd = MPTQTOpd;
+      Type intDType =
+          cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
+      while (!commutingOpStack.empty()) {
+        // get front of the commuting op stack and replace its first operand
+        // with oldOpd
+        auto currOp = commutingOpStack.top();
+        commutingOpStack.pop();
+        llvm::SmallVector<Value> currOperands(currOp->getOperands());
+        currOperands[0] = oldOpd;
+        // get new result type
+        auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
+        auto intType =
+            rewriter.getType<ValueTensorType>(oldType.getSizes(), intDType);
+        // rewrite currOp to have new operands and result type
+        // store this as oldOpd for next loop
+        oldOpd = rewriter
+                     .create(loc, (currOp->getName()).getIdentifier(),
+                             currOperands, intType, currOp->getAttrs())
+                     ->getResult(0);
+      }
+
+      // stack is empty, so oldOpd is now the corrected version of the
+      // SrcOp's original operand
+      // convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
+      auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
+      auto qTorchType =
+          cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
+      auto newMPTQTType = rewriter.getType<ValueTensorType>(
+          cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
+      operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+          loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
     }
+
     if (!dequanted) {
-      return rewriter.notifyMatchFailure(
-          op, "no dequantized transpose inputs found.");
+      return rewriter.notifyMatchFailure(op, "No dequantizations found.");
     }
+
     rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
     return success();
   }
@@ -356,11 +372,14 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
         RemoveUnused<AtenDequantizeSelfOp>,
         RemoveUnused<AtenDequantizeTensorOp>,
         RemoveUnused<AtenQuantizePerTensorOp>,
-        RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>, QuantizeOperands<AtenConvolutionOp>,
-        QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
-        QuantizeTransposedOperands<AtenMatmulOp>,
-        QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
-        QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
+        RemoveUnused<AtenIntReprOp>, RemoveUnused<AtenReshapeOp>,
+        RemoveUnused<AtenSliceTensorOp>, RemoveUnused<AtenTransposeIntOp>,
+        RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
+        QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
+        QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
+        QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
+        QuantizeOperandsPastCommutingOps<AtenMmOp, 1>,
+        QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
         QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
         context);
diff --git a/projects/pt1/examples/_example_utils.py b/projects/pt1/examples/_example_utils.py
new file mode 100644
index 000000000000..8f63b4fd4a63
--- /dev/null
+++ b/projects/pt1/examples/_example_utils.py
@@ -0,0 +1,52 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+ +from PIL import Image +import requests + +import torch +from torchvision import transforms + + +DEFAULT_IMAGE_URL = ( + "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" +) +DEFAULT_LABEL_URL = ( + "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt" +) + + +def load_and_preprocess_image(url: str = DEFAULT_IMAGE_URL): + headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36" + } + img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB") + # preprocessing pipeline + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + img_preprocessed = preprocess(img) + return torch.unsqueeze(img_preprocessed, 0) + + +def load_labels(url: str = DEFAULT_LABEL_URL): + classes_text = requests.get( + url=url, + stream=True, + ).text + labels = [line.strip() for line in classes_text.splitlines()] + return labels + + +def top3_possibilities(res, labels): + _, indexes = torch.sort(res, descending=True) + percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100 + top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]] + return top3 diff --git a/projects/pt1/examples/fximporter_resnet18.py b/projects/pt1/examples/fximporter_resnet18.py new file mode 100644 index 000000000000..8776c42fa7e4 --- /dev/null +++ b/projects/pt1/examples/fximporter_resnet18.py @@ -0,0 +1,59 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. 
+ +import sys +from pathlib import Path + +import torch +import torch.utils._pytree as pytree +import torchvision.models as models +from torch_mlir import fx +from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend +from torch_mlir_e2e_test.configs.utils import ( + recursively_convert_to_numpy, +) + +sys.path.append(str(Path(__file__).absolute().parent)) +from _example_utils import ( + top3_possibilities, + load_and_preprocess_image, + load_labels, + DEFAULT_IMAGE_URL, +) + + +print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr) +img = load_and_preprocess_image(DEFAULT_IMAGE_URL) +labels = load_labels() + +resnet18 = models.resnet18(pretrained=True).eval() +module = fx.export_and_import( + resnet18, + torch.ones(1, 3, 224, 224), + output_type="linalg-on-tensors", + func_name=resnet18.__class__.__name__, +) +backend = refbackend.RefBackendLinalgOnTensorsBackend() +compiled = backend.compile(module) +fx_module = backend.load(compiled) + +params = { + **dict(resnet18.named_buffers(remove_duplicate=False)), +} +params_flat, params_spec = pytree.tree_flatten(params) +params_flat = list(params_flat) +with torch.no_grad(): + numpy_inputs = recursively_convert_to_numpy(params_flat + [img]) + +golden_prediction = top3_possibilities(resnet18.forward(img), labels) +print("PyTorch prediction") +print(golden_prediction) + +prediction = top3_possibilities( + torch.from_numpy(getattr(fx_module, resnet18.__class__.__name__)(*numpy_inputs)), + labels, +) +print("torch-mlir prediction") +print(prediction) diff --git a/projects/pt1/examples/torchscript_resnet18.py b/projects/pt1/examples/torchscript_resnet18.py index 0cc5b5dda96a..ea56653ca6f6 100644 --- a/projects/pt1/examples/torchscript_resnet18.py +++ b/projects/pt1/examples/torchscript_resnet18.py @@ -4,71 +4,36 @@ # Also available under a BSD-style license. See LICENSE. 
 import sys
-
-from PIL import Image
-import requests
+from pathlib import Path
 
 import torch
 import torchvision.models as models
-from torchvision import transforms
-
 from torch_mlir import torchscript
 from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
 
-
-def load_and_preprocess_image(url: str):
-    headers = {
-        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
-    }
-    img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
-    # preprocessing pipeline
-    preprocess = transforms.Compose(
-        [
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-        ]
-    )
-    img_preprocessed = preprocess(img)
-    return torch.unsqueeze(img_preprocessed, 0)
-
-
-def load_labels():
-    classes_text = requests.get(
-        "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
-        stream=True,
-    ).text
-    labels = [line.strip() for line in classes_text.splitlines()]
-    return labels
-
-
-def top3_possibilities(res):
-    _, indexes = torch.sort(res, descending=True)
-    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
-    top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
-    return top3
+sys.path.append(str(Path(__file__).absolute().parent))
+from _example_utils import (
+    top3_possibilities,
+    load_and_preprocess_image,
+    load_labels,
+    DEFAULT_IMAGE_URL,
+)
 
 
 def predictions(torch_func, jit_func, img, labels):
-    golden_prediction = top3_possibilities(torch_func(img))
+    golden_prediction = top3_possibilities(torch_func(img), labels)
     print("PyTorch prediction")
     print(golden_prediction)
-    prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
+    prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())), labels)
     print("torch-mlir prediction")
     print(prediction)
 
 
-image_url = (
-    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
-)
-
-print("load image from " + image_url, file=sys.stderr)
-img = load_and_preprocess_image(image_url)
+print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
+img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
 labels = load_labels()
 
-resnet18 = models.resnet18(pretrained=True)
-resnet18.train(False)
+resnet18 = models.resnet18(pretrained=True).eval()
 module = torchscript.compile(
     resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
 )
diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir
index d8ec0fa6495f..30f8716ebdf0 100644
--- a/test/Conversion/TorchToStablehlo/basic.mlir
+++ b/test/Conversion/TorchToStablehlo/basic.mlir
@@ -291,33 +291,6 @@ func.func @torch.runtime.assert(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.uniform(
-// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
-// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
-// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
-// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
-// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
-// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
-// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
-// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
-// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
-// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
-// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
-// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
-func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
-  %none = torch.constant.none
-  %float0 = torch.constant.float 0.0
-  %float1 = torch.constant.float 1.0
-  %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
-  return %0 : !torch.vtensor<[32, 64],f64>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
 // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
diff --git a/test/Conversion/TorchToStablehlo/rng.mlir b/test/Conversion/TorchToStablehlo/rng.mlir
new file mode 100644
index 000000000000..31241caacb28
--- /dev/null
+++ b/test/Conversion/TorchToStablehlo/rng.mlir
@@ -0,0 +1,100 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.uniform(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
+// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
+// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
+// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
+// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
+// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
+// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
+// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
+// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
+func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
+  %none = torch.constant.none
+  %float0 = torch.constant.float 0.0
+  %float1 = torch.constant.float 1.0
+  %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
+  return %0 : !torch.vtensor<[32, 64],f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.randn.generator
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT32:.*]] = torch.constant.int 32
+// CHECK: %[[INT64:.*]] = torch.constant.int 64
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
+// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
+// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f64>
+// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
+// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
+// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f64>
+func.func @torch.aten.randn.generator() -> !torch.vtensor<[32, 64],f64> {
+  %none = torch.constant.none
+  %int32 = torch.constant.int 32
+  %int64 = torch.constant.int 64
+  %size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f64>
+  return %0 : !torch.vtensor<[32, 64],f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.randn.generator$f32
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT32:.*]] = torch.constant.int 32
+// CHECK: %[[INT64:.*]] = torch.constant.int 64
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
+// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
+// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f32>, tensor<f32>, tensor<2xi64>) -> tensor<32x64xf32>
+// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf32> -> !torch.vtensor<[32,64],f32>
+// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f32>
+func.func @torch.aten.randn.generator$f32() -> !torch.vtensor<[32, 64],f32> {
+  %none = torch.constant.none
+  %int32 = torch.constant.int 32
+  %int64 = torch.constant.int 64
+  %size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f32>
+  return %0 : !torch.vtensor<[32, 64],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.normal_functional(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 2.000000e+00
+// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
+// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
+// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
+// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
+// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
+// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
+// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
+// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
+func.func @torch.aten.normal_functional(%arg0: !torch.vtensor<[32, 64], f64>) -> !torch.vtensor<[32, 64], f64> {
+  %none = torch.constant.none
+  %mean = torch.constant.float 2.0
+  %std = torch.constant.float 1.0
+  %0 = torch.aten.normal_functional %arg0, %mean, %std, %none : !torch.vtensor<[32, 64], f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64], f64>
+  return %0 : !torch.vtensor<[32, 64],f64>
+}
diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir
index f98cb842f5d3..594295d4e86d 100644
--- a/test/Dialect/Torch/fuse-quantized-ops.mlir
+++ b/test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -28,6 +28,60 @@ func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si
 
 // -----
 
+// CHECK-LABEL: @matmul_commuting
+func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.vtensor<[1,1024,1024],f32> {
+  %float5.000000e-01 = torch.constant.float 5.000000e-01
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int-128 = torch.constant.int -128
+  %int2 = torch.constant.int 2
+  %int128 = torch.constant.int 128
+  %int1024 = torch.constant.int 1024
+  %int12 = torch.constant.int 12
+  %0 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float5.000000e-01, %int-128 : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
+  %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
+  %2 = torch.aten.slice.Tensor %1, %int0, %int0, %int1, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  %3 = torch.aten.slice.Tensor %1, %int0, %int1, %int2, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  %4 = torch.prim.ListConstruct %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.reshape %2, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  %6 = torch.aten.reshape %3, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  %7 = torch.aten.transpose.int %5, %int1, %int2 : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
+  %8 = torch.aten.quantize_per_tensor %7, %float5.000000e-01, %int0, %int12 : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  %9 = torch.aten.int_repr %8 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
+  %10 = torch.aten._make_per_tensor_quantized_tensor %9, %float5.000000e-01, %int0 : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  %11 = torch.aten.dequantize.self %10 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],f32>
+  %12 = torch.aten.matmul %11, %6 : !torch.vtensor<[1,1024,128],f32>, !torch.vtensor<[1,128,1024],f32> -> !torch.vtensor<[1,1024,1024],f32>
+
+  // CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[IN128:.+]] = torch.constant.int -128
+  // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[I128:.+]] = torch.constant.int 128
+  // CHECK-DAG: %[[I1024:.+]] = torch.constant.int 1024
+  // CHECK-DAG: %[[I12:.+]] = torch.constant.int 12
+  // CHECK-DAG: %[[MPTQT0:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[IN128]] : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
+  // CHECK-DAG: %[[DQ0:.+]] = torch.aten.dequantize.self %[[MPTQT0]] : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
+  // CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %[[DQ0]], %[[I0]], %[[I0]], %[[I1]], %[[I1]] : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I1]], %[[I128]], %[[I1024]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[RESHAPE0:.+]] = torch.aten.reshape %[[SLICE0]], %[[LIST]] : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  // CHECK-DAG: %[[TR0:.+]] = torch.aten.transpose.int %[[RESHAPE0]], %[[I1]], %[[I2]] : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
+  // CHECK-DAG: %[[Q0:.+]] = torch.aten.quantize_per_tensor %[[TR0]], %[[HALF]], %[[I0]], %[[I12]] : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  // CHECK-DAG: %[[IR0:.+]] = torch.aten.int_repr %[[Q0]] : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
+  // CHECK-DAG: %[[MPTQT1:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR0]], %[[HALF]], %[[I0]] : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  // CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[I0]], %[[I1]], %[[I2]], %[[I1]] : !torch.vtensor<[2,128,32,32],si8>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],si8>
+  // CHECK-DAG: %[[RESHAPE1:.+]] = torch.aten.reshape %[[SLICE1]], %[[LIST]] : !torch.vtensor<[1,128,32,32],si8>, !torch.list<int> -> !torch.vtensor<[1,128,1024],si8>
+  // CHECK-DAG: %[[MPTQT2:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[RESHAPE1]], %[[HALF]], %[[IN128]] : !torch.vtensor<[1,128,1024],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1024],!torch.qint8>
+  // CHECK-DAG: %[[MATMUL:.+]] = torch.aten.matmul %[[MPTQT1]], %[[MPTQT2]] : !torch.vtensor<[1,1024,128],!torch.qint8>, !torch.vtensor<[1,128,1024],!torch.qint8> -> !torch.vtensor<[1,1024,1024],!torch.qint32>
+  // CHECK-DAG: %[[IR1:.+]] = torch.aten.int_repr %[[MATMUL]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],si32>
+  // CHECK-DAG: %[[MPTQT3:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR1]], %[[QUARTER]], %[[I0]] : !torch.vtensor<[1,1024,1024],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,1024],!torch.qint32>
+  // CHECK-DAG: %[[DQ1:.+]] = torch.aten.dequantize.tensor %[[MPTQT3]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],f32>
+  return %12 : !torch.vtensor<[1,1024,1024],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @convolution_bias
 func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %scale = torch.constant.float 0.5
@@ -43,21 +97,21 @@ func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.
   %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
   %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
 
-  // CHECK: %[[DTYPE:.+]] = torch.constant.int 14
-  // CHECK: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
-  // CHECK: %[[HALF:.+]] = torch.constant.float 5.000000e-01
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
-  // CHECK: %[[ONE:.+]] = torch.constant.int 1
-  // CHECK: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
-  // CHECK: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
-  // CHECK: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
-  // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
-  // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
-  // CHECK: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
-  // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
+  // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
+  // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
+  // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
+  // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
+  // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
+  // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
+  // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
+  // CHECK-DAG: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
   return %16 : !torch.vtensor<[1,3,7,7],f32>
 }
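
The fuse-quantized-ops rewrite above leans on the commutation relation documented in `isQCommutingOp`. As a quick illustration outside the patch, the following minimal Python sketch checks that per-tensor dequantization commutes exactly with the shape-only ops exercised by the `@matmul_commuting` test; the `dequant` helper is hypothetical and only mirrors the affine dequantization formula.

```python
import torch

scale, zero_point = 0.5, -128

def dequant(q: torch.Tensor) -> torch.Tensor:
    # Per-tensor affine dequantization: (q - zero_point) * scale,
    # mirroring aten.dequantize.self on a per-tensor quantized input.
    return (q.to(torch.float32) - zero_point) * scale

q = torch.randint(-128, 128, (2, 128, 4, 4), dtype=torch.int8)

# [MPTQT -> Dequant -> Op(float)]: dequantize first, then apply shape ops.
lhs = dequant(q).transpose(1, 2).reshape(2, 128, 16)
# [Op(int) -> MPTQT -> Dequant]: apply the same shape ops on the integer
# tensor, and dequantize last.
rhs = dequant(q.transpose(1, 2).reshape(2, 128, 16))

# Exact equality: dequantization is elementwise and the shape ops only move
# elements around, which is why QuantizeOperandsPastCommutingOps may sink
# them below the dequantize without changing results.
assert torch.equal(lhs, rhs)
```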
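Similarly, each of the three ATen ops handled by the new `Rng.cpp` maps onto a single `stablehlo.rng` op, differing only in distribution and parameters. A small sketch of their eager semantics follows; it assumes a recent PyTorch that exposes the functional variants `torch.ops.aten.uniform` and `torch.ops.aten.normal_functional`.

```python
import torch

x = torch.empty(32, 64, dtype=torch.float64)

# aten.uniform -> stablehlo.rng, distribution = UNIFORM with bounds (from, to).
u = torch.ops.aten.uniform(x, 0.0, 1.0)

# aten.randn.generator -> stablehlo.rng, distribution = NORMAL, mean 0, stddev 1.
r = torch.ops.aten.randn([32, 64], dtype=torch.float64)

# aten.normal_functional -> stablehlo.rng, distribution = NORMAL with (mean, std).
n = torch.ops.aten.normal_functional(x, 2.0, 1.0)

# Only the shape/dtype of the inputs is used; the conversion rejects a
# non-None generator operand, since only the global default generator is
# supported.
assert u.shape == r.shape == n.shape == torch.Size([32, 64])
```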