Skip to content

Commit

Permalink
Merge pull request #443 from Xilinx/bump_to_f6721e59
Browse files Browse the repository at this point in the history
[AutoBump] Merge with f6721e5 (Oct 08) (73)
  • Loading branch information
mgehre-amd authored Jan 8, 2025
2 parents 16f7253 + 9d1eb7d commit f9df768
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 47 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
Location loc, SmallVector<int64_t> dimensions,
Value input, Value &result);

// Flips an input tensor based on the values of axis list.
Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
SmallVector<int64_t> axis);

} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
49 changes: 40 additions & 9 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
int64_t &dim,
SmallVector<Value> &resultShape,
SmallVector<Value> &offsets,
SmallVector<Value> &strides) {
Expand All @@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");

Expand Down Expand Up @@ -1857,14 +1857,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));

SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
SmallVector<Value> resultShape, offsets, strides;
int64_t dim;
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
AtenSliceTensorOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure();
}

// If stride is negative, then flip the input tensor corresponding to that
// dim, update the stride for flipped tensor by multiplying it by -1, and
// update the offset as follows:
// flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
//
// For example:
// Input = [0, 1, 2, 3, 4, 5]
// stride = [-2], result_shape = [2], offset = [3]
// Result = [3, 1]
// After flipping:
// Input = [5, 4, 3, 2, 1, 0]
// stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
// Result = [3, 1]

Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
SmallVector<int64_t>{dim});
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value isNegativeStride = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, strides[dim], zero);
strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
Value resShapeMulStride =
rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
Value flippedOffset =
rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
offsets[dim] = rewriter.create<arith::SelectOp>(
loc, isNegativeStride, flippedOffset, offsets[dim]);

input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
flippedInput, input);

SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
auto sliceType = RankedTensorType::get(
dynShape, resultType.getElementType(), resultType.getEncoding());
Expand Down Expand Up @@ -2095,12 +2127,11 @@ class ConvertAtenSliceScatterOp
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));

SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
SmallVector<Value> resultShape, offsets, strides;
int64_t dim;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure();
}

Expand Down
39 changes: 1 addition & 38 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,9 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value self = adaptor.getSelf();
auto selfRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
Type elementType =
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

SmallVector<int64_t> axis;
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
Expand All @@ -242,40 +237,8 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
}
}

// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);

SmallVector<utils::IteratorType> iteratorTypes(
selfRank, utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, context));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, self.getType(), self, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(
loc, dims[flipDim], indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, self, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);

Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);

return success();
}
};
Expand Down
41 changes: 41 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
.getResult(0);
return success();
}

// Flips an input tensor based on the values of axis list.
Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
Value input, SmallVector<int64_t> axis) {
Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
auto selfRank = cast<RankedTensorType>(input.getType()).getRank();

// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);

SmallVector<utils::IteratorType> iteratorTypes(selfRank,
utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(), input, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, input, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);
return flipped;
}

0 comments on commit f9df768

Please sign in to comment.