Skip to content

Commit

Permalink
Merge pull request #444 from Xilinx/bump_to_614fcdd1
Browse files Browse the repository at this point in the history
[AutoBump] Merge with fixes of 614fcdd (Oct 08) (74)
  • Loading branch information
mgehre-amd authored Jan 8, 2025
2 parents f9df768 + 9be8dfd commit 90d34f2
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 17 deletions.
15 changes: 11 additions & 4 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1658,10 +1658,17 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

// TODO: Handle the case where the dim(th) dimension is dynamic.
// assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
Value cmp = rewriter.create<arith::CmpIOp>(
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
rewriter.create<cf::AssertOp>(
op.getLoc(), cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
}

const TypeConverter *typeConverter = getTypeConverter();
Expand All @@ -1671,7 +1678,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {

// If the dim(th) dimension of the operand tensor type is statically known and
// is not 1, `aten.squeeze` behaves as an identity operation. A dynamic
// dim(th) dimension does not take this path.
if (inputType.getDimSize(dim) != 1) {
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
}
Expand Down
50 changes: 37 additions & 13 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1184,10 +1184,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Special depthwise case: Cin = Cout = groups.
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
// of groups) to be depthwise in their documentation, but the linalg ops
Expand All @@ -1199,21 +1195,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
weightShape[1] == 1) {
// Collapse weight shape (C/G == 1)
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
weightShape[2], weightShape[3]};
SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
for (unsigned i = 0; i < numSpatialDims; i++) {
collapsedDims.push_back({i + 2});
collapsedShape.push_back(weightShape[i + 2]);
}
Type collapsedType = RankedTensorType::get(
makeShapeLLVMCompatible(collapsedShape), weightDTy);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, weight, collapsedDims);
if (!inputZp) {
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
switch (numSpatialDims) {
case 1:
conv = rewriter
.create<linalg::DepthwiseConv1DNcwCwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 2:
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D and 2D depthwise convolution "
"supported for special case of group convolution");
};
} else {
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D depthwise quantized convolution "
"supported for special case of group convolution");

// currently, the only named depthwise qconv op is nhwc_hwc
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
// linalg conv result nhwc -> nchw
Expand Down Expand Up @@ -1260,6 +1280,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = cast<RankedTensorType>(tensor.getType());
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,7 @@
"ContiguousModule_basic",
"Conv1dNoPaddingGroupModule_basic",
"Conv1dNoPaddingModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
Expand Down Expand Up @@ -1970,6 +1971,7 @@
"ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic",
"ContiguousModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv1dNoPaddingGroupModule_basic",
"Conv1dNoPaddingModule_basic",
"Conv2dBiasNoPaddingModule_basic",
Expand Down Expand Up @@ -3645,6 +3647,7 @@
"ContainsIntList_False",
"ContainsIntList_True",
"Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
Expand Down Expand Up @@ -4336,6 +4339,7 @@
"ContainsIntList_False",
"ContainsIntList_True",
"Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
Expand Down
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
module.forward(inputVec, weight, bias)


class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
    """E2E test module for a depthwise 1-D convolution on static shapes.

    The annotated input is ``[2, 4, 6]`` and the annotated weight is
    ``[4, 1, 3]``; ``groups=4`` equals the input channel count, with a single
    input channel per group in the weight.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 4, 6], torch.float32, True),
            ([4, 1, 3], torch.float32, True),
        ]
    )
    def forward(self, inputVec, weight):
        # No bias; stride and dilation are 1, padding is 4.
        result = torch.ops.aten.conv1d(
            inputVec,
            weight,
            bias=None,
            stride=[1],
            padding=[4],
            dilation=[1],
            groups=4,
        )
        return result


@register_test_case(
    module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
)
def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
    """Run the depthwise conv1d module on a random input and a random filter."""
    # Shapes match the static annotations on the module's forward().
    activations = tu.rand(2, 4, 6)
    filters = torch.randn(4, 1, 3)
    module.forward(activations, filters)


class Conv2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
17 changes: 17 additions & 0 deletions test/Conversion/TorchToLinalg/squeeze.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// Squeezes dim 0 of a rank-3 tensor whose dimensions are all dynamic. The
// expected lowering reads the size of dim 0 with tensor.dim, compares it to 1
// with arith.cmpi, guards that comparison with cf.assert, and then merges
// dims 0 and 1 via tensor.collapse_shape to produce the rank-2 result.
// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic
func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor<?x?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index
// CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1"
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
%int0 = torch.constant.int 0
%1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}

0 comments on commit 90d34f2

Please sign in to comment.