diff --git a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py index 0267aca4..85eebe00 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py @@ -93,9 +93,11 @@ def compress_weight( args=quantization_args, dtype=quantization_args.pytorch_dtype(), ) + else: + quantized_weight = weight - if device is not None: - quantized_weight = quantized_weight.to(device) + if device is not None: + quantized_weight = quantized_weight.to(device) return {"weight": quantized_weight} diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index ce9f0a57..c236f8c9 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -94,6 +94,8 @@ def compress_weight( args=quantization_args, dtype=torch.int8, ) + else: + quantized_weight = weight packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits) weight_shape = torch.tensor(weight.shape)