forked from ai-forever/ru-gpts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeneration_wrapper.py
294 lines (258 loc) · 11.8 KB
/
generation_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import random
from typing import Union, Iterable
import numpy as np
import torch
from deepspeed import DeepSpeedConfig
from torch.nn import CrossEntropyLoss
from transformers import GPT2Tokenizer, PreTrainedModel, PretrainedConfig
import mpu
from fp16 import FP16_Module
from model import GPT2Model
from download_utils import download_model_files
from transformers.utils import logging
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
# Alias for the type of ``None``; used in the ``Union[..., NoneType]`` hints below.
NoneType = type(None)
def get_deepspeed_config(path):
    """Build a ``DeepSpeedConfig`` from the config file at ``path``.

    Args:
        path: location of the DeepSpeed configuration file.

    Returns:
        The ``DeepSpeedConfig`` instance constructed from ``path``.
    """
    config = DeepSpeedConfig(path)
    return config
def get_sparse_attention_config(path, num_heads):
    """Create a DeepSpeed sparse-attention config from a DeepSpeed config file.

    The ``sparse_attention`` section of the config selects the sparsity class
    through its ``mode`` key. Every remaining key of that section is forwarded
    to the selected class as a keyword argument.

    Args:
        path: location of the DeepSpeed configuration file.
        num_heads: number of attention heads, passed as ``num_heads``.

    Returns:
        An instance of the selected DeepSpeed sparsity-config class. Returns
        ``None`` when the config has no ``sparse_attention`` attribute or it is
        falsy (e.g. empty).

    Raises:
        NotImplementedError: if ``mode`` is missing or not one of
            ``dense``, ``fixed``, ``bigbird``, ``bslongformer``, ``variable``.

    Note:
        The ``'mode'`` key is deleted from the section object in place before
        the class is instantiated.
    """
    ds_config = get_deepspeed_config(path)
    section = ds_config.sparse_attention if hasattr(ds_config, 'sparse_attention') else None
    if not section:
        return None

    mode = section.get('mode')
    # Imports stay lazy so deepspeed's sparse-attention ops are only loaded
    # when a valid sparse mode is actually requested.
    if mode == 'dense':
        from deepspeed.ops.sparse_attention import DenseSparsityConfig as sparsity_cls
    elif mode == 'fixed':
        from deepspeed.ops.sparse_attention import FixedSparsityConfig as sparsity_cls
    elif mode == 'bigbird':
        from deepspeed.ops.sparse_attention import BigBirdSparsityConfig as sparsity_cls
    elif mode == 'bslongformer':
        from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig as sparsity_cls
    elif mode == 'variable':
        from deepspeed.ops.sparse_attention import VariableSparsityConfig as sparsity_cls
    else:
        raise NotImplementedError(
            f'Given sparsity mode, {mode}, has not been implemented yet!'
        )

    # 'mode' is a selector, not a constructor argument, so drop it first.
    del section['mode']
    return sparsity_cls(num_heads=num_heads, **section)
def get_model(deepspeed_config_path):
    """Construct the ruGPT3XL ``GPT2Model``, move it to the GPU and wrap it for fp16.

    The architecture is fixed: 24 layers, hidden size 2048, 16 heads,
    vocabulary 50264, max sequence length 2048 and dropout 0.1 everywhere.
    Sparse attention is enabled only when the DeepSpeed config provides a
    ``sparse_attention`` section.

    Args:
        deepspeed_config_path: DeepSpeed config file passed to
            ``get_sparse_attention_config``.

    Returns:
        The network wrapped in ``FP16_Module``.
    """
    heads = 16
    layer_pattern = 'alternating'
    sparsity = get_sparse_attention_config(deepspeed_config_path, heads)
    if sparsity is None:
        logger.info("Use dense attention")
    else:
        logger.info(f"Use sparse attention with mode {layer_pattern}")

    dropout = 0.1
    network = GPT2Model(
        num_layers=24,
        vocab_size=50264,
        hidden_size=2048,
        num_attention_heads=heads,
        embedding_dropout_prob=dropout,
        attention_dropout_prob=dropout,
        output_dropout_prob=dropout,
        max_sequence_length=2048,
        checkpoint_activations=False,
        checkpoint_num_layers=1,
        parallel_output=False,
        deepspeed_sparsity_config=sparsity,
        sparse_mode=layer_pattern,
    )
    # Place parameters on the current CUDA device before the fp16 conversion.
    network.cuda(torch.cuda.current_device())
    return FP16_Module(network)
