diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index 620292cf..b3c77c58 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -27,7 +27,7 @@
 import contextlib
 from functools import wraps
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union
 
 import torch
 
@@ -161,6 +161,7 @@ def register_offload_parameter(
     module: torch.nn.Module,
     name: str,
     parameter: torch.nn.Parameter,
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
     Register a parameter to the given module which may be offloaded
@@ -168,18 +169,24 @@
     :param module: maybe offloaded module
     :param name: name of newly registered parameter
     :param parameter: parameter being registered
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters on module
     """
+    has_onload = any(p.device != torch.device("meta") for p in module.parameters())
     module.register_parameter(name, parameter)
 
     if has_offloaded_params(module):
-        update_offload_parameter(module, name, parameter.data)
-        set_module_tensor_to_device(module, name, "meta")
+        weights_map = module._hf_hook.weights_map
+        offload_to_weights_map(weights_map, name, parameter.data, offload_device)
+        if not has_onload:
+            set_module_tensor_to_device(module, name, "meta")
 
 
 def update_offload_parameter(
     module: torch.nn.Module,
     name: str,
     data: Optional[torch.Tensor],
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
     Update the data of an existing parameter and its offload dict. Supports both
@@ -188,6 +195,8 @@
     :param module: module containing the parameter to update
     :param name: name of module parameter to update
     :param data: tensor to update parameter with
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters on module
     """
     param = getattr(module, name)
     data = data.to(param.dtype)
@@ -199,7 +208,7 @@
     # update offload dict
     if has_offloaded_params(module):
         weights_map = module._hf_hook.weights_map
-        offload_to_weights_map(weights_map, name, data)
+        offload_to_weights_map(weights_map, name, data, offload_device)
 
 
 def delete_offload_parameter(module: torch.nn.Module, name: str):
@@ -240,7 +249,7 @@ def offload_to_weights_map(
     weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
     key: str,
     value: torch.Tensor,
-    default_device: torch.device = torch.device("cpu"),
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
     Helper function which implements offloaded item assignment for PrefixedDataset,
@@ -249,21 +258,23 @@
     :param weights_map: weight map to be updated with offload information
     :param key: key used to identify weight location
     :param value: weight being offloaded
-    :param default_device: in the event that the weights_map does already contain
-        offloaded weights or use disk offloading, the weight will be offloaded to the
-        `default_device`
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters in weights_map
     """
     if isinstance(weights_map, PrefixedDataset):
+        if offload_device == "disk":
+            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
+
         dataset = weights_map.dataset
         key = f"{weights_map.prefix}{key}"
-        offload_to_weights_map(dataset, key, value)
+        offload_to_weights_map(dataset, key, value, offload_device)
 
     elif isinstance(weights_map, OffloadedWeightsLoader):
         if key not in weights_map.all_keys:
             weights_map.all_keys.append(key)
 
-        if len(weights_map.index) <= 0:
-            offload_to_weights_map(weights_map.state_dict, key, value)
+        if len(weights_map.index) <= 0 and offload_device != "disk":
+            offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
 
         else:
             raise NotImplementedError(
@@ -271,11 +282,20 @@
             )
 
     elif isinstance(weights_map, dict):
-        if key in weights_map:
-            offload_device = weights_map[key].device
-        else:
-            tens = next(iter(weights_map.values()), None)
-            offload_device = tens.device if tens is not None else default_device
+        if offload_device == "disk":
+            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
+
+        # infer offload device
+        if offload_device is None:
+            if key in weights_map:
+                offload_device = weights_map[key].device
+            else:
+                tens = next(iter(weights_map.values()), None)
+                if tens is None:
+                    raise ValueError(
+                        "Cannot infer offload device from empty weights_map"
+                    )
+                offload_device = tens.device
 
         weights_map[key] = value.to(device=offload_device)
diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py
index befeeb84..1002a4f5 100644
--- a/tests/test_utils/test_offload.py
+++ b/tests/test_utils/test_offload.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch
 from compressed_tensors.utils import (
     align_module_device,
@@ -72,7 +73,7 @@ def test_register_offload_parameter():
     # register a param after offloading, check that added param was offloaded
     register_offload_parameter(module, "d", parameter)
     assert hasattr(module, "d") and module.d.device == torch.device("meta")
-    assert "d" in module._hf_hook.weights_map
+    assert module._hf_hook.weights_map["d"].device == torch.device("cpu")
 
     # added parameters can be onloaded and offloaded
     with align_module_device(module, execution_device="cpu"):
@@ -81,6 +82,18 @@
     assert module.c.device == torch.device("meta")
     assert module.d.device == torch.device("meta")
 
+    # parameters can be added during onload
+    with align_module_device(module, execution_device="cpu"):
+        register_offload_parameter(module, "e", parameter)
+        assert module.e.device == torch.device("cpu")
+
+    # parameters can be added before onload and with explicit offload
+    register_offload_parameter(module, "f", parameter, offload_device="cpu")
+    assert module._hf_hook.weights_map["f"].device == torch.device("cpu")
+    with align_module_device(module, execution_device="cpu"):
+        assert module.f.device == torch.device("cpu")
+    assert module._hf_hook.weights_map["f"].device == torch.device("cpu")
+
 
 @requires_accelerate()
 def test_update_offload_parameter():
@@ -195,7 +208,9 @@ def test_offload_to_weights_map():
 
     # Dict empty
     weights_map = {}
-    offload_to_weights_map(weights_map, name, new_value)
+    with pytest.raises(ValueError):
+        offload_to_weights_map(weights_map, name, new_value)
+    offload_to_weights_map(weights_map, name, new_value, offload_device="cpu")
     assert weights_map[name] == new_value
 
     # Dict populated
@@ -205,7 +220,9 @@
 
     # OffloadedWeightsLoader[Dict] empty
     weights_map = OffloadedWeightsLoader({})
-    offload_to_weights_map(weights_map, name, new_value)
+    with pytest.raises(ValueError):
+        offload_to_weights_map(weights_map, name, new_value)
+    offload_to_weights_map(weights_map, name, new_value, offload_device="cpu")
     assert weights_map[name] == new_value
 
     # OffloadedWeightsLoader[Dict] populated
@@ -215,7 +232,9 @@
 
     # PrefixedDataset[Dict] empty
     weights_map = PrefixedDataset({}, prefix)
-    offload_to_weights_map(weights_map, name, new_value)
+    with pytest.raises(ValueError):
+        offload_to_weights_map(weights_map, name, new_value)
+    offload_to_weights_map(weights_map, name, new_value, offload_device="cpu")
     assert weights_map[name] == new_value
 
     # PrefixedDataset[Dict] populated
@@ -225,7 +244,9 @@
 
     # PrefixedDataset[OffloadedWeightsLoader[Dict]] empty
     weights_map = PrefixedDataset(OffloadedWeightsLoader({}), prefix)
-    offload_to_weights_map(weights_map, name, new_value)
+    with pytest.raises(ValueError):
+        offload_to_weights_map(weights_map, name, new_value)
+    offload_to_weights_map(weights_map, name, new_value, offload_device="cpu")
     assert weights_map[name] == new_value
 
     # PrefixedDataset[OffloadedWeightsLoader[Dict]] populated
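
Usage sketch (not part of the diff): a minimal illustration of the new offload_device argument for the plain-dict case, mirroring the empty-dict scenarios exercised in test_offload_to_weights_map above. It assumes offload_to_weights_map is importable from compressed_tensors.utils, as in the test file; the tensor values and key names are arbitrary.

import pytest
import torch
from compressed_tensors.utils import offload_to_weights_map

weights_map = {}
value = torch.tensor([1.0, 2.0])

# with an empty map and no offload_device, the target device cannot be inferred
with pytest.raises(ValueError):
    offload_to_weights_map(weights_map, "weight", value)

# passing an explicit offload_device resolves the ambiguity
offload_to_weights_map(weights_map, "weight", value, offload_device="cpu")
assert weights_map["weight"].device == torch.device("cpu")

# plain dicts cannot back disk offloading, so "disk" is rejected
with pytest.raises(ValueError):
    offload_to_weights_map(weights_map, "bias", value, offload_device="disk")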