diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py
index 477d084b1d9309..da0b354fe76efa 100644
--- a/examples/modular-transformers/modeling_new_task_model.py
+++ b/examples/modular-transformers/modeling_new_task_model.py
@@ -249,9 +249,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
     def _update_causal_mask(
         self,
         attention_mask,
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c09c11050041d8..3e8d73f94b13ec 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1862,7 +1862,7 @@ def tie_weights(self):
         If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
         the weights instead.
         """
-        if getattr(self.config, "tie_word_embeddings", True):
+        if getattr(self.config.get_text_config(decoder=True), "tie_word_embeddings", True):
             output_embeddings = self.get_output_embeddings()
             if output_embeddings is not None:
                 self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@@ -2104,7 +2104,10 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
                 new_num_tokens = new_embeddings.weight.shape[0]
 
         # if word embeddings are not tied, make sure that lm head is resized as well
-        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
+        if (
+            self.get_output_embeddings() is not None
+            and not self.config.get_text_config(decoder=True).tie_word_embeddings
+        ):
             old_lm_head = self.get_output_embeddings()
             if isinstance(old_lm_head, torch.nn.Embedding):
                 new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
@@ -4604,7 +4607,10 @@ def _load_pretrained_model(
                     _loaded_keys = loaded_keys
                 not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
                 # If we're about to tie the output embeds to the input embeds we don't need to init them
-                if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
+                if (
+                    hasattr(model.config.get_text_config(decoder=True), "tie_word_embeddings")
+                    and model.config.get_text_config(decoder=True).tie_word_embeddings
+                ):
                     output_embeddings = model.get_output_embeddings()
                     if output_embeddings is not None:
                         # Still need to initialize if there is a bias term since biases are not tied.
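Note: the modeling_utils.py hunks above route every `tie_word_embeddings` lookup through `config.get_text_config(decoder=True)`, so composite configs resolve the flag from their nested text config while single-config models keep resolving it on the config itself. A minimal sketch of that resolution (the two configs below are illustrative only, not part of this patch):

    from transformers import LlamaConfig, LlavaConfig

    # Single-config model: get_text_config() falls back to the config itself.
    plain = LlamaConfig()
    assert plain.get_text_config(decoder=True) is plain

    # Composite model: the nested text config is returned, so the tying flag
    # is read from the sub-config where it is actually defined.
    composite = LlavaConfig()
    print(getattr(composite.get_text_config(decoder=True), "tie_word_embeddings", True))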
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 43bf1c22687fe2..0b330b4aeeda2c 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -1360,6 +1360,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
     config_class = AriaConfig
     _supports_flash_attn_2 = False
     _supports_sdpa = False
+    _tied_weights_keys = ["language_model.lm_head.weight"]
 
     def __init__(self, config: AriaConfig):
         super().__init__(config)
@@ -1406,9 +1407,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py
index 78c6e08bdfd0e5..295e2dcb7465b1 100644
--- a/src/transformers/models/aria/modular_aria.py
+++ b/src/transformers/models/aria/modular_aria.py
@@ -1337,6 +1337,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
     config_class = AriaConfig
     _supports_flash_attn_2 = False
     _supports_sdpa = False
+    _tied_weights_keys = ["language_model.lm_head.weight"]
 
     def __init__(self, config: AriaConfig):
         super().__init__(config)
@@ -1383,9 +1384,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py
index 539a3e365c9883..3ea3b76878da7c 100644
--- a/src/transformers/models/blip_2/configuration_blip_2.py
+++ b/src/transformers/models/blip_2/configuration_blip_2.py
@@ -302,9 +302,6 @@ def __init__(
         text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
         self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
 
-        self.tie_word_embeddings = self.text_config.tie_word_embeddings
-        self.is_encoder_decoder = self.text_config.is_encoder_decoder
-
         self.num_query_tokens = num_query_tokens
         self.image_text_hidden_size = image_text_hidden_size
         self.image_token_index = image_token_index
diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index 1818c6bb0963c0..01c8f4dcbc9a5b 100644
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -1433,7 +1433,7 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
 
     def __init__(self, config: CLIPTextConfig):
         super().__init__(config)
-        text_model = CLIPTextModel._from_config(config, attn_implementation=config._attn_implementation)
+        text_model = CLIPTextModel._from_config(config)
         self.text_model = text_model.text_model
 
         self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
@@ -1514,7 +1514,7 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
 
     def __init__(self, config: CLIPVisionConfig):
         super().__init__(config)
-        vision_model = CLIPVisionModel._from_config(config, attn_implementation=config._attn_implementation)
+        vision_model = CLIPVisionModel._from_config(config)
        self.vision_model = vision_model.vision_model
 
         self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
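Note: AriaForConditionalGeneration hard-codes `_tied_weights_keys = ["language_model.lm_head.weight"]` because its inner LM's tied key is fixed, while later hunks derive the list from the wrapped model at init time. Either way the idiom is the same: re-expose the nested LM's tied keys under the wrapper's attribute name. A stand-alone sketch (the helper name is hypothetical, not part of the library):

    def prefix_tied_weights_keys(inner_keys, attr_name="language_model"):
        # Re-expose a nested LM's tied-weight keys under the wrapper's attribute name.
        if inner_keys is None:
            return None
        return [f"{attr_name}.{k}" for k in inner_keys]

    # A causal LM typically reports ["lm_head.weight"]; the wrapper then declares
    # ["language_model.lm_head.weight"], matching the hard-coded Aria value above.
    print(prefix_tied_weights_keys(["lm_head.weight"]))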
diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py
index 557f63338a1f02..722d9078d28a09 100644
--- a/src/transformers/models/emu3/modeling_emu3.py
+++ b/src/transformers/models/emu3/modeling_emu3.py
@@ -1791,6 +1791,8 @@ def forward(
 
 
 class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["text_model.lm_head.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.text_model = Emu3ForCausalLM._from_config(config.text_config)
diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py
index e9b80d5cbb4deb..da6016dc266bf9 100644
--- a/src/transformers/models/emu3/modular_emu3.py
+++ b/src/transformers/models/emu3/modular_emu3.py
@@ -1103,6 +1103,8 @@ def forward(**super_kwargs):
 
 
 class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["text_model.lm_head.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.text_model = Emu3ForCausalLM._from_config(config.text_config)
diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py
index a7afb411c44805..79f82b5ac4bd99 100644
--- a/src/transformers/models/fuyu/modeling_fuyu.py
+++ b/src/transformers/models/fuyu/modeling_fuyu.py
@@ -151,9 +151,9 @@ def __init__(self, config: FuyuConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.text_config.vocab_size
-        self.language_model = AutoModelForCausalLM.from_config(
-            config.text_config, attn_implementation=config._attn_implementation
-        )
+        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
 
         self.vision_embed_tokens = nn.Linear(
             config.patch_size * config.patch_size * config.num_channels, config.hidden_size
@@ -181,9 +181,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
     def gather_continuous_embeddings(
         self,
         word_embeddings: torch.Tensor,
diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py
index 4a101b1d93b4f7..283409327bf914 100644
--- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py
+++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py
@@ -2104,9 +2104,7 @@ def __init__(self, config: GroundingDinoConfig):
         )
 
         # Create text backbone
-        self.text_backbone = AutoModel.from_config(
-            config.text_config, add_pooling_layer=False, attn_implementation=config._attn_implementation
-        )
+        self.text_backbone = AutoModel.from_config(config.text_config, add_pooling_layer=False)
         self.text_projection = nn.Linear(config.text_config.hidden_size, config.d_model)
 
         if config.embedding_init_target or not config.two_stage:
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index bdae2c70d5f28b..4e819811a9849d 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -1285,13 +1285,6 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.text_model.set_input_embeddings(value)
 
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        model_embeds = self.text_model.resize_token_embeddings(
-            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
-        )
-        self.config.text_config.vocab_size = model_embeds.num_embeddings
-        return model_embeds
-
     def inputs_merger(
         self,
         input_ids: torch.LongTensor,
@@ -1515,32 +1508,6 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        # model_embeds = self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)
-        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-        if new_num_tokens is None and pad_to_multiple_of is None:
-            return model_embeds
-
-        # Update base model and current model config
-        # Ignore copy
-        self.config.text_config.vocab_size = model_embeds.weight.shape[0]
-        self.vocab_size = self.config.text_config.vocab_size
-
-        # Tie weights again if needed
-        self.tie_weights()
-
-        return model_embeds
-
-    def tie_weights(self):
-        """
-        Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
-        """
-        output_embeddings = self.get_output_embeddings()
-        input_embeddings = self.get_input_embeddings()
-
-        if getattr(self.config, "tie_word_embeddings", True):
-            output_embeddings.weight = input_embeddings.weight
-
     @add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
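Note: with the Idefics2 overrides for `resize_token_embeddings`/`tie_weights` removed, re-tying after a resize is left to the inherited `PreTrainedModel` path. A quick sanity check of that path on a deliberately tiny single-tower model (the sizes below are arbitrary, not a real checkpoint):

    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig(
        vocab_size=100,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        tie_word_embeddings=True,
    )
    model = LlamaForCausalLM(config)
    model.resize_token_embeddings(120)

    # After the generic resize, input embeddings and lm_head should still share storage.
    assert model.get_input_embeddings().weight.data_ptr() == model.get_output_embeddings().weight.data_ptr()
    print(model.config.vocab_size)  # 120, updated by the generic resize path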
diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py
index 0bfa2823c52e7d..31cf1a2e8f1173 100644
--- a/src/transformers/models/idefics3/modeling_idefics3.py
+++ b/src/transformers/models/idefics3/modeling_idefics3.py
@@ -1094,17 +1094,6 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.tie_weights
-    def tie_weights(self):
-        """
-        Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
-        """
-        output_embeddings = self.get_output_embeddings()
-        input_embeddings = self.get_input_embeddings()
-
-        if getattr(self.config, "tie_word_embeddings", True):
-            output_embeddings.weight = input_embeddings.weight
-
     @add_start_docstrings_to_model_forward(IDEFICS3_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Idefics3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
diff --git a/src/transformers/models/instructblip/configuration_instructblip.py b/src/transformers/models/instructblip/configuration_instructblip.py
index d4f15ab0a58e5f..328d64761a5667 100644
--- a/src/transformers/models/instructblip/configuration_instructblip.py
+++ b/src/transformers/models/instructblip/configuration_instructblip.py
@@ -302,9 +302,6 @@ def __init__(
         text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
         self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
 
-        self.tie_word_embeddings = self.text_config.tie_word_embeddings
-        self.is_encoder_decoder = self.text_config.is_encoder_decoder
-
         self.num_query_tokens = num_query_tokens
         self.image_token_index = image_token_index
         self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
diff --git a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py
index 14687a96e54f37..6776c1b62b8852 100644
--- a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py
@@ -308,9 +308,6 @@ def __init__(
         text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
         self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
 
-        self.tie_word_embeddings = self.text_config.tie_word_embeddings
-        self.is_encoder_decoder = self.text_config.is_encoder_decoder
-
         self.num_query_tokens = num_query_tokens
         self.video_token_index = video_token_index
         self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
index 7184955af3aa56..1376e85c6f9525 100644
--- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
@@ -137,9 +137,6 @@ def __init__(
         text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
         self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
 
-        self.tie_word_embeddings = self.text_config.tie_word_embeddings
-        self.is_encoder_decoder = self.text_config.is_encoder_decoder
-
         self.num_query_tokens = num_query_tokens
         self.video_token_index = video_token_index
         self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
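Note: the BLIP-2/InstructBLIP-style configs stop mirroring `tie_word_embeddings` and `is_encoder_decoder` onto the composite config; callers are expected to read them from the nested text config, which `get_text_config(decoder=True)` resolves. A sketch of the new lookup (config construction only, no weights involved):

    from transformers import Blip2Config

    config = Blip2Config()  # builds default vision/Q-Former/text sub-configs
    text_config = config.get_text_config(decoder=True)

    # The flags now live only on the text config.
    print(getattr(text_config, "tie_word_embeddings", True))
    print(getattr(text_config, "is_encoder_decoder", False))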
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 3d9bc339fd29a4..93d7465291cb19 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -242,6 +242,10 @@ def __init__(self, config: LlavaConfig):
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self.post_init()
 
@@ -264,16 +268,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-        # update vocab size
-        self.config.text_config.vocab_size = model_embeds.num_embeddings
-        self.vocab_size = model_embeds.num_embeddings
-        return model_embeds
-
     def get_image_features(
         self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
     ):
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index 71e46389892822..51df47233b26c0 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -357,6 +357,9 @@ def __init__(self, config: LlavaNextConfig):
 
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
         self.post_init()
@@ -395,18 +398,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-        # update vocab size
-        self.config.text_config.vocab_size = model_embeds.num_embeddings
-        self.vocab_size = model_embeds.num_embeddings
-        return model_embeds
-
     def _merge_input_ids_with_image_features(
         self,
         image_features,
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index b1ae26aaac80e0..257c81aa8fe4df 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -397,6 +397,9 @@ def __init__(
 
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
         self.vision_resampler = LlavaNextVideoPooler(config)
@@ -430,16 +433,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-        # update vocab size
-        self.config.text_config.vocab_size = model_embeds.num_embeddings
-        self.vocab_size = model_embeds.num_embeddings
-        return model_embeds
-
     def _merge_input_ids_with_image_features(
         self,
         image_features,
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index 7bc88ec95ab359..5c5471479e86bf 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -374,6 +374,9 @@ def __init__(self, config: LlavaOnevisionConfig):
 
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
         self.post_init()
 
     # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
@@ -400,10 +403,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.tie_weights
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
     def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
         """
         Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
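Note: the reason every wrapper prefixes its keys with `language_model.` is that a nested module's parameters appear in the parent's state_dict under the attribute name, so the tied-weight bookkeeping has to carry the same prefix. A plain-torch sketch with toy modules (not the real classes):

    import torch.nn as nn

    class TinyCausalLM(nn.Module):
        def __init__(self, vocab_size=10, hidden_size=4):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            self.lm_head.weight = self.embed_tokens.weight  # tied

    class TinyWrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.language_model = TinyCausalLM()

    print(sorted(TinyWrapper().state_dict().keys()))
    # ['language_model.embed_tokens.weight', 'language_model.lm_head.weight']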
diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index cb9a20dadc677e..b40c366a6d75ef 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -1986,6 +1986,9 @@ def __init__(self, config: MllamaConfig):
 
         self.vision_model = MllamaVisionModel._from_config(config.vision_config)
         self.language_model = MllamaForCausalLM._from_config(config.text_config)
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
         self.multi_modal_projector = nn.Linear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
@@ -2011,9 +2014,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
     @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig")
     def forward(
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index 35e107d7cb7a8c..a1c15b7a0b3775 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -1914,12 +1914,9 @@ def __init__(self, config: MoshiConfig):
         self.embed_tokens = nn.ModuleList(
             [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)]
         )
-        self.audio_encoder = AutoModel.from_config(
-            config.audio_encoder_config, attn_implementation=config._attn_implementation
-        )
+        self.audio_encoder = AutoModel.from_config(config.audio_encoder_config)
         self.decoder = MoshiForCausalLM(config)
 
-        config.depth_decoder_config._attn_implementation_internal = config._attn_implementation
         self.depth_decoder = MoshiDepthDecoder(config.depth_decoder_config)
 
         self.num_codebooks = config.num_codebooks
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 0d7eca8aa1dc06..36a9e59118b678 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -335,10 +335,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
     def _update_causal_mask(
         self,
         attention_mask,
diff --git a/src/transformers/models/pix2struct/configuration_pix2struct.py b/src/transformers/models/pix2struct/configuration_pix2struct.py
index f6924c4784b3ac..db2f0ff7e3ca31 100644
--- a/src/transformers/models/pix2struct/configuration_pix2struct.py
+++ b/src/transformers/models/pix2struct/configuration_pix2struct.py
@@ -14,9 +14,6 @@
 # limitations under the License.
 """Pix2Struct model configuration"""
 
-import os
-from typing import Union
-
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
 
@@ -147,26 +144,6 @@ def __init__(
             **kwargs,
         )
 
-    @classmethod
-    def from_pretrained(
-        cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs
-    ) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs)
-
-        # get the text config dict if we are loading from Pix2StructConfig
-        if config_dict.get("model_type") == "pix2struct":
-            config_dict = config_dict["text_config"]
-
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
-
 
 class Pix2StructVisionConfig(PretrainedConfig):
     r"""
@@ -266,26 +243,6 @@ def __init__(
         self.relative_attention_max_distance = relative_attention_max_distance
         self.d_kv = d_kv
 
-    @classmethod
-    def from_pretrained(
-        cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs
-    ) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs)
-
-        # get the vision config dict if we are loading from Pix2StructConfig
-        if config_dict.get("model_type") == "pix2struct":
-            config_dict = config_dict["vision_config"]
-
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
-
 
 class Pix2StructConfig(PretrainedConfig):
     r"""
diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py
index ba5f1421ed4af1..77ce68659a50f7 100644
--- a/src/transformers/models/pix2struct/modeling_pix2struct.py
+++ b/src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -1733,14 +1733,6 @@ def get_output_embeddings(self) -> nn.Module:
     def set_output_embeddings(self, new_embeddings):
         self.decoder.set_output_embeddings(new_embeddings)
 
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
-        model_embeds = self.decoder.resize_token_embeddings(new_num_tokens)
-
-        # update vocab size
-        self.config.text_config.vocab_size = new_num_tokens
-
-        return model_embeds
-
     def get_decoder(self):
         return self.decoder
 
diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 421bb3801dfd37..a844a67861d540 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -856,6 +856,9 @@ def __init__(self, config: Qwen2AudioConfig):
         self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
         self.post_init()
@@ -894,18 +897,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-        # update vocab size
-        self.config.text_config.vocab_size = model_embeds.num_embeddings
-        self.vocab_size = model_embeds.num_embeddings
-        return model_embeds
-
     def _merge_input_ids_with_audio_features(
         self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask, labels
     ):
diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
index dfadca184483e4..9fa099d19230d5 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -261,6 +261,9 @@ def get_encoder(self):
     def get_decoder(self):
         return self.decoder
 
+    def get_input_embeddings(self):
+        return self.decoder.get_input_embeddings()
+
     def get_output_embeddings(self):
         return self.decoder.get_output_embeddings()
 
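Note: SpeechEncoderDecoderModel (and VisionEncoderDecoderModel below) now expose the decoder's input embeddings at the top level. The generic helpers in modeling_utils.py only ever go through `get_input_embeddings()`/`get_output_embeddings()`, so this delegation is what lets them reach the right modules. A toy sketch of the pattern (plain torch, hypothetical class names):

    import torch.nn as nn

    class ToyDecoder(nn.Module):
        def __init__(self, vocab_size=10, hidden_size=4):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        def get_input_embeddings(self):
            return self.embed_tokens

        def get_output_embeddings(self):
            return self.lm_head

    class ToyEncoderDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)  # stand-in for the speech/vision encoder
            self.decoder = ToyDecoder()

        # Mirrors the delegation added in the diff: code that only knows about
        # get_input_embeddings() now reaches the decoder's embedding table.
        def get_input_embeddings(self):
            return self.decoder.get_input_embeddings()

        def get_output_embeddings(self):
            return self.decoder.get_output_embeddings()

    print(ToyEncoderDecoder().get_input_embeddings().num_embeddings)  # 10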
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index aeff4ad1d0c29d..293fb10ae27795 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -245,6 +245,9 @@ def __init__(self, config: VideoLlavaConfig):
         self.multi_modal_projector = VideoLlavaMultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self.post_init()
 
@@ -266,17 +269,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-        # update vocab size
-        self.config.text_config.vocab_size = model_embeds.num_embeddings
-        self.config.vocab_size = model_embeds.num_embeddings
-        self.vocab_size = model_embeds.num_embeddings
-        return model_embeds
-
     def _merge_input_ids_with_visual_features(
         self, visual_features, inputs_embeds, input_ids, attention_mask, labels, num_frames=1
     ):
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 084b92c9771c4e..0daaa8327b631b 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -242,6 +242,10 @@ def __init__(self, config: VipLlavaConfig):
         self.multi_modal_projector = VipLlavaMultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self.post_init()
 
@@ -264,16 +268,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def tie_weights(self):
-        return self.language_model.tie_weights()
-
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-        # update vocab size
-        self.config.text_config.vocab_size = model_embeds.num_embeddings
-        self.vocab_size = model_embeds.num_embeddings
-        return model_embeds
-
     # Ignore copy
     def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: List[int]):
         """
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
index abca01987c7293..55c759f8e9ae2e 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -237,6 +237,9 @@ def get_encoder(self):
     def get_decoder(self):
         return self.decoder
 
+    def get_input_embeddings(self):
+        return self.decoder.get_input_embeddings()
+
     def get_output_embeddings(self):
         return self.decoder.get_output_embeddings()
 
@@ -659,12 +662,6 @@ def forward(
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    def resize_token_embeddings(self, *args, **kwargs):
-        raise NotImplementedError(
-            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
-            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
-        )
-
     def _reorder_cache(self, past_key_values, beam_idx):
         # apply decoder cache reordering here
         return self.decoder._reorder_cache(past_key_values, beam_idx)
diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py
index 5e18b006a5d815..5556f14a0b93b1 100644
--- a/tests/models/blip_2/test_modeling_blip_2.py
+++ b/tests/models/blip_2/test_modeling_blip_2.py
@@ -1123,6 +1123,13 @@ def is_pipeline_test_to_skip(
 
     def setUp(self):
         self.model_tester = Blip2ModelTester(self)
+        common_properties = ["image_token_index", "num_query_tokens", "image_text_hidden_size"]
+        self.config_tester = ConfigTester(
+            self, config_class=Blip2Config, has_text_modality=False, common_properties=common_properties
+        )
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
 
     def test_for_conditional_generation(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py
index baacc12caa073d..335b690d879f18 100644
--- a/tests/models/instructblip/test_modeling_instructblip.py
+++ b/tests/models/instructblip/test_modeling_instructblip.py
@@ -158,7 +158,10 @@ class InstructBlipVisionModelTest(ModelTesterMixin, unittest.TestCase):
     def setUp(self):
         self.model_tester = InstructBlipVisionModelTester(self)
         self.config_tester = ConfigTester(
-            self, config_class=InstructBlipVisionConfig, has_text_modality=False, hidden_size=37
+            self,
+            config_class=InstructBlipConfig,
+            has_text_modality=False,
+            common_properties=["num_query_tokens", "image_token_index"],
         )
 
     def test_config(self):
diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
index 3be5f89325cf38..6b59a9878aa4f4 100644
--- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
+++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
@@ -163,8 +163,9 @@ class InstructBlipVideoVisionModelTest(ModelTesterMixin, unittest.TestCase):
 
     def setUp(self):
         self.model_tester = InstructBlipVideoVisionModelTester(self)
+        common_properties = ["num_query_tokens", "video_token_index"]
         self.config_tester = ConfigTester(
-            self, config_class=InstructBlipVideoVisionConfig, has_text_modality=False, hidden_size=37
+            self, config_class=InstructBlipVideoConfig, has_text_modality=False, common_properties=common_properties
         )
 
     def test_config(self):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 0d12bf77d861fe..965d7593693397 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2283,7 +2283,7 @@ def test_load_save_without_tied_weights(self):
 
     def test_tied_weights_keys(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        config.tie_word_embeddings = True
+        config.get_text_config().tie_word_embeddings = True
         for model_class in self.all_model_classes:
             model_tied = model_class(config)
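Note: the common test now flips the flag through `get_text_config()`, which writes to the nested text config for composite models and to the config itself for single-tower models, so `test_tied_weights_keys` exercises the same lookup path the modeling changes above rely on. A small illustration (configs only, no weights):

    from transformers import LlavaConfig

    config = LlavaConfig()
    config.get_text_config().tie_word_embeddings = True

    # The write lands on the nested text config, where tie_weights() now looks.
    print(config.text_config.tie_word_embeddings)  # True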