Skip to content

Commit

Permalink
thrust/mr: fix the case of reuising a block for a smaller alloc. (#1232)
Browse files Browse the repository at this point in the history
* thrust/mr: fix the case of reuising a block for a smaller alloc.

Previously, the pool happily returned a pointer to a larger oversized
block than requested, without storing the information that the block is
now smaller, which meant that on deallocation, it'd look for the
descriptor of the block in the wrong place. This is now fixed by moving
the descriptor to always be where deallocation can find it using the
user-provided size, and by storing the original size to restore the
descriptor to its rightful place when deallocating.

Also a drive-by fix for a bug where in certain cases the reallocated
cached oversized block wasn't removed from the cached list. Whoops.
Kinda surprised this hasn't exploded before.

* thrust/mr: add aliases to reused pointer traits in pool.h
  • Loading branch information
griwes committed Jan 24, 2024
1 parent da0e1f2 commit 7622c4b
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 42 deletions.
41 changes: 37 additions & 4 deletions thrust/testing/mr_pool.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,23 +123,26 @@ public:

virtual tracked_pointer<void> do_allocate(std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override
{
ASSERT_EQUAL(static_cast<bool>(id_to_allocate), true);
ASSERT_EQUAL(id_to_allocate || id_to_allocate == -1u, true);

void * raw = upstream.do_allocate(n, alignment);
tracked_pointer<void> ret(raw);
ret.id = id_to_allocate;
ret.size = n;
ret.alignment = alignment;

id_to_allocate = 0;
if (id_to_allocate != -1u)
{
id_to_allocate = 0;
}

return ret;
}

virtual void do_deallocate(tracked_pointer<void> p, std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override
{
ASSERT_EQUAL(p.size, n);
ASSERT_EQUAL(p.alignment, alignment);
ASSERT_GEQUAL(p.size, n);
ASSERT_GEQUAL(p.alignment, alignment);

if (id_to_deallocate != 0)
{
Expand Down Expand Up @@ -318,6 +321,36 @@ void TestPoolCachingOversized()
upstream.id_to_allocate = 7;
tracked_pointer<void> a9 = pool.do_allocate(2048, 32);
ASSERT_EQUAL(a9.id, 7u);

// make sure that reusing a larger oversized block for a smaller allocation works
// this is NVIDIA/cccl#585
upstream.id_to_allocate = 8;
tracked_pointer<void> a10 = pool.do_allocate(2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT);
pool.do_deallocate(a10, 2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT);
tracked_pointer<void> a11 = pool.do_allocate(2048, THRUST_MR_DEFAULT_ALIGNMENT);
ASSERT_EQUAL(a11.ptr, a10.ptr);
pool.do_deallocate(a11, 2048, THRUST_MR_DEFAULT_ALIGNMENT);

// original minimized reproducer from NVIDIA/cccl#585:
{
upstream.id_to_allocate = -1u;

auto ptr1 = pool.allocate(43920240);
auto ptr2 = pool.allocate(2465264);
pool.deallocate(ptr1, 43920240);
pool.deallocate(ptr2, 2465264);
auto ptr3 = pool.allocate(4930528);
pool.deallocate(ptr3, 4930528);
auto ptr4 = pool.allocate(14640080);
std::memset(thrust::raw_pointer_cast(ptr4), 0xff, 14640080);

auto crash = pool.allocate(4930528);

pool.deallocate(crash, 4930528);
pool.deallocate(ptr4, 14640080);

upstream.id_to_allocate = 0;
}
}

void TestUnsynchronizedPoolCachingOversized()
Expand Down
108 changes: 70 additions & 38 deletions thrust/thrust/mr/pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,17 @@ class unsynchronized_pool_resource final

private:
typedef typename Upstream::pointer void_ptr;
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<char>::other char_ptr;
typedef thrust::detail::pointer_traits<void_ptr> void_ptr_traits;
typedef typename void_ptr_traits::template rebind<char>::other char_ptr;

struct block_descriptor;
struct chunk_descriptor;
struct oversized_block_descriptor;

typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<block_descriptor>::other block_descriptor_ptr;
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<chunk_descriptor>::other chunk_descriptor_ptr;
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<oversized_block_descriptor>::other oversized_block_descriptor_ptr;
typedef typename void_ptr_traits::template rebind<block_descriptor>::other block_descriptor_ptr;
typedef typename void_ptr_traits::template rebind<chunk_descriptor>::other chunk_descriptor_ptr;
typedef typename void_ptr_traits::template rebind<oversized_block_descriptor>::other oversized_block_descriptor_ptr;
typedef thrust::detail::pointer_traits<oversized_block_descriptor_ptr> oversized_block_ptr_traits;

struct block_descriptor
{
Expand Down Expand Up @@ -194,6 +196,7 @@ class unsynchronized_pool_resource final
oversized_block_descriptor_ptr prev;
oversized_block_descriptor_ptr next;
oversized_block_descriptor_ptr next_cached;
std::size_t current_size;
};

struct pool
Expand Down Expand Up @@ -244,17 +247,20 @@ class unsynchronized_pool_resource final
}

// deallocate cached oversized/overaligned memory
while (detail::pointer_traits<oversized_block_descriptor_ptr>::get(m_oversized))
while (oversized_block_ptr_traits::get(m_oversized))
{
oversized_block_descriptor_ptr alloc = m_oversized;
m_oversized = thrust::raw_reference_cast(*m_oversized).next;

oversized_block_descriptor desc =
thrust::raw_reference_cast(*alloc);

void_ptr p = static_cast<void_ptr>(
static_cast<char_ptr>(
static_cast<void_ptr>(alloc)
) - thrust::raw_reference_cast(*alloc).size
);
m_upstream->do_deallocate(p, thrust::raw_reference_cast(*alloc).size + sizeof(oversized_block_descriptor), thrust::raw_reference_cast(*alloc).alignment);
static_cast<char_ptr>(static_cast<void_ptr>(alloc)) -
desc.current_size);
m_upstream->do_deallocate(
p, desc.size + sizeof(oversized_block_descriptor),
desc.alignment);
}

m_cached_oversized = oversized_block_descriptor_ptr();
Expand All @@ -272,7 +278,7 @@ class unsynchronized_pool_resource final
{
oversized_block_descriptor_ptr ptr = m_cached_oversized;
oversized_block_descriptor_ptr * previous = &m_cached_oversized;
while (detail::pointer_traits<oversized_block_descriptor_ptr>::get(ptr))
while (oversized_block_ptr_traits::get(ptr))
{
oversized_block_descriptor desc = *ptr;
bool is_good = desc.size >= bytes && desc.alignment >= alignment;
Expand Down Expand Up @@ -305,23 +311,39 @@ class unsynchronized_pool_resource final
{
if (previous != &m_cached_oversized)
{
oversized_block_descriptor previous_desc = **previous;
previous_desc.next_cached = desc.next_cached;
**previous = previous_desc;
*previous = desc.next_cached;
}
else
{
m_cached_oversized = desc.next_cached;
}

desc.next_cached = oversized_block_descriptor_ptr();

auto ret =
static_cast<char_ptr>(static_cast<void_ptr>(ptr)) -
desc.size;

if (bytes != desc.size) {
desc.current_size = bytes;

ptr = static_cast<oversized_block_descriptor_ptr>(
static_cast<void_ptr>(ret + bytes));

if (oversized_block_ptr_traits::get(desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = ptr;
} else {
m_oversized = ptr;
}

if (oversized_block_ptr_traits::get(desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = ptr;
}
}

*ptr = desc;

return static_cast<void_ptr>(
static_cast<char_ptr>(
static_cast<void_ptr>(ptr)
) - desc.size
);
return static_cast<void_ptr>(ret);
}

previous = &thrust::raw_reference_cast(*ptr).next_cached;
Expand All @@ -343,10 +365,11 @@ class unsynchronized_pool_resource final
desc.prev = oversized_block_descriptor_ptr();
desc.next = m_oversized;
desc.next_cached = oversized_block_descriptor_ptr();
desc.current_size = bytes;
*block = desc;
m_oversized = block;

if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.next))
if (oversized_block_ptr_traits::get(desc.next))
{
oversized_block_descriptor next = *desc.next;
next.prev = block;
Expand Down Expand Up @@ -439,7 +462,7 @@ class unsynchronized_pool_resource final
assert(detail::is_power_of_2(alignment));

// verify that the pointer is at least as aligned as claimed
assert(reinterpret_cast<detail::intmax_t>(detail::pointer_traits<void_ptr>::get(p)) % alignment == 0);
assert(reinterpret_cast<detail::intmax_t>(void_ptr_traits::get(p)) % alignment == 0);

// the deallocated block is oversized and/or overaligned
if (n > m_options.largest_block_size || alignment > m_options.alignment)
Expand All @@ -451,35 +474,44 @@ class unsynchronized_pool_resource final
);

oversized_block_descriptor desc = *block;
assert(desc.current_size == n);
assert(desc.alignment == alignment);

if (m_options.cache_oversized)
{
desc.next_cached = m_cached_oversized;
*block = desc;

if (desc.size != n) {
desc.current_size = desc.size;
block = static_cast<oversized_block_descriptor_ptr>(
static_cast<void_ptr>(static_cast<char_ptr>(p) +
desc.size));
if (oversized_block_ptr_traits::get(desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = block;
} else {
m_oversized = block;
}

if (oversized_block_ptr_traits::get(desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = block;
}
}

m_cached_oversized = block;
*block = desc;

return;
}

if (!detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.prev))
{
assert(m_oversized == block);
if (oversized_block_ptr_traits::get(
desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = desc.next;
} else {
m_oversized = desc.next;
}
else
{
oversized_block_descriptor prev = *desc.prev;
assert(prev.next == block);
prev.next = desc.next;
*desc.prev = prev;
}

if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.next))
{
oversized_block_descriptor next = *desc.next;
assert(next.prev == block);
next.prev = desc.prev;
*desc.next = next;
if (oversized_block_ptr_traits::get(desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = desc.prev;
}

m_upstream->do_deallocate(p, desc.size + sizeof(oversized_block_descriptor), desc.alignment);
Expand Down

0 comments on commit 7622c4b

Please sign in to comment.