Divide as a * (1.0 / b) during weight compression #3055

Open · wants to merge 21 commits into base: develop
@@ -286,7 +286,6 @@ def calculate_quantized_weight(
     config: WeightCompressionConfig,
     scale: Tensor,
     zero_point: Optional[Tensor] = None,
-    invert_scale=False,
 ) -> Tensor:
     """
     Quantizes the weight tensor using the provided scale and zero point.
@@ -295,7 +294,6 @@ def calculate_quantized_weight(
     :param config: Weight compression configuration.
     :param scale: Scale tensor used for quantization.
     :param zero_point: Zero point tensor used for quantization.
-    :param invert_scale: applies inversion for scale and then multiply by weights instead of division.
     :return: Quantized weight tensor of uint8 or int8 type.
     """
     if weight.dtype != TensorDataType.float32:
@@ -309,11 +307,7 @@ def calculate_quantized_weight(
     level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
     level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1
 
-    if invert_scale:
-        scale = fns.power(scale, -1)
-        compressed_weights = weight * scale
-    else:
-        compressed_weights = weight / scale
+    compressed_weights = weight * fns.reciprocal(scale)
     if zero_point is not None:
         compressed_weights += zero_point.astype(weight.dtype)
     compressed_weights = fns.round(compressed_weights)
@@ -328,7 +322,6 @@ def do_int_quantization(
     config: WeightCompressionConfig,
     precomputed_scale: Tensor = None,
     precomputed_zero_point: Tensor = None,
-    invert_scale=False,
 ) -> Tuple[Tensor, Tensor, Tensor]:
     """
     The method quantizes the given weights to integer data type uniformly in accordance with the compression config.
@@ -351,8 +344,6 @@ def do_int_quantization(
     :param config: Information on how to compress (quantize) a specific weight.
     :param precomputed_scale: Precomputed scale.
     :param precomputed_zero_point: Precomputed zero point.
-    :param invert_scale: applies inversion for scale and then multiply by weights instead of division.
-        Need as reference implementation for OV.
     :return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type,
         scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization.
     """
@@ -380,7 +371,7 @@ def do_int_quantization(
     if precomputed_zero_point is not None:
         zero_point = precomputed_zero_point
 
-    compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point, invert_scale)
+    compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point)
     return compressed_weights, scale, zero_point
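For context on why this change matters: IEEE-754 division is correctly rounded in a single step, while `a * (1.0 / b)` rounds twice (once for the reciprocal, once for the product), so the two can differ in the last bit, and OpenVINO's compression pipeline computes the latter. A standalone NumPy sketch (not part of the PR) of how often the two strategies disagree:

```python
# Sketch: in float32, a / b and a * (1.0 / b) are not guaranteed to be equal,
# because the reciprocal-then-multiply path accumulates a second rounding.
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(1_000_000).astype(np.float32)
b = rng.standard_normal(1_000_000).astype(np.float32)

mismatch = np.sum((a / b) != (a * np.reciprocal(b)))
print(f"{mismatch} of {a.size} elementwise results differ in the last bit")
```

Replacing the opt-in `invert_scale` branch with an unconditional `fns.reciprocal(scale)` keeps the NumPy reference on the same numeric path as OV instead of maintaining both variants.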
2 changes: 2 additions & 0 deletions nncf/tensor/functions/__init__.py
@@ -46,10 +46,12 @@
 from nncf.tensor.functions.numeric import minimum as minimum
 from nncf.tensor.functions.numeric import moveaxis as moveaxis
 from nncf.tensor.functions.numeric import multiply as multiply
+from nncf.tensor.functions.numeric import ones as ones
 from nncf.tensor.functions.numeric import ones_like as ones_like
 from nncf.tensor.functions.numeric import percentile as percentile
 from nncf.tensor.functions.numeric import power as power
 from nncf.tensor.functions.numeric import quantile as quantile
+from nncf.tensor.functions.numeric import reciprocal as reciprocal
 from nncf.tensor.functions.numeric import reshape as reshape
 from nncf.tensor.functions.numeric import round as round
 from nncf.tensor.functions.numeric import searchsorted as searchsorted
35 changes: 35 additions & 0 deletions nncf/tensor/functions/numeric.py
@@ -818,6 +818,27 @@ def zeros(
     return Tensor(get_numeric_backend_fn("zeros", backend)(shape, dtype=dtype, device=device))
 
 
+def ones(
+    shape: Tuple[int, ...],
+    *,
+    backend: TensorBackend,
+    dtype: Optional[TensorDataType] = None,
+    device: Optional[TensorDeviceType] = None,
+) -> Tensor:
+    """
+    Return a new tensor of the given shape and type, filled with ones.
+
+    :param shape: Shape of the new tensor.
+    :param backend: The backend type for which the ones tensor is required.
+    :param dtype: The data type of the returned tensor. If dtype is not given,
+        the default data type is determined by the backend.
+    :param device: The device on which the tensor will be allocated. If device is not given,
+        the default device is determined by the backend.
+    :return: A tensor of the specified shape and data type, filled with ones.
+    """
+    return Tensor(get_numeric_backend_fn("ones", backend)(shape, dtype=dtype, device=device))
+
+
 def eye(
     n: int,
     m: Optional[int] = None,
@@ -905,3 +926,17 @@ def ceil(a: Tensor) -> Tensor:
     :return: An array of the same type as a, containing the ceiling values.
     """
     return Tensor(ceil(a.data))
+
+
+@functools.singledispatch
+@tensor_guard
+def reciprocal(a: Tensor) -> Tensor:
+    """
+    Compute the reciprocal of a tensor or a float.
+
+    This function returns a new tensor where each element is the reciprocal of the corresponding element in `a`.
+
+    :param a: The input tensor or float.
+    :return: A tensor containing the reciprocal of each element in `a`.
+    """
+    return Tensor(reciprocal(a.data))
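For readers unfamiliar with the `fns` layer: `reciprocal` follows the same pattern as the other numeric functions here. The generic function dispatches via `functools.singledispatch` on the type of the wrapped `.data` payload, and each backend registers a concrete implementation (see the NumPy and Torch files below). A stripped-down model of the mechanism, with the NNCF `Tensor` wrapper and `@tensor_guard` omitted for brevity:

```python
# Simplified sketch of the dispatch pattern; the real code also unwraps and
# rewraps the NNCF Tensor around the backend call.
import functools

import numpy as np


@functools.singledispatch
def reciprocal(a):
    raise NotImplementedError(f"reciprocal is not implemented for {type(a)}")


@reciprocal.register(np.ndarray)
def _(a: np.ndarray) -> np.ndarray:
    return np.reciprocal(a)


print(reciprocal(np.array([2.0, 4.0])))  # [0.5  0.25]
```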
18 changes: 18 additions & 0 deletions nncf/tensor/functions/numpy_numeric.py
@@ -394,6 +394,19 @@ def zeros(
     return np.zeros(shape, dtype=dtype)
 
 
+def ones(
+    shape: Tuple[int, ...],
+    *,
+    dtype: Optional[TensorDataType] = None,
+    device: Optional[TensorDeviceType] = None,
+) -> np.ndarray:
+    if device is not None and device != TensorDeviceType.CPU:
+        raise ValueError("numpy_numeric.ones only supports CPU device.")
+    if dtype is not None:
+        dtype = DTYPE_MAP[dtype]
+    return np.ones(shape, dtype=dtype)
+
+
 def eye(
     n: int,
     m: Optional[int] = None,
@@ -431,3 +444,8 @@ def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]:
 @register_numpy_types(numeric.ceil)
 def _(a: Union[np.ndarray, np.generic]) -> np.ndarray:
     return np.ceil(a)
+
+
+@register_numpy_types(numeric.reciprocal)
+def _(a: Union[np.ndarray, np.generic]) -> np.ndarray:
+    return np.reciprocal(a)
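One NumPy-specific caveat (documented NumPy behavior, not something this diff changes): `np.reciprocal` applied to an integer array stays in integer arithmetic, so any magnitude above 1 collapses to 0. Callers are expected to pass float tensors, which the cross-framework test below does by casting to float32 first:

```python
import numpy as np

print(np.reciprocal(np.array([2, 4])))                    # [0 0] -- integer dtype
print(np.reciprocal(np.array([2, 4], dtype=np.float32)))  # [0.5  0.25]
```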
18 changes: 18 additions & 0 deletions nncf/tensor/functions/torch_numeric.py
@@ -423,6 +423,19 @@ def zeros(
     return torch.zeros(*shape, dtype=dtype, device=device)
 
 
+def ones(
+    shape: Tuple[int, ...],
+    *,
+    dtype: Optional[TensorDataType] = None,
+    device: Optional[TensorDeviceType] = None,
+) -> torch.Tensor:
+    if dtype is not None:
+        dtype = DTYPE_MAP[dtype]
+    if device is not None:
+        device = DEVICE_MAP[device]
+    return torch.ones(*shape, dtype=dtype, device=device)
+
+
 def eye(
     n: int,
     m: Optional[int] = None,
@@ -465,3 +478,8 @@ def _(a: torch.Tensor) -> torch.Tensor:
 @numeric.ceil.register(torch.Tensor)
 def _(a: torch.Tensor) -> torch.Tensor:
     return torch.ceil(a)
+
+
+@numeric.reciprocal.register(torch.Tensor)
+def _(a: torch.Tensor) -> torch.Tensor:
+    return torch.reciprocal(a)
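The Torch backend behaves slightly differently here: per the PyTorch docs, `torch.reciprocal` promotes integral inputs to the default floating dtype rather than staying in integer arithmetic, a small backend asymmetry worth remembering when both implementations are expected to agree. A quick comparison:

```python
import numpy as np
import torch

print(torch.reciprocal(torch.tensor([2, 4])))  # tensor([0.5000, 0.2500]) -- promoted
print(np.reciprocal(np.array([2, 4])))         # [0 0] -- stays integral
```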
19 changes: 19 additions & 0 deletions tests/cross_fw/test_templates/template_test_nncf_tensor.py
@@ -1514,6 +1514,19 @@ def test_fn_zeros(self):
         assert tensor_a.shape == shape
         assert fns.all(tensor_a == 0)
 
+    def test_fn_ones(self):
+        shape = (2, 2)
+        for dtype in TensorDataType:
+            if dtype == TensorDataType.bfloat16 and self.backend() == TensorBackend.numpy:
+                continue
+            tensor_a = fns.ones(shape, backend=self.backend(), dtype=dtype, device=self.device())
+            assert isinstance(tensor_a, Tensor)
+            assert tensor_a.device == self.device()
+            assert tensor_a.backend == self.backend()
+            assert tensor_a.dtype == dtype
+            assert tensor_a.shape == shape
+            assert fns.all(tensor_a == 1)
+
     @pytest.mark.parametrize(
         "n, m, ref",
         (
@@ -1695,3 +1708,9 @@ def test_svd(self, a, full_matrices, abs_res_ref):
         for act, abs_ref in zip(res, abs_res_ref):
             assert isinstance(act, Tensor)
             assert fns.allclose(fns.abs(act), abs_ref, atol=1e-7)
+
+    @pytest.mark.parametrize("a,ref", [([1], [1.0]), ([2, 4], [0.5, 0.25])])
+    def test_reciprocal(self, a, ref):
+        t_a = Tensor(self.to_tensor(a)).astype(TensorDataType.float32)
+        res = fns.reciprocal(t_a)
+        assert fns.allclose(res, Tensor(self.to_tensor(ref)).astype(TensorDataType.float32))
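For completeness, a usage sketch of the two new functions through the backend-agnostic API. The import paths follow what the test file above appears to use and should be treated as assumptions rather than verified API:

```python
# Assumed import paths, mirroring tests/cross_fw/test_templates/template_test_nncf_tensor.py.
from nncf.tensor import functions as fns
from nncf.tensor.definitions import TensorBackend, TensorDataType

ones = fns.ones((2, 2), backend=TensorBackend.numpy, dtype=TensorDataType.float32)
rec = fns.reciprocal(ones * 4.0)
print(rec.data)  # [[0.25 0.25]
                 #  [0.25 0.25]]
```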
45 changes: 30 additions & 15 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -1078,32 +1078,47 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
     assert ref_e8m0_nodes == names_e8m0
 
 
-@pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM))
-def test_np_ov_compression_decompression(mode):
-    sz = 60
-    w = np.arange(-sz, sz).reshape(2, sz).astype(np.float32) / 9.0
+@pytest.mark.parametrize("mode", [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM])
+@pytest.mark.parametrize(
+    "w,s,zp",
+    [
+        (
+            np.array([[1.4372410774230957]], np.float32),
+            np.array([[-0.9581607580184937]], np.float32),
+            np.array([[1]], np.int32),
+        ),
+        (np.arange(-60, 60).reshape(2, 60).astype(np.float32) / 9.0, None, None),
+    ],
+)
+def test_np_ov_compression_decompression(mode, w, s, zp):
     w = Tensor(w)
+    if s is not None:
+        s = Tensor(s)
+    if mode == CompressWeightsMode.INT4_SYM:
+        zp = None
+    if zp is not None:
+        zp = Tensor(zp)
 
     config = WeightCompressionConfig(mode)
 
-    compressed_weighs, scale, zp = do_int_quantization(w, -1, config, invert_scale=True)
-    decompressed_weighs = do_int_dequantization(compressed_weighs, scale, zp)
+    compressed_weights, s, zp = do_int_quantization(w, -1, config, precomputed_scale=s, precomputed_zero_point=zp)
+    decompressed_weights = do_int_dequantization(compressed_weights, s, zp)
 
-    compressed_weighs = compressed_weighs.data
-    decompressed_weighs = decompressed_weighs.data
+    compressed_weights = compressed_weights.data
+    decompressed_weights = decompressed_weights.data
     zp_shape = zp.shape if zp is not None else None
 
-    compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, scale.shape, zp_shape)
+    compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, s.shape, zp_shape)
     compress_decompress = OVWeightCompressionAlgoBackend.get_compress_decompress_pipeline(
-        config, w.shape, scale.shape, zp_shape
+        config, w.shape, s.shape, zp_shape
     )
 
-    params = [w.data, scale.data, zp.data] if zp is not None else [w.data, scale.data]
-    compressed_weighs_ov = compress(params)
-    decompressed_weighs_ov = compress_decompress(params)
+    params = [w.data, s.data, zp.data] if zp is not None else [w.data, s.data]
+    compressed_weights_ov = compress(params)
+    decompressed_weights_ov = compress_decompress(params)
 
-    assert np.allclose(compressed_weighs, compressed_weighs_ov)
-    assert np.allclose(decompressed_weighs, decompressed_weighs_ov)
+    assert np.allclose(compressed_weights, compressed_weights_ov, atol=0)
+    assert np.allclose(decompressed_weights, decompressed_weights_ov, atol=0)
 
 
 @pytest.mark.parametrize(
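Note the strengthened assertions: with `atol=0`, `np.allclose` effectively demands that the NumPy path and the compiled OV pipeline agree element for element (only the default relative tolerance remains), and the new scalar case sits near a quantization rounding tie where the two division strategies can land on different levels. A condensed NumPy-only sketch of that sensitivity; the INT4_SYM level range [-8, 7] is taken from the `level_low`/`level_high` formulas in `calculate_quantized_weight` above:

```python
# Sketch, not the test itself: probe the rounding tie with both strategies.
import numpy as np

w = np.float32(1.4372410774230957)   # weight from the new parametrization
s = np.float32(-0.9581607580184937)  # precomputed scale from the same case

q_div = np.clip(np.round(w / s), -8, 7).astype(np.int8)                 # old path
q_rec = np.clip(np.round(w * np.reciprocal(s)), -8, 7).astype(np.int8)  # new path

# w / s sits almost exactly on -1.5, so the two paths may round to different
# 4-bit levels; the test pins NNCF to the reciprocal variant that OV uses.
print(q_div, q_rec)
```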
2 changes: 1 addition & 1 deletion tests/post_training/data/wc_reference_data.yaml
@@ -23,7 +23,7 @@ tinyllama_int8_data_free_backend_TORCH:
   num_int4: 0
   num_int8: 312
 tinyllama_data_aware_gptq_scale_estimation_stateful_backend_OV:
-  metric_value: 0.86503
+  metric_value: 0.81880
   num_int4: 94
   num_int8: 124
   metrics_xfail_reason: "Issue-148819"