diff --git a/llama-index-core/llama_index/core/callbacks/token_counting.py b/llama-index-core/llama_index/core/callbacks/token_counting.py
index a1b5edc207e16..736743dde8f60 100644
--- a/llama-index-core/llama_index/core/callbacks/token_counting.py
+++ b/llama-index-core/llama_index/core/callbacks/token_counting.py
@@ -30,6 +30,7 @@ class TokenCountingEvent:
     completion_token_count: int
     prompt_token_count: int
     total_token_count: int = 0
+    cached_tokens: int = 0
     event_id: str = ""
 
     def __post_init__(self) -> None:
@@ -56,7 +57,8 @@ def get_tokens_from_response(
 
     possible_input_keys = ("prompt_tokens", "input_tokens")
     possible_output_keys = ("completion_tokens", "output_tokens")
-
+    openai_prompt_tokens_details_key = "prompt_tokens_details"
+
     prompt_tokens = 0
     for input_key in possible_input_keys:
         if input_key in usage:
@@ -68,8 +70,12 @@
         if output_key in usage:
             completion_tokens = usage[output_key]
             break
-
-    return prompt_tokens, completion_tokens
+
+    cached_tokens = 0
+    if openai_prompt_tokens_details_key in usage:
+        cached_tokens = usage[openai_prompt_tokens_details_key].get("cached_tokens", 0)
+
+    return prompt_tokens, completion_tokens, cached_tokens
 
 
 def get_llm_token_counts(
@@ -83,9 +89,9 @@
 
         if completion:
             # get from raw or additional_kwargs
-            prompt_tokens, completion_tokens = get_tokens_from_response(completion)
+            prompt_tokens, completion_tokens, cached_tokens = get_tokens_from_response(completion)
         else:
-            prompt_tokens, completion_tokens = 0, 0
+            prompt_tokens, completion_tokens, cached_tokens = 0, 0, 0
 
         if prompt_tokens == 0:
             prompt_tokens = token_counter.get_string_tokens(str(prompt))
@@ -99,6 +105,7 @@
             prompt_token_count=prompt_tokens,
             completion=str(completion),
             completion_token_count=completion_tokens,
+            cached_tokens=cached_tokens,
         )
     elif EventPayload.MESSAGES in payload:
@@ -109,9 +116,9 @@
         response_str = str(response)
 
         if response:
-            prompt_tokens, completion_tokens = get_tokens_from_response(response)
+            prompt_tokens, completion_tokens, cached_tokens = get_tokens_from_response(response)
         else:
-            prompt_tokens, completion_tokens = 0, 0
+            prompt_tokens, completion_tokens, cached_tokens = 0, 0, 0
 
         if prompt_tokens == 0:
             prompt_tokens = token_counter.estimate_tokens_in_messages(messages)
@@ -125,6 +132,7 @@
             prompt_token_count=prompt_tokens,
             completion=response_str,
             completion_token_count=completion_tokens,
+            cached_tokens=cached_tokens,
         )
     else:
         return TokenCountingEvent(
@@ -133,6 +141,7 @@
             prompt_token_count=0,
             completion="",
             completion_token_count=0,
+            cached_tokens=0,
         )
 
 
@@ -214,7 +223,9 @@
                     "LLM Prompt Token Usage: "
                     f"{self.llm_token_counts[-1].prompt_token_count}\n"
                     "LLM Completion Token Usage: "
-                    f"{self.llm_token_counts[-1].completion_token_count}",
+                    f"{self.llm_token_counts[-1].completion_token_count}\n"
+                    "LLM Cached Tokens: "
+                    f"{self.llm_token_counts[-1].cached_tokens}",
                 )
         elif (
             event_type == CBEventType.EMBEDDING
@@ -251,6 +262,11 @@ def prompt_llm_token_count(self) -> int:
     def completion_llm_token_count(self) -> int:
         """Get the current total LLM completion token count."""
         return sum([x.completion_token_count for x in self.llm_token_counts])
+
+    @property
+    def total_cached_token_count(self) -> int:
+        """Get the current total cached token count."""
+        return sum([x.cached_tokens for x in self.llm_token_counts])
 
     @property
     def total_embedding_token_count(self) -> int:
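
A minimal usage sketch of the counters this change exposes, assuming the diff above is applied. TokenCountingHandler, CallbackManager, and Settings already exist in llama_index.core; only total_cached_token_count and TokenCountingEvent.cached_tokens come from this patch, and the cached figure is only populated when the response usage payload carries OpenAI's prompt_tokens_details.cached_tokens field.

    # Usage sketch (not part of the patch): reading the new cached-token counters.
    from llama_index.core import Settings
    from llama_index.core.callbacks import CallbackManager, TokenCountingHandler

    token_counter = TokenCountingHandler()
    Settings.callback_manager = CallbackManager([token_counter])

    # ... run queries / chat calls against an OpenAI LLM here ...

    print("prompt tokens:    ", token_counter.prompt_llm_token_count)
    print("completion tokens:", token_counter.completion_llm_token_count)
    print("cached tokens:    ", token_counter.total_cached_token_count)  # new property

    # Per-call breakdown via the new TokenCountingEvent.cached_tokens field.
    for event in token_counter.llm_token_counts:
        print(event.prompt_token_count, event.completion_token_count, event.cached_tokens)

With no LLM calls made, all three totals print 0, since each property sums over an empty llm_token_counts list.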