diff --git a/rpcs3/util/shared_ptr.hpp b/rpcs3/util/shared_ptr.hpp index 6850849579..c3a9bc06d1 100644 --- a/rpcs3/util/shared_ptr.hpp +++ b/rpcs3/util/shared_ptr.hpp @@ -22,7 +22,7 @@ namespace stx constexpr bool is_same_ptr_v = true; template - constexpr bool is_same_ptr_cast_v = std::is_same_v || std::is_convertible_v && is_same_ptr_v; + constexpr bool is_same_ptr_cast_v = std::is_same_v || (std::is_convertible_v && is_same_ptr_v); template class single_ptr; @@ -36,13 +36,13 @@ namespace stx // Basic assumption of userspace pointer size constexpr uint c_ptr_size = 47; - // Use lower 17 bits as atomic_ptr internal refcounter (pointer is shifted) + // Use lower 17 bits as atomic_ptr internal counter of borrowed refs (pointer itself is shifted) constexpr uint c_ref_mask = 0x1ffff, c_ref_size = 17; struct shared_counter { // Stored destructor - void (*destroy)(shared_counter* _this); + atomic_t destroy{}; // Reference counter atomic_t refs{1}; @@ -113,8 +113,6 @@ namespace stx constexpr single_ptr() noexcept = default; - constexpr single_ptr(std::nullptr_t) noexcept {} - single_ptr(const single_ptr&) = delete; single_ptr(single_ptr&& r) noexcept @@ -138,11 +136,6 @@ namespace stx single_ptr& operator=(const single_ptr&) = delete; - single_ptr& operator=(std::nullptr_t) noexcept - { - reset(); - } - single_ptr& operator=(single_ptr&& r) noexcept { m_ptr = r.m_ptr; @@ -164,7 +157,7 @@ namespace stx if (m_ptr) [[likely]] { const auto o = d(); - o->destroy(o); + o->destroy.load()(o); m_ptr = nullptr; } } @@ -257,7 +250,7 @@ namespace stx } } - ptr->m_ctr.destroy = [](shared_counter* _this) + ptr->m_ctr.destroy.raw() = [](shared_counter* _this) noexcept { delete reinterpret_cast*>(reinterpret_cast(_this) - offsetof(shared_data, m_ctr)); }; @@ -313,7 +306,7 @@ namespace stx ptr->m_count = count; - ptr->m_ctr.destroy = [](shared_counter* _this) + ptr->m_ctr.destroy.raw() = [](shared_counter* _this) noexcept { shared_data* ptr = reinterpret_cast*>(reinterpret_cast(_this) - offsetof(shared_data, m_ctr)); @@ -360,8 +353,6 @@ namespace stx constexpr shared_ptr() noexcept = default; - constexpr shared_ptr(std::nullptr_t) noexcept {} - shared_ptr(const shared_ptr& r) noexcept : m_ptr(r.m_ptr) { @@ -649,7 +640,7 @@ namespace stx if (v >> c_ref_size && !o->refs.sub_fetch(c_ref_mask + 1 - (v & c_ref_mask))) { - o->destroy(o); + o->destroy.load()(o); } } @@ -657,7 +648,7 @@ namespace stx atomic_ptr& operator=(std::remove_cv_t value) noexcept { shared_type r = make_single(std::move(value)); - r.d()->refs += c_ref_mask; + r.d()->refs.raw() += c_ref_mask; atomic_ptr old; old.m_val.raw() = m_val.exchange(reinterpret_cast(std::exchange(r.m_ptr, nullptr)) << c_ref_size); @@ -688,6 +679,11 @@ namespace stx return *this; } + void reset() noexcept + { + store(shared_type{}); + } + shared_type load() const noexcept { shared_type r; @@ -714,8 +710,8 @@ namespace stx r.m_ptr = std::launder(reinterpret_cast(prev >> c_ref_size)); r.d()->refs++; - // Dereference if same pointer - m_val.fetch_op([prev = prev](uptr& val) + // Dereference if still the same pointer + const auto [_, did_deref] = m_val.fetch_op([prev = prev](uptr& val) { if (val >> c_ref_size == prev >> c_ref_size) { @@ -726,6 +722,12 @@ namespace stx return false; }); + if (!did_deref) + { + // Otherwise fix ref count (atomic_ptr has been overwritten) + r.d()->refs--; + } + return r; } @@ -734,6 +736,81 @@ namespace stx return load(); } + // Atomically inspect pointer with the possibility to reference it if necessary + template > + RT peek_op(F op) const noexcept + { + shared_type r; + + // Add reference + const auto [prev, did_ref] = m_val.fetch_op([](uptr& val) + { + if (val >> c_ref_size) + { + val++; + return true; + } + + return false; + }); + + // Set fake unreferenced pointer + if (did_ref) + { + r.m_ptr = std::launder(reinterpret_cast(prev >> c_ref_size)); + } + + // Result temp storage + std::conditional_t, int, RT> result; + + // Invoke + if constexpr (std::is_void_v) + { + std::invoke(op, std::as_const(r)); + + if (!did_ref) + { + return; + } + } + else + { + result = std::invoke(op, std::as_const(r)); + + if (!did_ref) + { + return result; + } + } + + // Dereference if still the same pointer + const auto [_, did_deref] = m_val.fetch_op([prev = prev](uptr& val) + { + if (val >> c_ref_size == prev >> c_ref_size) + { + val--; + return true; + } + + return false; + }); + + if (did_deref) + { + // Deactivate fake pointer + r.m_ptr = nullptr; + } + + if constexpr (std::is_void_v) + { + return; + } + else + { + return result; + } + } + template >> void store(Args&&... args) noexcept { @@ -786,29 +863,165 @@ namespace stx return value; } - // bool compare_exchange(shared_type& cmp_and_old, shared_type exch) - // { - // } + // Ineffective + [[nodiscard]] bool compare_exchange(shared_type& cmp_and_old, shared_type exch) + { + const uptr _old = reinterpret_cast(cmp_and_old.m_ptr); + const uptr _new = reinterpret_cast(exch.m_ptr); - // template >> - // shared_type compare_and_swap(const shared_ptr& cmp, shared_type exch) - // { - // } + if (exch.m_ptr) + { + exch.d().refs += c_ref_mask; + } - // template >> - // bool compare_and_swap_test(const shared_ptr& cmp, shared_type exch) - // { - // } + atomic_ptr old; - // template >> - // shared_type compare_and_swap(const single_ptr& cmp, shared_type exch) - // { - // } + const uptr _val = m_val.fetch_op([&](uptr& val) + { + if (val >> c_ref_size == _old) + { + // Set new value + val = _new << c_ref_size; + } + else if (val) + { + // Reference previous value + val++; + } + }); - // template >> - // bool compare_and_swap_test(const single_ptr& cmp, shared_type exch) - // { - // } + if (_val >> c_ref_size == _old) + { + // Success (exch is consumed, cmp_and_old is unchanged) + if (exch.m_ptr) + { + exch.m_ptr = nullptr; + } + + // Cleanup + old.m_val.raw() = _val; + return true; + } + + atomic_ptr old_exch; + old_exch.m_val.raw() = reinterpret_cast(std::exchange(exch.m_ptr, nullptr)) << 17; + + // Set to reset old cmp_and_old value + old.m_val.raw() = (cmp_and_old.m_ptr << c_ref_size) | c_ref_mask; + + if (!_val) + { + return false; + } + + // Set referenced pointer + cmp_and_old.m_ptr = std::launder(reinterpret_cast(_val >> c_ref_size)); + cmp_and_old.d()->refs++; + + // Dereference if still the same pointer + const auto [_, did_deref] = m_val.fetch_op([_val](uptr& val) + { + if (val >> c_ref_size == _val >> c_ref_size) + { + val--; + return true; + } + + return false; + }); + + if (!did_deref) + { + // Otherwise fix ref count (atomic_ptr has been overwritten) + cmp_and_old.d()->refs--; + } + + return false; + } + + // Unoptimized + template >> + shared_type compare_and_swap(const shared_ptr& cmp, shared_type exch) + { + verify(HERE), is_same_ptr(); + + shared_type old = cmp; + + if (compare_exchange(old, std::move(exch))) + { + return old; + } + else + { + return old; + } + } + + // More lightweight than compare_exchange + template >> + bool compare_and_swap_test(const shared_ptr& cmp, shared_type exch) + { + verify(HERE), is_same_ptr(); + + const uptr _old = reinterpret_cast(cmp.m_ptr); + const uptr _new = reinterpret_cast(exch.m_ptr); + + if (exch.m_ptr) + { + exch.d().refs += c_ref_mask; + } + + atomic_ptr old; + + const auto [_val, ok] = m_val.fetch_op([&](uptr& val) + { + if (val >> c_ref_size == _old) + { + // Set new value + val = _new << c_ref_size; + return true; + } + + return false; + }); + + if (ok) + { + // Success (exch is consumed, cmp_and_old is unchanged) + exch.m_ptr = nullptr; + old.m_val.raw() = _val; + return true; + } + + // Failure (return references) + old.m_val.raw() = reinterpret_cast(std::exchange(exch.m_ptr, nullptr)) << 17; + return false; + } + + // Unoptimized + template >> + shared_type compare_and_swap(const single_ptr& cmp, shared_type exch) + { + verify(HERE), is_same_ptr(); + + shared_type old = cmp; + + if (compare_exchange(old, std::move(exch))) + { + return old; + } + else + { + return old; + } + } + + // Supplementary + template >> + bool compare_and_swap_test(const single_ptr& cmp, shared_type exch) + { + return compare_and_swap_test(reinterpret_cast&>(cmp), std::move(exch)); + } // Simple atomic load is much more effective than load(), but it's a non-owning reference const volatile void* observe() const noexcept @@ -833,6 +1046,44 @@ namespace stx return observe() == r.get(); } }; + + // Some nullptr replacement for few cases + constexpr struct null_ptr_t + { + template + constexpr operator single_ptr() const noexcept + { + return {}; + } + + template + constexpr operator shared_ptr() const noexcept + { + return {}; + } + + template + constexpr operator atomic_ptr() const noexcept + { + return {}; + } + + explicit constexpr operator bool() const noexcept + { + return false; + } + + constexpr operator std::nullptr_t() const noexcept + { + return nullptr; + } + + constexpr std::nullptr_t get() const noexcept + { + return nullptr; + } + + } null_ptr; } namespace std @@ -850,6 +1101,7 @@ namespace std } } +using stx::null_ptr; using stx::single_ptr; using stx::shared_ptr; using stx::atomic_ptr;