Skip to content

Commit

Permalink
Suport for embedded representation (#3156)
Browse files Browse the repository at this point in the history
* inputs_embeds

* leave feature validation to transformers

---------

Co-authored-by: rchivereanu <[email protected]>
  • Loading branch information
Radu1999 and rchivereanu authored Jan 10, 2025
1 parent a41aada commit a7e3707
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions sentence_transformers/models/Transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,11 @@ def __repr__(self) -> str:

def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
"""Returns token_embeddings, cls_token"""
trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
if "token_type_ids" in features:
trans_features["token_type_ids"] = features["token_type_ids"]
trans_features = {
key: value
for key, value in features.items()
if key in ["input_ids", "attention_mask", "token_type_ids", "inputs_embeds"]
}

output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
output_tokens = output_states[0]
Expand Down

0 comments on commit a7e3707

Please sign in to comment.