Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoBump] Merge with f08bfc4f (Oct 04) (66) #431

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2521,7 +2521,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return failure();

auto shapeSizes = shapeType.getSizes();
int64_t dataRank = dataType.getSizes().size();
ArrayRef<int64_t> dataShape = dataType.getSizes();
int64_t dataRank = dataShape.size();
int64_t shapeRank = shapeSizes.size();
if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
return failure();
Expand All @@ -2543,22 +2544,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
// we are using torch implementation Torch::AtenBroadcastToOp which
// takes list of int
for (int i = 0; i < shapeSizes[0]; i++) {
// extract dim from shape
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
loc, selectResultType, shape, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
Value selectDim = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), extract);

if (i + rankDifference >= 0) {
// compute dim to pass to broadcast op. For non-broadcastable dims,
// pass -1
Value dim;
if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) {
// 1. if dataShape[i + rankDiff] > 1, then this cannot be
// broadcasted
// 2. we will explicitly disallow broadcasting dynamic dims that are
// secretly 1.
dim = rewriter.create<Torch::ConstantIntOp>(loc, -1);
// Assert dataShape[i + rankDiff] >= selectDim. If both are
// constant, this should fold out.
Value iv =
rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
auto sz = rewriter.create<Torch::AtenSizeIntOp>(
loc, rewriter.getType<Torch::IntType>(), data, iv);
dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz);
Value gtSelect =
rewriter.create<Torch::AtenGeIntOp>(loc, sz, selectDim);
rewriter.create<Torch::RuntimeAssertOp>(
loc, gtSelect,
rewriter.getStringAttr(
"onnx.Expand input has a dim that is not statically 1; "
"expected this dim >= dim provided shape."));
} else {
// 1. excess selectDims get included in broadcast (shapeSizes[0] >
// dataRank)
// 2. selectDims which correspond to dataShape == 1 get included in
// broadcast
dim = selectDim;
}

dimList.push_back(dim);
}
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def import_onnx(contents):
# Import the ONNX model proto from the file contents:
raw_model = onnx.load_from_string(contents)
# since it does not affect current e2e tests, data_prop is left false here
model_proto = onnx.shape_inference.infer_shapes(raw_model)
model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True)

# Import the ONNX module into an MLIR module:
context = Context()
Expand Down
20 changes: 8 additions & 12 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1608,16 +1608,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
// CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
// CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
// CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
// CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1
// CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
Expand All @@ -1634,16 +1631,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor
// CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
// CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
// CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
// CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1
// CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
// CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
// CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
// CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int
// CHECK-NEXT: torch.runtime.assert %[[GE]]
// CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
// CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
// CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
// CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1
// CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
// CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]]
// CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
// CHECK: return %[[EXPAND]]
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
Expand Down
Loading