From a7d1441d657bdb2abf04c3931017ef8c9c3580cd Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 7 Jan 2025 19:11:02 +0000 Subject: [PATCH] Correctly list the chat template file in the Tokenizer saved files list (#34974) * Correctly list the chat template file in the saved files list * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add save file checking to test * make fixup * better filename handling * make fixup --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/tokenization_utils_base.py | 4 ++++ tests/test_tokenization_common.py | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index de0bc87b26b676..86e07a382f8812 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2429,6 +2429,7 @@ def save_pretrained( tokenizer_config["extra_special_tokens"] = self.extra_special_tokens tokenizer_config.update(self.extra_special_tokens) + saved_raw_chat_template = False if self.chat_template is not None: if isinstance(self.chat_template, dict): # Chat template dicts are saved to the config as lists of dicts with fixed key names. @@ -2439,6 +2440,7 @@ def save_pretrained( elif kwargs.get("save_raw_chat_template", False): with open(chat_template_file, "w", encoding="utf-8") as f: f.write(self.chat_template) + saved_raw_chat_template = True logger.info(f"chat template saved in {chat_template_file}") if "chat_template" in tokenizer_config: tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too @@ -2498,6 +2500,8 @@ def save_pretrained( logger.info(f"Special tokens file saved in {special_tokens_map_file}") file_names = (tokenizer_config_file, special_tokens_map_file) + if saved_raw_chat_template: + file_names += (chat_template_file,) save_files = self._save_pretrained( save_directory=save_directory, diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index ed09d800ad6dd5..d6957757dc55d8 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1107,7 +1107,9 @@ def test_chat_template(self): tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) with tempfile.TemporaryDirectory() as tmp_dir_name: - tokenizer.save_pretrained(tmp_dir_name) + save_files = tokenizer.save_pretrained(tmp_dir_name) + # Check we aren't saving a chat_template.jinja file + self.assertFalse(any(file.endswith("chat_template.jinja") for file in save_files)) new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted @@ -1117,7 +1119,9 @@ def test_chat_template(self): new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) with tempfile.TemporaryDirectory() as tmp_dir_name: - tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True) + save_files = tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True) + # Check we are saving a chat_template.jinja file + self.assertTrue(any(file.endswith("chat_template.jinja") for file in save_files)) chat_template_file = Path(tmp_dir_name) / "chat_template.jinja" self.assertTrue(chat_template_file.is_file()) self.assertEqual(chat_template_file.read_text(), dummy_template)