Skip to content

Commit

Permalink
[API/MemRef] Implement canonical stride validation for MemRefValue cr…
Browse files Browse the repository at this point in the history
…eation (#252)

Add optional stride validation in `MemRefValue::create` to compute
canonical stride and compare against given strides while creaing a
memref view from DLPack tensors. We need to handle special cases for
zero-sized and unit-sized dimensions since frameworks deal with them
arbitrarily while converting to the corresponding DLPack tensor. Add
Python tests to verify both canonical and non-canonical stride
validation.
  • Loading branch information
jhalakpatel authored Oct 8, 2024
1 parent 223cc67 commit 4dbc5cc
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ MLIR_CAPI_EXPORTED MTRT_Status
mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
int64_t bitsPerElement, int64_t rank, const int64_t *shape,
const int64_t *strides, MTRT_Device device, MTRT_Stream stream,
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result);
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
bool assertCanonicalStrides = false);

/// Creates an externally managed MemRef value. The caller provides all the
/// metadata for the MemRef including the shape, strides (in elements), pointer,
Expand All @@ -142,7 +143,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefCreateExternal(
MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank,
const int64_t *shape, const int64_t *strides, MTRT_Device device,
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result);
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
bool assertCanonicalStrides = false);

/// Destroys `MTRT_MemRefValue` in a potentially asynchronous manner.
/// If `buffer` is a device buffer, device memory is freed in the stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ class MemRefValue : public RuntimeValue {
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device,
std::optional<ScalarType> scalarType);
std::optional<ScalarType> scalarType,
std::optional<bool> assertCanonicalStrides = {});

mlirtrt::runtime::PointerType getBufferKind() { return addressSpace; }
int64_t getElementBitWidth() const { return bitsPerElement; }
Expand Down Expand Up @@ -917,15 +918,17 @@ class RuntimeClient {
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device = {},
std::optional<CudaStream> stream = {},
std::optional<ScalarType> scalarType = {});
std::optional<ScalarType> scalarType = {},
std::optional<bool> assertCanonicalStrides = {});

StatusOr<std::unique_ptr<MemRefValue>>
createExternalMemRef(PointerType addressSpace, int64_t bitsPerElement,
uintptr_t ptr, int64_t offset,
llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device = {},
std::optional<ScalarType> scalarType = {});
std::optional<ScalarType> scalarType = {},
std::optional<bool> assertCanonicalStrides = {});

/// Frees the memory in `value`. The `stream` may optionally be provided
/// for resources that can be deallocated asynchronously.
Expand Down
12 changes: 8 additions & 4 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ MTRT_Status
mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
int64_t bitsPerElement, int64_t rank, const int64_t *shape,
const int64_t *strides, MTRT_Device device, MTRT_Stream stream,
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) {
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
bool assertCanonicalStrides) {
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
unwrap(client)->allocateMemRef(
unwrap(pointerKind), bitsPerElement,
Expand All @@ -244,7 +245,8 @@ mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
: std::optional(unwrap(stream)->getRawStream()),
scalarType != MTRT_ScalarTypeCode::MTRT_ScalarTypeCode_unknown
? std::optional(ScalarType(unwrap(scalarType)))
: std::nullopt);
: std::nullopt,
std::optional(assertCanonicalStrides));

