Commit

Clean-up composite configs (#34603)
* remove manual assignment tie-word-embeddings

* remove another unused attribute

* fix tests

* fix tests

* remove unnecessary overwrites

* fix

* decoder=True

* clean pix2struct

* run-all

* forgot `_tied_weights_keys` when adding Emu3

* also Aria + fix-copies

* and clean aria
zucchini-nlp authored Jan 15, 2025
1 parent c61fcde commit 09d5f76
Showing 33 changed files with 68 additions and 219 deletions.
3 changes: 0 additions & 3 deletions examples/modular-transformers/modeling_new_task_model.py
@@ -249,9 +249,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

def _update_causal_mask(
self,
attention_mask,
12 changes: 9 additions & 3 deletions src/transformers/modeling_utils.py
@@ -1862,7 +1862,7 @@ def tie_weights(self):
If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
weights instead.
"""
if getattr(self.config, "tie_word_embeddings", True):
if getattr(self.config.get_text_config(decoder=True), "tie_word_embeddings", True):
output_embeddings = self.get_output_embeddings()
if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@@ -2104,7 +2104,10 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
new_num_tokens = new_embeddings.weight.shape[0]

# if word embeddings are not tied, make sure that lm head is resized as well
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
if (
self.get_output_embeddings() is not None
and not self.config.get_text_config(decoder=True).tie_word_embeddings
):
old_lm_head = self.get_output_embeddings()
if isinstance(old_lm_head, torch.nn.Embedding):
new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
@@ -4604,7 +4607,10 @@ def _load_pretrained_model(
_loaded_keys = loaded_keys
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
# If we're about to tie the output embeds to the input embeds we don't need to init them
if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
if (
hasattr(model.config.get_text_config(decoder=True), "tie_word_embeddings")
and model.config.get_text_config(decoder=True).tie_word_embeddings
):
output_embeddings = model.get_output_embeddings()
if output_embeddings is not None:
# Still need to initialize if there is a bias term since biases are not tied.
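The modeling_utils.py hunks above stop reading `tie_word_embeddings` off the top-level config and query `config.get_text_config(decoder=True)` instead, so composite models no longer need to mirror the flag manually. A minimal sketch of the behaviour this relies on, assuming `get_text_config` returns the nested decoder/text config for composite configs and the config object itself otherwise:

```python
from transformers import LlamaConfig, LlavaConfig

# Plain text-only config: get_text_config() hands back the same object,
# so the generic code behaves exactly as before.
plain = LlamaConfig(tie_word_embeddings=True)
assert plain.get_text_config(decoder=True) is plain

# Composite config: the flag lives on the nested text config, which is what
# tie_weights / _resize_token_embeddings / _load_pretrained_model now inspect.
composite = LlavaConfig()
composite.text_config.tie_word_embeddings = False
assert composite.get_text_config(decoder=True) is composite.text_config
print(composite.get_text_config(decoder=True).tie_word_embeddings)  # False
```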
4 changes: 1 addition & 3 deletions src/transformers/models/aria/modeling_aria.py
@@ -1360,6 +1360,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]

def __init__(self, config: AriaConfig):
super().__init__(config)
@@ -1406,9 +1407,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

def get_image_features(
self,
pixel_values: torch.FloatTensor,
4 changes: 1 addition & 3 deletions src/transformers/models/aria/modular_aria.py
@@ -1337,6 +1337,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]

def __init__(self, config: AriaConfig):
super().__init__(config)
@@ -1383,9 +1384,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

def get_image_features(
self,
pixel_values: torch.FloatTensor,
3 changes: 0 additions & 3 deletions src/transformers/models/blip_2/configuration_blip_2.py
@@ -302,9 +302,6 @@ def __init__(
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.image_text_hidden_size = image_text_hidden_size
self.image_token_index = image_token_index
4 changes: 2 additions & 2 deletions src/transformers/models/clip/modeling_clip.py
@@ -1433,7 +1433,7 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
def __init__(self, config: CLIPTextConfig):
super().__init__(config)

text_model = CLIPTextModel._from_config(config, attn_implementation=config._attn_implementation)
text_model = CLIPTextModel._from_config(config)
self.text_model = text_model.text_model

self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
@@ -1514,7 +1514,7 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)

vision_model = CLIPVisionModel._from_config(config, attn_implementation=config._attn_implementation)
vision_model = CLIPVisionModel._from_config(config)
self.vision_model = vision_model.vision_model

self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
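The CLIP hunks above (and the Fuyu, GroundingDino and Moshi ones below) also drop the explicit `attn_implementation=config._attn_implementation` argument from `from_config`/`_from_config` calls. A rough, hypothetical sketch of why the extra kwarg is redundant, assuming the factory already reads the implementation recorded on the config it receives (toy names, not the library internals):

```python
# Toy illustration only: the attention backend travels with the config object,
# so a factory that is handed the config does not need it passed separately.
class ToyConfig:
    def __init__(self, attn_implementation: str = "sdpa"):
        self._attn_implementation = attn_implementation


class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config
        # read the backend straight from the config instead of a kwarg
        self.attn_backend = config._attn_implementation

    @classmethod
    def _from_config(cls, config: ToyConfig) -> "ToyModel":
        return cls(config)


sub_config = ToyConfig(attn_implementation="flash_attention_2")
model = ToyModel._from_config(sub_config)  # no attn_implementation= needed
print(model.attn_backend)  # flash_attention_2
```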
2 changes: 2 additions & 0 deletions src/transformers/models/emu3/modeling_emu3.py
@@ -1791,6 +1791,8 @@ def forward(


class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]

def __init__(self, config):
super().__init__(config)
self.text_model = Emu3ForCausalLM._from_config(config.text_config)
2 changes: 2 additions & 0 deletions src/transformers/models/emu3/modular_emu3.py
@@ -1103,6 +1103,8 @@ def forward(**super_kwargs):


class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]

def __init__(self, config):
super().__init__(config)
self.text_model = Emu3ForCausalLM._from_config(config.text_config)
9 changes: 3 additions & 6 deletions src/transformers/models/fuyu/modeling_fuyu.py
@@ -151,9 +151,9 @@ def __init__(self, config: FuyuConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]

self.vision_embed_tokens = nn.Linear(
config.patch_size * config.patch_size * config.num_channels, config.hidden_size
@@ -181,9 +181,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

def gather_continuous_embeddings(
self,
word_embeddings: torch.Tensor,
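The Fuyu change above (repeated in the Llava, Mllama and other wrappers below) replaces the per-model `tie_weights` override with a `_tied_weights_keys` list copied from the inner language model and re-prefixed with `language_model.`, so the generic tying and safe-serialization logic can find the shared parameters from the composite model's root. A minimal, self-contained sketch of the prefixing pattern (toy classes, not the library code):

```python
class ToyCausalLM:
    # the inner LM declares which of its parameters are tied to the embeddings
    _tied_weights_keys = ["lm_head.weight"]


class ToyWrapper:
    _tied_weights_keys = None

    def __init__(self):
        self.language_model = ToyCausalLM()
        # re-export the inner keys under the wrapper's attribute path
        if self.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [
                f"language_model.{k}" for k in self.language_model._tied_weights_keys
            ]


print(ToyWrapper()._tied_weights_keys)  # ['language_model.lm_head.weight']
```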
src/transformers/models/grounding_dino/modeling_grounding_dino.py
@@ -2104,9 +2104,7 @@ def __init__(self, config: GroundingDinoConfig):
)

# Create text backbone
self.text_backbone = AutoModel.from_config(
config.text_config, add_pooling_layer=False, attn_implementation=config._attn_implementation
)
self.text_backbone = AutoModel.from_config(config.text_config, add_pooling_layer=False)
self.text_projection = nn.Linear(config.text_config.hidden_size, config.d_model)

if config.embedding_init_target or not config.two_stage:
33 changes: 0 additions & 33 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -1285,13 +1285,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.text_model.set_input_embeddings(value)

def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.text_model.resize_token_embeddings(
new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
)
self.config.text_config.vocab_size = model_embeds.num_embeddings
return model_embeds

def inputs_merger(
self,
input_ids: torch.LongTensor,
@@ -1515,32 +1508,6 @@ def get_output_embeddings(self):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
# model_embeds = self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
if new_num_tokens is None and pad_to_multiple_of is None:
return model_embeds

# Update base model and current model config
# Ignore copy
self.config.text_config.vocab_size = model_embeds.weight.shape[0]
self.vocab_size = self.config.text_config.vocab_size

# Tie weights again if needed
self.tie_weights()

return model_embeds

def tie_weights(self):
"""
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
"""
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()

if getattr(self.config, "tie_word_embeddings", True):
output_embeddings.weight = input_embeddings.weight

@add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
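With the Idefics2 overrides removed, resizing and tying fall back to the generic `PreTrainedModel` implementations, which (per the modeling_utils.py hunk above) consult `get_text_config(decoder=True).tie_word_embeddings` and then point the output embedding weights at the input embeddings. A toy sketch of that tying step, assuming standard `nn.Embedding`/`nn.Linear` modules:

```python
import torch.nn as nn


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def tie_weights(self):
        # roughly what _tie_or_clone_weights does outside of torchscript:
        # share a single Parameter between input and output embeddings
        self.lm_head.weight = self.embed_tokens.weight


model = TinyLM()
model.tie_weights()
assert model.lm_head.weight is model.embed_tokens.weight
```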
11 changes: 0 additions & 11 deletions src/transformers/models/idefics3/modeling_idefics3.py
@@ -1094,17 +1094,6 @@ def get_output_embeddings(self):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.tie_weights
def tie_weights(self):
"""
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
"""
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()

if getattr(self.config, "tie_word_embeddings", True):
output_embeddings.weight = input_embeddings.weight

@add_start_docstrings_to_model_forward(IDEFICS3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Idefics3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
src/transformers/models/instructblip/configuration_instructblip.py
@@ -302,9 +302,6 @@ def __init__(
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.image_token_index = image_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
src/transformers/models/instructblipvideo/configuration_instructblipvideo.py
@@ -308,9 +308,6 @@ def __init__(
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.video_token_index = video_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
src/transformers/models/instructblipvideo/modular_instructblipvideo.py
@@ -137,9 +137,6 @@ def __init__(
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.video_token_index = video_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
14 changes: 4 additions & 10 deletions src/transformers/models/llava/modeling_llava.py
@@ -242,6 +242,10 @@ def __init__(self, config: LlavaConfig):
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)

if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]

self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

self.post_init()
@@ -264,16 +268,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds

def get_image_features(
self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
):
15 changes: 3 additions & 12 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -357,6 +357,9 @@ def __init__(self, config: LlavaNextConfig):

self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]

self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()
@@ -395,18 +398,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
def tie_weights(self):
return self.language_model.tie_weights()

# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds

def _merge_input_ids_with_image_features(
self,
image_features,
src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -397,6 +397,9 @@ def __init__(

self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]

self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.vision_resampler = LlavaNextVideoPooler(config)
@@ -430,16 +433,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds

def _merge_input_ids_with_image_features(
self,
image_features,
src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -374,6 +374,9 @@ def __init__(self, config: LlavaOnevisionConfig):

self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]

self.post_init()

# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
@@ -400,10 +403,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.tie_weights
def tie_weights(self):
return self.language_model.tie_weights()

def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
6 changes: 3 additions & 3 deletions src/transformers/models/mllama/modeling_mllama.py
@@ -1986,6 +1986,9 @@ def __init__(self, config: MllamaConfig):

self.vision_model = MllamaVisionModel._from_config(config.vision_config)
self.language_model = MllamaForCausalLM._from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]

self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
@@ -2011,9 +2014,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig")
def forward(
5 changes: 1 addition & 4 deletions src/transformers/models/moshi/modeling_moshi.py
@@ -1914,12 +1914,9 @@ def __init__(self, config: MoshiConfig):
self.embed_tokens = nn.ModuleList(
[nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)]
)
self.audio_encoder = AutoModel.from_config(
config.audio_encoder_config, attn_implementation=config._attn_implementation
)
self.audio_encoder = AutoModel.from_config(config.audio_encoder_config)
self.decoder = MoshiForCausalLM(config)

config.depth_decoder_config._attn_implementation_internal = config._attn_implementation
self.depth_decoder = MoshiDepthDecoder(config.depth_decoder_config)

self.num_codebooks = config.num_codebooks