add offload_device argument (#228)
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs authored Dec 19, 2024
1 parent 665c987 commit 0f4760a
Showing 2 changed files with 62 additions and 21 deletions.
52 changes: 36 additions & 16 deletions src/compressed_tensors/utils/offload.py
@@ -27,7 +27,7 @@

import contextlib
from functools import wraps
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Literal, Optional, Union

import torch

@@ -161,25 +161,32 @@ def register_offload_parameter(
module: torch.nn.Module,
name: str,
parameter: torch.nn.Parameter,
offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
"""
Register a parameter to the given module which may be offloaded
:param module: maybe offloaded module
:param name: name of newly registered parameter
:param parameter: parameter being registered
:param offload_device: device to which the weight will be offloaded. If None is
provided, the device is inferred from the module's existing parameters
"""
has_onload = any(p.device != torch.device("meta") for p in module.parameters())
module.register_parameter(name, parameter)

if has_offloaded_params(module):
update_offload_parameter(module, name, parameter.data)
set_module_tensor_to_device(module, name, "meta")
weights_map = module._hf_hook.weights_map
offload_to_weights_map(weights_map, name, parameter.data, offload_device)
if not has_onload:
set_module_tensor_to_device(module, name, "meta")


def update_offload_parameter(
module: torch.nn.Module,
name: str,
data: Optional[torch.Tensor],
offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
"""
Update the data of an existing parameter and its offload dict. Supports both
@@ -188,6 +195,8 @@ def update_offload_parameter(
:param module: module containing the parameter to update
:param name: name of module parameter to update
:param data: tensor to update parameter with
:param offload_device: device to which the weight will be offloaded. If None is
provided, the device is inferred from the module's existing parameters
"""
param = getattr(module, name)
data = data.to(param.dtype)
@@ -199,7 +208,7 @@ def update_offload_parameter(
# update offload dict
if has_offloaded_params(module):
weights_map = module._hf_hook.weights_map
offload_to_weights_map(weights_map, name, data)
offload_to_weights_map(weights_map, name, data, offload_device)


def delete_offload_parameter(module: torch.nn.Module, name: str):
@@ -240,7 +249,7 @@ def offload_to_weights_map(
weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
key: str,
value: torch.Tensor,
default_device: torch.device = torch.device("cpu"),
offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
"""
Helper function which implements offloaded item assignment for PrefixedDataset,
@@ -249,33 +258,44 @@
:param weights_map: weight map to be updated with offload information
:param key: key used to identify weight location
:param value: weight being offloaded
:param default_device: in the event that the weights_map does already contain
offloaded weights or use disk offloading, the weight will be offloaded to the
`default_device`
:param offload_device: device to which the weight will be offloaded. If None is
provided, the device is inferred from the tensors already in weights_map
"""
if isinstance(weights_map, PrefixedDataset):
if offload_device == "disk":
raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

dataset = weights_map.dataset
key = f"{weights_map.prefix}{key}"
offload_to_weights_map(dataset, key, value)
offload_to_weights_map(dataset, key, value, offload_device)

elif isinstance(weights_map, OffloadedWeightsLoader):
if key not in weights_map.all_keys:
weights_map.all_keys.append(key)

if len(weights_map.index) <= 0:
offload_to_weights_map(weights_map.state_dict, key, value)
if len(weights_map.index) <= 0 and offload_device != "disk":
offload_to_weights_map(weights_map.state_dict, key, value, offload_device)

else:
raise NotImplementedError(
"Updating weights_map with disk offloading is not implemented yet"
)

elif isinstance(weights_map, dict):
if key in weights_map:
offload_device = weights_map[key].device
else:
tens = next(iter(weights_map.values()), None)
offload_device = tens.device if tens is not None else default_device
if offload_device == "disk":
raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

# infer offload device
if offload_device is None:
if key in weights_map:
offload_device = weights_map[key].device
else:
tens = next(iter(weights_map.values()), None)
if tens is None:
raise ValueError(
"Cannot infer offload device from empty weights_map"
)
offload_device = tens.device

weights_map[key] = value.to(device=offload_device)

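For context, a minimal usage sketch of the new argument (not part of the commit). It assumes a module that has already been offloaded with accelerate — here via accelerate.cpu_offload as one possible setup — and that the helpers are importable from compressed_tensors.utils, as they are in the test file below.

import torch
from accelerate import cpu_offload
from compressed_tensors.utils import (
    register_offload_parameter,
    update_offload_parameter,
)

model = torch.nn.Linear(4, 4)
cpu_offload(model, execution_device=torch.device("cpu"))  # attaches offload hooks and a weights_map

scale = torch.nn.Parameter(torch.ones(4), requires_grad=False)

# Default (offload_device=None): the offload target is inferred from the
# weights already tracked for the module.
register_offload_parameter(model, "scale", scale)

# Explicit offload_device: the offloaded copy of the new parameter is placed
# on the requested device, as exercised in test_register_offload_parameter below.
zero_point = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
register_offload_parameter(model, "zero_point", zero_point, offload_device="cpu")

# Updates can likewise route the refreshed tensor to a given offload device.
update_offload_parameter(model, "scale", torch.full((4,), 2.0), offload_device="cpu")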
31 changes: 26 additions & 5 deletions tests/test_utils/test_offload.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from compressed_tensors.utils import (
align_module_device,
@@ -72,7 +73,7 @@ def test_register_offload_parameter():
# register a param after offloading, check that added param was offloaded
register_offload_parameter(module, "d", parameter)
assert hasattr(module, "d") and module.d.device == torch.device("meta")
assert "d" in module._hf_hook.weights_map
assert module._hf_hook.weights_map["d"].device == torch.device("cpu")

# added parameters can be onloaded and offloaded
with align_module_device(module, execution_device="cpu"):
@@ -81,6 +82,18 @@
assert module.c.device == torch.device("meta")
assert module.d.device == torch.device("meta")

# parameters can be added during onload
with align_module_device(module, execution_device="cpu"):
register_offload_parameter(module, "e", parameter)
assert module.e.device == torch.device("cpu")

# parameters can be added before onload and with explicit offload
register_offload_parameter(module, "f", parameter, offload_device="cpu")
assert module._hf_hook.weights_map["f"].device == torch.device("cpu")
with align_module_device(module, execution_device="cpu"):
assert module.f.device == torch.device("cpu")
assert module._hf_hook.weights_map["f"].device == torch.device("cpu")


@requires_accelerate()
def test_update_offload_parameter():
@@ -195,7 +208,9 @@ def test_offload_to_weights_map():

# Dict empty
weights_map = {}
offload_to_weights_map(weights_map, name, new_value)
with pytest.raises(ValueError):
offload_to_weights_map(weights_map, name, new_value)
offload_to_weights_map(weights_map, name, new_value, offload_device="cpu")
assert weights_map[name] == new_value

# Dict populated
@@ -205,7 +220,9 @@

# OffloadedWeightsLoader[Dict] empty
weights_map = OffloadedWeightsLoader({})
offload_to_weights_map(weights_map, name, new_value)
with pytest.raises(ValueError):
offload_to_weights_map(weights_map, name, new_value)
offload_to_weights_map(weights_map, name, new_value, offload_device="cpu")
assert weights_map[name] == new_value

# OffloadedWeightsLoader[Dict] populated
@@ -215,7 +232,9 @@

# PrefixedDataset[Dict] empty
weights_map = PrefixedDataset({}, prefix)
offload_to_weights_map(weights_map, name, new_value)
with pytest.raises(ValueError):
offload_to_weights_map(weights_map, name, new_value)
offload_to_weights_map(weights_map, name, new_value, offload_device="cpu")
assert weights_map[name] == new_value

# PrefixedDataset[Dict] populated
@@ -225,7 +244,9 @@

# PrefixedDataset[OffloadedWeightsLoader[Dict]] empty
weights_map = PrefixedDataset(OffloadedWeightsLoader({}), prefix)
offload_to_weights_map(weights_map, name, new_value)
with pytest.raises(ValueError):
offload_to_weights_map(weights_map, name, new_value)
offload_to_weights_map(weights_map, name, new_value, offload_device="cpu")
assert weights_map[name] == new_value

# PrefixedDataset[OffloadedWeightsLoader[Dict]] populated
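A short behavior sketch for the updated offload_to_weights_map, mirroring test_offload_to_weights_map above (illustrative only; import path follows the test file). Populated maps still infer the target device from an existing entry, but empty maps no longer fall back to a silent CPU default.

import torch
from compressed_tensors.utils import offload_to_weights_map

value = torch.tensor([1.0, 2.0])

# Populated dict: with offload_device=None the device is inferred from an
# entry already present in the map.
weights_map = {"existing": torch.zeros(2)}
offload_to_weights_map(weights_map, "new", value)

# Empty dict: the device cannot be inferred, so it must be passed explicitly;
# omitting offload_device raises a ValueError.
offload_to_weights_map({}, "new", value, offload_device="cpu")

# Requesting "disk" raises: ValueError for plain dicts and PrefixedDataset,
# and NotImplementedError for OffloadedWeightsLoader, where disk updates are
# not implemented yet.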
