Skip to content

Commit

Permalink
fix char/uchar support in constant data
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Apr 9, 2024
1 parent 3309143 commit 815126c
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/luisa/ast/constant_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class LC_AST_API ConstantDecoder {

protected:
virtual void _decode_bool(bool x) noexcept = 0;
virtual void _decode_char(char x) noexcept = 0;
virtual void _decode_uchar(ubyte x) noexcept = 0;
virtual void _decode_short(short x) noexcept = 0;
virtual void _decode_ushort(ushort x) noexcept = 0;
virtual void _decode_int(int x) noexcept = 0;
Expand Down
11 changes: 8 additions & 3 deletions src/ast/constant_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,17 @@ void ConstantDecoder::_decode_array(const Type *type, const std::byte *data) noe
void ConstantDecoder::_decode(const Type *type, const std::byte *data) noexcept {
switch (type->tag()) {
case Type::Tag::BOOL: _decode_bool(*reinterpret_cast<const bool *>(data)); break;
case Type::Tag::FLOAT32: _decode_float(*reinterpret_cast<const float *>(data)); break;
case Type::Tag::INT8: _decode_char(*reinterpret_cast<const char *>(data)); break;
case Type::Tag::UINT8: _decode_uchar(*reinterpret_cast<const uchar *>(data)); break;
case Type::Tag::INT16: _decode_short(*reinterpret_cast<const short *>(data)); break;
case Type::Tag::UINT16: _decode_ushort(*reinterpret_cast<const ushort *>(data)); break;
case Type::Tag::INT32: _decode_int(*reinterpret_cast<const int *>(data)); break;
case Type::Tag::UINT32: _decode_uint(*reinterpret_cast<const uint *>(data)); break;
case Type::Tag::INT64: _decode_long(*reinterpret_cast<const slong *>(data)); break;
case Type::Tag::UINT64: _decode_ulong(*reinterpret_cast<const ulong *>(data)); break;
case Type::Tag::FLOAT16: _decode_half(*reinterpret_cast<const half *>(data)); break;
case Type::Tag::INT16: _decode_short(*reinterpret_cast<const short *>(data)); break;
case Type::Tag::UINT16: _decode_ushort(*reinterpret_cast<const ushort *>(data)); break;
case Type::Tag::FLOAT32: _decode_float(*reinterpret_cast<const float *>(data)); break;
case Type::Tag::FLOAT64: _decode_double(*reinterpret_cast<const double *>(data)); break;
case Type::Tag::VECTOR: _decode_vector(type, data); break;
case Type::Tag::MATRIX: _decode_matrix(type, data); break;
case Type::Tag::ARRAY: _decode_array(type, data); break;
Expand Down Expand Up @@ -113,6 +116,8 @@ class ConstantSerializer final : public ConstantDecoder {

protected:
void _decode_bool(bool x) noexcept override { _s.append(luisa::format("bool({})", x)); }
void _decode_char(char x) noexcept override { _s.append(luisa::format("char({})", static_cast<int>(x))); }
void _decode_uchar(uchar x) noexcept override { _s.append(luisa::format("uchar({})", static_cast<uint>(x))); }
void _decode_short(short x) noexcept override { _s.append(luisa::format("short({})", x)); }
void _decode_ushort(ushort x) noexcept override { _s.append(luisa::format("ushort({})", x)); }
void _decode_int(int x) noexcept override { _s.append(luisa::format("int({})", x)); }
Expand Down
6 changes: 6 additions & 0 deletions src/backends/common/hlsl/hlsl_codegen_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,12 @@ class CodegenConstantPrinter final : public ConstantDecoder {
void _decode_bool(bool x) noexcept override {
PrintValue<bool>{}(x, _str);
}
void _decode_char(char x) noexcept override {
LUISA_NOT_IMPLEMENTED();
}
void _decode_uchar(uchar x) noexcept override {
LUISA_NOT_IMPLEMENTED();
}
void _decode_short(short x) noexcept override {
LUISA_NOT_IMPLEMENTED();
}
Expand Down
2 changes: 2 additions & 0 deletions src/backends/cuda/cuda_codegen_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1851,6 +1851,8 @@ class CUDAConstantPrinter final : public ConstantDecoder {

protected:
void _decode_bool(bool x) noexcept override { _codegen->_scratch << (x ? "true" : "false"); }
void _decode_char(char x) noexcept override { _codegen->_scratch << luisa::format("lc_byte({})", static_cast<int>(x)); }
void _decode_uchar(ubyte x) noexcept override { _codegen->_scratch << luisa::format("lc_ubyte({})", static_cast<uint>(x)); }
void _decode_short(short x) noexcept override { _codegen->_scratch << luisa::format("lc_short({})", x); }
void _decode_ushort(ushort x) noexcept override { _codegen->_scratch << luisa::format("lc_ushort({})", x); }
void _decode_int(int x) noexcept override { _codegen->_scratch << luisa::format("lc_int({})", x); }
Expand Down
2 changes: 2 additions & 0 deletions src/backends/metal/metal_codegen_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ class MetalConstantPrinter final : public ConstantDecoder {

protected:
void _decode_bool(bool x) noexcept override { _codegen->_scratch << (x ? "true" : "false"); }
void _decode_char(char x) noexcept override { _codegen->_scratch << luisa::format("char({})", static_cast<int>(x)); }
void _decode_uchar(uchar x) noexcept override { _codegen->_scratch << luisa::format("uchar({})", static_cast<uint>(x)); }
void _decode_short(short x) noexcept override { _codegen->_scratch << luisa::format("short({})", x); }
void _decode_ushort(ushort x) noexcept override { _codegen->_scratch << luisa::format("ushort({})", x); }
void _decode_int(int x) noexcept override { _codegen->_scratch << luisa::format("int({})", x); }
Expand Down

0 comments on commit 815126c

Please sign in to comment.