inference.py (forked from dvlab-research/LongLoRA)
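# Answer a question about a long text file ("material") with a LongLoRA/LLaMA
# model, optionally using the flash-attention replacement from this repo.
# Example invocation (model path, file name, and question are illustrative):
#   python3 inference.py \
#       --base_model /path/to/llama-2-7b-longlora-32k \
#       --material materials/paper.txt \
#       --question "What is the main contribution of this paper?" \
#       --context_size 32768 \
#       --flash_attn True \
#       --max_gen_len 512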
import os
import sys
import math
import torch
import argparse
import textwrap
import transformers
from peft import PeftModel
from transformers import GenerationConfig
from llama_attn_replace import replace_llama_attn

PROMPT_DICT = {
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
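# In main(), {instruction} is filled with the material text followed by the
# question on a new line, e.g. "<contents of material.txt>\nWhat is the main
# finding?" (file name and question here are illustrative).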

def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--material', type=str, default="", help='path to a .txt file used as context')
    parser.add_argument('--question', type=str, default="", help='question to ask about the material')
    parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
    parser.add_argument('--cache_dir', type=str, default="./cache")
    parser.add_argument('--context_size', type=int, default=-1, help='context size used during fine-tuning (sets the RoPE scaling factor)')
    parser.add_argument('--flash_attn', type=bool, default=False, help='whether to use flash attention')
    parser.add_argument('--temperature', type=float, default=0.6, help='sampling temperature')
    parser.add_argument('--top_p', type=float, default=0.9, help='nucleus sampling top-p')
    parser.add_argument('--max_gen_len', type=int, default=512, help='maximum number of newly generated tokens')
    args = parser.parse_args()
    return args
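
# Note: argparse's `type=bool` converts any non-empty string to True, so
# `--flash_attn False` still enables the flag; omit the argument entirely to
# keep flash attention disabled.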

def read_txt_file(material_txt):
    # Only plain-text material is supported.
    if not material_txt.endswith(".txt"):
        raise ValueError("Only .txt files are supported.")
    with open(material_txt) as f:
        content = f.read()
    return content

def build_generator(
    model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
):
    def response(prompt):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Note: temperature/top_p only take effect when sampling is enabled in
        # the model's generation config; otherwise decoding is greedy.
        output = model.generate(
            **inputs,
            max_new_tokens=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            use_cache=use_cache
        )
        out = tokenizer.decode(output[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the newly generated text is returned.
        out = out.split(prompt)[1].strip()
        return out

    return response

def main(args):
    if args.flash_attn:
        replace_llama_attn(inference=True)

    # Set RoPE scaling factor
    config = transformers.AutoConfig.from_pretrained(
        args.base_model,
        cache_dir=args.cache_dir,
    )

    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and args.context_size > orig_ctx_len:
        scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
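    # Example (assuming a base context of 4096 tokens): with --context_size 32768,
    # scaling_factor = ceil(32768 / 4096) = 8.0, i.e. linear RoPE scaling by 8x.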

    # Load model and tokenizer
    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.base_model,
        config=config,
        cache_dir=args.cache_dir,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # 32001 = LLaMA's 32000-token vocabulary plus the pad token added during
    # LongLoRA fine-tuning.
    model.resize_token_embeddings(32001)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.base_model,
        cache_dir=args.cache_dir,
        model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
        padding_side="right",
        use_fast=False,
    )

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    respond = build_generator(model, tokenizer, temperature=args.temperature, top_p=args.top_p,
                              max_gen_len=args.max_gen_len, use_cache=True)

    material = read_txt_file(args.material)
    prompt_no_input = PROMPT_DICT["prompt_no_input"]
    prompt = prompt_no_input.format_map({"instruction": material + "\n%s" % args.question})

    output = respond(prompt=prompt)
    print("output", output)

if __name__ == "__main__":
    args = parse_config()
    main(args)