if (bufferImpl.isError())
return wrap(bufferImpl.getStatus());
Expand All @@ -257,7 +259,8 @@ MTRT_Status mtrtMemRefCreateExternal(
MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank,
const int64_t *shape, const int64_t *strides, MTRT_Device device,
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) {
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
bool assertCanonicalStrides) {
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
unwrap(client)->createExternalMemRef(
unwrap(pointerKind), bitsPerElement, ptr, offset,
Expand All @@ -267,7 +270,8 @@ MTRT_Status mtrtMemRefCreateExternal(
: std::optional(unwrap(device)),
scalarType == MTRT_ScalarTypeCode_unknown
? std::nullopt
: std::optional(ScalarType(unwrap(scalarType))));
: std::optional(ScalarType(unwrap(scalarType))),
std::optional(assertCanonicalStrides));

if (bufferImpl.isError())
return wrap(bufferImpl.getStatus());
Expand Down
63 changes: 57 additions & 6 deletions mlir-tensorrt/executor/lib/Runtime/API/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,50 @@ static StatusOr<int64_t> getFootprintInBytes(llvm::ArrayRef<int64_t> shape,
return sizeBytes;
}

static llvm::SmallVector<int64_t> getCanonicalStride(const llvm::ArrayRef<int64_t>& shape) {
if (shape.empty())
return {};

llvm::SmallVector<int64_t> canonicalStride(shape.size(), 1);
int64_t cumulativeProduct = 1;

for (int64_t dimIndex = shape.size() - 1; dimIndex >= 0; --dimIndex) {
bool isFirstZeroDim = (shape[dimIndex] == 0 && dimIndex != static_cast<int64_t>(shape.size()) - 1);
// For dimensions with size 0 or 1, the stride can be arbitrary.
// We set it to 1 here, but other values would also be valid.
if (isFirstZeroDim || shape[dimIndex] == 1)
canonicalStride[dimIndex] = 1;
else
canonicalStride[dimIndex] = cumulativeProduct;
// For zero-sized dimensions (except the last one), we don't update the cumulative product
// This allows for consistent handling of zero-sized dimensions across different frameworks
cumulativeProduct *= isFirstZeroDim ? 1 : shape[dimIndex];
}

return canonicalStride;
}

static bool areStridesEquivalent(llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> stride,
llvm::ArrayRef<int64_t> expectedStride) {
if (shape.size() != stride.size() || shape.size() != expectedStride.size())
return false;

for (size_t i = 0; i < shape.size(); ++i)
// Allow arbitrary strides for dimensions with size 0 or 1
// This accounts for discrepancies in how different frameworks handle these cases
if (stride[i] != expectedStride[i] && shape[i] != 0 && shape[i] != 1)
return false;

return true;
}

StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device,
std::optional<ScalarType> scalarType) {
std::optional<const Device *> device, std::optional<ScalarType> scalarType,
std::optional<bool> assertCanonicalStrides) {
if (!client)
return getInvalidArgStatus("a valid RuntimeClient must be provided to "
"create a tracked MemRef object");
Expand All @@ -691,6 +729,19 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
return getInvalidArgStatus("a specific device must be provided for MemRefs "
"that are device-visible");

// Check if given strides match canonical stride
if (assertCanonicalStrides && *assertCanonicalStrides) {
llvm::SmallVector<int64_t> canonicalStride = getCanonicalStride(shape);
if (!strides.empty() &&
!areStridesEquivalent(shape, strides, canonicalStride)) {
std::string errorMsg =
llvm::formatv("Given strides [{0}] do not match canonical strides "
"[{1}] for shape [{2}]",
strides, canonicalStride, shape);
return getInvalidArgStatus(errorMsg.c_str());
}
}

return std::unique_ptr<MemRefValue>(
new MemRefValue(client, addressSpace, bitsPerElement, ptr, offset, shape,
strides, device, scalarType));
Expand Down Expand Up @@ -777,7 +828,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
PointerType addressSpace, int64_t bitsPerElement,
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device, std::optional<CudaStream> stream,
std::optional<ScalarType> scalarType) {
std::optional<ScalarType> scalarType, std::optional<bool> assertCanonicalStrides) {
if (addressSpace == PointerType::device ||
addressSpace == PointerType::unified) {
if (!device || !*device)
Expand All @@ -800,7 +851,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
// Create the descriptor.
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
MemRefValue::create(this, addressSpace, bitsPerElement, allocation->ptr,
0, shape, strides, device, scalarType);
0, shape, strides, device, scalarType, assertCanonicalStrides);
if (bufferImpl.isError())
return bufferImpl.getStatus();

Expand All @@ -811,11 +862,11 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::createExternalMemRef(
PointerType addressSpace, int64_t bitsPerElement, uintptr_t ptr,
int64_t offset, llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> strides, std::optional<const Device *> device,
std::optional<ScalarType> scalarType) {
std::optional<ScalarType> scalarType, std::optional<bool> assertCanonicalStrides) {
// Create the descriptor.
StatusOr<std::unique_ptr<MemRefValue>> memref =
MemRefValue::create(this, addressSpace, bitsPerElement, ptr, offset,
shape, strides, device, scalarType);
shape, strides, device, scalarType, assertCanonicalStrides);
if (!memref.isOk())
return memref.getStatus();

Expand Down
31 changes: 19 additions & 12 deletions mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ static std::unique_ptr<PyMemRefValue> createMemRef(
}

static std::unique_ptr<PyMemRefValue>
createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) {
createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule,
std::optional<bool> assertCanonicalStrides) {
DLManagedTensor *managedTensor = static_cast<DLManagedTensor *>(
PyCapsule_GetPointer(capsule.ptr(), "dltensor"));

Expand Down Expand Up @@ -368,14 +369,16 @@ createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) {
}

