Skip to content

Commit

Permalink
Simplify the smart_batching_collate function (#1852)
Browse files Browse the repository at this point in the history
* Simplify the smart_batching_collate function

* Simplify slightly further, update type hints

---------

Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
vsuarezpaniagua and tomaarsen authored Dec 13, 2023
1 parent 1565ef6 commit 0417a0b
Showing 1 changed file with 10 additions and 21 deletions.
31 changes: 10 additions & 21 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import stat
from collections import OrderedDict
import warnings
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional, Literal
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional, Literal, TYPE_CHECKING
import numpy as np
from numpy import ndarray
import transformers
Expand All @@ -31,6 +31,10 @@
logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from sentence_transformers.readers import InputExample


def get_device_name() -> Literal["mps", "cuda", "cpu"]:
"""
Returns the name of the device where this module is running on.
Expand Down Expand Up @@ -536,36 +540,21 @@ def save_to_hub(self,
return folder_url


def smart_batching_collate(self, batch: List["InputExample"]) -> Tuple[List[Dict[str, Tensor]], Tensor]:
    """
    Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model.

    Here, ``batch`` is a list of InputExample instances: [InputExample(...), ...].
    Each example carries ``example.texts`` (a sequence of strings, one per "column",
    e.g. query/positive/negative) and a scalar ``example.label``.

    :param batch:
        a batch from a SmartBatchingDataset
    :return:
        a tuple ``(sentence_features, labels)`` where ``sentence_features`` is one
        tokenized feature dict per text column and ``labels`` is a 1-D tensor with
        one entry per example
    """
    # Collect the texts of every example; each inner list is one example's columns.
    texts = [example.texts for example in batch]
    # zip(*texts) transposes [example][column] -> [column][example], so each
    # tokenize() call receives all sentences of one column across the batch.
    sentence_features = [self.tokenize(sentences) for sentences in zip(*texts)]
    labels = torch.tensor([example.label for example in batch])
    return sentence_features, labels


def _text_length(self, text: Union[List[int], List[List[int]]]):
"""
Help function to get the length for the input text. Text can be either
Expand Down

0 comments on commit 0417a0b

Please sign in to comment.