Merge pull request #77 from SkywardAI/feat/chat_summary
add dialog and summarize pipeline
Aisuko authored Apr 21, 2024
2 parents ae8c8d2 + 9d91164 commit 4a9cb7b
Showing 3 changed files with 82 additions and 1 deletion.
@@ -52,3 +52,38 @@ def test_chat_summary(self):
        # res is str and should not be None
        self.assertIsNotNone(res)
        self.assertIsInstance(res, str)


    def test_dialog_with_pipe(self):
        """
        Test the dialog_with_pipe method
        """
        con = PipelinesFactory.init_conversation()
        con.add_message({"role": "user", "content": "Do you like the weather in Melbourne?"})
        con.add_message({"role": "assistant", "content": "Melbourne is also sunny, which is my favourite weather"})
        con.add_message({"role": "user", "content": "Why is Melbourne a good place to travel?"})
        pipe = Dialog.dialog_with_pipe(conver_pipe=self.pipe_con, con=con)

        # the returned conversation should not be None
        self.assertIsNotNone(pipe)


    def test_summary_with_pipe(self):
        """
        Test summary with pipe method
        """
        paragraph="""
        Melbourne, the vibrant capital of Victoria,
        Australia, pulsates with a captivating blend of culture, art, and sport.
        Its laneways are adorned with striking street art, while grand Victorian-era buildings stand as testaments to its rich history.
        The city boasts world-class museums, like the Melbourne Museum and the National Gallery of Victoria,
        alongside bustling markets and hidden bars waiting to be discovered.
        Sports enthusiasts revel in the electric atmosphere of the Melbourne Cricket Ground and Rod Laver Arena,
        hosting iconic events like the Australian Open and the Formula 1 Grand Prix. With its diverse culinary scene,
        renowned coffee culture, and thriving nightlife, Melbourne offers an unforgettable experience for every visitor."""
        pipe=PipelinesFactory.customized_pipe(model=self.summarization_model, device_map='auto')
        res=Dialog.summary_with_pipe(summary_pipe=pipe, paragraph=paragraph, max_length=10)

        # res is str and should not be None
        self.assertIsNotNone(res)
        self.assertIsInstance(res, str)
47 changes: 46 additions & 1 deletion src/kimchima/utils/dialog.py
@@ -29,7 +29,7 @@ def __init__(self):
        )

    @classmethod
-    def chat_summary(clas, *args,**kwargs)-> str:
+    def chat_summary(cls, *args,**kwargs)-> str:
        r"""
        Chat and summarize the conversation.
        """
@@ -64,3 +64,48 @@ def chat_summary(clas, *args,**kwargs)-> str:
        logger.info("Finish summarization pipeline")

        return response[0].get('summary_text')

    @classmethod
    def dialog_with_pipe(cls, *args, **kwargs):
        r"""
        Run a conversational pipeline on the given conversation.

        Args:
            * conver_pipe: a pipeline built with the `conversational` task,
              e.g. pipeline(task='conversational', model="microsoft/GODEL-v1_1-base-seq2seq", tokenizer=tokenizer)
            * con: a Hugging Face transformers Conversation instance
            * **kwargs:
                * max_length: maximum length of the response
                * min_length: minimum length of the response
                * top_k: number of highest-probability tokens to sample from
                * top_p: cumulative probability mass for nucleus sampling
                * temperature: sampling temperature
                * do_sample: whether to sample instead of using greedy decoding
        """
        conver_pipe=kwargs.pop("conver_pipe", None)
        if conver_pipe is None:
            raise ValueError("conversation pipeline is required")

        con=kwargs.pop("con", None)
        if con is None:
            raise ValueError("con is required")

        return conver_pipe(con, **kwargs)

    @classmethod
    def summary_with_pipe(cls, *args, **kwargs):
        r"""
        Summarize conversation records with the summarization pipeline.
        """
        summary_pipe=kwargs.pop("summary_pipe", None)
        if summary_pipe is None:
            raise ValueError("summary pipeline is required")

        paragraph=kwargs.pop("paragraph", None)
        if paragraph is None:
            raise ValueError("paragraph is required")

        return summary_pipe(paragraph, **kwargs)



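For reference, here is a minimal usage sketch of the two helpers added above, driven by plain transformers pipelines instead of the test fixtures. The model names, the import path (kimchima.utils.dialog), and the generation kwargs are illustrative assumptions rather than anything mandated by this commit; note also that the `conversational` pipeline task is deprecated in recent transformers releases.

# Sketch only, not part of this commit. The checkpoints below are examples;
# any conversational / summarization models should work.
from transformers import Conversation, pipeline

from kimchima.utils.dialog import Dialog

# conversational pipeline, as suggested in the dialog_with_pipe docstring
conver_pipe = pipeline(task="conversational", model="microsoft/GODEL-v1_1-base-seq2seq")

con = Conversation()
con.add_message({"role": "user", "content": "Why is Melbourne a good place to travel?"})

# dialog_with_pipe returns the Conversation updated with the model's reply
con = Dialog.dialog_with_pipe(conver_pipe=conver_pipe, con=con, max_length=64)
print(con.messages[-1]["content"])

# summary_with_pipe returns the raw summarization pipeline output (a list of dicts)
summary_pipe = pipeline(task="summarization", model="sshleifer/distilbart-cnn-12-6")
res = Dialog.summary_with_pipe(
    summary_pipe=summary_pipe,
    paragraph="Melbourne, the vibrant capital of Victoria, Australia, pulsates with a captivating blend of culture, art, and sport.",
    min_length=5,
    max_length=20,
)
print(res[0]["summary_text"])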
1 change: 1 addition & 0 deletions src/kimchima/utils/downloader.py
@@ -91,6 +91,7 @@ def auto_downloader(cls, *args, **kwargs):

        folder_name=kwargs.pop("folder_name", None)
        if folder_name is None:
            # TODO: folder_name equal to model_name will cause a download issue
            folder_name = model_name

        # save_pretrained only saves the model weights, not the configuration