if (data) {
s = mtrtMemRefCreateExternal(client, addressSpace, bytesPerElement * 8,
reinterpret_cast<uintptr_t>(data), offset,
rank, shape, strides, device, elementType,
&result);
s = mtrtMemRefCreateExternal(
client, addressSpace, bytesPerElement * 8,
reinterpret_cast<uintptr_t>(data), offset, rank, shape, strides, device,
elementType, &result,
assertCanonicalStrides ? *assertCanonicalStrides : false);
} else {
s = mtrtMemRefCreate(client, addressSpace, bytesPerElement * 8, rank, shape,
strides, device, mtrtStreamGetNull(), elementType,
&result);
s = mtrtMemRefCreate(
client, addressSpace, bytesPerElement * 8, rank, shape, strides, device,
mtrtStreamGetNull(), elementType, &result,
assertCanonicalStrides ? *assertCanonicalStrides : false);
}

THROW_IF_MTRT_ERROR(s);
Expand Down Expand Up @@ -788,11 +791,15 @@ PYBIND11_MODULE(_api, m) {
"returns a new memref and allocates uninitialized backing storage")
.def(
"create_memref_view_from_dlpack",
[](PyRuntimeClient &self, py::capsule capsule) {
return createMemRefViewFromDLPack(self, capsule).release();
[](PyRuntimeClient &self, py::capsule capsule,
std::optional<bool> assertCanonicalStrides) {
return createMemRefViewFromDLPack(self, capsule,
assertCanonicalStrides)
.release();
},
py::arg("dltensor") = py::none(), py::keep_alive<0, 1>(),
py::keep_alive<0, 2>())
py::arg("dltensor") = py::none(),
py::arg("assert_canonical_strides") = py::none(),
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
.def(
"create_device_memref_view",
[](PyRuntimeClient &self, uintptr_t ptr, std::vector<int64_t> shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,54 @@ def create_dangling_memref():
# CHECK-LABEL: Test memref maintains data's lifetime
# CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2]
# CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2]


def check_non_canonical_stride(client, assert_canonical_strides):
try:
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
a = cp.transpose(t)
memref = client.create_memref_view_from_dlpack(
a.__dlpack__(), assert_canonical_strides
)
except Exception as e:
print(f"Received error message: {str(e)}")


def check_canonical_stride(client, assert_canonical_strides):
try:
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
memref = client.create_memref_view_from_dlpack(
t.__dlpack__(), assert_canonical_strides
)
except Exception as e:
print(f"Received error message: {str(e)}")


def test_memref_strides():
print("Testing non-canonical stride: assert_canonical_strides = True")
non_canonical_result = check_non_canonical_stride(
client, assert_canonical_strides=True
)

print("Testing non-canonical stride: assert_canonical_strides = False")
non_canonical_result = check_non_canonical_stride(
client, assert_canonical_strides=False
)

print("Testing canonical stride: assert_canonical_strides = True")
canonical_result = check_canonical_stride(client, assert_canonical_strides=True)

print("Testing canonical stride: assert_canonical_strides = False")
canonical_result = check_canonical_stride(client, assert_canonical_strides=False)


print("Test memref strides")
test_memref_strides()

# CHECK-LABEL: Test memref strides
# CHECK-NEXT: Testing non-canonical stride: assert_canonical_strides = True
# CHECK-NEXT: Received error message: InvalidArgument: InvalidArgument:
# CHECK-SAME: Given strides [1, 4] do not match canonical strides [3, 1] for shape [4, 3]
# CHECK-NEXT: Testing non-canonical stride: assert_canonical_strides = False
# CHECK-NEXT: Testing canonical stride: assert_canonical_strides = True
# CHECK-NEXT: Testing canonical stride: assert_canonical_strides = False

0 comments on commit 4dbc5cc

Please sign in to comment.