diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index f38fc8f9824d3b..822692315cd681 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1640,9 +1640,9 @@ def __init__(
         )
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
         sliding_cache_shape = (
-            self.batch_size,
+            self.max_batch_size,
             self.num_key_value_heads,
             min(config.sliding_window, max_cache_len),
             self.head_dim,
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index 9811c02163aa67..1dde941fe0c9c0 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -579,7 +579,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,
diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
index 145905287acde4..78419e78c08b5c 100644
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ b/src/transformers/models/cohere2/modular_cohere2.py
@@ -459,7 +459,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 116504a7312799..e64559b266509f 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -577,7 +577,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index f73b9ea840aec8..5f21fc6bfffd61 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -403,7 +403,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = HybridCache(
                 self.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,
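
For reference, a minimal sketch of what a caller looks like after this rename: HybridCache is constructed with max_batch_size (the keyword these call sites switch to) instead of batch_size. The checkpoint name, cache length, and prompt below are illustrative assumptions, not part of the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

# Hypothetical example setup; any model that uses HybridCache works the same way.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
inputs = tokenizer("Hello", return_tensors="pt")

# The cache capacity argument is now `max_batch_size`; the constructor kwargs
# mirror the call sites changed in this diff (config, max_batch_size,
# max_cache_len, device, dtype).
past_key_values = HybridCache(
    model.config,
    max_batch_size=inputs.input_ids.shape[0],
    max_cache_len=128,  # illustrative; must cover prompt + generated tokens
    device=model.device,
    dtype=model.dtype,
)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=16)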