diff --git a/thrust/testing/mr_pool.cu b/thrust/testing/mr_pool.cu index 30c1f18a476..d4f8a3056ed 100644 --- a/thrust/testing/mr_pool.cu +++ b/thrust/testing/mr_pool.cu @@ -123,7 +123,7 @@ public: virtual tracked_pointer do_allocate(std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override { - ASSERT_EQUAL(static_cast(id_to_allocate), true); + ASSERT_EQUAL(id_to_allocate || id_to_allocate == -1u, true); void * raw = upstream.do_allocate(n, alignment); tracked_pointer ret(raw); @@ -131,15 +131,18 @@ public: ret.size = n; ret.alignment = alignment; - id_to_allocate = 0; + if (id_to_allocate != -1u) + { + id_to_allocate = 0; + } return ret; } virtual void do_deallocate(tracked_pointer p, std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override { - ASSERT_EQUAL(p.size, n); - ASSERT_EQUAL(p.alignment, alignment); + ASSERT_GEQUAL(p.size, n); + ASSERT_GEQUAL(p.alignment, alignment); if (id_to_deallocate != 0) { @@ -318,6 +321,36 @@ void TestPoolCachingOversized() upstream.id_to_allocate = 7; tracked_pointer a9 = pool.do_allocate(2048, 32); ASSERT_EQUAL(a9.id, 7u); + + // make sure that reusing a larger oversized block for a smaller allocation works + // this is NVIDIA/cccl#585 + upstream.id_to_allocate = 8; + tracked_pointer a10 = pool.do_allocate(2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT); + pool.do_deallocate(a10, 2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT); + tracked_pointer a11 = pool.do_allocate(2048, THRUST_MR_DEFAULT_ALIGNMENT); + ASSERT_EQUAL(a11.ptr, a10.ptr); + pool.do_deallocate(a11, 2048, THRUST_MR_DEFAULT_ALIGNMENT); + + // original minimized reproducer from NVIDIA/cccl#585: + { + upstream.id_to_allocate = -1u; + + auto ptr1 = pool.allocate(43920240); + auto ptr2 = pool.allocate(2465264); + pool.deallocate(ptr1, 43920240); + pool.deallocate(ptr2, 2465264); + auto ptr3 = pool.allocate(4930528); + pool.deallocate(ptr3, 4930528); + auto ptr4 = pool.allocate(14640080); + std::memset(thrust::raw_pointer_cast(ptr4), 0xff, 14640080); + + auto crash = pool.allocate(4930528); + + pool.deallocate(crash, 4930528); + pool.deallocate(ptr4, 14640080); + + upstream.id_to_allocate = 0; + } } void TestUnsynchronizedPoolCachingOversized() diff --git a/thrust/thrust/mr/pool.h b/thrust/thrust/mr/pool.h index 0214f484b15..433e6409325 100644 --- a/thrust/thrust/mr/pool.h +++ b/thrust/thrust/mr/pool.h @@ -154,15 +154,17 @@ class unsynchronized_pool_resource final private: typedef typename Upstream::pointer void_ptr; - typedef typename thrust::detail::pointer_traits::template rebind::other char_ptr; + typedef thrust::detail::pointer_traits void_ptr_traits; + typedef typename void_ptr_traits::template rebind::other char_ptr; struct block_descriptor; struct chunk_descriptor; struct oversized_block_descriptor; - typedef typename thrust::detail::pointer_traits::template rebind::other block_descriptor_ptr; - typedef typename thrust::detail::pointer_traits::template rebind::other chunk_descriptor_ptr; - typedef typename thrust::detail::pointer_traits::template rebind::other oversized_block_descriptor_ptr; + typedef typename void_ptr_traits::template rebind::other block_descriptor_ptr; + typedef typename void_ptr_traits::template rebind::other chunk_descriptor_ptr; + typedef typename void_ptr_traits::template rebind::other oversized_block_descriptor_ptr; + typedef thrust::detail::pointer_traits oversized_block_ptr_traits; struct block_descriptor { @@ -194,6 +196,7 @@ class unsynchronized_pool_resource final oversized_block_descriptor_ptr prev; oversized_block_descriptor_ptr next; oversized_block_descriptor_ptr next_cached; + std::size_t current_size; }; struct pool @@ -244,17 +247,20 @@ class unsynchronized_pool_resource final } // deallocate cached oversized/overaligned memory - while (detail::pointer_traits::get(m_oversized)) + while (oversized_block_ptr_traits::get(m_oversized)) { oversized_block_descriptor_ptr alloc = m_oversized; m_oversized = thrust::raw_reference_cast(*m_oversized).next; + oversized_block_descriptor desc = + thrust::raw_reference_cast(*alloc); + void_ptr p = static_cast( - static_cast( - static_cast(alloc) - ) - thrust::raw_reference_cast(*alloc).size - ); - m_upstream->do_deallocate(p, thrust::raw_reference_cast(*alloc).size + sizeof(oversized_block_descriptor), thrust::raw_reference_cast(*alloc).alignment); + static_cast(static_cast(alloc)) - + desc.current_size); + m_upstream->do_deallocate( + p, desc.size + sizeof(oversized_block_descriptor), + desc.alignment); } m_cached_oversized = oversized_block_descriptor_ptr(); @@ -272,7 +278,7 @@ class unsynchronized_pool_resource final { oversized_block_descriptor_ptr ptr = m_cached_oversized; oversized_block_descriptor_ptr * previous = &m_cached_oversized; - while (detail::pointer_traits::get(ptr)) + while (oversized_block_ptr_traits::get(ptr)) { oversized_block_descriptor desc = *ptr; bool is_good = desc.size >= bytes && desc.alignment >= alignment; @@ -305,9 +311,7 @@ class unsynchronized_pool_resource final { if (previous != &m_cached_oversized) { - oversized_block_descriptor previous_desc = **previous; - previous_desc.next_cached = desc.next_cached; - **previous = previous_desc; + *previous = desc.next_cached; } else { @@ -315,13 +319,31 @@ class unsynchronized_pool_resource final } desc.next_cached = oversized_block_descriptor_ptr(); + + auto ret = + static_cast(static_cast(ptr)) - + desc.size; + + if (bytes != desc.size) { + desc.current_size = bytes; + + ptr = static_cast( + static_cast(ret + bytes)); + + if (oversized_block_ptr_traits::get(desc.prev)) { + thrust::raw_reference_cast(*desc.prev).next = ptr; + } else { + m_oversized = ptr; + } + + if (oversized_block_ptr_traits::get(desc.next)) { + thrust::raw_reference_cast(*desc.next).prev = ptr; + } + } + *ptr = desc; - return static_cast( - static_cast( - static_cast(ptr) - ) - desc.size - ); + return static_cast(ret); } previous = &thrust::raw_reference_cast(*ptr).next_cached; @@ -343,10 +365,11 @@ class unsynchronized_pool_resource final desc.prev = oversized_block_descriptor_ptr(); desc.next = m_oversized; desc.next_cached = oversized_block_descriptor_ptr(); + desc.current_size = bytes; *block = desc; m_oversized = block; - if (detail::pointer_traits::get(desc.next)) + if (oversized_block_ptr_traits::get(desc.next)) { oversized_block_descriptor next = *desc.next; next.prev = block; @@ -439,7 +462,7 @@ class unsynchronized_pool_resource final assert(detail::is_power_of_2(alignment)); // verify that the pointer is at least as aligned as claimed - assert(reinterpret_cast(detail::pointer_traits::get(p)) % alignment == 0); + assert(reinterpret_cast(void_ptr_traits::get(p)) % alignment == 0); // the deallocated block is oversized and/or overaligned if (n > m_options.largest_block_size || alignment > m_options.alignment) @@ -451,35 +474,44 @@ class unsynchronized_pool_resource final ); oversized_block_descriptor desc = *block; + assert(desc.current_size == n); + assert(desc.alignment == alignment); if (m_options.cache_oversized) { desc.next_cached = m_cached_oversized; - *block = desc; + + if (desc.size != n) { + desc.current_size = desc.size; + block = static_cast( + static_cast(static_cast(p) + + desc.size)); + if (oversized_block_ptr_traits::get(desc.prev)) { + thrust::raw_reference_cast(*desc.prev).next = block; + } else { + m_oversized = block; + } + + if (oversized_block_ptr_traits::get(desc.next)) { + thrust::raw_reference_cast(*desc.next).prev = block; + } + } + m_cached_oversized = block; + *block = desc; return; } - if (!detail::pointer_traits::get(desc.prev)) - { - assert(m_oversized == block); + if (oversized_block_ptr_traits::get( + desc.prev)) { + thrust::raw_reference_cast(*desc.prev).next = desc.next; + } else { m_oversized = desc.next; } - else - { - oversized_block_descriptor prev = *desc.prev; - assert(prev.next == block); - prev.next = desc.next; - *desc.prev = prev; - } - if (detail::pointer_traits::get(desc.next)) - { - oversized_block_descriptor next = *desc.next; - assert(next.prev == block); - next.prev = desc.prev; - *desc.next = next; + if (oversized_block_ptr_traits::get(desc.next)) { + thrust::raw_reference_cast(*desc.next).prev = desc.prev; } m_upstream->do_deallocate(p, desc.size + sizeof(oversized_block_descriptor), desc.alignment);