From 0417a0b3b7204aa14f1c2279d718b1cb17aae449 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=ADctor=20Su=C3=A1rez-Paniagua?=
Date: Wed, 13 Dec 2023 16:52:02 +0000
Subject: [PATCH] Simplify the smart_batching_collate function (#1852)

* Simplify the smart_batching_collate function

* Simplify slightly further, update type hints

---------

Co-authored-by: Tom Aarsen
---
 sentence_transformers/SentenceTransformer.py | 31 +++++++-------------
 1 file changed, 10 insertions(+), 21 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index cd645caf7..826f1a611 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -5,7 +5,7 @@
 import stat
 from collections import OrderedDict
 import warnings
-from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional, Literal
+from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional, Literal, TYPE_CHECKING
 import numpy as np
 from numpy import ndarray
 import transformers
@@ -31,6 +31,10 @@
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from sentence_transformers.readers import InputExample
+
+
 def get_device_name() -> Literal["mps", "cuda", "cpu"]:
     """
     Returns the name of the device where this module is running on.
@@ -536,36 +540,21 @@ def save_to_hub(self,
 
         return folder_url
 
-    def smart_batching_collate(self, batch):
+    def smart_batching_collate(self, batch: List["InputExample"]) -> Tuple[List[Dict[str, Tensor]], Tensor]:
         """
         Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model
-        Here, batch is a list of tuples: [(tokens, label), ...]
+        Here, batch is a list of InputExample instances: [InputExample(...), ...]
 
         :param batch:
             a batch from a SmartBatchingDataset
         :return:
             a batch of tensors for the model
         """
-        num_texts = len(batch[0].texts)
-        texts = [[] for _ in range(num_texts)]
-        labels = []
-
-        for example in batch:
-            for idx, text in enumerate(example.texts):
-                texts[idx].append(text)
-
-            labels.append(example.label)
-
-        labels = torch.tensor(labels)
-
-        sentence_features = []
-        for idx in range(num_texts):
-            tokenized = self.tokenize(texts[idx])
-            sentence_features.append(tokenized)
-
+        texts = [example.texts for example in batch]
+        sentence_features = [self.tokenize(sentence) for sentence in zip(*texts)]
+        labels = torch.tensor([example.label for example in batch])
         return sentence_features, labels
 
-
     def _text_length(self, text: Union[List[int], List[List[int]]]):
         """
         Help function to get the length for the input text. Text can be either
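
Note (not part of the patch): a minimal sketch of what the simplified collate does, assuming a stubbed InputExample and a fake tokenize() standing in for SentenceTransformer.tokenize. It shows how zip(*texts) transposes the per-example texts into per-column lists, so each column (for example all anchors, then all positives) is tokenized as one batch, matching the behavior of the old explicit loops.

# Illustrative sketch only; InputExample is stubbed with a dataclass and
# tokenize() is a fake stand-in for SentenceTransformer.tokenize.
from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class InputExample:
    texts: List[str]
    label: float = 0.0


def tokenize(sentences: List[str]) -> Dict[str, List[str]]:
    # Stand-in for SentenceTransformer.tokenize; just echoes its input.
    return {"texts": list(sentences)}


batch = [
    InputExample(texts=["anchor 1", "positive 1"], label=1.0),
    InputExample(texts=["anchor 2", "positive 2"], label=0.0),
]

# Same three lines as the refactored method body:
texts = [example.texts for example in batch]                      # [[a1, p1], [a2, p2]]
sentence_features = [tokenize(column) for column in zip(*texts)]  # one dict per text column
labels = torch.tensor([example.label for example in batch])

print(sentence_features)
# [{'texts': ['anchor 1', 'anchor 2']}, {'texts': ['positive 1', 'positive 2']}]
print(labels)  # tensor([1., 0.])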