def setup_model(weights_path, deepspeed_config_path):
    """Build the network and load trained weights into it.

    Args:
        weights_path: checkpoint file (concatenated into a log message, so it
            is expected to be a ``str``). Its ``'module'`` entry must hold the
            state dict.
        deepspeed_config_path: DeepSpeed config file forwarded to ``get_model``.

    Returns:
        The loaded network, switched to eval mode.
    """
    net = get_model(deepspeed_config_path)
    logger.info("Load checkpoint from " + weights_path)
    # The identity map_location keeps storages where they were deserialized
    # (CPU) instead of the device they were saved from.
    saved = torch.load(weights_path, map_location=lambda storage, loc: storage)
    net.load_state_dict(saved['module'])
    net.eval()
    logger.info("Model Loaded")
    return net
def get_masks_and_position_ids(data,
                               eod_token,
                               reset_position_ids,
                               reset_attention_mask):
    """Build the causal attention mask, loss mask and position ids for a batch.

    Args:
        data: 2-D tensor of token ids, ``(batch_size, seq_length)``.
        eod_token: id marking document boundaries. Positions holding it get
            loss weight 0.
        reset_position_ids: if true, position ids restart at 0 after every
            ``eod_token``.
        reset_attention_mask: if true, tokens after an ``eod_token`` cannot
            attend to anything up to and including it. The mask then gets one
            slice per batch row instead of a single shared slice.

    Returns:
        ``(attention_mask, loss_mask, position_ids)`` where:
        - ``attention_mask`` is a lower-triangular float tensor of shape
          ``(batch_size or 1, 1, seq_length, seq_length)``;
        - ``loss_mask`` is a float tensor shaped like ``data``;
        - ``position_ids`` is a long tensor shaped like ``data``.
    """
    batch_size, seq_length = data.size()

    # One shared causal mask suffices unless it must differ per row.
    mask_batch = batch_size if reset_attention_mask else 1
    causal = torch.ones((mask_batch, seq_length, seq_length), device=data.device)
    attention_mask = torch.tril(causal).view(mask_batch, 1, seq_length, seq_length)

    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    loss_mask[data == eod_token] = 0.0

    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # expand_as returns a view sharing one row; it must be cloned before the
    # per-row in-place edits below.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if not (reset_position_ids or reset_attention_mask):
        return attention_mask, loss_mask, position_ids

    for row in range(batch_size):
        # Column indices of the EOD tokens in this row.
        eod_positions = position_ids[row, data[row] == eod_token]
        if reset_position_ids:
            eod_positions = eod_positions.clone()
        segment_start = 0
        for pos in eod_positions:
            boundary = pos + 1
            if reset_attention_mask:
                # Block attention from the next segment back into this one.
                attention_mask[row, 0, boundary:, :boundary] = 0
            if reset_position_ids:
                # Shift so the next segment's positions start again at 0.
                position_ids[row, boundary:] -= (boundary - segment_start)
                segment_start = boundary
    return attention_mask, loss_mask, position_ids
class ModelOutput(object):
    """Minimal result container returned by ``RuGPT3XL.__call__``.

    Exposes ``logits`` and ``loss`` as attributes, and ``logits`` via
    ``output["logits"]`` as well.
    """

    def __init__(self, logits, loss=None):
        self.logits, self.loss = logits, loss

    def __getitem__(self, key):
        """Return ``logits`` for the key ``"logits"``; raise ``StopIteration`` otherwise.

        ``StopIteration`` (rather than ``KeyError``) is load-bearing. The class
        defines neither ``__contains__`` nor ``__iter__``, so ``x in output``
        falls back to calling ``__getitem__(0)``, ``(1)``, ... and treats
        ``StopIteration`` as end-of-sequence. Membership tests such as
        ``"past_key_values" in output`` therefore evaluate to ``False`` instead
        of raising. Presumably this is relied on by the transformers
        generation code; verify before changing.
        """
        if key != "logits":
            raise StopIteration
        return self.logits
