VLM: enable skipped tests #35746

Open · wants to merge 3 commits into main
1 change: 1 addition & 0 deletions src/transformers/models/aria/modeling_aria.py
@@ -1553,6 +1553,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
cache_position=cache_position,
)

logits = outputs[0]
1 change: 1 addition & 0 deletions src/transformers/models/aria/modular_aria.py
@@ -1530,6 +1530,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
cache_position=cache_position,
)

logits = outputs[0]
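The Aria change (in both the modeling and modular files) is one-line plumbing: the wrapper's `forward` now passes `cache_position` through to the inner language model, which newer cache implementations use to index the slots being written. Below is a minimal sketch of what `cache_position` conveys during generation; the helper and values are illustrative, not Aria's actual code.

```python
import torch

def make_cache_position(past_length: int, num_new_tokens: int) -> torch.Tensor:
    """Absolute positions of the tokens fed in this forward pass (toy helper)."""
    return torch.arange(past_length, past_length + num_new_tokens, dtype=torch.long)

# Prefill: 5 prompt tokens, empty cache -> positions 0..4
print(make_cache_position(past_length=0, num_new_tokens=5))  # tensor([0, 1, 2, 3, 4])

# Decode step: cache already holds 5 tokens, 1 new token -> position 5
print(make_cache_position(past_length=5, num_new_tokens=1))  # tensor([5])
```

If the multimodal wrapper drops this argument, the inner language model has to re-derive positions on its own, which is presumably why several cache-dependent generation tests had to stay skipped.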
6 changes: 6 additions & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -579,6 +579,9 @@ def _init_weights(self, module):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""

BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""
@@ -2094,6 +2097,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
r"""
Returns:
@@ -2217,6 +2221,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@@ -2242,6 +2247,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=True, # toggle for easier access to loss/logits below
labels=labels,
use_cache=use_cache,
)
loss = outputs.loss
logits = outputs.logits
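BLIP-2 (and the InstructBlip variants below) get the analogous treatment for `use_cache`: the argument is documented and threaded from the outer `forward` into the language-model call, so a `use_cache` passed by callers or by `generate()` actually reaches the decoder. A rough sketch of that contract with toy classes (not the real BLIP-2 modules) follows.

```python
import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    """Minimal stand-in for a language model that honours `use_cache`."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states, past_key_values=None, use_cache=False):
        # Pretend the key/value states are just the projected inputs.
        new_kv = self.proj(hidden_states)
        if past_key_values is not None:
            new_kv = torch.cat([past_key_values, new_kv], dim=1)
        # Only return the cache when explicitly requested.
        return hidden_states, (new_kv if use_cache else None)

class ToyVlmWrapper(nn.Module):
    """Composite model: the fix is to forward `use_cache` instead of dropping it."""

    def __init__(self):
        super().__init__()
        self.language_model = ToyDecoder()

    def forward(self, inputs_embeds, past_key_values=None, use_cache=None):
        return self.language_model(
            inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,  # the kwarg now reaches the language model
        )

wrapper = ToyVlmWrapper()
x = torch.randn(1, 4, 16)
_, cache = wrapper(x, use_cache=True)
print(cache is not None)  # True: the cache is returned and can seed the next step
```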
(File header not shown in this extract; the hunks below are the matching changes in InstructBlip.)
@@ -441,6 +441,9 @@ def _init_weights(self, module):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""


@@ -1375,6 +1378,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1485,6 +1489,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@@ -1510,6 +1515,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
(File header not shown in this extract; the hunks below are the matching changes in InstructBlipVideo.)
@@ -1265,6 +1265,9 @@ def forward(
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""


@@ -1369,6 +1372,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1512,6 +1516,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@@ -1537,6 +1542,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
(File header not shown in this extract; the hunks below apply the same change in a second InstructBlipVideo file.)
@@ -188,6 +188,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r"""
```python
@@ -322,6 +323,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@@ -347,6 +349,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
13 changes: 9 additions & 4 deletions src/transformers/models/kosmos2/modeling_kosmos2.py
@@ -1694,6 +1694,7 @@ def prepare_inputs_for_generation(
past_key_values=None,
attention_mask=None,
use_cache=None,
cache_position=None,
**model_kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1704,17 +1705,21 @@
attention_mask = input_ids.new_ones(input_shape)

position_ids = None
if cache_position is None:
past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

# cut input_ids if past_key_values is used
if past_key_values is not None:
position_ids = create_position_ids_from_input_ids(
input_ids,
padding_idx=self.config.pad_token_id,
past_key_values_length=0,
)[:, -1:]
)

if input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
position_ids = position_ids[:, -input_ids.shape[1] :]

input_ids = input_ids[:, -1:]
# the image info. is already encoded into the past keys/values
image_embeds = None
image_embeds_position_mask = None
elif image_embeds_position_mask is not None:
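Kosmos-2's `prepare_inputs_for_generation` now derives a `cache_position` when none is provided and uses it to slice `input_ids`, instead of hard-coding `input_ids[:, -1:]`. That keeps the hook correct when several uncached tokens are fed at once, as in assisted decoding or prompt lookup. A standalone sketch of the trimming idea under those assumptions (not the exact Kosmos-2 method):

```python
import torch

def trim_decoder_inputs(input_ids, past_length, cache_position=None):
    """Slice `input_ids` down to the tokens the model has not cached yet.

    `input_ids` holds the full sequence so far; `past_length` is how many
    tokens are already in the KV cache. This mirrors the cache_position-based
    trimming idea, not the exact Kosmos-2 implementation.
    """
    if cache_position is None:
        # Positions of the not-yet-cached tokens within the full sequence.
        cache_position = torch.arange(
            past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    if input_ids.shape[1] != cache_position.shape[0]:
        # Keep only the new tokens; works for one token (greedy decoding) or
        # several (assisted decoding), unlike a fixed `input_ids[:, -1:]`.
        input_ids = input_ids[:, cache_position]
    return input_ids, cache_position

full = torch.tensor([[11, 12, 13, 14, 15]])
new_ids, pos = trim_decoder_inputs(full, past_length=3)
print(new_ids.tolist(), pos.tolist())  # [[14, 15]] [3, 4]
```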
45 changes: 32 additions & 13 deletions tests/generation/test_utils.py
@@ -503,7 +503,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()

if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
@@ -634,7 +634,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()

if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
@@ -963,7 +963,7 @@ def test_contrastive_generate(self):
config, inputs_dict = self.prepare_config_and_inputs_for_generate()

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True

@@ -992,7 +992,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
config, inputs_dict = self.prepare_config_and_inputs_for_generate()

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True

@@ -1032,7 +1032,7 @@ def test_contrastive_generate_low_memory(self):
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

config.is_decoder = True
@@ -1151,6 +1151,10 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
"prophetnet",
"seamlessm4t",
"clvp",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1159,7 +1163,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

config.is_decoder = True
@@ -1224,6 +1228,10 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
"seamlessm4t",
"clvp",
"fuyu",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1232,7 +1240,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

config.is_decoder = True
@@ -1338,6 +1346,10 @@ def test_assisted_decoding_sample(self):
"prophetnet",
"seamlessm4t",
"clvp",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1346,7 +1358,7 @@ def test_assisted_decoding_sample(self):
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

config.is_decoder = True
@@ -1538,7 +1550,7 @@ def test_past_key_values_format(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

model = model_class(config).to(torch_device)
@@ -1573,7 +1585,14 @@ def test_past_key_values_format(self):

# Encoder-Decoder checks
if config.is_encoder_decoder:
encoder_num_attention_heads = config.encoder_attention_heads
# encoder-decoder models usually don't have text config
# below is needed only for Pix2Struct which we cannot modify now due to BC
config = config.get_text_config()
encoder_num_attention_heads = (
config.encoder_attention_heads
if hasattr(config, "encoder_attention_heads")
else config.num_attention_heads
)
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
batch_size, seq_length = inputs["decoder_input_ids"].shape
for i in range(num_hidden_layers):
@@ -1772,14 +1791,14 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")

config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

# Let's make it always:
@@ -2129,7 +2148,7 @@ def test_assisted_decoding_with_num_logits_to_keep(self):

config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True
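The recurring test change replaces `hasattr(config, "use_cache")` with `hasattr(config.get_text_config(), "use_cache")`. Composite VLM configs usually keep cache-related attributes on a nested text config rather than on the top-level object, so the old check skipped these models unnecessarily. A simplified illustration with toy config classes (not the real transformers ones):

```python
class ToyTextConfig:
    """Decoder config: this is where cache-related attributes live."""
    use_cache = True

class ToyVlmConfig:
    """Composite config: vision + text sub-configs, no `use_cache` of its own."""

    def __init__(self):
        self.text_config = ToyTextConfig()

    def get_text_config(self):
        # Mirrors the idea behind the library's get_text_config helper:
        # return the sub-config that drives text generation.
        return self.text_config

config = ToyVlmConfig()
print(hasattr(config, "use_cache"))                    # False -> test used to be skipped
print(hasattr(config.get_text_config(), "use_cache"))  # True  -> test now runs
```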
6 changes: 3 additions & 3 deletions tests/models/aria/test_modeling_aria.py
@@ -83,14 +83,14 @@ def __init__(
moe_intermediate_size=4,
moe_num_experts=4,
moe_topk=2,
num_attention_heads=20,
num_attention_heads=8,
num_experts_per_tok=3,
num_hidden_layers=2,
num_key_value_heads=20,
num_key_value_heads=8,
rope_theta=5000000,
vocab_size=99,
eos_token_id=2,
head_dim=2,
head_dim=4,
),
is_training=True,
vision_config=Idefics3VisionConfig(
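The tiny Aria test config switches from 20 attention heads of size 2 to 8 heads of size 4, with the key/value head count kept equal to the head count. The per-layer key/value cache shape exercised by the newly enabled generation tests follows directly from these numbers; a quick illustrative check, assuming the usual heads-times-head_dim layout:

```python
# Rough shape check for the tiny test config (illustrative values only).
num_attention_heads = 8
num_key_value_heads = 8
head_dim = 4
batch_size, seq_len = 2, 7

# Query projection width and per-layer KV cache shape for one layer.
q_width = num_attention_heads * head_dim            # 32
kv_cache_shape = (batch_size, num_key_value_heads, seq_len, head_dim)

print(q_width)         # 32
print(kv_cache_shape)  # (2, 8, 7, 4)
```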