atomic.cpp: fix cond_handle data structures

Fix a critical bug where an id could go out of range.
This commit is contained in:
Nekotekina 2020-11-06 11:55:01 +03:00
parent 1e45437498
commit bfe9580551

View file

@@ -193,13 +193,13 @@ namespace atomic_wait
} }
// Max allowed thread number is chosen to fit in 16 bits // Max allowed thread number is chosen to fit in 16 bits
static std::aligned_storage_t<sizeof(atomic_wait::cond_handle), alignof(atomic_wait::cond_handle)> s_cond_list[UINT16_MAX]{}; static std::aligned_storage_t<sizeof(atomic_wait::cond_handle), alignof(atomic_wait::cond_handle)> s_cond_list[UINT16_MAX + 1]{};
// Used to allow concurrent notifying // Used to allow concurrent notifying
static atomic_t<u16> s_cond_refs[UINT16_MAX + 1]{}; static atomic_t<u32> s_cond_refs[UINT16_MAX + 1]{};
// Allocation bits // Allocation bits
static atomic_t<u64, 64> s_cond_bits[::align<u32>(UINT16_MAX, 64) / 64]{}; static atomic_t<u64, 64> s_cond_bits[(UINT16_MAX + 1) / 64]{};
// Allocation semaphore // Allocation semaphore
static atomic_t<u32, 64> s_cond_sema{0}; static atomic_t<u32, 64> s_cond_sema{0};
@@ -237,18 +237,24 @@ static u32 cond_alloc()
return false; return false;
}); });
if (ok) if (ok) [[likely]]
{ {
// Find lowest clear bit // Find lowest clear bit
const u32 id = group * 64 + std::countr_one(bits); const u32 id = group * 64 + std::countr_one(bits);
if (id == 0) [[unlikely]]
{
// Special case, set bit and continue
continue;
}
// Construct inplace before it can be used // Construct inplace before it can be used
new (s_cond_list + id) atomic_wait::cond_handle(); new (s_cond_list + id) atomic_wait::cond_handle();
// Add first reference // Add first reference
verify(HERE), !s_cond_refs[id]++; verify(HERE), !s_cond_refs[id]++;
return id + 1; return id;
} }
} }
@@ -261,7 +267,7 @@ static atomic_wait::cond_handle* cond_get(u32 cond_id)
{ {
if (cond_id - 1 < u32{UINT16_MAX}) [[likely]] if (cond_id - 1 < u32{UINT16_MAX}) [[likely]]
{ {
return std::launder(reinterpret_cast<atomic_wait::cond_handle*>(s_cond_list + (cond_id - 1))); return std::launder(reinterpret_cast<atomic_wait::cond_handle*>(s_cond_list + cond_id));
} }
return nullptr; return nullptr;
@@ -276,7 +282,7 @@ static void cond_free(u32 cond_id)
} }
// Dereference, destroy on last ref // Dereference, destroy on last ref
if (--s_cond_refs[cond_id - 1]) if (--s_cond_refs[cond_id])
{ {
return; return;
} }
@@ -285,7 +291,7 @@ static void cond_free(u32 cond_id)
cond_get(cond_id)->~cond_handle(); cond_get(cond_id)->~cond_handle();
// Remove the allocation bit // Remove the allocation bit
s_cond_bits[(cond_id - 1) / 64] &= ~(1ull << ((cond_id - 1) % 64)); s_cond_bits[cond_id / 64] &= ~(1ull << (cond_id % 64));
// Release the semaphore // Release the semaphore
s_cond_sema--; s_cond_sema--;
@@ -295,9 +301,9 @@ static u32 cond_lock(atomic_t<u16>* sema)
{ {
while (const u32 cond_id = sema->load()) while (const u32 cond_id = sema->load())
{ {
const auto [old, ok] = s_cond_refs[cond_id - 1].fetch_op([](u16& ref) const auto [old, ok] = s_cond_refs[cond_id].fetch_op([](u32& ref)
{ {
if (!ref || ref == UINT16_MAX) if (!ref || ref == UINT32_MAX)
{ {
// Don't reference already deallocated semaphore // Don't reference already deallocated semaphore
return false; return false;
@@ -312,9 +318,9 @@ static u32 cond_lock(atomic_t<u16>* sema)
return cond_id; return cond_id;
} }
if (old == UINT16_MAX) if (old == UINT32_MAX)
{ {
fmt::raw_error("Thread limit " STRINGIZE(UINT16_MAX) " for a single address reached in atomic notifier."); fmt::raw_error("Thread limit " STRINGIZE(UINT32_MAX) " for a single address reached in atomic notifier.");
} }
if (sema->load() != cond_id) if (sema->load() != cond_id)