Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/bump_to_e9ed4af9' into bump_to_e…
Browse files Browse the repository at this point in the history
…9ed4af9
  • Loading branch information
mgehre-amd committed Jan 9, 2025
2 parents bf3b4b7 + 90d34f2 commit cc708c2
Show file tree
Hide file tree
Showing 11 changed files with 510 additions and 147 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
Location loc, SmallVector<int64_t> dimensions,
Value input, Value &result);

// Flips an input tensor based on the values of axis list.
Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
SmallVector<int64_t> axis);

} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
11 changes: 7 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,18 +635,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

// TODO: Implement max and min cases
if (reduction == "mul") {
reduction = "multiply";
reduction = "prod";
} else if (reduction == "max" || reduction == "min") {
return rewriter.notifyMatchFailure(
binder.op, "max/min reduction unsupported for scatter elements");
} else if (reduction == "add") {
reduction = "sum";
}

Value cstStrReduction =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);

rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
binder.op, resultType, data, constAxis, indices, updates,
cstStrReduction);
cstStrReduction, cstTrue);
return success();
});
patterns.onOp(
Expand Down
64 changes: 51 additions & 13 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
int64_t &dim,
SmallVector<Value> &resultShape,
SmallVector<Value> &offsets,
SmallVector<Value> &strides) {
Expand All @@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");

Expand Down Expand Up @@ -1658,10 +1658,17 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

// TODO: Handle the case where the dim(th) dimension is dynamic.
// assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
Value cmp = rewriter.create<arith::CmpIOp>(
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
rewriter.create<cf::AssertOp>(
op.getLoc(), cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
}

const TypeConverter *typeConverter = getTypeConverter();
Expand All @@ -1671,7 +1678,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {

// If the dim(th) dimension of operand tensor type is not statically unit,
// `aten.squeeze` will behave as an identity operation.
if (inputType.getDimSize(dim) != 1) {
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
}
Expand Down Expand Up @@ -1857,14 +1864,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));

SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
SmallVector<Value> resultShape, offsets, strides;
int64_t dim;
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
AtenSliceTensorOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure();
}

// If stride is negative, then flip the input tensor corresponding to that
// dim, update the stride for flipped tensor by multiplying it by -1, and
// update the offset as follows:
// flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
//
// For example:
// Input = [0, 1, 2, 3, 4, 5]
// stride = [-2], result_shape = [2], offset = [3]
// Result = [3, 1]
// After flipping:
// Input = [5, 4, 3, 2, 1, 0]
// stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
// Result = [3, 1]

Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
SmallVector<int64_t>{dim});
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value isNegativeStride = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, strides[dim], zero);
strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
Value resShapeMulStride =
rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
Value flippedOffset =
rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
offsets[dim] = rewriter.create<arith::SelectOp>(
loc, isNegativeStride, flippedOffset, offsets[dim]);

input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
flippedInput, input);

SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
auto sliceType = RankedTensorType::get(
dynShape, resultType.getElementType(), resultType.getEncoding());
Expand Down Expand Up @@ -2095,12 +2134,11 @@ class ConvertAtenSliceScatterOp
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));

SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
SmallVector<Value> resultShape, offsets, strides;
int64_t dim;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure();
}

Expand Down
89 changes: 38 additions & 51 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,9 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value self = adaptor.getSelf();
auto selfRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
Type elementType =
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

SmallVector<int64_t> axis;
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
Expand All @@ -242,40 +237,8 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
}
}

// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);

SmallVector<utils::IteratorType> iteratorTypes(
selfRank, utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, context));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, self.getType(), self, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(
loc, dims[flipDim], indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, self, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);

Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);

return success();
}
};
Expand Down Expand Up @@ -1221,10 +1184,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Special depthwise case: Cin = Cout = groups.
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
// of groups) to be depthwise in their documentation, but the linalg ops
Expand All @@ -1236,21 +1195,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
weightShape[1] == 1) {
// Collapse weight shape (C/G == 1)
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
weightShape[2], weightShape[3]};
SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
for (unsigned i = 0; i < numSpatialDims; i++) {
collapsedDims.push_back({i + 2});
collapsedShape.push_back(weightShape[i + 2]);
}
Type collapsedType = RankedTensorType::get(
makeShapeLLVMCompatible(collapsedShape), weightDTy);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, weight, collapsedDims);
if (!inputZp) {
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
switch (numSpatialDims) {
case 1:
conv = rewriter
.create<linalg::DepthwiseConv1DNcwCwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 2:
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D and 2D depthwise convolution "
"supported for special case of group convolution");
};
} else {
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D depthwise quantized convolution "
"supported for special case of group convolution");

// currently, the only named depthwise qconv op is nhwc_hwc
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
// linalg conv result nhwc -> nchw
Expand Down Expand Up @@ -1297,6 +1280,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = cast<RankedTensorType>(tensor.getType());
Expand Down
41 changes: 41 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
.getResult(0);
return success();
}

// Flips an input tensor based on the values of axis list.
Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
Value input, SmallVector<int64_t> axis) {
Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
auto selfRank = cast<RankedTensorType>(input.getType()).getRank();

// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);

SmallVector<utils::IteratorType> iteratorTypes(selfRank,
utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(), input, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, input, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);
return flipped;
}
Loading

0 comments on commit cc708c2

Please sign in to comment.