rsx: Use address_range64 to simplify MM intersection tests

This commit is contained in:
kd-11 2025-05-26 03:42:56 +03:00 committed by kd-11
parent 2ea7ff6b14
commit 45718d7679
3 changed files with 17 additions and 31 deletions

View file

@ -18,7 +18,7 @@ namespace rsx
{
for (const auto& block : g_deferred_mprotect_queue)
{
utils::memory_protect(reinterpret_cast<void*>(block.start), block.length, block.prot);
utils::memory_protect(reinterpret_cast<void*>(block.range.start), block.range.length(), block.prot);
}
g_deferred_mprotect_queue.clear();
@ -28,7 +28,7 @@ namespace rsx
{
// We could stack and merge requests here, but that is more trouble than it is truly worth.
// A fresh call to memory_protect only takes a few nanoseconds of setup overhead, it is not worth the risk of hanging because of conflicts.
g_deferred_mprotect_queue.push_back({ start, length, prot });
g_deferred_mprotect_queue.push_back({ utils::address_range64::start_length(start, length), prot });
}
void mm_protect(void* ptr, u64 length, utils::protection prot)
@ -41,7 +41,7 @@ namespace rsx
// Naive merge. Eventually it makes more sense to do conflict resolution, but it's not as important.
const auto start = reinterpret_cast<u64>(ptr);
const auto end = start + length;
const auto range = utils::address_range64::start_length(start, length);
std::lock_guard lock(g_mprotect_queue_lock);
@ -50,7 +50,7 @@ namespace rsx
// Basically an unlock op. Flush if any overlap is detected
for (const auto& block : g_deferred_mprotect_queue)
{
if (block.overlaps(start, end))
if (block.overlaps(range))
{
mm_flush_mprotect_queue_internal();
break;
@ -90,7 +90,7 @@ namespace rsx
}
}
void mm_flush(const rsx::simple_array<utils::address_range32>& ranges)
void mm_flush(const rsx::simple_array<utils::address_range64>& ranges)
{
std::lock_guard lock(g_mprotect_queue_lock);
if (g_deferred_mprotect_queue.empty())
@ -98,16 +98,9 @@ namespace rsx
return;
}
const auto ranges64 = ranges.map([](const auto& r)
{
const u64 start = reinterpret_cast<uintptr_t>(vm::base(r.start));
const u64 end = start + r.length();
return std::make_pair(start, end);
});
for (const auto& block : g_deferred_mprotect_queue)
{
if (ranges64.any(FN(block.overlaps(x.first, x.second))))
if (ranges.any(FN(block.overlaps(x))))
{
mm_flush_mprotect_queue_internal();
return;

View file

@ -10,24 +10,17 @@ namespace rsx
{
struct MM_block
{
u64 start;
u64 length;
utils::address_range64 range;
utils::protection prot;
inline bool overlaps(u64 start, u64 end) const
inline bool overlaps(const utils::address_range64& test) const
{
// [Start, End] is not a proper closed range, there is an off-by-one by design.
// FIXME: Use address_range64
const u64 this_end = this->start + this->length;
return (this->start < end && start < this_end);
return range.overlaps(test);
}
inline bool overlaps(u64 addr) const
{
// [Start, End] is not a proper closed range, there is an off-by-one by design.
// FIXME: Use address_range64
const u64 this_end = this->start + this->length;
return (addr >= start && addr < this_end);
return range.overlaps(addr);
}
};
@ -39,6 +32,6 @@ namespace rsx
void mm_protect(void* start, u64 length, utils::protection prot);
void mm_flush_lazy();
void mm_flush(u32 vm_address);
void mm_flush(const rsx::simple_array<utils::address_range32>& ranges);
void mm_flush(const rsx::simple_array<utils::address_range64>& ranges);
void mm_flush();
}

View file

@ -58,16 +58,16 @@ namespace rsx
auto res = ::rsx::reservation_lock<true>(write_address, write_length, read_address, read_length);
rsx::simple_array<utils::address_range32> flush_mm_ranges =
u8* dst = vm::_ptr<u8>(write_address);
const u8* src = vm::_ptr<u8>(read_address);
rsx::simple_array<utils::address_range64> flush_mm_ranges =
{
utils::address_range32::start_length(write_address, write_length).to_page_range(),
utils::address_range32::start_length(read_address, read_length).to_page_range()
utils::address_range64::start_length(reinterpret_cast<u64>(dst), write_length),
utils::address_range64::start_length(reinterpret_cast<u64>(src), read_length)
};
rsx::mm_flush(flush_mm_ranges);
u8 *dst = vm::_ptr<u8>(write_address);
const u8 *src = vm::_ptr<u8>(read_address);
const bool is_overlapping = dst_dma == src_dma && [&]() -> bool
{
const u32 src_max = src_offset + read_length;