add weight convert
Signed-off-by: shunxing12345 <[email protected]>
shunxing12345 committed Jan 17, 2025
1 parent f28eeaa commit e4a74b1
Showing 4 changed files with 304 additions and 618 deletions.
30 changes: 15 additions & 15 deletions src/transformers/models/telechat2/configuration_telechat2.py
@@ -40,27 +40,27 @@ class TeleChat2Config(PretrainedConfig):
             `inputs_ids` passed when calling [`TeleChat2Model`]
         hidden_size (`int`, *optional*, defaults to 4096):
             Dimension of the hidden representations.
-        ffn_hidden_size (`int`, *optional*, defaults to 11008):
+        intermediate_size (`int`, *optional*, defaults to 11008):
             Dimension of the MLP representations.
         hidden_dropout (`float`, *optional*, defaults to 0.0):
             The dropout probability applied to the hidden layers in the MLP blocks.
-        n_layer (`int`, *optional*, defaults to 30):
+        num_hidden_layers (`int`, *optional*, defaults to 30):
             Number of hidden layers in the Transformer decoder.
-        n_head (`int`, *optional*, defaults to 32):
+        num_attention_heads (`int`, *optional*, defaults to 32):
             Number of attention heads for each attention layer in the Transformer decoder.
         num_key_value_heads (`int`, *optional*, defaults to 32):
             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
-            `num_key_value_heads=n_head`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
             `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
             by meanpooling all the original heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
-            `n_head`.
+            `num_attention_heads`.
         max_position_embeddings (`int`, *optional*, defaults to 2048):
             The maximum sequence length that this model might ever be used with. TeleChat2 supports up to 8192 tokens.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        layer_norm_epsilon (`float`, *optional*, defaults to 1e-06):
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
             The epsilon used by the rms normalization layers.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models). Only
@@ -153,14 +153,14 @@ def __init__(
         self,
         vocab_size=32000,
         hidden_size=4096,
-        ffn_hidden_size=11008,
+        intermediate_size=11008,
         hidden_dropout=0.0,
-        n_layer=30,
-        n_head=32,
+        num_hidden_layers=30,
+        num_attention_heads=32,
         num_key_value_heads=32,
         max_position_embeddings=2048,
         initializer_range=0.02,
-        layer_norm_epsilon=1e-6,
+        rms_norm_eps=1e-6,
         use_cache=True,
         pad_token_id=None,
         bos_token_id=1,
@@ -177,17 +177,17 @@ def __init__(
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
-        self.ffn_hidden_size = ffn_hidden_size
-        self.num_hidden_layers = n_layer
-        self.num_attention_heads = n_head
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
         self.hidden_dropout = hidden_dropout
         # for backward compatibility
         if num_key_value_heads is None:
-            num_key_value_heads = n_head
+            num_key_value_heads = num_attention_heads

         self.num_key_value_heads = num_key_value_heads
         self.initializer_range = initializer_range
-        self.layer_norm_epsilon = layer_norm_epsilon
+        self.rms_norm_eps = rms_norm_eps
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
         self.rope_theta = rope_theta
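This hunk renames the TeleChat-specific configuration arguments to the standard Hugging Face names: `ffn_hidden_size` becomes `intermediate_size`, `n_layer` becomes `num_hidden_layers`, `n_head` becomes `num_attention_heads`, and `layer_norm_epsilon` becomes `rms_norm_eps`. A minimal sketch of constructing the config with the new argument names, using the default values shown in the diff above:

from transformers import TeleChat2Config

config = TeleChat2Config(
    vocab_size=32000,
    hidden_size=4096,
    intermediate_size=11008,   # previously ffn_hidden_size
    num_hidden_layers=30,      # previously n_layer
    num_attention_heads=32,    # previously n_head
    num_key_value_heads=32,    # equal to num_attention_heads, i.e. plain multi-head attention
    rms_norm_eps=1e-6,         # previously layer_norm_epsilon
)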
161 changes: 161 additions & 0 deletions src/transformers/models/telechat2/convert_telechat2_weigths_to_hf.py
@@ -0,0 +1,161 @@
import argparse
import json
import os
import re

import torch
from safetensors.torch import load_file
from tokenizers import processors

from transformers import TeleChat2Config, TeleChat2ForCausalLM


# fmt: off
# `None` means we drop the key
STATE_DICT_MAPPING = {
    # Model keys
    r"transformer.word_embeddings.weight": r"model.embed_tokens.weight",
    r"transformer.ln_f.weight": r"model.norm.weight",

    # Layers keys
    r"transformer.h.(\d+).input_layernorm.weight": r"model.layers.\1.input_layernorm.weight",
    r"transformer.h.(\d+).post_attention_layernorm.weight": r"model.layers.\1.post_attention_layernorm.weight",

    # Attention keys
    r"transformer.h.(\d+).self_attention.dense.weight": r"model.layers.\1.self_attn.o_proj.weight",
    # key_value will later be split into k_proj|v_proj
    r"transformer.h.(\d+).self_attention.key_value.(weight|bias)": r"model.layers.\1.self_attn.key_value.\2",
    r"transformer.h.(\d+).self_attention.query.(weight|bias)": r"model.layers.\1.self_attn.query.\2",

    # MLP keys
    r"transformer.h.(\d+).mlp.gate_proj.weight": r"model.layers.\1.mlp.gate_proj.weight",
    r"transformer.h.(\d+).mlp.up_proj.weight": r"model.layers.\1.mlp.up_proj.weight",
    r"transformer.h.(\d+).mlp.down_proj.weight": r"model.layers.\1.mlp.down_proj.weight",
}
# fmt: on
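# For example, `transformer.h.3.self_attention.dense.weight` is rewritten to
# `model.layers.3.self_attn.o_proj.weight` by the mapping above.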


def load_weights(input_dir: str):
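    # Checkpoint shards are expected to follow the `<name>-00001-of-00002.<ext>` naming
    # scheme; the sort key below extracts the shard index from the file name.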
    safetensor_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".safetensors")]
    bin_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".bin")]

    all_weights = {}

    if safetensor_files:
        safetensor_files = sorted(safetensor_files, key=lambda x: int(x.rsplit("-", 3)[1]))
        for file in safetensor_files:
            tensors = load_file(file)
            all_weights.update(tensors)
        return all_weights

    elif bin_files:
        bin_files = sorted(bin_files, key=lambda x: int(x.rsplit("-", 3)[1]))
        for file in bin_files:
            tensors = torch.load(file, map_location="cpu")
            all_weights.update(tensors)
        return all_weights

    else:
        raise ValueError("No .safetensors or .bin files found in the specified directory.")


def map_old_key_to_new(old_key):
    for pattern, replacement in STATE_DICT_MAPPING.items():
        if replacement is None:
            if re.fullmatch(pattern, old_key):
                return None
        else:
            new_key, n_replace = re.subn(pattern, replacement, old_key)
            # Early exit of the loop
            if n_replace > 0:
                return new_key

    raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")


def convert_state_dict(original_state_dict: dict, config: TeleChat2Config):
    new_dict = {}

    head_dim = config.hidden_size // config.num_attention_heads
    query_size = config.num_attention_heads * head_dim
    kv_size = config.num_key_value_heads * head_dim

    for old_key, value in original_state_dict.items():
        new_key = map_old_key_to_new(old_key)
        if new_key is None:
            continue

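        # the original checkpoint stores key and value in a single fused `key_value` projection;
        # split it into the separate k_proj / v_proj tensors used on the Hugging Face side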
if "key_value." in new_key:
k_proj, v_proj = (
value[:head_dim, ...],
value[head_dim : 2 * head_dim, ...],
)
new_dict[new_key.replace("key_value.", "k_proj.")] = k_proj
new_dict[new_key.replace("key_value.", "v_proj.")] = v_proj
else:
new_dict[new_key] = value
return new_dict


def convert_config(original_config: dict):
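    # keys are the Hugging Face `TeleChat2Config` argument names, values are the
    # corresponding names used in the original checkpoint's config.json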
    key_mapping = {
        "intermediate_size": "ffn_hidden_size",
        "rms_norm_eps": "layer_norm_epsilon",
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }
    similar_keys_to_keep = [
        "max_position_embeddings",
        "hidden_size",
        "num_key_value_heads",
        "head_dim",
        "attention_dropout",
        "use_cache",
        "eos_token_id",
        "pad_token_id",
        "tie_word_embeddings",
        "vocab_size",
    ]
    new_config_kwargs = {k: original_config[v] for k, v in key_mapping.items()}
    new_config_kwargs.update({k: v for k, v in original_config.items() if k in similar_keys_to_keep})

    new_config = TeleChat2Config(**new_config_kwargs)
    return new_config


def convert_telechat2_model(input_dir, output_dir, use_post_processor=False):
    # Load and convert config
    with open(os.path.join(input_dir, "config.json")) as f:
        original_config = json.load(f)
    config = convert_config(original_config)
    config.save_pretrained(output_dir)

    # Load and convert weights
    original_state_dict = load_weights(input_dir)
    new_dict = convert_state_dict(original_state_dict, config)
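    # instantiating on the meta device avoids allocating memory for random weights;
    # `assign=True` then attaches the converted tensors directly to the model's parameters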
    with torch.device("meta"):
        model = TeleChat2ForCausalLM(config)
    model.load_state_dict(new_dict, strict=True, assign=True)
    model.save_pretrained(output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_dir",
        type=str,
        help="Location of the local folder copied from the Hub.",
    )
    parser.add_argument(
        "output_dir",
        type=str,
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--use_post_processor",
        action="store_true",
        help="Whether to apply post processor with special tokens",
    )

    args = parser.parse_args()
    convert_telechat2_model(args.input_dir, args.output_dir, args.use_post_processor)
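A usage sketch for the converter; the paths below are placeholders and the command assumes it is run from the repository root with transformers installed from this source tree:

# python src/transformers/models/telechat2/convert_telechat2_weigths_to_hf.py TeleChat2-7B telechat2-7b-hf
#
# Once converted, the checkpoint can be loaded like any other Hugging Face model:
from transformers import TeleChat2ForCausalLM

model = TeleChat2ForCausalLM.from_pretrained("telechat2-7b-hf")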