
Commit bef7dde

Replace deprecated batch_size with max_batch_size when using HybridCache (#35498)

* Replace deprecated batch_size with max_batch_size

- Functionality remains the same, because the property getter batch_size(self) returned max_batch_size anyway (see the sketch after the commit metadata below).
- This change just avoids an unnecessary deprecation warning.

* Use max_batch_size instead of deprecated batch_size with HybridCache

* Use max_batch_size instead of deprecated batch_size with HybridCache

- Change generated code to match original source
mtreinik authored Jan 16, 2025
1 parent 99e0ab6 commit bef7dde
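
For reference on why the rename is behavior-preserving: the deprecated batch_size attribute is a property that simply delegates to max_batch_size. A minimal, self-contained sketch of that kind of shim (the class name and warning category here are illustrative, not the actual transformers code):

    # Illustrative deprecation shim; the real HybridCache may use a
    # different warning mechanism and message.
    import warnings

    class CacheSketch:
        def __init__(self, max_batch_size: int):
            self.max_batch_size = max_batch_size

        @property
        def batch_size(self) -> int:
            # Deprecated alias: warn, then return max_batch_size, so old
            # callers get identical behavior plus a warning.
            warnings.warn(
                "`batch_size` is deprecated, use `max_batch_size` instead.",
                FutureWarning,
            )
            return self.max_batch_size

    cache = CacheSketch(max_batch_size=4)
    assert cache.batch_size == 4  # works, but emits the deprecation warning
    print(cache.max_batch_size)   # preferred spelling: no warning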
Showing 5 changed files with 6 additions and 6 deletions.

4 changes: 2 additions & 2 deletions src/transformers/cache_utils.py

@@ -1640,9 +1640,9 @@ def __init__(
         )
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
         sliding_cache_shape = (
-            self.batch_size,
+            self.max_batch_size,
             self.num_key_value_heads,
             min(config.sliding_window, max_cache_len),
             self.head_dim,
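
The two allocations above differ only in the sequence-length dimension: global-attention layers get the full max_cache_len, while sliding-window layers are capped at config.sliding_window. A standalone sketch of the shape arithmetic, with made-up values:

    # Illustrative shape computation mirroring the hunk above; all
    # numbers are invented for the example.
    max_batch_size, num_key_value_heads, head_dim = 2, 8, 64
    max_cache_len, sliding_window = 4096, 1024

    global_cache_shape = (max_batch_size, num_key_value_heads, max_cache_len, head_dim)
    sliding_cache_shape = (
        max_batch_size,
        num_key_value_heads,
        min(sliding_window, max_cache_len),  # a sliding layer never exceeds its window
        head_dim,
    )

    print(global_cache_shape)   # (2, 8, 4096, 64)
    print(sliding_cache_shape)  # (2, 8, 1024, 64)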

2 changes: 1 addition & 1 deletion src/transformers/models/cohere2/modeling_cohere2.py

@@ -579,7 +579,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,
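
The same one-keyword change repeats in the three remaining files below. For code that builds a HybridCache directly, a hedged usage sketch (assumes a transformers version that includes this commit; Gemma2Config is used only as a convenient config that defines a sliding window, and all values are illustrative):

    import torch
    from transformers import Gemma2Config
    from transformers.cache_utils import HybridCache

    config = Gemma2Config(num_hidden_layers=2)
    cache = HybridCache(
        config,
        max_batch_size=1,   # previously batch_size=1, which triggered a deprecation warning
        max_cache_len=512,
        device="cpu",
        dtype=torch.float32,
    )
    print(cache.max_batch_size)  # 1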

2 changes: 1 addition & 1 deletion src/transformers/models/cohere2/modular_cohere2.py

@@ -459,7 +459,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,

2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py

@@ -577,7 +577,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,

2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modular_gemma2.py

@@ -403,7 +403,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,
