[AutoBump] Merge with fixes of 1d48596 (May 10)
mgehre-amd committed Aug 27, 2024
2 parents b42b9a4 + 1d48596 commit e4364c6
Showing 16 changed files with 225 additions and 123 deletions.
12 changes: 9 additions & 3 deletions docs/development.md
@@ -13,9 +13,6 @@ While this is running, you can already setup the Python venv and dependencies in

## Setup your Python VirtualEnvironment and Dependencies

-Also, ensure that you have the appropriate `python-dev` package installed
-to access the Python development libraries / headers.

```shell
python -m venv mlir_venv
source mlir_venv/bin/activate
@@ -26,6 +23,15 @@ python -m pip install -r requirements.txt
python -m pip install -r torchvision-requirements.txt
```

+Also, ensure that you have the appropriate `python-dev` package installed
+to access the Python development libraries / headers. For example, you can install
+it with the following `apt` command on Ubuntu/Debian.
+
+```shell
+sudo apt install python3-dev
+```


## (Optional) Set up pre-commit

This project uses [pre-commit](https://pre-commit.com/) in its CI. You can
4 changes: 2 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -651,7 +651,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
result = rewriter.create<Torch::AtenMaximumOp>(
binder.getLoc(), resultType, result, operands[i]);
}
-          rewriter.replaceOp(binder.op, result.getDefiningOp());
+          rewriter.replaceOp(binder.op, result);
return success();
});
patterns.onOp(
@@ -667,7 +667,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
result = rewriter.create<Torch::AtenMinimumOp>(
binder.getLoc(), resultType, result, operands[i]);
}
-          rewriter.replaceOp(binder.op, result.getDefiningOp());
+          rewriter.replaceOp(binder.op, result);
return success();
});
patterns.onOp("Neg", 1,
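Both `replaceOp` fixes in this file address the same hazard: when the ONNX `Max` or `Min` node has a single input, `result` is still the incoming operand, and an operand `Value` can be a block argument with no defining op, so `result.getDefiningOp()` may return null. Handing `replaceOp` the `Value` itself is always valid, and it also sidesteps ambiguity when a defining op has more than one result. A minimal sketch of the idea, written as a hypothetical helper rather than code from this commit:

```cpp
#include "mlir/IR/PatternMatch.h"

// Hypothetical helper (not from this commit): replace `op` with `result`
// without ever asking for a defining op.
static void replaceWithValue(mlir::PatternRewriter &rewriter,
                             mlir::Operation *op, mlir::Value result) {
  // result.getDefiningOp() returns null when `result` is a block argument,
  // e.g. an op operand forwarded unchanged, so the old code could pass a
  // null Operation* to replaceOp. A Value is always safe; it converts
  // implicitly to the ValueRange overload of replaceOp.
  rewriter.replaceOp(op, result);
}
```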
27 changes: 26 additions & 1 deletion lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -216,6 +216,12 @@ namespace {

template <typename T> struct DimensionTraits {};

+template <> struct DimensionTraits<AtenMaxPool1dOp> {
+  static constexpr int64_t Dim = 1;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};

template <> struct DimensionTraits<AtenMaxPool2dOp> {
static constexpr int64_t Dim = 2;
// unused const variable warning suppression:
@@ -359,7 +365,24 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {

Type elementType = cast<RankedTensorType>(self.getType()).getElementType();

-    if constexpr (Dim == 2) {
+    if constexpr (Dim == 1) {
+      SmallVector<Value, 4> outTensorShape;
+      Value maxPool1d, paddedInput;
+      TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
+          elementType,
+          APFloat::getInf(
+              cast<mlir::FloatType>(elementType).getFloatSemantics(),
+              /*Negative=*/true));
+      if (failed(createPoolingOp<linalg::PoolingNcwMaxOp>(
+              op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
+              /*dimensionality=*/1, kernelSizeIntValues, strideInts,
+              paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
+              paddedInput, maxPool1d)))
+        return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d");
+      Type newResultType = this->getTypeConverter()->convertType(op.getType());
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool1d);
+      return success();
+    } else if constexpr (Dim == 2) {
SmallVector<Value, 4> outTensorShape;
// `maxpool2d` contains the result of maxpool2d operation over the input.
Value maxPool2d, paddedInput;
@@ -1123,8 +1146,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
+  target.addIllegalOp<AtenMaxPool1dOp>();
target.addIllegalOp<AtenMaxPool2dOp>();
target.addIllegalOp<AtenMaxPool3dOp>();
+  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool1dOp>>(typeConverter, context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);

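Taken together, the new `DimensionTraits<AtenMaxPool1dOp>` specialization, the `if constexpr (Dim == 1)` branch targeting `linalg::PoolingNcwMaxOp` (NCW layout), and the legality/pattern registration above extend the shared `ConvertAtenMaxPoolOp` template to 1-D pooling. A self-contained sketch of the compile-time dispatch idiom the pattern relies on, using simplified stand-in types rather than the real torch-mlir ops:

```cpp
#include <cstdint>
#include <iostream>

// Stand-ins for AtenMaxPool1dOp / AtenMaxPool2dOp; illustration only.
struct MaxPool1dOp {};
struct MaxPool2dOp {};

// Each op advertises its spatial rank through a traits specialization.
template <typename T> struct DimensionTraits {};
template <> struct DimensionTraits<MaxPool1dOp> {
  static constexpr int64_t Dim = 1;
};
template <> struct DimensionTraits<MaxPool2dOp> {
  static constexpr int64_t Dim = 2;
};

// One pattern body; the rank-specific path is selected at compile time,
// so each instantiation contains only the code for its own rank.
template <typename OpTy> void lower() {
  constexpr int64_t Dim = DimensionTraits<OpTy>::Dim;
  if constexpr (Dim == 1) {
    std::cout << "lower to linalg.pooling_ncw_max (NCW)\n";
  } else if constexpr (Dim == 2) {
    std::cout << "lower to linalg.pooling_nchw_max (NCHW)\n";
  }
}

int main() {
  lower<MaxPool1dOp>();
  lower<MaxPool2dOp>();
}
```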
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10696,6 +10696,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
16 changes: 10 additions & 6 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1124,6 +1124,8 @@
"Matmul_matvec",
"Matmul_vecmat",
"MatmulStaticBroadcast_basic",
"MaxPool1dStaticModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool2dStaticModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool3dStaticModule_basic",
@@ -2082,6 +2084,9 @@
"CumsumStaticNegativeDimModule_basic",
"CumsumInputDtypeInt32Module_basic",
"EyeStaticModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
@@ -2552,6 +2557,11 @@
"LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic",
"LogSoftmaxBackwardModule_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic",
@@ -2867,12 +2877,6 @@
"SliceCopyMax_Module_basic",
}

-if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
-    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
-        # ERROR: Found dtype (torch.float64) but expected (torch.float32)
-        "ReduceL1NormWithDTypeModule_basic",
-    }

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
# ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
4 changes: 4 additions & 0 deletions projects/pt1/examples/example-requirements.txt
@@ -0,0 +1,4 @@
+datasets
+transformers
+requests
+pillow
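With this file in place, the examples' extra dependencies can be installed in one step, e.g. `python -m pip install -r projects/pt1/examples/example-requirements.txt` from the repository root.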
106 changes: 0 additions & 106 deletions projects/pt1/examples/torchdynamo_resnet18.py

This file was deleted.

@@ -11,6 +11,6 @@
model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False
)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
-    outf.write(str(module))
+    outf.write(module.operation.get_asm())

print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}")
@@ -25,6 +25,6 @@ def forward(self, data):
model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True
)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
-    outf.write(str(module))
+    outf.write(module.operation.get_asm())

print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}")
@@ -2705,6 +2705,11 @@ def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dty
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], kernel_size=[2]))
+def aten〇max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
    self_rank, self_dtype = self_rank_dtype
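A note on the pairing: the C++ strings added to lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp earlier in this commit mirror this Python dtype function line for line. In torch-mlir, the abstract interp library is generated from these Python definitions, so the two hunks are expected to travel together.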
