[cache] add a test to confirm we can use cache at train time (#35709)
* add test

* augment test as suggested

* Update tests/utils/test_modeling_utils.py

Co-authored-by: Arthur <[email protected]>

* rerun tests

---------

Co-authored-by: Arthur <[email protected]>
gante and ArthurZucker authored Jan 16, 2025
1 parent 57bf1a1 commit aeeceb9
Showing 1 changed file with 38 additions and 0 deletions.
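
For context, here is a minimal standalone sketch of the behavior this commit tests: calling a model in training mode with use_cache=True should return a DynamicCache. The checkpoint name below is an assumption; the test itself uses the TINY_MISTRAL constant defined elsewhere in the file.

# Hedged sketch of the tested behavior; the checkpoint name is an assumption.
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"  # assumed value of TINY_MISTRAL
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

model.train()  # put the model in training mode
outputs = model(**inputs, use_cache=True)  # explicitly request a cache

# training mode no longer prevents the cache from being built and returned
assert isinstance(outputs.past_key_values, DynamicCache)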
tests/utils/test_modeling_utils.py
@@ -37,6 +37,7 @@
AutoModel,
AutoModelForImageClassification,
AutoModelForSequenceClassification,
DynamicCache,
LlavaForConditionalGeneration,
OwlViTForObjectDetection,
PretrainedConfig,
@@ -1790,6 +1791,43 @@ def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
)
self.assertTrue(check_models_equal(model, model_loaded))

def test_cache_when_needed_at_train_time(self):
"""
Some fine-tuning methods require the use of cache, like prefix tuning in PEFT. This test checks that a cache
is used at train time if we request it. Related issue: #35648
"""
model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
tokenizer = AutoTokenizer.from_pretrained(TINY_MISTRAL)
model_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# By default it is not training, we have to set it
self.assertFalse(model.training)
model.train()

# If we set `use_cache=True` while training, then a cache is returned
model_outputs = model(**model_inputs, use_cache=True)
self.assertIsInstance(model_outputs.past_key_values, DynamicCache)
self.assertTrue(model.training)

# simulate injecting virtual tokens like in prefix tuning
num_virtual_tokens = 3
# each entry stacks the (key, value) states of one layer: (2, batch, num_kv_heads, seq_len, head_dim)
past_key_values = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
model_inputs["attention_mask"] = torch.cat(
(
model_inputs["attention_mask"],
torch.ones(1, num_virtual_tokens).to(model_inputs["attention_mask"].device),
),
dim=1,
)
model_outputs = model(**model_inputs, past_key_values=past_key_values, use_cache=True)
self.assertTrue(model.training)

# We can also disable the cache to skip a few operations, if the training loop doesn't need cache
model_outputs = model(**model_inputs, use_cache=False)
self.assertIsNone(model_outputs.past_key_values)
self.assertTrue(model.training)


@slow
@require_torch
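
As a usage note on the prefix-tuning portion of the test above, here is a hedged sketch of the legacy-to-DynamicCache conversion it relies on. The shapes assume a 2-layer model with 2 KV heads and head dim 8 (matching a tiny random Mistral checkpoint); adjust them for other models.

# Sketch (not the test itself): building a prefix-tuning-style cache by hand.
# Shape assumptions: 2 layers, batch 1, 2 KV heads, head_dim 8.
import torch
from transformers import DynamicCache

num_virtual_tokens = 3
# legacy format holds one (key, value) pair per layer; a (2, ...) tensor unpacks into both
legacy = [torch.randn(2, 1, 2, num_virtual_tokens, 8) for _ in range(2)]
cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == num_virtual_tokens

# any attention mask passed alongside must also cover the virtual tokens,
# which is why the test concatenates num_virtual_tokens extra ones to it
attention_mask = torch.ones(1, 6 + num_virtual_tokens)  # 6 real tokens + 3 virtual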
