Merge remote-tracking branch 'origin' into kylesayrs/upstream-candidates
kylesayrs committed Dec 19, 2024
2 parents 9177650 + 975cb22 commit 38d7dbf
Showing 11 changed files with 171 additions and 23 deletions.
@@ -137,7 +137,7 @@ def from_compression_config(
                format, **sparsity_config
            )
        if quantization_config is not None:
-            quantization_config = QuantizationConfig.parse_obj(quantization_config)
+            quantization_config = QuantizationConfig.model_validate(quantization_config)

        return cls(
            sparsity_config=sparsity_config, quantization_config=quantization_config
@@ -193,7 +193,7 @@ def parse_sparsity_config(

        if is_compressed_tensors_config(compression_config):
            s_config = compression_config.sparsity_config
-            return s_config.dict() if s_config is not None else None
+            return s_config.model_dump() if s_config is not None else None

        return compression_config.get(SPARSITY_CONFIG_NAME, None)

@@ -214,7 +214,7 @@ def parse_quantization_config(

        if is_compressed_tensors_config(compression_config):
            q_config = compression_config.quantization_config
-            return q_config.dict() if q_config is not None else None
+            return q_config.model_dump() if q_config is not None else None

        quantization_config = deepcopy(compression_config)
        quantization_config.pop(SPARSITY_CONFIG_NAME, None)
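These renames track the Pydantic v1 to v2 API migration: parse_obj is replaced by model_validate and .dict() by .model_dump(). A minimal sketch of the equivalence, using a hypothetical stand-in model rather than the library's real config classes:

    from pydantic import BaseModel

    class ExampleConfig(BaseModel):
        # hypothetical stand-in for QuantizationConfig / SparsityConfig
        format: str = "dense"

    data = {"format": "sparse"}

    # Pydantic v1 style (deprecated in v2):
    #   config = ExampleConfig.parse_obj(data)
    #   payload = config.dict()

    # Pydantic v2 equivalents, as used throughout this commit:
    config = ExampleConfig.model_validate(data)  # dict -> model
    payload = config.model_dump()                # model -> dict
    assert payload == {"format": "sparse"}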
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_config.py
@@ -160,7 +160,7 @@ def model_post_init(self, __context):

    def to_dict(self):
        # for compatibility with HFQuantizer
-        return self.dict()
+        return self.model_dump()

    @staticmethod
    def from_pretrained(
2 changes: 1 addition & 1 deletion src/compressed_tensors/version.py
@@ -17,7 +17,7 @@
"""


-version_base = "0.8.0"
+version_base = "0.8.1"
is_release = True  # change to True to set the generated version as a release version
2 changes: 0 additions & 2 deletions tests/conftest.py
@@ -44,8 +44,6 @@ def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor
    min_val = torch.amin(value, dim=dim, keepdims=True)
    max_val = torch.amax(value, dim=dim, keepdims=True)
    scale, zp = calculate_qparams(min_val, max_val, args)
-    scale = scale.reshape((1, 1))
-    zp = zp.reshape((1, 1))
    update_parameter_data(module, scale, f"{base_name}_scale")
    update_parameter_data(module, zp, f"{base_name}_zero_point")
@@ -99,8 +99,8 @@ def test_hf_compressor_tensors_config(s_config, q_config, tmp_path):
    )
    q_config = QuantizationConfig(**q_config) if q_config is not None else None

-    s_config_dict = s_config.dict() if s_config is not None else None
-    q_config_dict = q_config.dict() if q_config is not None else None
+    s_config_dict = s_config.model_dump() if s_config is not None else None
+    q_config_dict = q_config.model_dump() if q_config is not None else None

    assert compressor.sparsity_config == s_config
    assert compressor.quantization_config == q_config
2 changes: 1 addition & 1 deletion tests/test_quantization/lifecycle/test_apply.py
@@ -222,7 +222,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
        },
        "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
    }
-    return QuantizationConfig.parse_obj(config_dict)
+    return QuantizationConfig.model_validate(config_dict)


@requires_accelerate()
@@ -110,4 +110,4 @@ def get_sample_dynamic_tinyllama_quant_config():
        },
        "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
    }
-    return QuantizationConfig.parse_obj(config_dict)
+    return QuantizationConfig.model_validate(config_dict)
89 changes: 87 additions & 2 deletions tests/test_quantization/lifecycle/test_initialize.py
@@ -14,16 +14,26 @@


import pytest
+from compressed_tensors.quantization import (
+    ActivationOrdering,
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStatus,
+    QuantizationStrategy,
+)
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.quant_config import QuantizationStatus
from tests.testing_utils import requires_accelerate
from torch.nn import Linear

NUM_BITS = 8
+Q_PARAM_NAMES = {
+    "input_activations": "input",
+    "weights": "weight",
+    "output_activations": "output",
+}


@pytest.fixture
@@ -115,3 +125,78 @@ def test_initialize_module_for_quantization_offloaded(
        input_activations,
        layer,
    )


+@pytest.mark.parametrize(
+    "weights,input_activations",
+    [
+        (
+            QuantizationArgs(strategy="tensor"),
+            QuantizationArgs(strategy="tensor"),
+        ),
+        (
+            QuantizationArgs(strategy="channel"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2, actorder="group"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="group", group_size=2, actorder="weight"),
+            None,
+        ),
+        (
+            QuantizationArgs(strategy="block"),
+            QuantizationArgs(strategy="block"),
+        ),
+        (
+            QuantizationArgs(strategy="token"),
+            QuantizationArgs(strategy="token"),
+        ),
+    ],
+)
+def test_initialize_quantization_parameters(weights, input_activations):
+    quantization_scheme = QuantizationScheme(
+        targets=["*"],
+        weights=weights,
+        input_activations=input_activations,
+    )
+    layer = Linear(7, 8)
+    initialize_module_for_quantization(layer, quantization_scheme)
+
+    for q_type in ("input_activations", "weights"):
+        args = getattr(quantization_scheme, q_type)
+        if args is None:
+            continue
+        q_param_name = Q_PARAM_NAMES[q_type]
+
+        # scale and zero point
+        if args.strategy == QuantizationStrategy.TENSOR:
+            expected_shape = (1,)
+
+        elif args.strategy == QuantizationStrategy.CHANNEL:  # only weight
+            expected_shape = (layer.weight.shape[0], 1)
+
+        elif args.strategy == QuantizationStrategy.GROUP:  # only weight
+            num_groups = layer.weight.shape[1] // args.group_size
+            expected_shape = (layer.weight.shape[0], max(num_groups, 1))
+
+        elif args.strategy == QuantizationStrategy.BLOCK:
+            expected_shape = (1,)
+
+        elif args.strategy == QuantizationStrategy.TOKEN:
+            expected_shape = (1, 1)
+
+        assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape
+        assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape
+
+        # g_idx
+        if args.actorder == ActivationOrdering.GROUP:
+            assert getattr(layer, f"{q_param_name}_g_idx").shape == (
+                layer.weight.shape[1],
+            )
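For reference, the shape rules this new test asserts can be collapsed into a small stand-alone helper. This is an illustrative sketch mirroring the assertions above, not part of the library API:

    # Hypothetical helper: expected scale/zero-point shape for a Linear(in_features, out_features)
    def expected_qparam_shape(strategy, out_features, in_features, group_size=None):
        if strategy in ("tensor", "block"):
            return (1,)
        if strategy == "channel":
            return (out_features, 1)
        if strategy == "group":
            return (out_features, max(in_features // group_size, 1))
        if strategy == "token":
            return (1, 1)
        raise ValueError(f"unknown strategy: {strategy}")

    # e.g. the group case above: Linear(7, 8) with group_size=2 -> (8, 3)
    assert expected_qparam_shape("group", out_features=8, in_features=7, group_size=2) == (8, 3)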
20 changes: 10 additions & 10 deletions tests/test_quantization/test_configs/test_strategies.py
@@ -67,8 +67,8 @@ def test_channelwise(
    if input_symmetry is not None:
        mock_per_channel_calibration(model, base_name="input", value=inputs)

-    assert list(model.weight_scale.shape) == [model_shape[1], 1]
-    assert list(model.weight_zero_point.shape) == [model_shape[1], 1]
+    assert model.weight_scale.shape == (model_shape[1], 1)
+    assert model.weight_zero_point.shape == (model_shape[1], 1)


@torch.no_grad
@@ -97,14 +97,14 @@ def test_group(
        model, base_name="input", value=inputs, group_size=group_size
    )

-    assert list(model.weight_scale.shape) == [
+    assert model.weight_scale.shape == (
        model_shape[1],
        int(model_shape[0] / group_size),
-    ]
-    assert list(model.weight_zero_point.shape) == [
+    )
+    assert model.weight_zero_point.shape == (
        model_shape[1],
        int(model_shape[0] / group_size),
-    ]
+    )


@torch.no_grad
@@ -131,8 +131,8 @@ def test_token(
    mock_per_channel_calibration(model, base_name="weight", value=model.weight)
    mock_per_token_calibration(model, base_name="input", value=inputs)

-    assert list(model.input_scale.shape) == [1, 1]
-    assert list(model.input_zero_point.shape) == [1, 1]
+    assert model.input_scale.shape == (1, 1)
+    assert model.input_zero_point.shape == (1, 1)

-    assert list(model.weight_scale.shape) == [256, 1]
-    assert list(model.weight_zero_point.shape) == [256, 1]
+    assert model.weight_scale.shape == (256, 1)
+    assert model.weight_zero_point.shape == (256, 1)
7 changes: 7 additions & 0 deletions tests/test_quantization/test_quant_config.py
@@ -72,3 +72,10 @@ def test_load_scheme_from_preset(scheme_name: str):
    assert scheme_name in config.config_groups
    assert isinstance(config.config_groups[scheme_name], QuantizationScheme)
    assert config.config_groups[scheme_name].targets == targets
+
+
+def test_to_dict():
+    config_groups = {"group_1": QuantizationScheme(targets=[])}
+    config = QuantizationConfig(config_groups=config_groups)
+    reloaded = QuantizationConfig.model_validate(config.to_dict())
+    assert config == reloaded
58 changes: 58 additions & 0 deletions tests/test_quantization/test_utils/test_helpers.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization.utils import calculate_qparams
+
+
+@pytest.mark.parametrize(
+    "keepdims,strategy,exp_shape",
+    [
+        (False, QuantizationStrategy.TENSOR, torch.Size([1])),
+        (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])),
+        (True, QuantizationStrategy.GROUP, torch.Size([1, 1])),
+        (False, QuantizationStrategy.BLOCK, torch.Size([1])),
+        (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])),
+    ],
+)
+def test_calculate_qparams(keepdims, strategy, exp_shape):
+    value = torch.randn(14, 5)
+    min_val = torch.amin(value, dim=tuple(), keepdims=keepdims)
+    max_val = torch.amax(value, dim=tuple(), keepdims=keepdims)
+
+    if strategy == QuantizationStrategy.GROUP:
+        args = QuantizationArgs(strategy=strategy, group_size=2)
+    else:
+        args = QuantizationArgs(strategy=strategy)
+    scale, zp = calculate_qparams(min_val, max_val, args)
+    assert scale.shape == exp_shape
+    assert zp.shape == exp_shape
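The keepdims flag in the parametrization matters because the expected qparam shapes track the shape of the min/max inputs: reductions taken with dims kept produce the 2-D scales (the channel, group, and token rows above), while collapsed reductions yield a single-element scale. A quick illustration of the two reduction shapes in plain torch, independent of the library:

    import torch

    value = torch.randn(14, 5)

    # full-tensor reduction: rank collapses to a scalar
    flat_min = torch.amin(value)                             # shape: torch.Size([])

    # reduction keeping dims: rank preserved, suited to per-channel/per-token qparams
    kept_min = torch.amin(value, dim=(0, 1), keepdim=True)   # shape: torch.Size([1, 1])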
