
[AutoBump] Merge with a33d1232 (Sep 27) (60) #422

Merged
merged 10 commits on Dec 17, 2024
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/Utils/Utils.h
@@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,

Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy);
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy);

Value castIntToIndex(OpBuilder &b, Location loc, Value v);

27 changes: 27 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9224,6 +9224,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy
}];
}

def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$pos_weight,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
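
For reference, the per-element loss computed by aten::binary_cross_entropy_with_logits (before the optional weight scaling and the reduction step) follows the standard PyTorch definition. A sketch in LaTeX, where x is self, t is target, and p is pos_weight (taken as 1 when absent):

\ell(x, t) = -\bigl[\, p\, t \,\log \sigma(x) \;+\; (1 - t)\,\log\bigl(1 - \sigma(x)\bigr) \,\bigr]

The reduction operand then selects none (0), mean (1), or sum (2) over the elementwise losses.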

def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
AllowsTypeRefinement,
HasValueSemantics,
16 changes: 8 additions & 8 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1662,29 +1662,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
auto shapeType = Torch::ValueTensorType::get(
binder.op->getContext(), SmallVector<int64_t>{inputRank},
resultType.getOptionalDtype());

Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
binder.getLoc(), shapeType, operand);

if (inputRank == 0) {
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
binder.op, resultType, shape);
return success();
}

if (start == 0 && end == -1) {
rewriter.replaceOp(binder.op, shape);
return success();
}

Value sv = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(start));

Value ev = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(end));

Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);

Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);

shape = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), resultType, shape, dim, sv, ev, step);

rewriter.replaceOp(binder.op, shape);
rewriter.replaceOpWithNewOp<Torch::AtenSliceTensorOp>(
binder.op, resultType, shape, dim, sv, ev, step);
return success();
});

75 changes: 75 additions & 0 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -1500,6 +1500,79 @@ class ConvertAtenSortOp : public OpConversionPattern<AtenSortOp> {
};
} // namespace

namespace {
class ConvertAtenCumprodOp : public OpConversionPattern<AtenCumprodOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value input = adaptor.getSelf();
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Type elementType = resultType.getElementType();
Type inputElementType =
cast<RankedTensorType>(input.getType()).getElementType();

// Convert the input element type to the result's element type. The only
// possible mismatch is an input element type that is an integer other than
// `si64`, so in that case we convert the input directly to `si64`. All
// other cases are handled by this op's dtype definition.
if (elementType != inputElementType) {
Value torchInput = convertTensorToDtype(
rewriter, loc, op.getSelf(),
rewriter.getIntegerType(64, IntegerType::Signed));
input = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(torchInput.getType()),
torchInput);
}

int64_t inputRank = resultType.getRank();
Value dtype = op.getDtype();
if (!isa<Torch::NoneType>(dtype.getType()))
return rewriter.notifyMatchFailure(
op, "unsupported: dtype argument not supported");

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant dim value is supported");
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "invalid dim");

SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
Value output = createOneInitTensor(rewriter, loc, sizes, elementType);
output = rewriter.create<tensor::CastOp>(loc, resultType, output);

SmallVector<Value> accSizes(sizes);
accSizes.erase(accSizes.begin() + dim);
SmallVector<int64_t> accStatic(
makeShapeTorchCompatible(resultType.getShape()));
accStatic.erase(accStatic.begin() + dim);
Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType);
Type accType =
RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
acc = rewriter.create<tensor::CastOp>(loc, accType, acc);

Value result = createTMTensorScanOp(
rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
[](OpBuilder &b, Location loc, Value input, Value acc) {
Value prod =
(isa<mlir::FloatType>(input.getType())
? b.create<arith::MulFOp>(loc, input, acc)->getResult(0)
: b.create<arith::MulIOp>(loc, input, acc)->getResult(0));
b.create<TMTensor::YieldOp>(loc, prod);
});

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};
} // namespace
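
For reference only (this is not the MLIR lowering itself), a minimal C++ sketch of the inclusive product scan that AtenCumprodOp lowers to via createTMTensorScanOp: the running accumulator starts at 1, mirroring createOneInitTensor above. The helper name cumprodDim1 and the 2-D row-major layout are illustrative assumptions.

#include <cstdio>
#include <vector>

// Inclusive product scan along dim 1 of a rows x cols row-major buffer:
// out[i][j] = in[i][0] * in[i][1] * ... * in[i][j].
static std::vector<double> cumprodDim1(const std::vector<double> &in,
                                       int rows, int cols) {
  std::vector<double> out(in.size());
  for (int i = 0; i < rows; ++i) {
    double acc = 1.0; // accumulator initialized to one, like the init tensor
    for (int j = 0; j < cols; ++j) {
      acc *= in[i * cols + j];
      out[i * cols + j] = acc;
    }
  }
  return out;
}

int main() {
  std::vector<double> in = {1, 2, 3, 4, 5, 6};
  std::vector<double> out = cumprodDim1(in, /*rows=*/2, /*cols=*/3);
  for (double v : out)
    std::printf("%g ", v); // prints: 1 2 6 4 20 120
  std::printf("\n");
}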

namespace {
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public:
@@ -2243,6 +2316,8 @@ class ConvertTorchToTMTensor
patterns.add<ConvertAtenSortOp>(typeConverter, context);
target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
target.addIllegalOp<AtenCumprodOp>();
patterns.add<ConvertAtenCumprodOp>(typeConverter, context);
target.addIllegalOp<AtenScaledDotProductAttentionOp>();
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
context);
10 changes: 10 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
@@ -138,6 +138,16 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
}

Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy) {
Value initTensor =
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
Value c1 =
b.create<arith::ConstantOp>(loc, b.getOneAttr(type.getElementType()));
return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
}

Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
38 changes: 38 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9155,6 +9155,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@@ -10307,6 +10310,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %0 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
@@ -11878,6 +11893,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int4 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %2#1 : !torch.int\n"
" }\n"
" torch.prim.If.yield %4 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -14650,6 +14684,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
73 changes: 73 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8799,6 +8799,77 @@ class DecomposeAtenCrossEntropyLossOp
};
} // namespace

namespace {
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
using OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto self = op.getSelf();
auto target = op.getTarget();
auto posWeight = op.getPosWeight();
auto weight = op.getWeight();
auto reduction = op.getReduction();

Value loss;
auto one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto _one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));

auto _target =
rewriter.create<AtenMulScalarOp>(loc, target.getType(), target, _one);
auto _target_1 = rewriter.create<AtenAddScalarOp>(loc, _target.getType(),
_target, one, one);
Value mm =
rewriter.create<AtenMulTensorOp>(loc, self.getType(), _target_1, self);
Value logSigm =
rewriter.create<AtenLogSigmoidOp>(loc, self.getType(), self);

if (!isa<Torch::NoneType>(posWeight.getType())) {
auto logWeight = rewriter.create<AtenAddScalarOp>(
loc, posWeight.getType(),
rewriter.create<AtenSubScalarOp>(loc, posWeight.getType(), posWeight,
one, one),
one, one);
loss = rewriter.create<AtenSubTensorOp>(
loc, mm.getType(), mm,
rewriter.create<AtenMulTensorOp>(loc, logWeight.getType(), logWeight,
logSigm),
one);
} else {
loss =
rewriter.create<AtenSubTensorOp>(loc, mm.getType(), mm, logSigm, one);
}

if (!isa<Torch::NoneType>(weight.getType())) {
loss =
rewriter.create<AtenMulTensorOp>(loc, loss.getType(), loss, weight);
}

// apply loss reduction.
int64_t reductionInt;
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
return rewriter.notifyMatchFailure(op, "no reduction type is appointed!");
}

auto none = rewriter.create<ConstantNoneOp>(loc);
Value res;
if (reductionInt == 1) {
res = rewriter.create<AtenMeanOp>(loc, op.getType(), loss, none);
} else if (reductionInt == 2) {
res = rewriter.create<AtenSumOp>(loc, op.getType(), loss, none);
} else {
res = loss;
}

rewriter.replaceOp(op, res);
return success();
}
};
} // namespace
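
The pattern above assembles the numerically stable form of the loss from elementwise Torch ops. As a rough sketch (the mapping onto the individual ops is approximate), with x = self, t = target, and p = pos_weight:

\ell(x, t) = (1 - t)\,x \;-\; l_w \,\log\sigma(x), \qquad l_w = (p - 1)\,t + 1 \ \text{(or } l_w = 1 \text{ when pos\_weight is absent)}

The result is then optionally multiplied elementwise by weight and reduced according to reduction (1 = mean, 2 = sum, anything else = no reduction).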

namespace {
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
@@ -10025,6 +10096,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);