Skip to content

Commit

Permalink
XLM-RoBERTa: update tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
nreimers committed Jan 10, 2020
1 parent b788845 commit 9dbde8d
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions sentence_transformers/models/XLMRoBERTa.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ def __init__(self, model_name_or_path: str, max_seq_length: int = 128, do_lower_
super(XLMRoBERTa, self).__init__()
self.config_keys = ['max_seq_length', 'do_lower_case']
self.do_lower_case = do_lower_case
self.xlm_roberta = XLMRobertaModel.from_pretrained(model_name_or_path)
self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name_or_path, do_lower_case=do_lower_case)

if max_seq_length > 511:
logging.warning("RoBERTa only allows a max_seq_length of 511 (514 with special tokens). Value will be set to 511")
max_seq_length = 511
if max_seq_length > self.tokenizer.max_len_single_sentence:
logging.warning("XLM-RoBERTa only allows a max_seq_length of "+self.tokenizer.max_len_single_sentence)
max_seq_length = self.tokenizer.max_len_single_sentence
self.max_seq_length = max_seq_length


self.xlm_roberta = XLMRobertaModel.from_pretrained(model_name_or_path)
self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name_or_path, do_lower_case=do_lower_case)
self.cls_token_id = self.tokenizer.convert_tokens_to_ids([self.tokenizer.cls_token])[0]
self.sep_token_id = self.tokenizer.convert_tokens_to_ids([self.tokenizer.sep_token])[0]
self.cls_token_id = self.tokenizer.cls_token_id
self.eos_token_id = self.tokenizer.eos_token_id

def forward(self, features):
"""Returns token_embeddings, cls_token"""
Expand Down Expand Up @@ -58,7 +58,7 @@ def get_sentence_features(self, tokens: List[int], pad_seq_length: int):
pad_seq_length = min(pad_seq_length, self.max_seq_length)

tokens = tokens[:pad_seq_length]
input_ids = [self.cls_token_id] + tokens + [self.sep_token_id] + [self.sep_token_id]
input_ids = [self.cls_token_id] + tokens + [self.eos_token_id]
sentence_length = len(input_ids)

pad_seq_length += 3 ##Add Space for CLS + SEP + SEP token
Expand Down

0 comments on commit 9dbde8d

Please sign in to comment.