From 66e3e52349751389cd61925e84ab64cdc6257dec Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 13 Jan 2025 16:11:10 +0800 Subject: [PATCH] use fshl/fshr to implement bit rotate --- src/backends/fallback/fallback_codegen.cpp | 68 ++-------------------- 1 file changed, 4 insertions(+), 64 deletions(-) diff --git a/src/backends/fallback/fallback_codegen.cpp b/src/backends/fallback/fallback_codegen.cpp index 688324ee8..832326bf7 100644 --- a/src/backends/fallback/fallback_codegen.cpp +++ b/src/backends/fallback/fallback_codegen.cpp @@ -765,77 +765,17 @@ class FallbackCodegen { } [[nodiscard]] llvm::Value *_translate_binary_rotate_left(CurrentFunction ¤t, IRBuilder &b, const xir::Value *value, const xir::Value *shift) noexcept { + LUISA_ASSERT(value->type() == shift->type(), "Type mismatch for rotate left."); auto llvm_value = _lookup_value(current, b, value); auto llvm_shift = _lookup_value(current, b, shift); - auto value_type = value->type(); - auto elem_type = value_type->is_vector() ? value_type->element() : value_type; - LUISA_ASSERT(value_type != nullptr, "Operand type is null."); - LUISA_ASSERT(value_type == shift->type(), "Type mismatch for rotate left."); - LUISA_ASSERT(value_type->is_scalar() || value_type->is_vector(), "Invalid operand type."); - auto bit_width = 0u; - switch (elem_type->tag()) { - case Type::Tag::INT8: [[fallthrough]]; - case Type::Tag::UINT8: bit_width = 8; break; - case Type::Tag::INT16: [[fallthrough]]; - case Type::Tag::UINT16: bit_width = 16; break; - case Type::Tag::INT32: [[fallthrough]]; - case Type::Tag::UINT32: bit_width = 32; break; - case Type::Tag::INT64: [[fallthrough]]; - case Type::Tag::UINT64: bit_width = 64; break; - default: LUISA_ERROR_WITH_LOCATION( - "Invalid operand type for rotate left operation: {}.", - elem_type->description()); - } - auto llvm_elem_type = _translate_type(elem_type, true); - auto llvm_bit_width = llvm::ConstantInt::get(llvm_elem_type, bit_width); - if (value_type->is_vector()) { - llvm_bit_width = llvm::ConstantVector::getSplat( - llvm::ElementCount::getFixed(value_type->dimension()), - llvm_bit_width); - } - auto shifted_left = b.CreateShl(llvm_value, llvm_shift); - auto complement_shift = b.CreateSub(llvm_bit_width, llvm_shift); - auto shifted_right = b.CreateLShr(llvm_value, complement_shift); - return b.CreateOr(shifted_left, shifted_right); + return b.CreateIntrinsic(llvm_value->getType(), llvm::Intrinsic::fshl, {llvm_value, llvm_value, llvm_shift}); } [[nodiscard]] llvm::Value *_translate_binary_rotate_right(CurrentFunction ¤t, IRBuilder &b, const xir::Value *value, const xir::Value *shift) noexcept { - // Lookup LLVM values for operands + LUISA_ASSERT(value->type() == shift->type(), "Type mismatch for rotate right."); auto llvm_value = _lookup_value(current, b, value); auto llvm_shift = _lookup_value(current, b, shift); - auto value_type = value->type(); - auto elem_type = value_type->is_vector() ? value_type->element() : value_type; - - // Type and null checks - LUISA_ASSERT(value_type != nullptr, "Operand type is null."); - LUISA_ASSERT(value_type == shift->type(), "Type mismatch for rotate right."); - LUISA_ASSERT(value_type->is_scalar() || value_type->is_vector(), "Invalid operand type."); - - auto bit_width = 0u; - switch (elem_type->tag()) { - case Type::Tag::INT8: [[fallthrough]]; - case Type::Tag::UINT8: bit_width = 8; break; - case Type::Tag::INT16: [[fallthrough]]; - case Type::Tag::UINT16: bit_width = 16; break; - case Type::Tag::INT32: [[fallthrough]]; - case Type::Tag::UINT32: bit_width = 32; break; - case Type::Tag::INT64: [[fallthrough]]; - case Type::Tag::UINT64: bit_width = 64; break; - default: LUISA_ERROR_WITH_LOCATION( - "Invalid operand type for rotate right operation: {}.", - elem_type->description()); - } - auto llvm_elem_type = _translate_type(elem_type, true); - auto llvm_bit_width = llvm::ConstantInt::get(llvm_elem_type, bit_width); - if (value_type->is_vector()) { - llvm_bit_width = llvm::ConstantVector::getSplat( - llvm::ElementCount::getFixed(value_type->dimension()), - llvm_bit_width); - } - auto shifted_right = b.CreateLShr(llvm_value, llvm_shift); - auto complement_shift = b.CreateSub(llvm_bit_width, llvm_shift); - auto shifted_left = b.CreateShl(llvm_value, complement_shift); - return b.CreateOr(shifted_left, shifted_right); + return b.CreateIntrinsic(llvm_value->getType(), llvm::Intrinsic::fshr, {llvm_value, llvm_value, llvm_shift}); } [[nodiscard]] llvm::Value *_translate_binary_less(CurrentFunction ¤t, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {