Skip to content

Commit

Permalink
Merge pull request #185 from Xilinx/matthias.revert_to_upstream
Browse files Browse the repository at this point in the history
Reduce our diff compared to upstream
  • Loading branch information
mgehre-amd authored Jun 21, 2024
2 parents 7e834f9 + 5710f3c commit 0b089c8
Show file tree
Hide file tree
Showing 10 changed files with 4 additions and 44 deletions.
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3066,6 +3066,7 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
Operation *op, Value x, Type dtype) {
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();

auto loc = op->getLoc();

// buildNormalCdf, mean = zero, sigma = one
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,

auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);

if (dtype) {
return rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7538,8 +7538,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArcSinCosOp<AtenAsinOp>>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArcSinCosOp<AtenAcosOp>>(
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
newEnd =
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
}
newEnd = rewriter.create<PrimMinIntOp>(op.getLoc(), newEnd, dimSize);

newStart = rewriter.create<PrimMinIntOp>(op.getLoc(), newStart, dimSize);
newEnd = rewriter.create<PrimMinIntOp>(op.getLoc(), newEnd, dimSize);
Expand Down
1 change: 0 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,6 @@
"ElementwiseAcosTensorIntModule_basic",
"ElementwiseAsinTensorIntModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"Im2ColModule_basic",
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic",
"PrimsSumFloatModule_basic",
"RepeatInterleaveFillModule_basic",
Expand Down
3 changes: 0 additions & 3 deletions projects/pt1/python/torch_mlir/dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ def _get_decomposition_table():
aten._native_batch_norm_legit,
aten.squeeze,
aten.cumsum,
aten.im2col,
aten.index_select,
aten.linalg_vector_norm,
aten.eye,
]
# TODO: enable test once 2.1.0 is stable
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"QuantizedMLP_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"RepeatInterleaveModule_basic",
"Im2ColModule_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
}
Expand Down
20 changes: 0 additions & 20 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5141,26 +5141,6 @@ def forward(self, x):
def Add_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))

# ==============================================================================

class Im2Col_Module(torch.nn.Module):
    """E2E test module exercising aten.im2col on a rank-4 float32 input.

    The dynamic annotation ([-1, -1, -1, -1]) marks every dimension of the
    input as unknown at compile time; the test driver feeds a concrete
    (3, 4, 5, 2) tensor at runtime.
    """

    def __init__(self):
        # NOTE(review): the original also created an unused `self.tensor`
        # attribute (copy-paste boilerplate); it is dropped here.
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # args: kernel_size=[9, 1], dilation=[1, 1], padding=[4, 0],
        # stride=[1, 1].  (Stray C-style trailing semicolon removed.)
        return torch.ops.aten.im2col(x, [9, 1], [1, 1], [4, 0], [1, 1])


@register_test_case(module_factory=lambda: Im2Col_Module())
def Im2ColModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, 2))


# ==============================================================================

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1850,22 +1850,6 @@ def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
# ==============================================================================


class EyeStaticModule(torch.nn.Module):
    """E2E test module producing a statically-shaped identity matrix.

    Takes no runtime inputs; forward() returns the result of
    aten.eye with fixed dimensions.
    """

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        num_rows, num_cols = 3, 5
        return torch.ops.aten.eye(num_rows, num_cols)


@register_test_case(module_factory=lambda: EyeStaticModule())
def EyeStaticModule_basic(module, tu: TestUtils):
    module.forward()

# ==============================================================================


class EmptyStridedModule(torch.nn.Module):

def __init__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,4 +419,4 @@ def forward(self, a, b):

@register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))

0 comments on commit 0b089c8

Please sign in to comment.