From 9d9116441b704958a3c781d0405964ad1eee9818 Mon Sep 17 00:00:00 2001
From: Aisuko
Date: Sun, 21 Apr 2024 11:17:26 +0000
Subject: [PATCH] add dialog and summarize pipeline

Signed-off-by: Aisuko
---
 .../tests/{test_chat.py => test_dialog.py}    | 35 ++++++++++
 src/kimchima/utils/dialog.py                  | 67 ++++++++++++++++++-
 src/kimchima/utils/downloader.py              |  1 +
 3 files changed, 102 insertions(+), 1 deletion(-)
 rename src/kimchima/tests/{test_chat.py => test_dialog.py} (52%)

diff --git a/src/kimchima/tests/test_chat.py b/src/kimchima/tests/test_dialog.py
similarity index 52%
rename from src/kimchima/tests/test_chat.py
rename to src/kimchima/tests/test_dialog.py
index 607125d..0b1a858 100644
--- a/src/kimchima/tests/test_chat.py
+++ b/src/kimchima/tests/test_dialog.py
@@ -52,3 +52,38 @@ def test_chat_summary(self):
         # res is str and should not be None
         self.assertIsNotNone(res)
         self.assertIsInstance(res, str)
+
+
+    def test_dialog_with_pipe(self):
+        """
+        Test dialog with pipe method
+        """
+        con = PipelinesFactory.init_conversation()
+        con.add_message({"role": "user", "content": "Dod you like weather of Melbourne?"})
+        con.add_message({"role": "assistant", "content": "Melbourne is also sunny which is my favourite weather"})
+        con.add_message({"role": "user", "content": "why Melbourne is a good place to travel?"})
+        pipe=Dialog.dialog_with_pipe(conver_pipe=self.pipe_con, con=con)
+
+        # the pipeline output should not be None
+        self.assertIsNotNone(pipe)
+
+
+    def test_summary_with_pipe(self):
+        """
+        Test summary with pipe method
+        """
+        paragraph="""
+        Melbourne, the vibrant capital of Victoria,
+        Australia, pulsates with a captivating blend of culture, art, and sport.
+        Its laneways are adorned with striking street art, while grand Victorian-era buildings stand as testaments to its rich history.
+        The city boasts world-class museums, like the Melbourne Museum and the National Gallery of Victoria,
+        alongside bustling markets and hidden bars waiting to be discovered.
+        Sports enthusiasts revel in the electric atmosphere of the Melbourne Cricket Ground and Rod Laver Arena,
+        hosting iconic events like the Australian Open and the Formula 1 Grand Prix. With its diverse culinary scene,
+        renowned coffee culture, and thriving nightlife, Melbourne offers an unforgettable experience for every visitor."""
+        pipe=PipelinesFactory.customized_pipe(model=self.summarization_model, device_map='auto')
+        res=Dialog.summary_with_pipe(summary_pipe=pipe, paragraph=paragraph, max_length=10)
+
+        # res is str and should not be None
+        self.assertIsNotNone(res)
+        self.assertIsInstance(res, str)
diff --git a/src/kimchima/utils/dialog.py b/src/kimchima/utils/dialog.py
index 9dbd0af..50fca7a 100644
--- a/src/kimchima/utils/dialog.py
+++ b/src/kimchima/utils/dialog.py
@@ -29,7 +29,7 @@ def __init__(self):
         )
 
     @classmethod
-    def chat_summary(clas, *args,**kwargs)-> str:
+    def chat_summary(cls, *args,**kwargs)-> str:
         r"""
         Chat and summarize the conversation.
         """
@@ -64,3 +64,68 @@ def chat_summary(clas, *args,**kwargs)-> str:
         logger.info("Finish summarization pipeline")
 
         return response[0].get('summary_text')
+
+    @classmethod
+    def dialog_with_pipe(cls, *args, **kwargs):
+        r"""
+        Conversational pipeline with the conversation.
+
+        All arguments are read from keywords; positional arguments are ignored.
+
+        Args:
+            * conver_pipe: pipeline with `conversational` task
+                * like pipeline(task='conversational',model="microsoft/GODEL-v1_1-base-seq2seq", tokenizer=tokenizer)
+            * con: Huggingface transformers Conversation class instance
+            * **kwargs: forwarded unchanged to `conver_pipe`, e.g.
+                * max_length: maximum length of the response
+                * min_length: minimum length of the response
+                * top_k: top k tokens to sample from
+                * top_p: top p tokens to sample from
+                * temperature: temperature of the sampling
+                * do_sample: whether to sample
+
+        Returns:
+            * the output of `conver_pipe(con, **kwargs)`, not post-processed
+
+        Raises:
+            * ValueError: if `conver_pipe` or `con` is missing or None
+        """
+        conver_pipe=kwargs.pop("conver_pipe", None)
+        if conver_pipe is None:
+            raise ValueError("conversation pipeline is required")
+
+        con=kwargs.pop("con", None)
+        if con is None:
+            raise ValueError("con is required")
+
+        return conver_pipe(con, **kwargs)
+
+    @classmethod
+    def summary_with_pipe(cls, *args, **kwargs)-> str:
+        r"""
+        Summarize conversation records with the summarization pipeline.
+
+        All arguments are read from keywords; positional arguments are ignored.
+
+        Args:
+            * summary_pipe: pipeline with `summarization` task
+            * paragraph: text to summarize
+            * **kwargs: forwarded unchanged to `summary_pipe`
+
+        Returns:
+            * `summary_text` of the first pipeline result
+
+        Raises:
+            * ValueError: if `summary_pipe` or `paragraph` is missing or None
+        """
+        summary_pipe=kwargs.pop("summary_pipe", None)
+        if summary_pipe is None:
+            raise ValueError("summary pipeline is required")
+
+        paragraph=kwargs.pop("paragraph", None)
+        if paragraph is None:
+            raise ValueError("paragraph is required")
+
+        # Same result handling as chat_summary: return the first summary text.
+        response=summary_pipe(paragraph, **kwargs)
+        return response[0].get('summary_text')
diff --git a/src/kimchima/utils/downloader.py b/src/kimchima/utils/downloader.py
index d897ea4..da08c07 100644
--- a/src/kimchima/utils/downloader.py
+++ b/src/kimchima/utils/downloader.py
@@ -91,6 +91,7 @@ def auto_downloader(cls, *args, **kwargs):
 
         folder_name=kwargs.pop("folder_name", None)
         if folder_name is None:
+            # TODO: folder_name equal to model_name will cause download issue
             folder_name = model_name
 
         # save_pretrained only saves the model weights, not the configuration