Skip to content

Commit

Permalink
Timm wrapper label names (#35553)
Browse files Browse the repository at this point in the history
* Add timm wrapper label names mapping

* Add index to classification pipeline

* Revert adding index for pipelines

* Add custom model check for loading timm labels

* Add tests for labels

* [run-slow] timm_wrapper

* Add note regarding label2id mapping
  • Loading branch information
qubvel authored Jan 8, 2025
1 parent f1639ea commit 59e5b3f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
34 changes: 32 additions & 2 deletions src/transformers/models/timm_wrapper/configuration_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from typing import Any, Dict

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils import is_timm_available, logging


if is_timm_available():
from timm.data import ImageNetInfo, infer_imagenet_subset


logger = logging.get_logger(__name__)
Expand All @@ -33,6 +37,9 @@ class TimmWrapperConfig(PretrainedConfig):
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default
imagenet models is set to `None` due to occlusions in the label descriptions.
Args:
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Expand Down Expand Up @@ -60,10 +67,30 @@ def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **k

@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
label_names = config_dict.get("label_names", None)
is_custom_model = "num_labels" in kwargs or "id2label" in kwargs

# if no labels added to config, use imagenet labeller in timm
if label_names is None and not is_custom_model:
imagenet_subset = infer_imagenet_subset(config_dict)
if imagenet_subset:
dataset_info = ImageNetInfo(imagenet_subset)
synsets = dataset_info.label_names()
label_descriptions = dataset_info.label_descriptions(as_dict=True)
label_names = [label_descriptions[synset] for synset in synsets]

if label_names is not None and not is_custom_model:
kwargs["id2label"] = dict(enumerate(label_names))

# if all label names are unique, create label2id mapping as well
if len(set(label_names)) == len(label_names):
kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
else:
kwargs["label2id"] = None

# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
# and to avoid duplicate attributes

num_labels_in_kwargs = kwargs.pop("num_labels", None)
num_labels_in_dict = config_dict.pop("num_classes", None)

Expand All @@ -80,6 +107,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
def to_dict(self) -> Dict[str, Any]:
output = super().to_dict()
output["num_classes"] = self.num_labels
output["label_names"] = list(self.id2label.values())
output.pop("id2label", None)
output.pop("label2id", None)
return output


Expand Down
48 changes: 48 additions & 0 deletions tests/models/timm_wrapper/test_modeling_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,54 @@ def test_do_pooling_option(self):
output = model(**inputs_dict, do_pooling=True)
self.assertIsNotNone(output.pooler_output)

def test_timm_config_labels(self):
# test timm config with no labels
checkpoint = "timm/resnet18.a1_in1k"
config = TimmWrapperConfig.from_pretrained(checkpoint)
self.assertIsNone(config.label2id)
self.assertIsInstance(config.id2label, dict)
self.assertEqual(len(config.id2label), 1000)
self.assertEqual(config.id2label[1], "goldfish, Carassius auratus")

# test timm config with labels in config
checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
config = TimmWrapperConfig.from_pretrained(checkpoint)

self.assertIsInstance(config.id2label, dict)
self.assertEqual(len(config.id2label), 10000)
self.assertEqual(config.id2label[1], "Sabella spallanzanii")

self.assertIsInstance(config.label2id, dict)
self.assertEqual(len(config.label2id), 10000)
self.assertEqual(config.label2id["Sabella spallanzanii"], 1)

# test custom labels are provided
checkpoint = "timm/resnet18.a1_in1k"
config = TimmWrapperConfig.from_pretrained(checkpoint, num_labels=2)
self.assertEqual(config.num_labels, 2)
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})

# test with provided id2label and label2id
checkpoint = "timm/resnet18.a1_in1k"
config = TimmWrapperConfig.from_pretrained(
checkpoint, num_labels=2, id2label={0: "LABEL_0", 1: "LABEL_1"}, label2id={"LABEL_0": 0, "LABEL_1": 1}
)
self.assertEqual(config.num_labels, 2)
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})

# test save load
checkpoint = "timm/resnet18.a1_in1k"
config = TimmWrapperConfig.from_pretrained(checkpoint)
with tempfile.TemporaryDirectory() as tmpdirname:
config.save_pretrained(tmpdirname)
restored_config = TimmWrapperConfig.from_pretrained(tmpdirname)

self.assertEqual(config.num_labels, restored_config.num_labels)
self.assertEqual(config.id2label, restored_config.id2label)
self.assertEqual(config.label2id, restored_config.label2id)


# We will verify our results on an image of cute cats
def prepare_img():
Expand Down

0 comments on commit 59e5b3f

Please sign in to comment.