class RuGPT3XL(PreTrainedModel):
    """``PreTrainedModel`` facade over the ruGPT3XL fp16 network.

    Holds the network built by ``setup_model`` together with its
    ``GPT2Tokenizer``, so that ``PreTrainedModel.generate`` can drive it.
    ``__call__`` returns a ``ModelOutput`` with logits, plus per-row losses
    when labels are supplied.
    """

    def __init__(self, model, tokenizer, model_path, seq_len=512):
        """Wrap an already-loaded network.

        Args:
            model: the network returned by ``setup_model``.
            tokenizer: tokenizer whose ``encoder`` vocabulary contains the
                ``'<pad>'`` and ``'<|endoftext|>'`` tokens.
            model_path: name or path the model was loaded from (stored on the
                instance).
            seq_len: length every row is right-padded to in ``__call__``.
        """
        # An empty PretrainedConfig satisfies the PreTrainedModel base class.
        super().__init__(PretrainedConfig())
        self.model = model
        self.pad_token_id = tokenizer.encoder['<pad>']
        self.eos_token_id = tokenizer.encoder['<|endoftext|>']
        self.seq_len = seq_len
        self.model_path = model_path
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, model_name_or_path, seq_len=512):
        """Fetch (if needed) and load ruGPT3XL, returning a ready-to-use wrapper.

        Sets up a single-process NCCL process group and model parallelism of
        size 1 for the ``mpu`` layers, then seeds Python, NumPy, torch and the
        mpu CUDA RNG with 1234. Finally it loads the tokenizer and the weights.

        Args:
            model_name_or_path: forwarded to ``GPT2Tokenizer.from_pretrained``
                and ``download_model_files``.
            seq_len: padding length used by ``__call__``.

        Returns:
            A ``RuGPT3XL`` instance wrapping the loaded network.
        """
        # Initialise distributed state only once per process. torch.distributed
        # refuses to create the default process group twice, and mpu's
        # initialize_model_parallel (presumably, as in Megatron) asserts its
        # groups are unset. Without these guards a second from_pretrained
        # call in the same process crashed.
        if not torch.distributed.is_initialized():
            master_addr = os.getenv('MASTER_ADDR', 'localhost')
            master_port = os.getenv('MASTER_PORT', '6000')
            init_method = 'tcp://' + master_addr + ':' + master_port
            torch.distributed.init_process_group(backend='nccl', world_size=1, rank=0,
                                                 init_method=init_method)
        if not mpu.model_parallel_is_initialized():
            mpu.initialize_model_parallel(1)

        seed = 1234
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        mpu.model_parallel_cuda_manual_seed(seed)

        tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
        logger.info("Check cached model files...")
        weights_path, deepspeed_config_path = download_model_files(model_name_or_path)
        model = setup_model(weights_path, deepspeed_config_path)
        model.cuda()
        model = model.eval()
        return cls(model, tokenizer=tokenizer, seq_len=seq_len, model_path=model_name_or_path)

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs):
        """Return ``kwargs`` with ``input_ids`` added.

        Presumably called by the transformers generation loop to build the
        arguments of each forward call.
        """
        kwargs.update({"input_ids": input_ids})
        return kwargs

    def generate(
            self, text: Union[str, NoneType] = None,
            input_ids: Union[torch.LongTensor, NoneType] = None,
            max_length: Union[int, NoneType] = None,
            min_length: Union[int, NoneType] = None,
            do_sample: Union[bool, NoneType] = None,
            early_stopping: Union[bool, NoneType] = None,
            num_beams: Union[int, NoneType] = None,
            temperature: Union[float, NoneType] = None,
            top_k: Union[int, NoneType] = None,
            top_p: Union[float, NoneType] = None,
            repetition_penalty: Union[float, NoneType] = None,
            bad_words_ids: Union[Iterable[int], NoneType] = None,
            bos_token_id: Union[int, NoneType] = None,
            pad_token_id: Union[int, NoneType] = None,
            eos_token_id: Union[int, NoneType] = None,
            length_penalty: Union[float, NoneType] = None,
            no_repeat_ngram_size: Union[int, NoneType] = None,
            num_return_sequences: Union[int, NoneType] = None,
            decoder_start_token_id: Union[int, NoneType] = None,
            use_cache: Union[bool, NoneType] = None,
            **model_kwargs):
        """Generate continuations with ``PreTrainedModel.generate`` and decode them.

        If ``text`` is given, it is tokenized and takes precedence over
        ``input_ids``. ``eos_token_id`` and ``pad_token_id`` default to the
        tokenizer's ``'<|endoftext|>'`` and ``'<pad>'`` ids. All other
        arguments are forwarded unchanged.

        Returns:
            A list with one decoded string per generated sequence.
        """
        if text is not None:
            input_ids = torch.cuda.LongTensor([self.tokenizer(text)['input_ids']])
        if eos_token_id is None:
            eos_token_id = self.eos_token_id
        if pad_token_id is None:
            pad_token_id = self.pad_token_id
        res = super().generate(
            input_ids=input_ids,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            bad_words_ids=bad_words_ids,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            decoder_start_token_id=decoder_start_token_id,
            use_cache=use_cache,
            **model_kwargs
        )
        return list(map(self.tokenizer.decode, res.tolist()))

    def __call__(self, text=None, input_ids=None, labels=None, **kwargs):
        """Run the network on a batch and return logits (and optional losses).

        Defining ``__call__`` here replaces the ``nn.Module.__call__``
        dispatch. Each row is right-padded with ``pad_token_id`` to
        ``seq_len`` and run through the model one row at a time with a causal
        mask.

        Args:
            text: string to tokenize when ``input_ids`` is not given
                (defaults to the empty string).
            input_ids: token ids as a tensor or (nested) list. A single
                unbatched 1-D sequence is treated as a batch of one.
            labels: optional token ids aligned with ``input_ids``; a 1-D
                sequence is likewise treated as a batch of one. When given, a
                shifted cross-entropy loss that ignores ``pad_token_id`` is
                computed for each row.
            **kwargs: accepted and ignored (presumably extra arguments passed
                by the generation loop).

        Returns:
            ``ModelOutput`` with:
            - ``logits``: the per-row logits concatenated along dim 0 and cut
              back to the input length;
            - ``loss``: a list of per-row losses, or ``None`` when no labels
              were given.
        """
        if input_ids is None:
            if text is None:
                text = ""
            input_ids = torch.cuda.LongTensor([self.tokenizer(text)['input_ids']])
        if isinstance(input_ids, list):
            input_ids = torch.cuda.LongTensor(input_ids)
        if isinstance(labels, list):
            labels = torch.cuda.LongTensor(labels)
        # A 1-D input used to be iterated token by token below (0-d tensors),
        # crashing on len(); treat it as a batch containing one sequence.
        if torch.is_tensor(input_ids) and input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if torch.is_tensor(labels) and labels.dim() == 1:
            labels = labels.unsqueeze(0)

        res = []
        if labels is not None:
            lbls = labels
        else:
            lbls = [None] * len(input_ids)
        loss = None
        original_context_length = 0
        for tokens, lbl in zip(input_ids, lbls):
            context_tokens = tokens.tolist()
            context_length = len(context_tokens)
            original_context_length = len(context_tokens)
            # Right-pad to seq_len; with a causal mask, trailing pads do not
            # affect logits at the real positions.
            if context_length < self.seq_len:
                context_tokens.extend([self.pad_token_id] * (self.seq_len - context_length))
                if labels is not None:
                    lbl = lbl.tolist()
                    lbl.extend([self.pad_token_id] * (self.seq_len - context_length))
                    lbl = torch.cuda.LongTensor(lbl)
            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])
            torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            tokens = context_tokens_tensor
            tokens = tokens.view(1, -1).contiguous()
            tokens = tokens.to(torch.cuda.current_device())
            attention_mask, loss_mask, position_ids = get_masks_and_position_ids(tokens, self.pad_token_id, False,
                                                                                 False)
            lm_logits = self.model(tokens, position_ids, attention_mask)
            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = lm_logits[..., :-1, :].contiguous()
                shift_labels = lbl[..., 1:].contiguous()
                # Flatten the tokens; padded label positions are ignored.
                loss_fct = CrossEntropyLoss(ignore_index=self.pad_token_id)
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            res.append((lm_logits, loss))
        # Drop the padded positions again.
        logits = torch.cat([x[0] for x in res], dim=0)[:, : original_context_length, :]
        if loss is not None:
            loss = [x[1] for x in res]
        return ModelOutput(logits, loss)