[AutoBump] Merge with fixes of e9ed4af9 (69) #438

Merged 35 commits on Jan 9, 2025
Changes from 33 commits
Commits
e9ed4af
[TOSA] Add legalization for aten.index_select (#3760)
justin-ngo-arm Oct 4, 2024
53f7532
Revert "[TorchToLinalg] perform rank0 elementwise computations outsid…
rsuderman Oct 4, 2024
f4840ed
[ONNX] Fix onnx.ScatterElements with AtenScatterReduceTwoOp lowering …
AmosLewis Oct 6, 2024
b08d086
[TOSA] Add legalization for fill, flip, and round (#3768)
justin-ngo-arm Oct 7, 2024
f6721e5
[MLIR][TORCH] Add support for negative step in aten.slice.Tensor op (…
vivekkhandelwal1 Oct 8, 2024
614fcdd
[MLIR][TORCH] Add support for 1-d group convolution (#3770)
vivekkhandelwal1 Oct 8, 2024
58489fa
torch.aten.squeeze.dim lowering with dynamic dims (#3749)
jinchen62 Oct 8, 2024
757fee4
[AutoBump] Merge with fixes of e9ed4af9 (Oct 04)
mgehre-amd Jan 2, 2025
604aaec
Update xfail
mgehre-amd Jan 2, 2025
9cfdc65
[AutoBump] Merge with fixes of 53f7532e (Oct 04)
mgehre-amd Jan 2, 2025
3bfe046
Merge branch 'bump_to_2374b9e0' into bump_to_e9ed4af9
mgehre-amd Jan 2, 2025
a64a7c9
Merge remote-tracking branch 'origin/bump_to_e9ed4af9' into bump_to_5…
mgehre-amd Jan 2, 2025
76a95f2
Fix xfail
mgehre-amd Jan 3, 2025
6a63518
Merge branch 'bump_to_e9ed4af9' into bump_to_53f7532e
mgehre-amd Jan 3, 2025
5409334
[AutoBump] Merge with f4840ed8 (Oct 06)
mgehre-amd Jan 3, 2025
60131e6
[AutoBump] Merge with fixes of b08d0868 (Oct 07)
mgehre-amd Jan 3, 2025
c0eb38e
Update xfail
mgehre-amd Jan 3, 2025
2ee058d
[AutoBump] Merge with f6721e59 (Oct 08)
mgehre-amd Jan 3, 2025
f2f3960
[AutoBump] Merge with fixes of 614fcdd1 (Oct 08)
mgehre-amd Jan 6, 2025
40a686a
bump
mgehre-amd Jan 6, 2025
5b21918
Merge branch 'bump_to_b08d0868' into bump_to_f6721e59
mgehre-amd Jan 6, 2025
7d14c99
Merge branch 'bump_to_f6721e59' into bump_to_614fcdd1
mgehre-amd Jan 6, 2025
32c2a54
[AutoBump] Merge with fixes of 58489faf (Oct 08)
mgehre-amd Jan 6, 2025
ef59423
bump
mgehre-amd Jan 6, 2025
9d1eb7d
Merge branch 'bump_to_b08d0868' into bump_to_f6721e59
mgehre-amd Jan 6, 2025
d5922c5
Merge branch 'bump_to_f6721e59' into bump_to_614fcdd1
mgehre-amd Jan 6, 2025
e15c436
Merge branch 'bump_to_614fcdd1' into bump_to_58489faf
mgehre-amd Jan 6, 2025
a217546
Merge pull request #439 from Xilinx/bump_to_53f7532e
mgehre-amd Jan 8, 2025
13132ac
Merge pull request #441 from Xilinx/bump_to_f4840ed8
mgehre-amd Jan 8, 2025
16f7253
Merge pull request #442 from Xilinx/bump_to_b08d0868
mgehre-amd Jan 8, 2025
f9df768
Merge pull request #443 from Xilinx/bump_to_f6721e59
mgehre-amd Jan 8, 2025
9be8dfd
Merge pull request #445 from Xilinx/bump_to_58489faf
mgehre-amd Jan 8, 2025
90d34f2
Merge pull request #444 from Xilinx/bump_to_614fcdd1
mgehre-amd Jan 8, 2025
bf3b4b7
Merge remote-tracking branch 'origin/feature/backport_ea1_ops' into b…
mgehre-amd Jan 9, 2025
cc708c2
Merge remote-tracking branch 'origin/bump_to_e9ed4af9' into bump_to_e…
mgehre-amd Jan 9, 2025
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
@@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
Location loc, SmallVector<int64_t> dimensions,
Value input, Value &result);

// Flips an input tensor along the dimensions given in the axis list.
Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
SmallVector<int64_t> axis);

} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
11 changes: 7 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -635,18 +635,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

// TODO: Implement max and min cases
if (reduction == "mul") {
reduction = "multiply";
reduction = "prod";
} else if (reduction == "max" || reduction == "min") {
return rewriter.notifyMatchFailure(
binder.op, "max/min reduction unsupported for scatter elements");
} else if (reduction == "add") {
reduction = "sum";
}

Value cstStrReduction =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);

rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
binder.op, resultType, data, constAxis, indices, updates,
cstStrReduction);
cstStrReduction, cstTrue);
return success();
});
patterns.onOp(
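Not part of the diff: a minimal standalone sketch of the reduction-attribute mapping applied in the hunk above before emitting AtenScatterReduceTwoOp. The helper name mapOnnxReduction is hypothetical; only the string mapping and the include_self behaviour are taken from the change itself.

    // Standalone sketch; mapOnnxReduction is a hypothetical helper, not code
    // from the patch. It mirrors the string mapping in the hunk above.
    #include <iostream>
    #include <optional>
    #include <string>

    std::optional<std::string> mapOnnxReduction(const std::string &reduction) {
      if (reduction == "mul")
        return "prod"; // was "multiply" before this change
      if (reduction == "add")
        return "sum";
      if (reduction == "max" || reduction == "min")
        return std::nullopt; // the pattern reports a match failure here
      return reduction; // any other value is passed through unchanged
    }

    int main() {
      for (const char *r : {"mul", "add", "max"}) {
        auto mapped = mapOnnxReduction(r);
        std::cout << r << " -> " << (mapped ? *mapped : "<unsupported>") << "\n";
      }
      // The lowering also passes include_self = true (the new cstTrue operand).
      return 0;
    }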
64 changes: 51 additions & 13 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
int64_t &dim,
SmallVector<Value> &resultShape,
SmallVector<Value> &offsets,
SmallVector<Value> &strides) {
@@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");

@@ -1658,10 +1658,17 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

// TODO: Handle the case where the dim(th) dimension is dynamic.
// assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
Value cmp = rewriter.create<arith::CmpIOp>(
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
rewriter.create<cf::AssertOp>(
op.getLoc(), cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
}

const TypeConverter *typeConverter = getTypeConverter();
@@ -1671,7 +1678,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {

// If the dim(th) dimension of operand tensor type is not statically unit,
// `aten.squeeze` will behave as an identity operation.
if (inputType.getDimSize(dim) != 1) {
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
}
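For context on the hunk above (not part of the diff): a standalone sketch of the runtime behaviour the new cf.assert encodes. squeezeDim is a hypothetical stand-in; the point is that a dynamically sized dimension may only be squeezed when its runtime extent is 1.

    // Standalone sketch, not code from the patch.
    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Squeeze `dim` out of `shape`; mirrors the inserted cf.assert by checking
    // at "runtime" that the (dynamically sized) dim actually has extent 1.
    std::vector<int64_t> squeezeDim(std::vector<int64_t> shape, int64_t dim) {
      assert(shape[dim] == 1 && "Expected dynamic squeeze dim size to be 1");
      shape.erase(shape.begin() + dim);
      return shape;
    }

    int main() {
      // A ?x1x4 tensor whose runtime shape is 2x1x4: squeezing dim=1 gives 2x4.
      for (int64_t d : squeezeDim({2, 1, 4}, 1))
        std::cout << d << " "; // prints: 2 4
      std::cout << "\n";
      return 0;
    }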
@@ -1857,14 +1864,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));

SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
SmallVector<Value> resultShape, offsets, strides;
int64_t dim;
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
AtenSliceTensorOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure();
}

// If stride is negative, then flip the input tensor corresponding to that
// dim, update the stride for flipped tensor by multiplying it by -1, and
// update the offset as follows:
// flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
//
// For example:
// Input = [0, 1, 2, 3, 4, 5]
// stride = [-2], result_shape = [2], offset = [3]
// Result = [3, 1]
// After flipping:
// Input = [5, 4, 3, 2, 1, 0]
// stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
// Result = [3, 1]

Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
SmallVector<int64_t>{dim});
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value isNegativeStride = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, strides[dim], zero);
strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
Value resShapeMulStride =
rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
Value flippedOffset =
rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
offsets[dim] = rewriter.create<arith::SelectOp>(
loc, isNegativeStride, flippedOffset, offsets[dim]);

input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
flippedInput, input);

SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
auto sliceType = RankedTensorType::get(
dynShape, resultType.getElementType(), resultType.getEncoding());
@@ -2095,12 +2134,11 @@ class ConvertAtenSliceScatterOp
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));

SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
SmallVector<Value> resultShape, offsets, strides;
int64_t dim;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure();
}

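Not part of the diff: a standalone check of the offset arithmetic described in the negative-stride comment above, using the same 1-D worked example (names and values here are illustrative only).

    // Standalone sketch, not code from the patch.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Extract `len` elements starting at `offset` with a (signed) `stride`.
    std::vector<int64_t> extractSlice(const std::vector<int64_t> &in,
                                      int64_t offset, int64_t len,
                                      int64_t stride) {
      std::vector<int64_t> out;
      for (int64_t i = 0; i < len; ++i)
        out.push_back(in[offset + i * stride]);
      return out;
    }

    int main() {
      std::vector<int64_t> input = {0, 1, 2, 3, 4, 5};
      int64_t offset = 3, stride = -2, resultLen = 2;

      // Direct extraction with the negative stride reads indices 3 and 1.
      for (int64_t v : extractSlice(input, offset, resultLen, stride))
        std::cout << v << " "; // prints: 3 1
      std::cout << "\n";

      // Flip the input, take |stride|, and recompute
      // flipped_offset = input_size - result_len * |stride| = 6 - 4 = 2.
      std::vector<int64_t> flipped(input.rbegin(), input.rend());
      int64_t absStride = stride < 0 ? -stride : stride;
      int64_t flippedOffset =
          static_cast<int64_t>(input.size()) - resultLen * absStride;

      for (int64_t v : extractSlice(flipped, flippedOffset, resultLen, absStride))
        std::cout << v << " "; // prints: 3 1, matching the worked example above
      std::cout << "\n";
      return 0;
    }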
89 changes: 38 additions & 51 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -222,14 +222,9 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value self = adaptor.getSelf();
auto selfRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
Type elementType =
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

SmallVector<int64_t> axis;
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
@@ -242,40 +237,8 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
}
}

// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);

SmallVector<utils::IteratorType> iteratorTypes(
selfRank, utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, context));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, self.getType(), self, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(
loc, dims[flipDim], indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, self, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);

Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);

return success();
}
};
@@ -1221,10 +1184,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Special depthwise case: Cin = Cout = groups.
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
// of groups) to be depthwise in their documentation, but the linalg ops
@@ -1236,21 +1195,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
weightShape[1] == 1) {
// Collapse weight shape (C/G == 1)
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
weightShape[2], weightShape[3]};
SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
for (unsigned i = 0; i < numSpatialDims; i++) {
collapsedDims.push_back({i + 2});
collapsedShape.push_back(weightShape[i + 2]);
}
Type collapsedType = RankedTensorType::get(
makeShapeLLVMCompatible(collapsedShape), weightDTy);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, weight, collapsedDims);
if (!inputZp) {
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
switch (numSpatialDims) {
case 1:
conv = rewriter
.create<linalg::DepthwiseConv1DNcwCwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 2:
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D and 2D depthwise convolution "
"supported for special case of group convolution");
};
} else {
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D depthwise quantized convolution "
"supported for special case of group convolution");

// currently, the only named depthwise qconv op is nhwc_hwc
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
// linalg conv result nhwc -> nchw
@@ -1297,6 +1280,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = cast<RankedTensorType>(tensor.getType());
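Not part of the diff: a standalone sketch of the shape bookkeeping in the depthwise special case above. The example weight shape is illustrative; only the collapse of the leading (groups, C/G == 1) pair and the 1-D/2-D op choice follow the change, and the op mnemonics are assumed from the C++ class names used above.

    // Standalone sketch, not code from the patch.
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Depthwise weight layout is (groups, 1, k1, ..., kn); collapsing the
      // leading (groups, 1) pair yields (groups, k1, ..., kn).
      std::vector<int64_t> weightShape = {8, 1, 3}; // hypothetical 1-D case
      size_t numSpatialDims = weightShape.size() - 2;

      std::vector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
      for (size_t i = 0; i < numSpatialDims; ++i)
        collapsedShape.push_back(weightShape[i + 2]);

      // Named linalg op selected per spatial rank, as in the switch above
      // (mnemonics assumed from DepthwiseConv1DNcwCwOp / DepthwiseConv2DNchwChwOp).
      std::string opName = numSpatialDims == 1
                               ? "linalg.depthwise_conv_1d_ncw_cw"
                               : numSpatialDims == 2
                                     ? "linalg.depthwise_conv_2d_nchw_chw"
                                     : "<unsupported>";

      std::cout << "collapsed weight shape: ";
      for (int64_t d : collapsedShape)
        std::cout << d << " "; // prints: 8 3
      std::cout << "\nlinalg op: " << opName << "\n";
      return 0;
    }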
41 changes: 41 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
.getResult(0);
return success();
}

// Flips an input tensor along the dimensions given in the axis list.
Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
Value input, SmallVector<int64_t> axis) {
Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
auto selfRank = cast<RankedTensorType>(input.getType()).getRank();

// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);

SmallVector<utils::IteratorType> iteratorTypes(selfRank,
utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(), input, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, input, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);
return flipped;
}
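Not part of the diff: a standalone model of what the linalg.generic body above computes. On every axis in the flip list, output index i reads from input index size - 1 - i; the 2-D shape and the flip2d name are illustrative only.

    // Standalone sketch, not code from the patch.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Flip a row-major rows x cols tensor along the axes listed in `axis`.
    std::vector<int64_t> flip2d(const std::vector<int64_t> &in, int64_t rows,
                                int64_t cols, const std::vector<int64_t> &axis) {
      bool flipRows = false, flipCols = false;
      for (int64_t a : axis) {
        if (a == 0)
          flipRows = true;
        else
          flipCols = true;
      }
      std::vector<int64_t> out(in.size());
      for (int64_t r = 0; r < rows; ++r)
        for (int64_t c = 0; c < cols; ++c) {
          int64_t sr = flipRows ? rows - 1 - r : r; // mirrored source index
          int64_t sc = flipCols ? cols - 1 - c : c;
          out[r * cols + c] = in[sr * cols + sc];
        }
      return out;
    }

    int main() {
      // [[0,1,2],[3,4,5]] flipped along axis 1 becomes [[2,1,0],[5,4,3]].
      for (int64_t v : flip2d({0, 1, 2, 3, 4, 5}, 2, 3, {1}))
        std::cout << v << " "; // prints: 2 1 0 5 4 3
      std::cout << "\n";
      return 0;
    }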