From d2eba7f723e474dd4250eebb3172739c24ec1d58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=81kos=20Hadnagy?=
Date: Mon, 6 Jan 2025 23:26:24 +0100
Subject: [PATCH] Remove debugging-related stuff

---
 .../library/extensions/cuda/__init__.py       |   3 -
 .../test_marlin_int4_packed_tensor.py         | 156 ++++----------------
 .../test_marlin_int4_weight_qbits_tensor.py   |   5 +-
 3 files changed, 35 insertions(+), 129 deletions(-)

diff --git a/optimum/quanto/library/extensions/cuda/__init__.py b/optimum/quanto/library/extensions/cuda/__init__.py
index b23f3705..7ba29365 100644
--- a/optimum/quanto/library/extensions/cuda/__init__.py
+++ b/optimum/quanto/library/extensions/cuda/__init__.py
@@ -48,8 +48,6 @@ def get_max_cuda_arch():
 extra_cuda_cflags = [
     "--expt-extended-lambda",
     "--use_fast_math",
-    "-lineinfo",
-    '-O0'
 ]
 # We need to know the minimum CUDA Arch to select only the relevant kernels
 # but we cannot rely on __CUDA_ARCH__ as it is not set in host code (only on device code)
@@ -189,7 +187,6 @@ def gemm_f16i4_marlin(
         dtype=input.dtype,
         device=input.device,
     )
-    print(f"input shapes: {input.reshape((-1, input.shape[-1])).shape}, in2: {other.shape}, out: {output.reshape((-1, output.shape[-1])).shape}")
     ext.lib.marlin_gemm_f16i4(
         input.reshape((-1, input.shape[-1])),
         other,
diff --git a/test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py b/test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py
index 185dc26b..4048a3ea 100644
--- a/test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py
+++ b/test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py
@@ -12,139 +12,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+import numpy as np
 import pytest
 import torch
-from helpers import device_eq, random_qweight
-from tensor.weights.weight_helpers import check_weight_qtensor_linear
+from helpers import device_eq
+
+from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor
+
 
-from optimum.quanto import qint4
-from optimum.quanto.library.extensions import is_extension_available
-from optimum.quanto.tensor.weights import WeightQBitsTensor
-from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4WeightQBitsTensor
+def get_uint4_tensor(shape, device, random=False):
+    qmax = 2**4
+    if random:
+        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
+    else:
+        numel = np.prod(shape)
+        t = torch.tensor(range(numel), dtype=torch.int32)
+        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)
+    return t
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= sm80 not available"
-)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
 @pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
-def test_marlin_int4_weight_qbits_tensor_from_qbits_tensor(in_features, out_features):
-    qtype = qint4
-    group_size = 128
-    dtype = torch.float16
+@pytest.mark.parametrize("random", [True, False])
+def test_pack_marlin_int4_tensor(in_features, out_features, random):
     shape = (out_features, in_features)
     device = torch.device("cuda")
-    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)
-    # Create a MarlinInt4WeightQBitsTensor from the WeightQBitsTensor members
-    marlinqbt = MarlinInt4WeightQBitsTensor(
-        qtype=qbt.qtype,
-        axis=qbt.axis,
-        group_size=qbt._group_size,
-        size=qbt.size(),
-        stride=qbt.stride(),
-        data=qbt._data.unpack(),
-        scale=qbt._scale,
-        shift=qbt._shift,
-    )
-    assert marlinqbt.dtype == dtype
-    assert marlinqbt.qtype == qtype
-    assert marlinqbt.shape == shape
-    assert device_eq(marlinqbt.device, device)
-    # Verify the dequantized tensors are identical
-    assert torch.equal(marlinqbt.dequantize(), qbt.dequantize())
-    # Now verify that we can reconstruct the WeightQBitsTensor
-    new_qbt = marlinqbt.weight_qbits_tensor()
-    assert type(new_qbt) is WeightQBitsTensor
-    assert new_qbt.dtype == dtype
-    assert new_qbt.qtype == qtype
-    assert new_qbt.shape == shape
-    assert torch.equal(new_qbt._data, qbt._data)
-    assert torch.equal(new_qbt._scale, qbt._scale)
-    assert torch.equal(new_qbt._shift, qbt._shift)
+    t = get_uint4_tensor(shape, device, random)
+    packed = MarlinInt4PackedTensor.pack(t)
+    assert isinstance(packed, MarlinInt4PackedTensor)
+    assert device_eq(packed.device, device)
+    assert torch.equal(t, packed.unpack())
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_marlin_int4_weight_qbits_tensor_move(device):
-    qtype = qint4
-    group_size = 128
-    dtype = torch.float16
-    shape = (1024, 1024)
+def test_move_marlin_int4_packed_tensor(device):
+    shape = (256, 256)
     device = torch.device("cuda")
