Skip to content

Commit

Permalink
[fix] Prevent IndexError if output_hidden_states & ONNX (#3008)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen authored Oct 21, 2024
1 parent a028b58 commit f286d9f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sentence_transformers/models/Transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torc

features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})

if self.auto_model.config.output_hidden_states:
all_layer_idx = 2
if self.auto_model.config.output_hidden_states and len(output_states) > 2:
all_layer_idx = 2 # I.e. after last_hidden_states and pooler_output
if len(output_states) < 3: # Some models only output last_hidden_states and all_hidden_states
all_layer_idx = 1

Expand Down

0 comments on commit f286d9f

Please sign in to comment.