forked from llvm/torch-mlir
Merge pull request #269 from Xilinx/bump_to_911e7235
[AutoBump] Merge with 911e723 (May 13) (36)
Showing 13 changed files with 560 additions and 203 deletions.
@@ -0,0 +1,140 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"

#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;

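// Lowers torch.aten.uniform to stablehlo.rng with the UNIFORM distribution:
// the scalar `from`/`to` bounds are splatted to 0-d tensors of the output
// element type, and the statically known result shape is materialized as a
// constant shape tensor.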
template <>
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
    AtenUniformOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.getSelf();
  Value generator = adaptor.getGenerator();
  Location loc = op.getLoc();

  if (!isa<Torch::NoneType>(generator.getType()))
    return rewriter.notifyMatchFailure(
        op, "The generator has to be None because only global default "
            "generator is supported");

  auto elements = cast<RankedTensorType>(self.getType()).getShape();
  if (llvm::any_of(elements,
                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
  auto shapeTensor = rewriter.create<stablehlo::ConstantOp>(
      loc, rewriter.getI64TensorAttr(elements));
  auto outTy = getTypeConverter()->convertType(op.getType());
  auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
  Value from =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
  Value to =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
      op, outTy, from, to, shapeTensor, stablehlo::RngDistribution::UNIFORM);
  return success();
}

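// Lowers torch.aten.randn.generator to stablehlo.rng with the NORMAL
// distribution. stablehlo.rng's NORMAL operands are the mean and the standard
// deviation, materialized here as 0-d constants 0.0 and 1.0.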
template <>
LogicalResult ConvertAtenOp<AtenRandnGeneratorOp>::matchAndRewrite(
    AtenRandnGeneratorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value generator = adaptor.getGenerator();
  Location loc = op.getLoc();

  if (!isa<Torch::NoneType>(generator.getType())) {
    return rewriter.notifyMatchFailure(
        op, "The generator has to be None because only global default "
            "generator is supported");
  }
  llvm::SmallVector<int64_t> shape;
  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) {
    return rewriter.notifyMatchFailure(op, "size must be constant");
  }

  auto outTy = getTypeConverter()->convertType(op.getType());
  auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
  if (!isa<mlir::FloatType>(outElemTy)) {
    return rewriter.notifyMatchFailure(op,
                                       "only support output with float type");
  }
  auto scalarTy = RankedTensorType::get({}, outElemTy);

  Value shapeTensor = rewriter.create<stablehlo::ConstantOp>(
      loc, rewriter.getI64TensorAttr(shape));
  Value mean = rewriter.create<stablehlo::ConstantOp>(
      loc,
      DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 0.0)));
  Value std = rewriter.create<stablehlo::ConstantOp>(
      loc,
      DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 1.0)));

  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
      op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL);
  return success();
}

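// Lowers torch.aten.normal_functional to stablehlo.rng (NORMAL), splatting
// the scalar mean and std operands to 0-d tensors of the output element type.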
template <>
LogicalResult ConvertAtenOp<AtenNormalFunctionalOp>::matchAndRewrite(
    AtenNormalFunctionalOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.getSelf();
  Value generator = adaptor.getGenerator();
  Location loc = op.getLoc();

  if (!isa<Torch::NoneType>(generator.getType()))
    return rewriter.notifyMatchFailure(
        op, "The generator has to be None because only global default "
            "generator is supported");

  auto elements = cast<RankedTensorType>(self.getType()).getShape();
  if (llvm::any_of(elements,
                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
  auto shapeTensor = rewriter.create<stablehlo::ConstantOp>(
      loc, rewriter.getI64TensorAttr(elements));
  auto outTy = getTypeConverter()->convertType(op.getType());
  auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
  Value mean =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getMean(), outElemTy);
  Value std =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStd(), outElemTy);
  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
      op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL);
  return success();
}

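// Registers the three RNG conversion patterns above and marks their source
// ops illegal so that the conversion driver must rewrite them.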
void mlir::torch::torch_to_stablehlo::populateRngOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToStablehloOptions &options) {
  MLIRContext *context = patterns.getContext();

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)

  INSERT_ATENOP_PATTERN(AtenUniformOp);
  INSERT_ATENOP_PATTERN(AtenRandnGeneratorOp);
  INSERT_ATENOP_PATTERN(AtenNormalFunctionalOp);
#undef INSERT_ATENOP_PATTERN
}
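For context, a conversion pass would consume this populate function roughly as follows. This is a minimal sketch and not part of the commit: the surrounding pass class, the `typeConverter`, and the error handling are assumptions, and only standard MLIR dialect-conversion APIs are used.

#include "mlir/Transforms/DialectConversion.h"

// Hypothetical driver, for illustration only: runs the RNG patterns via
// partial conversion. The caller and typeConverter are assumed.
static LogicalResult convertRngOpsSketch(Operation *root,
                                         TypeConverter &typeConverter) {
  MLIRContext *context = root->getContext();
  ConversionTarget target(*context);
  // The generated stablehlo ops must be legal for the conversion to succeed.
  target.addLegalDialect<stablehlo::StablehloDialect>();

  RewritePatternSet patterns(context);
  TorchToStablehloOptions options; // defaults; the exact fields depend on the tree
  torch_to_stablehlo::populateRngOpPatternsAndLegality(typeConverter, patterns,
                                                       target, options);

  // Partial conversion: only the ops marked illegal above (the three Aten
  // RNG ops) must be rewritten; everything else is left untouched.
  return applyPartialConversion(root, target, std::move(patterns));
}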