atomic.cpp: implement some atomic wait operations.

Instead of only waiting while the value is bitwise equal to an expected value,
the wait can now use other predicates, such as less-than, greater-than, or a set-bit-count comparison.
Note: this is a draft and is untested; it is not expected to change existing behavior.
This commit is contained in:
Nekotekina 2020-11-11 10:25:22 +03:00
parent 829a697c39
commit 7cd1e767be
2 changed files with 207 additions and 44 deletions

View file

@ -20,6 +20,7 @@
#include <random> #include <random>
#include "asm.hpp" #include "asm.hpp"
#include "endian.hpp"
// Total number of entries, should be a power of 2. // Total number of entries, should be a power of 2.
static constexpr std::size_t s_hashtable_size = 1u << 18; static constexpr std::size_t s_hashtable_size = 1u << 18;
@ -30,13 +31,18 @@ static thread_local bool(*s_tls_wait_cb)(const void* data) = [](const void*){ re
// Callback for notification functions for optimizations // Callback for notification functions for optimizations
static thread_local void(*s_tls_notify_cb)(const void* data, u64 progress) = [](const void*, u64){}; static thread_local void(*s_tls_notify_cb)(const void* data, u64 progress) = [](const void*, u64){};
// Test whether the packed op byte has the given flag bit set.
// Returns true iff any bit of `rhs` is present in `lhs`.
static inline bool operator &(atomic_wait::op lhs, atomic_wait::op_flag rhs)
{
	const u8 lhs_bits = static_cast<u8>(lhs);
	const u8 rhs_bits = static_cast<u8>(rhs);
	return (lhs_bits & rhs_bits) != 0;
}
// Compare data in memory with old value, and return true if they are equal // Compare data in memory with old value, and return true if they are equal
template <bool CheckCb = true> template <bool CheckCb = true>
static NEVER_INLINE bool static NEVER_INLINE bool
#ifdef _WIN32 #ifdef _WIN32
__vectorcall __vectorcall
#endif #endif
ptr_cmp(const void* data, u32 size, __m128i old128, __m128i mask128, atomic_wait::info* ext = nullptr) ptr_cmp(const void* data, u32 _size, __m128i old128, __m128i mask128, atomic_wait::info* ext = nullptr)
{ {
if constexpr (CheckCb) if constexpr (CheckCb)
{ {
@ -46,32 +52,138 @@ ptr_cmp(const void* data, u32 size, __m128i old128, __m128i mask128, atomic_wait
} }
} }
const u64 old_value = _mm_cvtsi128_si64(old128); using atomic_wait::op;
const u64 mask = _mm_cvtsi128_si64(mask128); using atomic_wait::op_flag;
const u8 size = static_cast<u8>(_size);
const op flag{static_cast<u8>(_size >> 8)};
bool result = false; bool result = false;
switch (size) if (size <= 8)
{ {
case 1: result = (reinterpret_cast<const atomic_t<u8>*>(data)->load() & mask) == (old_value & mask); break; u64 new_value = 0;
case 2: result = (reinterpret_cast<const atomic_t<u16>*>(data)->load() & mask) == (old_value & mask); break; u64 old_value = _mm_cvtsi128_si64(old128);
case 4: result = (reinterpret_cast<const atomic_t<u32>*>(data)->load() & mask) == (old_value & mask); break; u64 mask = _mm_cvtsi128_si64(mask128) & (UINT64_MAX >> ((64 - size * 8) & 63));
case 8: result = (reinterpret_cast<const atomic_t<u64>*>(data)->load() & mask) == (old_value & mask); break;
case 16:
{
const auto v0 = std::bit_cast<__m128i>(atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data)));
const auto v1 = _mm_xor_si128(v0, old128);
const auto v2 = _mm_and_si128(v1, mask128);
const auto v3 = _mm_packs_epi16(v2, v2);
result = _mm_cvtsi128_si64(v3) == 0; switch (size)
break; {
case 1: new_value = reinterpret_cast<const atomic_t<u8>*>(data)->load(); break;
case 2: new_value = reinterpret_cast<const atomic_t<u16>*>(data)->load(); break;
case 4: new_value = reinterpret_cast<const atomic_t<u32>*>(data)->load(); break;
case 8: new_value = reinterpret_cast<const atomic_t<u64>*>(data)->load(); break;
default:
{
fprintf(stderr, "ptr_cmp(): bad size (arg=0x%x)" HERE "\n", _size);
std::abort();
}
}
if (flag & op_flag::bit_not)
{
new_value = ~new_value;
}
if (!mask) [[unlikely]]
{
new_value = 0;
old_value = 0;
}
else
{
if (flag & op_flag::byteswap)
{
switch (size)
{
case 2:
{
new_value = stx::se_storage<u16>::swap(static_cast<u16>(new_value));
old_value = stx::se_storage<u16>::swap(static_cast<u16>(old_value));
mask = stx::se_storage<u16>::swap(static_cast<u16>(mask));
break;
}
case 4:
{
new_value = stx::se_storage<u32>::swap(static_cast<u32>(new_value));
old_value = stx::se_storage<u32>::swap(static_cast<u32>(old_value));
mask = stx::se_storage<u32>::swap(static_cast<u32>(mask));
break;
}
case 8:
{
new_value = stx::se_storage<u64>::swap(new_value);
old_value = stx::se_storage<u64>::swap(old_value);
mask = stx::se_storage<u64>::swap(mask);
}
default:
{
break;
}
}
}
// Make most significant bit sign bit
const auto shv = std::countl_zero(mask);
new_value &= mask;
old_value &= mask;
new_value <<= shv;
old_value <<= shv;
}
s64 news = new_value;
s64 olds = old_value;
u64 newa = news < 0 ? (0ull - new_value) : new_value;
u64 olda = olds < 0 ? (0ull - old_value) : old_value;
switch (op{static_cast<u8>(static_cast<u8>(flag) & 0xf)})
{
case op::eq: result = old_value == new_value; break;
case op::slt: result = olds < news; break;
case op::sgt: result = olds > news; break;
case op::ult: result = old_value < new_value; break;
case op::ugt: result = old_value > new_value; break;
case op::alt: result = olda < newa; break;
case op::agt: result = olda > newa; break;
case op::pop:
{
// Count is taken from least significant byte and ignores some flags
const u64 count = _mm_cvtsi128_si64(old128) & 0xff;
u64 bitc = new_value;
bitc = (bitc & 0xaaaaaaaaaaaaaaaa) / 2 + (bitc & 0x5555555555555555);
bitc = (bitc & 0xcccccccccccccccc) / 4 + (bitc & 0x3333333333333333);
bitc = (bitc & 0xf0f0f0f0f0f0f0f0) / 16 + (bitc & 0x0f0f0f0f0f0f0f0f);
bitc = (bitc & 0xff00ff00ff00ff00) / 256 + (bitc & 0x00ff00ff00ff00ff);
bitc = ((bitc & 0xffff0000ffff0000) >> 16) + (bitc & 0x0000ffff0000ffff);
bitc = (bitc >> 32) + bitc;
result = count < bitc;
break;
}
default:
{
fmt::raw_error("ptr_cmp(): unrecognized atomic wait operation.");
}
}
} }
default: else if (size == 16 && (flag == op::eq || flag == (op::eq | op_flag::inverse)))
{ {
fprintf(stderr, "ptr_cmp(): bad size (size=%u)" HERE "\n", size); u128 new_value = atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data));
std::abort(); u128 old_value = std::bit_cast<u128>(old128);
u128 mask = std::bit_cast<u128>(mask128);
// TODO
result = !((old_value ^ new_value) & mask);
} }
else if (size == 16)
{
fmt::raw_error("ptr_cmp(): no alternative operations are supported for 16-byte atomic wait yet.");
}
if (flag & op_flag::inverse)
{
result = !result;
} }
// Check other wait variables if provided // Check other wait variables if provided
@ -101,16 +213,8 @@ __vectorcall
#endif #endif
cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m128i val2) cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m128i val2)
{ {
// In force wake up, one of the size arguments is zero (obsolete)
const u32 size = std::min(size1, size2);
if (!size) [[unlikely]]
{
return 2;
}
// Compare only masks, new value is not available in this mode // Compare only masks, new value is not available in this mode
if ((size1 | size2) == umax) if (size1 == umax)
{ {
// Simple mask overlap // Simple mask overlap
const auto v0 = _mm_and_si128(mask1, mask2); const auto v0 = _mm_and_si128(mask1, mask2);
@ -121,6 +225,17 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
// Generate masked value inequality bits // Generate masked value inequality bits
const auto v0 = _mm_and_si128(_mm_and_si128(mask1, mask2), _mm_xor_si128(val1, val2)); const auto v0 = _mm_and_si128(_mm_and_si128(mask1, mask2), _mm_xor_si128(val1, val2));
using atomic_wait::op;
using atomic_wait::op_flag;
const u8 size = std::min<u8>(static_cast<u8>(size2), static_cast<u8>(size1));
const op flag{static_cast<u8>(size2 >> 8)};
if (flag != op::eq && flag != (op::eq | op_flag::inverse))
{
fmt::raw_error("cmp_mask(): no operations are supported for notification with forced value yet.");
}
if (size <= 8) if (size <= 8)
{ {
// Generate sized mask // Generate sized mask
@ -128,14 +243,14 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
if (!(_mm_cvtsi128_si64(v0) & mask)) if (!(_mm_cvtsi128_si64(v0) & mask))
{ {
return 0; return flag & op_flag::inverse ? 2 : 0;
} }
} }
else if (size == 16) else if (size == 16)
{ {
if (!_mm_cvtsi128_si64(_mm_packs_epi16(v0, v0))) if (!_mm_cvtsi128_si64(_mm_packs_epi16(v0, v0)))
{ {
return 0; return flag & op_flag::inverse ? 2 : 0;
} }
} }
else else
@ -145,7 +260,7 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
} }
// Use force wake-up // Use force wake-up
return 2; return flag & op_flag::inverse ? 0 : 2;
} }
static atomic_t<u64> s_min_tsc{0}; static atomic_t<u64> s_min_tsc{0};
@ -227,7 +342,8 @@ namespace atomic_wait
// Temporarily reduced unique tsc stamp to 48 bits to make space for refs (TODO) // Temporarily reduced unique tsc stamp to 48 bits to make space for refs (TODO)
u64 tsc0 : 48 = 0; u64 tsc0 : 48 = 0;
u64 link : 16 = 0; u64 link : 16 = 0;
u16 size{}; u8 size{};
u8 flag{};
atomic_t<u16> refs{}; atomic_t<u16> refs{};
atomic_t<u32> sync{}; atomic_t<u32> sync{};
@ -262,6 +378,7 @@ namespace atomic_wait
tsc0 = 0; tsc0 = 0;
link = 0; link = 0;
size = 0; size = 0;
flag = 0;
sync = 0; sync = 0;
#ifdef USE_STD #ifdef USE_STD
@ -868,7 +985,8 @@ atomic_wait_engine::wait(const void* data, u32 size, __m128i old_value, u64 time
// Store some info for notifiers (some may be unused) // Store some info for notifiers (some may be unused)
cond->link = 0; cond->link = 0;
cond->size = static_cast<u16>(size); cond->size = static_cast<u8>(size);
cond->flag = static_cast<u8>(size >> 8);
cond->mask = mask; cond->mask = mask;
cond->oldv = old_value; cond->oldv = old_value;
cond->tsc0 = stamp0; cond->tsc0 = stamp0;
@ -877,7 +995,8 @@ atomic_wait_engine::wait(const void* data, u32 size, __m128i old_value, u64 time
{ {
// Extensions point to original cond_id, copy remaining info // Extensions point to original cond_id, copy remaining info
cond_ext[i]->link = cond_id; cond_ext[i]->link = cond_id;
cond_ext[i]->size = static_cast<u16>(ext[i].size); cond_ext[i]->size = static_cast<u8>(ext[i].size);
cond_ext[i]->flag = static_cast<u8>(ext[i].size >> 8);
cond_ext[i]->mask = ext[i].mask; cond_ext[i]->mask = ext[i].mask;
cond_ext[i]->oldv = ext[i].old; cond_ext[i]->oldv = ext[i].old;
cond_ext[i]->tsc0 = stamp0; cond_ext[i]->tsc0 = stamp0;
@ -1058,7 +1177,7 @@ alert_sema(u32 cond_id, const void* data, u64 info, u32 size, __m128i mask, __m1
u32 cmp_res = 0; u32 cmp_res = 0;
if (cond->sync && (!size ? (!info || cond->tid == info) : (cond->ptr == data && ((cmp_res = cmp_mask(size, mask, new_value, cond->size, cond->mask, cond->oldv)))))) if (cond->sync && (!size ? (!info || cond->tid == info) : (cond->ptr == data && ((cmp_res = cmp_mask(size, mask, new_value, cond->size | (cond->flag << 8), cond->mask, cond->oldv))))))
{ {
// Redirect if necessary // Redirect if necessary
const auto _old = cond; const auto _old = cond;

View file

@ -14,14 +14,56 @@ enum class atomic_wait_timeout : u64
inf = 0xffffffffffffffff, inf = 0xffffffffffffffff,
}; };
// Unused externally // Various extensions for atomic_t::wait
namespace atomic_wait namespace atomic_wait
{ {
// Max number of simultaneous atomic variables to wait on (can be extended if really necessary)
constexpr uint max_list = 8; constexpr uint max_list = 8;
struct root_info; struct root_info;
struct sema_handle; struct sema_handle;
// Predicate selector for atomic waiting, packed into the low nibble of the
// size argument's second byte. The wait continues while the chosen comparison
// between the loaded value and the supplied argument holds.
enum class op : u8
{
eq, // Wait while value is bitwise equal to the expected value (default)
slt, // Wait while signed value is less than the argument
sgt, // Wait while signed value is greater than the argument
ult, // Wait while unsigned value is less than the argument
ugt, // Wait while unsigned value is greater than the argument
alt, // Wait while absolute value (two's complement magnitude) is less than the argument
agt, // Wait while absolute value (two's complement magnitude) is greater than the argument
pop, // Wait while set bit count of the value is less than the argument (NOTE(review): ptr_cmp computes `count < popcount` — confirm intended direction)
__max // Number of operations; must fit in the low nibble (checked below)
};
// Low nibble holds the op, upper nibble holds op_flag bits — keep op count <= 8 here.
static_assert(static_cast<u8>(op::__max) == 8);
// Modifier flags stored in the upper nibble of the packed op byte;
// combined with an `op` via the operator| overloads below.
enum class op_flag : u8
{
inverse = 1 << 4, // Perform inverse operation (negate the result)
bit_not = 1 << 5, // Perform bitwise NOT on loaded value before operation
byteswap = 1 << 6, // Perform byteswap on both arguments and masks when applicable
};
// Endianness convenience aliases: choose byteswap only when the data's
// byte order differs from the host's native order.
constexpr op_flag op_ne = {}; // Native endianness: never byteswap
constexpr op_flag op_be = std::endian::native == std::endian::little ? op_flag::byteswap : op_flag{0}; // Big-endian data: swap on little-endian hosts
constexpr op_flag op_le = std::endian::native == std::endian::little ? op_flag{0} : op_flag::byteswap; // Little-endian data: swap on big-endian hosts
// Combine two flag modifiers into a packed op value (upper-nibble bits only).
constexpr op operator |(op_flag lhs, op_flag rhs)
{
const int combined = static_cast<u8>(lhs) | static_cast<u8>(rhs);
return op{static_cast<u8>(combined)};
}
// Attach a flag modifier to an operation (flag on the left-hand side).
constexpr op operator |(op_flag lhs, op rhs)
{
const int combined = static_cast<u8>(lhs) | static_cast<u8>(rhs);
return op{static_cast<u8>(combined)};
}
// Attach a flag modifier to an operation (flag on the right-hand side).
constexpr op operator |(op lhs, op_flag rhs)
{
const int combined = static_cast<u8>(lhs) | static_cast<u8>(rhs);
return op{static_cast<u8>(combined)};
}
struct info struct info
{ {
const void* data; const void* data;
@ -114,24 +156,24 @@ namespace atomic_wait
return *this; return *this;
} }
template <uint Index, typename T2, std::size_t Align, typename U> template <uint Index, op Flags = op::eq, typename T2, std::size_t Align, typename U>
constexpr void set(atomic_t<T2, Align>& var, U value) constexpr void set(atomic_t<T2, Align>& var, U value)
{ {
static_assert(Index < Max); static_assert(Index < Max);
m_info[Index].data = &var.raw(); m_info[Index].data = &var.raw();
m_info[Index].size = sizeof(T2); m_info[Index].size = sizeof(T2) | (static_cast<u8>(Flags) << 8);
m_info[Index].template set_value<T2>(value); m_info[Index].template set_value<T2>(value);
m_info[Index].mask = _mm_set1_epi64x(-1); m_info[Index].mask = _mm_set1_epi64x(-1);
} }
template <uint Index, typename T2, std::size_t Align, typename U, typename V> template <uint Index, op Flags = op::eq, typename T2, std::size_t Align, typename U, typename V>
constexpr void set(atomic_t<T2, Align>& var, U value, V mask) constexpr void set(atomic_t<T2, Align>& var, U value, V mask)
{ {
static_assert(Index < Max); static_assert(Index < Max);
m_info[Index].data = &var.raw(); m_info[Index].data = &var.raw();
m_info[Index].size = sizeof(T2); m_info[Index].size = sizeof(T2) | (static_cast<u8>(Flags) << 8);
m_info[Index].template set_value<T2>(value); m_info[Index].template set_value<T2>(value);
m_info[Index].template set_mask<T2>(mask); m_info[Index].template set_mask<T2>(mask);
} }
@ -1387,34 +1429,36 @@ public:
} }
// Timeout is discouraged // Timeout is discouraged
template <atomic_wait::op Flags = atomic_wait::op::eq>
void wait(type old_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf) const noexcept void wait(type old_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf) const noexcept
{ {
if constexpr (sizeof(T) <= 8) if constexpr (sizeof(T) <= 8)
{ {
const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value)); const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value));
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1)); atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
} }
else if constexpr (sizeof(T) == 16) else if constexpr (sizeof(T) == 16)
{ {
const __m128i old = std::bit_cast<__m128i>(old_value); const __m128i old = std::bit_cast<__m128i>(old_value);
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1)); atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
} }
} }
// Overload with mask (only selected bits are checked), timeout is discouraged // Overload with mask (only selected bits are checked), timeout is discouraged
template <atomic_wait::op Flags = atomic_wait::op::eq>
void wait(type old_value, type mask_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf) void wait(type old_value, type mask_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf)
{ {
if constexpr (sizeof(T) <= 8) if constexpr (sizeof(T) <= 8)
{ {
const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value)); const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value));
const __m128i mask = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(mask_value)); const __m128i mask = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(mask_value));
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), mask); atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), mask);
} }
else if constexpr (sizeof(T) == 16) else if constexpr (sizeof(T) == 16)
{ {
const __m128i old = std::bit_cast<__m128i>(old_value); const __m128i old = std::bit_cast<__m128i>(old_value);
const __m128i mask = std::bit_cast<__m128i>(mask_value); const __m128i mask = std::bit_cast<__m128i>(mask_value);
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), mask); atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), mask);
} }
} }