Skip to content

Commit

Permalink
Fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
jorickert committed Jan 7, 2025
1 parent 7b9b6fc commit 7371c8d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ class IndicesContiguousCounter {
} // namespace

// Decomposes ScatterNDs into a single Split and Concat.
// We can always split an ScatterNDs by splitting the input tensor together with
// We can always split ScatterNDs by splitting the input tensor together with
// the indices and their updates belonging to that part of the input tensor,
// performing the ScatterNDs on each split, and the concatenating the result.
// Here, we handle certain ScatterNDs where after splitting them into three,
Expand Down Expand Up @@ -998,8 +998,8 @@ struct DecomposeScatterNDPattern : public OpRewritePattern<ONNXScatterNDOp> {
// -- The expected index is calculated the following way:
// --- The expected index is initialized with the first index in indices and
// then always incremented by one.
// --- The increment works like an manual addition, the least significant
// digit/subindex gets incremented by one. If an digit overflows, it
// --- The increment works like a manual addition, the least significant
// digit/subindex gets incremented by one. If a digit overflows, it
// gets reset to the first index and the addition carries to the next,
// more significant digit. The addition overflows, if the index for an
// axis is equal to the size of this axis in updates/indices. (By
Expand Down
8 changes: 4 additions & 4 deletions test/mlir/onnx/onnx_decompose.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1148,21 +1148,21 @@ func.func @test_scatter_nd_dynamic(%data : tensor<*xf32>, %updates : tensor<1x1x
// CHECK: onnx.ScatterND

// -----
func.func @test_scatter_nd_mulit_dim_differ(%data : tensor<2x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<2x6x10x12xf32> {
func.func @test_scatter_nd_multi_dim_differ(%data : tensor<2x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<2x6x10x12xf32> {
%indices = onnx.Constant dense<[[[[0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 1, 6], [0, 1, 7], [0, 1, 8], [0, 1, 9]]]]> : tensor<1x1x10x3xi64>
%0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<2x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<2x6x10x12xf32>
onnx.Return %0 : tensor<2x6x10x12xf32>
}
// CHECK-LABEL: func.func @test_scatter_nd_mulit_dim_differ
// CHECK-LABEL: func.func @test_scatter_nd_multi_dim_differ
// CHECK: onnx.ScatterND

// -----
func.func @test_scatter_nd_mulit_dim_differ_multi_shift(%data : tensor<2x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<2x6x10x12xf32> {
func.func @test_scatter_nd_multi_dim_differ_multi_shift(%data : tensor<2x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<2x6x10x12xf32> {
%indices = onnx.Constant dense<[[[[1, 1, 0], [1, 1, 1], [1, 1, 2], [1, 1, 3], [1, 1, 4], [1, 1, 5], [1, 1, 6], [1, 1, 7], [1, 1, 8], [1, 1, 9]]]]> : tensor<1x1x10x3xi64>
%0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<2x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<2x6x10x12xf32>
onnx.Return %0 : tensor<2x6x10x12xf32>
}
// CHECK-LABEL: func.func @test_scatter_nd_mulit_dim_differ_multi_shift
// CHECK-LABEL: func.func @test_scatter_nd_multi_dim_differ_multi_shift
// CHECK: onnx.ScatterND

// -----
Expand Down

0 comments on commit 7371c8d

Please sign in to comment.