Skip to content

Commit

Permalink
[AutoBump] Merge with fixes of 346a536
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Aug 22, 2024
2 parents a7209ff + 346a536 commit 11f0b6b
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 221 deletions.
21 changes: 5 additions & 16 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,12 +675,12 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
}

class ConvertAten_IndexPutImplOp
: public OpConversionPattern<Aten_IndexPutImplOp> {
class ConvertAtenIndexPutHackedTwinOp
: public OpConversionPattern<AtenIndexPutHackedTwinOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor,
matchAndRewrite(AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Expand All @@ -699,17 +699,6 @@ class ConvertAten_IndexPutImplOp
return rewriter.notifyMatchFailure(
op, "unimplemented: the values tensor type must have sizes.");

// The unsafe should be either `False` or `none`.
if (!op.getUnsafe().getType().isa<Torch::NoneType>()) {
bool unsafe;
if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe)))
return rewriter.notifyMatchFailure(
op, "unimplemented: unsafe must be a constant");
else if (unsafe)
return rewriter.notifyMatchFailure(
op, "unimplemented: unsafe is expected to be false");
}

// The accumulate should be a torch constant of boolean type.
bool accumulate;
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
Expand Down Expand Up @@ -1624,8 +1613,8 @@ class ConvertTorchToTMTensor
RewritePatternSet patterns(context);
target.addIllegalOp<AtenBincountOp>();
patterns.add<ConvertAtenBincountOp>(typeConverter, context);
target.addIllegalOp<Aten_IndexPutImplOp>();
patterns.add<ConvertAten_IndexPutImplOp>(typeConverter, context);
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
patterns.add<ConvertAtenIndexPutHackedTwinOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
context);
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3970,8 +3970,8 @@ class SimplifyAten_IndexPutImplOp

// Handle Aten_IndexPutImplOp on 1d tensors
template <>
LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
Aten_IndexPutImplOp op, OpAdaptor adaptor,
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// TOSA scatter:
// // Copy the values_in tensor to the values_out tensor.
Expand Down Expand Up @@ -6227,7 +6227,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenAbsOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
Expand Down
Loading

0 comments on commit 11f0b6b

Please sign in to comment.