diff --git a/sentence_transformers/losses/SoftmaxLoss.py b/sentence_transformers/losses/SoftmaxLoss.py index aa3d87ae9..8b35d56ac 100644 --- a/sentence_transformers/losses/SoftmaxLoss.py +++ b/sentence_transformers/losses/SoftmaxLoss.py @@ -55,7 +55,7 @@ def __init__(self, if concatenation_sent_multiplication: num_vectors_concatenated += 1 logger.info("Softmax loss: #Vectors concatenated: {}".format(num_vectors_concatenated)) - self.classifier = nn.Linear(num_vectors_concatenated * sentence_embedding_dimension, num_labels) + self.classifier = nn.Linear(num_vectors_concatenated * sentence_embedding_dimension, num_labels, device=model.device) self.loss_fct = loss_fct def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):