Skip to content

Commit

Permalink
use fshl/fshr to implement bit rotate
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 13, 2025
1 parent eeda691 commit 66e3e52
Showing 1 changed file with 4 additions and 64 deletions.
68 changes: 4 additions & 64 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,77 +765,17 @@ class FallbackCodegen {
}

[[nodiscard]] llvm::Value *_translate_binary_rotate_left(CurrentFunction &current, IRBuilder &b, const xir::Value *value, const xir::Value *shift) noexcept {
LUISA_ASSERT(value->type() == shift->type(), "Type mismatch for rotate left.");
auto llvm_value = _lookup_value(current, b, value);
auto llvm_shift = _lookup_value(current, b, shift);
auto value_type = value->type();
auto elem_type = value_type->is_vector() ? value_type->element() : value_type;
LUISA_ASSERT(value_type != nullptr, "Operand type is null.");
LUISA_ASSERT(value_type == shift->type(), "Type mismatch for rotate left.");
LUISA_ASSERT(value_type->is_scalar() || value_type->is_vector(), "Invalid operand type.");
auto bit_width = 0u;
switch (elem_type->tag()) {
case Type::Tag::INT8: [[fallthrough]];
case Type::Tag::UINT8: bit_width = 8; break;
case Type::Tag::INT16: [[fallthrough]];
case Type::Tag::UINT16: bit_width = 16; break;
case Type::Tag::INT32: [[fallthrough]];
case Type::Tag::UINT32: bit_width = 32; break;
case Type::Tag::INT64: [[fallthrough]];
case Type::Tag::UINT64: bit_width = 64; break;
default: LUISA_ERROR_WITH_LOCATION(
"Invalid operand type for rotate left operation: {}.",
elem_type->description());
}
auto llvm_elem_type = _translate_type(elem_type, true);
auto llvm_bit_width = llvm::ConstantInt::get(llvm_elem_type, bit_width);
if (value_type->is_vector()) {
llvm_bit_width = llvm::ConstantVector::getSplat(
llvm::ElementCount::getFixed(value_type->dimension()),
llvm_bit_width);
}
auto shifted_left = b.CreateShl(llvm_value, llvm_shift);
auto complement_shift = b.CreateSub(llvm_bit_width, llvm_shift);
auto shifted_right = b.CreateLShr(llvm_value, complement_shift);
return b.CreateOr(shifted_left, shifted_right);
return b.CreateIntrinsic(llvm_value->getType(), llvm::Intrinsic::fshl, {llvm_value, llvm_value, llvm_shift});
}

[[nodiscard]] llvm::Value *_translate_binary_rotate_right(CurrentFunction &current, IRBuilder &b, const xir::Value *value, const xir::Value *shift) noexcept {
// Lookup LLVM values for operands
LUISA_ASSERT(value->type() == shift->type(), "Type mismatch for rotate right.");
auto llvm_value = _lookup_value(current, b, value);
auto llvm_shift = _lookup_value(current, b, shift);
auto value_type = value->type();
auto elem_type = value_type->is_vector() ? value_type->element() : value_type;

// Type and null checks
LUISA_ASSERT(value_type != nullptr, "Operand type is null.");
LUISA_ASSERT(value_type == shift->type(), "Type mismatch for rotate right.");
LUISA_ASSERT(value_type->is_scalar() || value_type->is_vector(), "Invalid operand type.");

auto bit_width = 0u;
switch (elem_type->tag()) {
case Type::Tag::INT8: [[fallthrough]];
case Type::Tag::UINT8: bit_width = 8; break;
case Type::Tag::INT16: [[fallthrough]];
case Type::Tag::UINT16: bit_width = 16; break;
case Type::Tag::INT32: [[fallthrough]];
case Type::Tag::UINT32: bit_width = 32; break;
case Type::Tag::INT64: [[fallthrough]];
case Type::Tag::UINT64: bit_width = 64; break;
default: LUISA_ERROR_WITH_LOCATION(
"Invalid operand type for rotate right operation: {}.",
elem_type->description());
}
auto llvm_elem_type = _translate_type(elem_type, true);
auto llvm_bit_width = llvm::ConstantInt::get(llvm_elem_type, bit_width);
if (value_type->is_vector()) {
llvm_bit_width = llvm::ConstantVector::getSplat(
llvm::ElementCount::getFixed(value_type->dimension()),
llvm_bit_width);
}
auto shifted_right = b.CreateLShr(llvm_value, llvm_shift);
auto complement_shift = b.CreateSub(llvm_bit_width, llvm_shift);
auto shifted_left = b.CreateShl(llvm_value, complement_shift);
return b.CreateOr(shifted_left, shifted_right);
return b.CreateIntrinsic(llvm_value->getType(), llvm::Intrinsic::fshr, {llvm_value, llvm_value, llvm_shift});
}

[[nodiscard]] llvm::Value *_translate_binary_less(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
Expand Down

0 comments on commit 66e3e52

Please sign in to comment.