@@ -58,30 +58,43 @@ namespace gpu
5858{
5959
6060template <typename T>
61- class TypedAllocator : public thrust::device_allocator<T>
62- {
63- public:
61+ struct TypedAllocator {
6462 using value_type = T;
65- using pointer = T*;
63+ using pointer = thrust::device_ptr<T>;
64+ using const_pointer = thrust::device_ptr<const T>;
65+ using size_type = std::size_t;
66+ using difference_type = std::ptrdiff_t;
67+
68+ TypedAllocator() noexcept : mInternalAllocator(nullptr) {}
69+ explicit TypedAllocator(ExternalAllocator* a) noexcept : mInternalAllocator(a) {}
6670
6771 template <typename U>
68- struct rebind {
69- using other = TypedAllocator<U>;
70- };
72+ TypedAllocator(const TypedAllocator<U>& o) noexcept : mInternalAllocator(o.mInternalAllocator)
73+ {
74+ }
7175
72- explicit TypedAllocator(ExternalAllocator* allocPtr)
73- : mInternalAllocator(allocPtr) {}
76+ pointer allocate(size_type n)
77+ {
78+ void* raw = mInternalAllocator->allocate(n * sizeof(T));
79+ return thrust::device_pointer_cast(static_cast<T*>(raw));
80+ }
7481
75- T* allocate(size_t n)
82+ void deallocate(pointer p, size_type n) noexcept
7683 {
77- return reinterpret_cast<T*>(mInternalAllocator->allocate(n * sizeof(T)));
84+ if (!p) {
85+ return;
86+ }
87+ void* raw = thrust::raw_pointer_cast(p);
88+ mInternalAllocator->deallocate(static_cast<char*>(raw), n * sizeof(T));
7889 }
7990
80- void deallocate(T* p, size_t n)
91+ bool operator==(TypedAllocator const& o) const noexcept
92+ {
93+ return mInternalAllocator == o.mInternalAllocator;
94+ }
95+ bool operator!=(TypedAllocator const& o) const noexcept
8196 {
82- char* raw_ptr = reinterpret_cast<char*>(p);
83- size_t bytes = n * sizeof(T);
84- mInternalAllocator->deallocate(raw_ptr, bytes); // redundant as internal dealloc is no-op.
97+ return !(*this == o);
8598 }
8699
87100 private:
0 commit comments