Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cache] add a test to confirm we can use cache at train time #35709

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
AutoModel,
AutoModelForImageClassification,
AutoModelForSequenceClassification,
DynamicCache,
LlavaForConditionalGeneration,
OwlViTForObjectDetection,
PretrainedConfig,
Expand Down Expand Up @@ -1790,6 +1791,43 @@ def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
)
self.assertTrue(check_models_equal(model, model_loaded))

def test_cache_at_train_time(self):
gante marked this conversation as resolved.
Show resolved Hide resolved
"""
Some fine-tuning methods require the use of cache, like prefix tuning in PEFT. This test checks that a cache
is at train time used if we request it. Related issue: #35648
"""
model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
tokenizer = AutoTokenizer.from_pretrained(TINY_MISTRAL)
model_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# By default it is not training, we have to set it
self.assertFalse(model.training)
model.train()

# If we set `use_cache=True` while training, then a cache is returned
model_outputs = model(**model_inputs, use_cache=True)
self.assertIsInstance(model_outputs.past_key_values, DynamicCache)
self.assertTrue(model.training)

# simulate injecting virtual tokens like in prefix tuning
num_virtual_tokens = 3
past_key_values = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
model_inputs["attention_mask"] = torch.cat(
(
model_inputs["attention_mask"],
torch.ones(1, num_virtual_tokens).to(model_inputs["attention_mask"].device),
),
dim=1,
)
model_outputs = model(**model_inputs, past_key_values=past_key_values, use_cache=True)
self.assertTrue(model.training)

# We can also disable the cache to skip a few operations, if the training loop doesn't need cache
model_outputs = model(**model_inputs, use_cache=False)
self.assertIsNone(model_outputs.past_key_values)
self.assertTrue(model.training)


@slow
@require_torch
Expand Down