-    # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
-    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device("cuda"))
-    marlinqbt = MarlinInt4WeightQBitsTensor(
-        qtype=qbt.qtype,
-        axis=qbt.axis,
-        group_size=qbt._group_size,
-        size=qbt.size(),
-        stride=qbt.stride(),
-        data=qbt._data.unpack(),
-        scale=qbt._scale,
-        shift=qbt._shift,
-    )
-    # Move to device, dequantize and compare
-    moved_qbt = marlinqbt.to(device)
-    assert isinstance(moved_qbt, WeightQBitsTensor)
-    if device.type != "cuda":
-        assert type(moved_qbt) is not MarlinInt4WeightQBitsTensor
-    assert marlinqbt.dtype == moved_qbt.dtype
-    assert marlinqbt.qtype == moved_qbt.qtype
-    assert marlinqbt.shape == moved_qbt.shape
-    assert torch.equal(marlinqbt.dequantize().to(device), moved_qbt.dequantize())
-
-
-def _test_marlin_int4_weight_qbits_tensor_linear(
-    dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
-):
-    # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
-    qbt = random_qweight(
-        (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda")
-    )
-    marlin_qweight = MarlinInt4WeightQBitsTensor(
-        qtype=qbt.qtype,
-        axis=qbt.axis,
-        group_size=qbt._group_size,
-        size=qbt.size(),
-        stride=qbt.stride(),
-        data=qbt._data.unpack(),
-        scale=qbt._scale,
-        shift=qbt._shift,
-    )
-    check_weight_qtensor_linear(marlin_qweight, batch_size, tokens, use_bias)
-
-
-@pytest.mark.skipif(
-    not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
-    reason="CUDA >= sm80 not available",
-)
-@pytest.mark.parametrize("batch_size", [1, 2])
-@pytest.mark.parametrize("tokens", [16, 32])
-@pytest.mark.parametrize("in_features", [1024])
-@pytest.mark.parametrize("out_features", [1024, 2048, 4096])
-@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
-def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):
-    dtype = torch.float16
-    weight_qtype = qint4
-    group_size = 128
-    _test_marlin_int4_weight_qbits_tensor_linear(
-        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
-    )
-
-
-#Tests previous Marlin kernel bug: https://github.com/huggingface/optimum-quanto/issues/332
-@pytest.mark.skipif(
-    not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
-    reason="CUDA >= sm80 not available",
-)
-@pytest.mark.parametrize("batch_size", [1, 2])
-@pytest.mark.parametrize("tokens", [48, 64])
-# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
-@pytest.mark.parametrize("in_features", [4096, 16384])
-@pytest.mark.parametrize("out_features", [2048, 4096])
-def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
-    dtype = torch.float16
-    weight_qtype = qint4
-    group_size = 128
-    _test_marlin_int4_weight_qbits_tensor_linear(
-        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias=False
-    )
\ No newline at end of file
+    t = get_uint4_tensor(shape, device)
+    packed = MarlinInt4PackedTensor.pack(t)
+    moved = packed.to("cuda")
+    assert isinstance(moved, MarlinInt4PackedTensor)
+    # Marlin int4 tensors are unpacked when moved out of CUDA device
+    moved = packed.to("cpu")
+    assert type(moved) is torch.Tensor
+    assert torch.equal(t, moved.to("cuda"))
diff --git a/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py b/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py
index a44db5b2..c1a8d93d 100644
--- a/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py
+++ b/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py
@@ -131,15 +131,14 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features,
     )
 
 
-@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)
+#Tests previous Marlin kernel bug: https://github.com/huggingface/optimum-quanto/issues/332
 @pytest.mark.skipif(
     not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
     reason="CUDA >= sm80 not available",
 )
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize("tokens", [48, 64])
-# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
-@pytest.mark.parametrize("in_features", [4096, 16384])
+@pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
 @pytest.mark.parametrize("out_features", [2048, 4096])
 def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
     dtype = torch.float16