diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index 5454cb365..235b67368 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -237,9 +237,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include diff --git a/src/xir/passes/aggregate_field_bitmask.h b/include/luisa/xir/passes/aggregate_field_bitmask.h similarity index 69% rename from src/xir/passes/aggregate_field_bitmask.h rename to include/luisa/xir/passes/aggregate_field_bitmask.h index 5a5c2de34..e4a3df166 100644 --- a/src/xir/passes/aggregate_field_bitmask.h +++ b/include/luisa/xir/passes/aggregate_field_bitmask.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace luisa::compute { @@ -12,7 +13,7 @@ namespace detail { class AggregateFieldTree; }// namespace detail -class alignas(16) AggregateFieldBitmask { +class LC_XIR_API alignas(16) AggregateFieldBitmask { private: const detail::AggregateFieldTree *_field_tree; @@ -43,7 +44,7 @@ class alignas(16) AggregateFieldBitmask { [[nodiscard]] const Type *type() const noexcept; public: - class ConstBitSpan { + class LC_XIR_API ConstBitSpan { protected: uint64_t *_bits; uint32_t _offset; @@ -51,15 +52,32 @@ class alignas(16) AggregateFieldBitmask { public: ConstBitSpan(uint64_t *bits, uint32_t offset, uint32_t size) noexcept : _bits{bits}, _offset{offset}, _size{size} {} + [[nodiscard]] const uint64_t *raw_bits() const noexcept { return _bits; } + [[nodiscard]] size_t offset() const noexcept { return _offset; } + [[nodiscard]] size_t size() const noexcept { return _size; } [[nodiscard]] bool all() const noexcept; [[nodiscard]] bool any() const noexcept; [[nodiscard]] bool none() const noexcept; }; - class BitSpan : public ConstBitSpan { + class LC_XIR_API BitSpan : public ConstBitSpan { public: using ConstBitSpan::ConstBitSpan; - void set(bool value = true) && noexcept; - void flip() && noexcept; + [[nodiscard]] uint64_t *raw_bits() noexcept { return _bits; } + + void set(bool value = true) noexcept; + void flip() noexcept; + + BitSpan &operator|=(const ConstBitSpan &rhs) noexcept; + BitSpan &operator&=(const ConstBitSpan &rhs) noexcept; + BitSpan &operator^=(const ConstBitSpan &rhs) noexcept; + [[nodiscard]] bool operator==(const ConstBitSpan &rhs) const noexcept; + [[nodiscard]] bool operator!=(const ConstBitSpan &rhs) const noexcept; + + BitSpan &operator|=(const AggregateFieldBitmask &rhs) noexcept { return *this |= rhs.access(); } + BitSpan &operator&=(const AggregateFieldBitmask &rhs) noexcept { return *this &= rhs.access(); } + BitSpan &operator^=(const AggregateFieldBitmask &rhs) noexcept { return *this ^= rhs.access(); } + [[nodiscard]] bool operator==(const AggregateFieldBitmask &rhs) const noexcept { return *this == rhs.access(); } + [[nodiscard]] bool operator!=(const AggregateFieldBitmask &rhs) const noexcept { return *this != rhs.access(); } }; [[nodiscard]] BitSpan access(luisa::span access_chain) noexcept; [[nodiscard]] ConstBitSpan access(luisa::span access_chain) const noexcept; diff --git a/include/luisa/xir/passes/ref_arg_usage.h b/include/luisa/xir/passes/ref_arg_usage.h new file mode 100644 index 000000000..a3b2f78d7 --- /dev/null +++ b/include/luisa/xir/passes/ref_arg_usage.h @@ -0,0 +1,7 @@ +#pragma once + +namespace luisa::compute::xir { + + + +} diff --git a/src/xir/CMakeLists.txt b/src/xir/CMakeLists.txt index f106dc833..8398bb8ed 100644 --- a/src/xir/CMakeLists.txt +++ b/src/xir/CMakeLists.txt @@ -56,6 +56,7 @@ set(LUISA_COMPUTE_XIR_SOURCES passes/sink_alloca.cpp passes/trace_gep.cpp passes/aggregate_field_bitmask.cpp + passes/ref_arg_usage.cpp ) add_library(luisa-compute-xir SHARED ${LUISA_COMPUTE_XIR_SOURCES}) diff --git a/src/xir/passes/aggregate_field_bitmask.cpp b/src/xir/passes/aggregate_field_bitmask.cpp index 046f1b1be..d61e9163d 100644 --- a/src/xir/passes/aggregate_field_bitmask.cpp +++ b/src/xir/passes/aggregate_field_bitmask.cpp @@ -3,8 +3,7 @@ #include #include #include - -#include "aggregate_field_bitmask.h" +#include namespace luisa::compute::xir { @@ -211,7 +210,7 @@ void AggregateFieldBitmask::flip() noexcept { } } -void AggregateFieldBitmask::BitSpan::set(bool value) && noexcept { +void AggregateFieldBitmask::BitSpan::set(bool value) noexcept { auto lower = _offset / 64u; auto upper = (_offset + _size - 1u) / 64u; if (lower == upper) {// all selected bits are in the same bucket @@ -252,7 +251,7 @@ void AggregateFieldBitmask::BitSpan::set(bool value) && noexcept { } } -void AggregateFieldBitmask::BitSpan::flip() && noexcept { +void AggregateFieldBitmask::BitSpan::flip() noexcept { auto lower = _offset / 64u; auto upper = (_offset + _size - 1u) / 64u; if (lower == upper) {// all selected bits are in the same bucket @@ -277,6 +276,56 @@ void AggregateFieldBitmask::BitSpan::flip() && noexcept { } } +// TODO: Implement the following methods in a SIMD-friendly way +AggregateFieldBitmask::BitSpan &AggregateFieldBitmask::BitSpan::operator|=(const ConstBitSpan &rhs) noexcept { + LUISA_DEBUG_ASSERT(_size == rhs.size(), "Size mismatch."); + for (auto i = 0u; i < _size; i++) { + if (auto rhs_bucket = rhs.raw_bits()[(i + rhs.offset()) / 64u]; + (rhs_bucket >> ((i + rhs.offset()) % 64u)) & 1ull) { + _bits[(_offset + i) / 64u] |= 1ull << ((_offset + i) % 64u); + } + } + return *this; +} + +AggregateFieldBitmask::BitSpan &AggregateFieldBitmask::BitSpan::operator&=(const ConstBitSpan &rhs) noexcept { + LUISA_DEBUG_ASSERT(_size == rhs.size(), "Size mismatch."); + for (auto i = 0u; i < _size; i++) { + if (auto rhs_bucket = rhs.raw_bits()[(i + rhs.offset()) / 64u]; + (rhs_bucket >> ((i + rhs.offset()) % 64u)) & 1ull) { + _bits[(_offset + i) / 64u] &= 1ull << ((_offset + i) % 64u); + } + } + return *this; +} + +AggregateFieldBitmask::BitSpan &AggregateFieldBitmask::BitSpan::operator^=(const ConstBitSpan &rhs) noexcept { + LUISA_DEBUG_ASSERT(_size == rhs.size(), "Size mismatch."); + for (auto i = 0u; i < _size; i++) { + if (auto rhs_bucket = rhs.raw_bits()[(i + rhs.offset()) / 64u]; + (rhs_bucket >> ((i + rhs.offset()) % 64u)) & 1ull) { + _bits[(_offset + i) / 64u] ^= 1ull << ((_offset + i) % 64u); + } + } + return *this; +} + +bool AggregateFieldBitmask::BitSpan::operator==(const ConstBitSpan &rhs) const noexcept { + if (_size != rhs.size()) { return false; } + if (this != &rhs) { + for (auto i = 0u; i < _size; i++) { + auto lhs_bit = (_bits[(_offset + i) / 64u] >> ((_offset + i) % 64u)) & 1ull; + auto rhs_bit = (rhs.raw_bits()[i / 64u] >> (i % 64u)) & 1ull; + if (lhs_bit != rhs_bit) { return false; } + } + } + return true; +} + +bool AggregateFieldBitmask::BitSpan::operator!=(const ConstBitSpan &rhs) const noexcept { + return !(*this == rhs); +} + bool AggregateFieldBitmask::ConstBitSpan::all() const noexcept { auto lower = _offset / 64u; auto upper = (_offset + _size - 1u) / 64u; diff --git a/src/xir/passes/ref_arg_usage.cpp b/src/xir/passes/ref_arg_usage.cpp new file mode 100644 index 000000000..92ef23689 --- /dev/null +++ b/src/xir/passes/ref_arg_usage.cpp @@ -0,0 +1,5 @@ +#include + +namespace luisa::compute::xir { + +}// namespace luisa::compute::xir diff --git a/src/xir/tests/test_aggregate_field_bitmasks.cpp b/src/xir/tests/test_aggregate_field_bitmasks.cpp index 39bf83881..23614f4ae 100644 --- a/src/xir/tests/test_aggregate_field_bitmasks.cpp +++ b/src/xir/tests/test_aggregate_field_bitmasks.cpp @@ -1,6 +1,4 @@ #include -#include "../passes/aggregate_field_bitmask.h" -#include "luisa/dsl/struct.h" using namespace luisa; using namespace luisa::compute;