Skip to content

Commit

Permalink
Merge pull request #269 from Xilinx/bump_to_911e7235
Browse files Browse the repository at this point in the history
[AutoBump] Merge with 911e723 (May 13) (36)
  • Loading branch information
mgehre-amd authored Sep 4, 2024
2 parents 01c9f23 + 375d6bf commit c17c667
Show file tree
Hide file tree
Showing 13 changed files with 560 additions and 203 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ pip install torch-mlir -f https://github.com/llvm/torch-mlir-release/releases/ex

## Demos

### FxImporter ResNet18
```shell
# Get the latest example if you haven't checked out the code
wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/fximporter_resnet18.py

# Run ResNet18 as a standalone script.
python projects/pt1/examples/fximporter_resnet18.py

# Output
load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
...
PyTorch prediction
[('Labrador retriever', 70.65674591064453), ('golden retriever', 4.988346099853516), ('Saluki, gazelle hound', 4.477451324462891)]
torch-mlir prediction
[('Labrador retriever', 70.6567153930664), ('golden retriever', 4.988325119018555), ('Saluki, gazelle hound', 4.477458477020264)]
```

### TorchScript ResNet18

Standalone script to Convert a PyTorch ResNet18 model to MLIR and run it on the CPU Backend:
Expand Down
32 changes: 1 addition & 31 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1819,36 +1819,6 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
AtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
Value generator = adaptor.getGenerator();
Location loc = op.getLoc();

if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");

auto elements = cast<RankedTensorType>(self.getType()).getShape();
if (llvm::any_of(elements,
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(elements));
auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
Value from =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
Value to =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
return success();
}

// Converts `aten.empty.memory_format` to `tensor.empty` op.
template <>
LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
Expand Down Expand Up @@ -2240,7 +2210,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
INSERT_ATENOP_PATTERN(AtenUniformOp);

INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchToStablehlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
Linear.cpp
ViewLike.cpp
Reduction.cpp
Rng.cpp
Pooling.cpp
Utils.cpp

Expand Down
5 changes: 5 additions & 0 deletions lib/Conversion/TorchToStablehlo/PopulatePatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ void populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);

void populateRngOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToStablehloOptions &options);

} // namespace torch_to_stablehlo
} // namespace torch
} // namespace mlir
Expand Down
140 changes: 140 additions & 0 deletions lib/Conversion/TorchToStablehlo/Rng.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"

#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;

template <>
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
AtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
Value generator = adaptor.getGenerator();
Location loc = op.getLoc();

if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");

auto elements = cast<RankedTensorType>(self.getType()).getShape();
if (llvm::any_of(elements,
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(elements));
auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
Value from =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
Value to =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenRandnGeneratorOp>::matchAndRewrite(
AtenRandnGeneratorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value generator = adaptor.getGenerator();
Location loc = op.getLoc();

if (!isa<Torch::NoneType>(generator.getType())) {
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
}
llvm::SmallVector<int64_t> shape;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) {
return rewriter.notifyMatchFailure(op, "size must be constant");
}

auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
if (!isa<mlir::FloatType>(outElemTy)) {
return rewriter.notifyMatchFailure(op,
"only support output with float type");
}
auto scalarTy = RankedTensorType::get({}, outElemTy);

Value shapeTensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(shape));
Value mean = rewriter.create<stablehlo::ConstantOp>(
loc,
DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 0.0)));
Value var = rewriter.create<stablehlo::ConstantOp>(
loc,
DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 1.0)));

rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
op, outTy, mean, var, shapeTensor, stablehlo::RngDistribution::NORMAL);
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenNormalFunctionalOp>::matchAndRewrite(
AtenNormalFunctionalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
Value generator = adaptor.getGenerator();
Location loc = op.getLoc();

if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");

auto elements = cast<RankedTensorType>(self.getType()).getShape();
if (llvm::any_of(elements,
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
auto shapeTensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(elements));
auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
Value mean =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getMean(), outElemTy);
Value std =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStd(), outElemTy);
rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL);
return success();
}

void mlir::torch::torch_to_stablehlo::populateRngOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();

#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)

INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenRandnGeneratorOp);
INSERT_ATENOP_PATTERN(AtenNormalFunctionalOp);
#undef INSERT_ATENOP_PATTERN
}
2 changes: 2 additions & 0 deletions lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class ConvertTorchToStablehlo
typeConverter, patterns, target, options);
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateRngOpPatternsAndLegality(
typeConverter, patterns, target, options);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
Expand Down
Loading

0 comments on commit c17c667

Please sign in to comment.