diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index d15057ce..12b43287 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -127,15 +127,17 @@ def decompress( yield other_name, value @staticmethod - def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool: + def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool: """ Check if a parameter should be compressed :param name: name of the parameter - :param targets: set of layer prefixes to compress + :param expanded_targets: set of layer prefixes to compress :return: whether or not the parameter should be compressed """ - if targets is None: + if expanded_targets is None: return name.endswith(".weight") - return name.endswith(".weight") and name[: -(len(".weight"))] in targets + return ( + name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets + ) diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index e4f8d7a7..11f0f326 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -32,11 +32,10 @@ "get_nested_weight_mappings", "get_quantization_state_dict", "is_quantization_param", - "get_nested_mappings_from_state_dict", ] -WEIGHT_MAPPING_TYPE = Dict[str, str] -NESTED_WEIGHT_MAPPING_TYPE = Dict[str, WEIGHT_MAPPING_TYPE] +WeightMappingType = Dict[str, str] +NestedWeightMappingType = Dict[str, WeightMappingType] def get_safetensors_folder( @@ -181,9 +180,7 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]: def get_nested_weight_mappings( model_path: str, params_to_nest: List[str], return_other_params: bool = False -) -> Union[ - NESTED_WEIGHT_MAPPING_TYPE, Tuple[NESTED_WEIGHT_MAPPING_TYPE, WEIGHT_MAPPING_TYPE] -]: +) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]: """ Takes a path to a state dict saved in safetensors format and returns a nested mapping from uncompressed parameterized layer names to the file locations of each @@ -256,17 +253,3 @@ def is_quantization_param(name: str) -> bool: return True return False - - -def get_nested_mappings_from_state_dict(state_dict, params_to_nest): - nested_weight_mappings = {} - for key in state_dict.keys(): - for param_name in params_to_nest: - maybe_match = match_param_name(key, param_name) - if maybe_match is not None: - dense_param = maybe_match - if dense_param not in nested_weight_mappings: - nested_weight_mappings[dense_param] = {} - nested_weight_mappings[dense_param][param_name] = state_dict[key] - - return nested_weight_mappings diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 7474795b..d1799a39 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -296,7 +296,7 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): @pytest.mark.parametrize( - "targets, ignore, expected", + "targets, ignore, expected_targets", [ ([], [], set()), (["layer1", "layer2"], [], {"layer1", "layer2"}), @@ -305,13 +305,13 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): (["re:layer.*"], ["layer3"], {"layer1", "layer2"}), ], ) -def test_expand_targets_with_mock(mock_model, targets, ignore, expected): - result = expand_targets(mock_model, targets, ignore) - assert result == expected +def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets): + expanded_targets = expand_targets(mock_model, targets, ignore) + assert expanded_targets == expected_targets @pytest.mark.parametrize( - "targets, ignore, expected", + "targets, ignore, expected_targets", [ ( ["re:model.layers.[01].self_attn.q_proj"], @@ -344,10 +344,10 @@ def test_expand_targets_with_mock(mock_model, targets, ignore, expected): ], ) def test_expand_targets_with_llama_stories( - llama_stories_model, targets, ignore, expected + llama_stories_model, targets, ignore, expected_targets ): - actual_targets = expand_targets(llama_stories_model, targets, ignore) - assert actual_targets == expected + expanded_targets = expand_targets(llama_stories_model, targets, ignore) + assert expanded_targets == expected_targets @pytest.mark.parametrize(