Skip to content

Commit

Permalink
Correctly list the chat template file in the Tokenizer saved files li…
Browse files Browse the repository at this point in the history
…st (#34974)

* Correctly list the chat template file in the saved files list

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <[email protected]>

* Add save file checking to test

* make fixup

* better filename handling

* make fixup

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
Rocketknight1 and ArthurZucker authored Jan 7, 2025
1 parent cdca3cf commit a7d1441
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,6 +2429,7 @@ def save_pretrained(
tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
tokenizer_config.update(self.extra_special_tokens)

saved_raw_chat_template = False
if self.chat_template is not None:
if isinstance(self.chat_template, dict):
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
Expand All @@ -2439,6 +2440,7 @@ def save_pretrained(
elif kwargs.get("save_raw_chat_template", False):
with open(chat_template_file, "w", encoding="utf-8") as f:
f.write(self.chat_template)
saved_raw_chat_template = True
logger.info(f"chat template saved in {chat_template_file}")
if "chat_template" in tokenizer_config:
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
Expand Down Expand Up @@ -2498,6 +2500,8 @@ def save_pretrained(
logger.info(f"Special tokens file saved in {special_tokens_map_file}")

file_names = (tokenizer_config_file, special_tokens_map_file)
if saved_raw_chat_template:
file_names += (chat_template_file,)

save_files = self._save_pretrained(
save_directory=save_directory,
Expand Down
8 changes: 6 additions & 2 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,9 @@ def test_chat_template(self):
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)

with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
save_files = tokenizer.save_pretrained(tmp_dir_name)
# Check we aren't saving a chat_template.jinja file
self.assertFalse(any(file.endswith("chat_template.jinja") for file in save_files))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)

self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
Expand All @@ -1117,7 +1119,9 @@ def test_chat_template(self):
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)

with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
save_files = tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
# Check we are saving a chat_template.jinja file
self.assertTrue(any(file.endswith("chat_template.jinja") for file in save_files))
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
self.assertTrue(chat_template_file.is_file())
self.assertEqual(chat_template_file.read_text(), dummy_template)
Expand Down

0 comments on commit a7d1441

Please sign in to comment.