Skip to content

Commit

Permalink
Implement aliasable mixin and alias activation ordering (#213)
Browse files Browse the repository at this point in the history
* implement aliasable mixin and alias activation ordering

Signed-off-by: Kyle Sayers <[email protected]>

* update docstring

Signed-off-by: Kyle Sayers <[email protected]>

* fix docstring

Signed-off-by: Kyle Sayers <[email protected]>

* uncomment

Signed-off-by: Kyle Sayers <[email protected]>

* rename and make abstract

Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
  • Loading branch information
kylesayrs authored Nov 29, 2024
1 parent 525ef3a commit 724d5ce
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 4 deletions.
20 changes: 17 additions & 3 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Dict, Optional, Union

import torch
from compressed_tensors.utils import Aliasable
from pydantic import BaseModel, Field, field_validator, model_validator


Expand Down Expand Up @@ -53,17 +54,30 @@ class QuantizationStrategy(str, Enum):
TOKEN = "token"


class ActivationOrdering(Aliasable, str, Enum):
    """
    Enum storing strategies for activation ordering

    Group: reorder groups and weight\n
    Weight: only reorder weight, not groups. Slightly lower accuracy but also lower
    latency when compared to group actorder\n
    Dynamic: alias for Group\n
    Static: alias for Weight\n
    """

    GROUP = "group"
    WEIGHT = "weight"
    # aliases
    DYNAMIC = "dynamic"
    STATIC = "static"

    # A plain property: the getter receives the member as `self`. Stacking
    # `@staticmethod` underneath `@property` contradicts the `self` parameter
    # and breaks on interpreters where staticmethod objects are not callable.
    @property
    def aliases(self) -> Dict[str, str]:
        """
        :return: mapping from each alias value to the canonical value it
            stands for
        """
        return {
            "dynamic": "group",
            "static": "weight",
        }


class QuantizationArgs(BaseModel, use_enum_values=True):
Expand Down
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:

return model


"""
Pre-Set Quantization Scheme Args
"""
Expand Down
37 changes: 36 additions & 1 deletion src/compressed_tensors/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional
from abc import abstractmethod
from typing import Any, Dict, Optional

import torch
from transformers import AutoConfig
Expand All @@ -24,6 +25,7 @@
"tensor_follows_mask_structure",
"replace_module",
"is_compressed_tensors_config",
"Aliasable",
]

FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
Expand Down Expand Up @@ -119,3 +121,36 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
return isinstance(compression_config, CompressedTensorsConfig)
except ImportError:
return False


class Aliasable:
    """
    A mixin for enums to allow aliasing of enum members

    Subclasses provide an ``aliases`` property mapping each alias value to the
    canonical value it stands for. Equality, inequality and hashing are all
    evaluated on canonical values, so an alias member, its canonical member and
    either of their raw values compare equal and hash identically.

    Example:
    >>> class MyClass(Aliasable, int, Enum):
    >>>     ...
    """

    @property
    @abstractmethod
    def aliases(self) -> Dict[str, str]:
        """
        :return: mapping from alias value to canonical value
        """
        raise NotImplementedError()

    def __eq__(self, other):
        """
        Compare canonical values

        :param other: another member of the same class, or a raw value
        :return: True if both sides resolve to the same canonical value
        """
        aliases = self.aliases
        self_value = aliases.get(self.value, self.value)

        # members of the same class are compared through their raw values
        if isinstance(other, self.__class__):
            other = other.value

        try:
            other_value = aliases.get(other, other)
        except TypeError:
            # unhashable operands (list, dict, ...) can never be alias keys;
            # compare them as-is instead of propagating the lookup error
            other_value = other

        return self_value == other_value

    def __ne__(self, other):
        """
        Inverse of ``__eq__``

        Defined explicitly because a mixed-in data type such as ``str`` supplies
        its own ``__ne__`` that compares raw values and ignores aliases, which
        would let ``a == b`` and ``a != b`` both be true for aliased members.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # hash the canonical value so members that compare equal hash equal
        canonical_value = self.aliases.get(self.value, self.value)
        return hash(canonical_value)
46 changes: 46 additions & 0 deletions tests/test_quantization/test_quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,28 @@ def test_actorder():
# test group inference with actorder
args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.GROUP)
assert args.strategy == QuantizationStrategy.GROUP
args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.DYNAMIC)
assert args.strategy == QuantizationStrategy.GROUP

# test invalid pairings
with pytest.raises(ValueError):
QuantizationArgs(group_size=None, actorder="group")
with pytest.raises(ValueError):
QuantizationArgs(group_size=None, actorder="weight")
with pytest.raises(ValueError):
QuantizationArgs(group_size=None, actorder="static")
with pytest.raises(ValueError):
QuantizationArgs(group_size=-1, actorder="group")
with pytest.raises(ValueError):
QuantizationArgs(group_size=-1, actorder="weight")
with pytest.raises(ValueError):
QuantizationArgs(group_size=-1, actorder="static")
with pytest.raises(ValueError):
QuantizationArgs(strategy="tensor", actorder="group")
with pytest.raises(ValueError):
QuantizationArgs(strategy="tensor", actorder="weight")
with pytest.raises(ValueError):
QuantizationArgs(strategy="tensor", actorder="static")

# test boolean and none defaulting
assert (
Expand All @@ -101,6 +115,38 @@ def test_actorder():
assert QuantizationArgs(group_size=1, actorder=None).actorder is None


def test_actorder_aliases():
    """
    Check that alias members, canonical members and their raw string values
    compare equal within a family (group/dynamic, weight/static) and unequal
    across families, with the enum member on either side of the operator
    """
    group = ActivationOrdering.GROUP
    dynamic = ActivationOrdering.DYNAMIC
    weight = ActivationOrdering.WEIGHT
    static = ActivationOrdering.STATIC

    # member against member
    for canonical, alias in ((group, dynamic), (weight, static)):
        assert canonical == alias
        assert alias == canonical

    # member against raw string of the same family
    matching = (
        (group, "dynamic"),
        (dynamic, "dynamic"),
        (group, "group"),
        (dynamic, "group"),
        (weight, "static"),
        (static, "static"),
        (weight, "weight"),
        (static, "weight"),
    )
    for member, raw in matching:
        assert member == raw
        assert raw == member

    # member against raw string of the other family
    mismatched = (
        (weight, "dynamic"),
        (static, "dynamic"),
        (weight, "group"),
        (static, "group"),
        (group, "static"),
        (dynamic, "static"),
        (group, "weight"),
        (dynamic, "weight"),
    )
    for member, raw in mismatched:
        assert member != raw
        assert raw != member


def test_invalid():
with pytest.raises(ValidationError):
QuantizationArgs(type="invalid")
Expand Down

0 comments on commit 724d5ce

Please sign in to comment.