support qwen2 and siglip weight conversion script to enable training … #1221

Open · wants to merge 1 commit into base: main
6 changes: 4 additions & 2 deletions megatron/core/models/vision/clip_vit_model.py
@@ -60,8 +60,9 @@ def __init__(
self.img_h = img_h
self.img_w = img_w

assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
if model_subtype == "siglip":
Collaborator commented:
Why are we moving this assert so it's only used in the siglip case? We want it for the base clip case, don't we?

Author replied:
We moved this assert so that it only applies in the siglip case, because the model used for siglip does not satisfy it, while the models used in the various clip cases do. By the way, was there a specific reason for adding this assert in the first place? If there is no other special reason, we could consider removing it entirely.
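
For context, a minimal sketch (not part of the PR) of the divisibility check under discussion. The 384x384 input with patch size 14 is an assumed SigLIP-style configuration (for example, so400m-patch14-384) that would trip the assert:

```python
# Not part of the PR: a toy check of the patch-divisibility assert.
# Assumed SigLIP-style config: 384x384 input, patch size 14 (e.g. so400m-patch14-384).
img_h, img_w, patch_dim = 384, 384, 14

print(img_h % patch_dim, img_w % patch_dim)  # 6 6 -> not evenly divisible
# assert img_h % patch_dim == 0  # would raise AssertionError for this configuration
```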

assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.num_patches_per_dim_h = self.img_h // self.patch_dim
self.num_patches_per_dim_w = self.img_w // self.patch_dim
self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
@@ -161,6 +162,7 @@ def forward(
x = x + self.position_embeddings(self.position_ids)
if self.ln_pre:
x = self.ln_pre(x)

x = x.permute(1, 0, 2) # [b, s, h] -> [s, b, h]
# `permute` can make the tensor non-contiguous, breaking pipelining.
x = x.contiguous()
30 changes: 26 additions & 4 deletions tools/checkpoint/loader_llama_mistral.py
@@ -19,7 +19,7 @@ def add_arguments(parser):

# TODO(jbarker): Need assertion to make sure *exactly* one of these is used
parser.add_argument('--model-size', type=str, required=True,
choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B'],
choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B', 'qwen2-7B'],
help='Model size can be `llama2-7B`, `llama2-13B`, `llama2-70B`, `llama3-8B`, `llama3-70B`, `mistral-7B` (for pretrained models), '
'and `llama2-7Bf`, `llama2-13Bf`, `llama2-70Bf`, `llama3-8Bf`, `llama3-70bf` and `mistral-7Bf` (for chat-finetuned models).')
parser.add_argument('--checkpoint-type', type=str, required=True,
@@ -59,6 +59,7 @@ def verify_transformers_version():
"mistral-7B": 1,
"mistral-7Bf": 1,
"yi-34B": 8,
"qwen2-7B": 1,
}


@@ -87,6 +88,8 @@ def convert_to_hf(model_path, input_base_path, model_size, tokenizer_path):
from transformers import LlamaConfig as ModelConfig
elif "mistral" in model_size:
from transformers import MistralConfig as ModelConfig
elif "qwen" in model_size:
from transformers import Qwen2Config as ModelConfig

# for backward compatibility, before you needed the repo to be called `my_repo/model_size`
if not os.path.isfile(os.path.join(input_base_path, "params.json")):
@@ -111,7 +114,7 @@ def convert_to_hf(model_path, input_base_path, model_size, tokenizer_path):

if "llama2" in model_size:
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
elif model_size in ["llama3", "mistral"]:
elif model_size in ["llama3", "mistral", "qwen"]:
tokenizer_class = transformers.AutoTokenizer.from_pretrained
else:
raise AttributeError(f"model_size={model_size} not supported")
@@ -126,6 +129,9 @@ def convert_to_hf(model_path, input_base_path, model_size, tokenizer_path):
elif "mistral" in model_size:
tokenizer = tokenizer_class.from_file(tokenizer_path)
vocab_size = 32768
elif "qwen" in model_size:
tokenizer = tokenizer_class(tokenizer_path)
vocab_size = 152064
else:
raise AttributeError(f"model_size={model_size} is not supported")

@@ -162,7 +168,7 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
# Unsharded
q_proj = loaded[f"layers.{layer_i}.attention.wq.weight"]
k_proj = loaded[f"layers.{layer_i}.attention.wk.weight"]
if ("llama2" in model_size) or ("mistral" in model_size):
if ("llama2" in model_size) or ("mistral" in model_size) or ("qwen" in model_size):
q_proj = permute(q_proj)
k_proj = permute(k_proj)
state_dict = {
@@ -353,6 +359,13 @@ def set_attn_state(args, layer, hf_layer):
hf_attn.k_proj.weight.reshape((ng, dim, -1)),
hf_attn.v_proj.weight.reshape((ng, dim, -1)),
], dim=1).reshape((-1, args.hidden_size)))
if args.add_qkv_bias:
attn.query_key_value.bias.data.copy_(torch.cat([
hf_attn.q_proj.bias.reshape((ng, dim*nh//ng)),
hf_attn.k_proj.bias.reshape((ng, dim)),
hf_attn.v_proj.bias.reshape((ng, dim)),
], dim=1).reshape((-1)))

attn.dense.weight.data.copy_(hf_attn.o_proj.weight)


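For reference, a small self-contained sketch (not part of the PR; head counts and dimensions are made up) of the interleaved [q | k | v] bias layout that the `torch.cat(...).reshape((-1))` above builds per key/value group for grouped-query attention:

```python
import torch

# Assumed toy sizes: 4 query heads (nh), 2 key/value groups (ng), head dim 8.
nh, ng, dim = 4, 2, 8

q_bias = torch.arange(nh * dim, dtype=torch.float32)  # [32]
k_bias = torch.arange(ng * dim, dtype=torch.float32)  # [16]
v_bias = torch.arange(ng * dim, dtype=torch.float32)  # [16]

# Same packing as the loader: for each KV group, concatenate that group's
# query biases, then its key bias, then its value bias.
qkv_bias = torch.cat([
    q_bias.reshape((ng, dim * nh // ng)),
    k_bias.reshape((ng, dim)),
    v_bias.reshape((ng, dim)),
], dim=1).reshape(-1)

print(qkv_bias.shape)  # torch.Size([64]) == (nh + 2 * ng) * dim
```
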
@@ -458,6 +471,9 @@ def _load_checkpoint(queue, args):
margs.tokenizer_type = "HuggingFaceTokenizer"
elif "mistral" in args.model_size:
margs.tokenizer_type = "HuggingFaceTokenizer"
elif "qwen" in args.model_size:
margs.tokenizer_type = "HuggingFaceTokenizer"
margs.add_qkv_bias = True

# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes.
@@ -530,6 +546,7 @@ def check_for_arg(arg_name, default=None):
md.output_layer = margs.untie_embeddings_and_output_weights
md.position_embedding_type = margs.position_embedding_type
md.linear_bias = margs.add_bias_linear
md.qkv_bias = margs.add_qkv_bias
md.norm_has_bias = False
md.swiglu = margs.swiglu
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
@@ -543,7 +560,8 @@ def check_for_arg(arg_name, default=None):

# Get true (non-padded) vocab size
tokenizer = transformers.AutoTokenizer.from_pretrained(margs.tokenizer_model)
md.true_vocab_size = tokenizer._tokenizer.get_vocab_size(with_added_tokens=True)
md.true_vocab_size = None


# Get first pipe stage.
mpu.set_tensor_model_parallel_rank(0)
@@ -594,6 +612,8 @@ def queue_put(name, msg):
if md.linear_bias:
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
if md.qkv_bias:
qkv_bias.append(layer.self_attention.query_key_value.bias.data)

# Handle gated linear units.
if md.swiglu:
@@ -618,6 +638,8 @@ def queue_put(name, msg):
message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0)
else:
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
if md.qkv_bias:
message["qkv bias"] = torch.cat(qkv_bias, dim=0)

queue_put(f"transformer layer {layer_num}", message)

17 changes: 14 additions & 3 deletions tools/checkpoint/saver_mcore.py
@@ -3,10 +3,11 @@
import os
import sys
import torch
from importlib.metadata import version
from pkg_resources import packaging

from setter import ModelSetter
from utils import get_mcore_transformer_block_key, print_memory_usage
from megatron.core.utils import get_te_version, is_te_min_version


class MCoreSetter(ModelSetter):
@@ -28,6 +29,7 @@ def set_embeddings(
word=None,
pos=None,
):
print(f"=============================== model.embedding.word_embeddings.weight.shape, {model.embedding.word_embeddings.weight.shape}")
cls.set_tensor(model.embedding.word_embeddings.weight, word)
if pos is not None:
cls.set_tensor(model.embedding.position_embeddings.weight, pos)
@@ -287,8 +289,9 @@ def add_arguments(parser):
def save_checkpoint(queue, args):

# Transformer engine >= 0.12.0, for CPU initialization.
assert is_te_min_version("0.12.0"), \
"transformer engine version: %s (>=0.12.0 required)." % get_te_version()
te_version = packaging.version.Version(version("transformer-engine"))
assert te_version >= packaging.version.Version("0.12.0"), \
"transformer engine version: %s (>=0.12.0 required)." % te_version

# Search in directory above this
sys.path.append(os.path.abspath(
@@ -402,6 +405,8 @@ def check_message(msg):
sys.argv.append('--untie-embeddings-and-output-weights')
if not md.linear_bias:
sys.argv.append('--disable-bias-linear')
if md.qkv_bias:
sys.argv.append('--add-qkv-bias')

if md.model_type == 'BERT' and not md.bert_binary_head:
sys.argv.append('--bert-no-binary-head')
@@ -500,6 +505,7 @@ def check_message(msg):
orig_word_embed = embeddings_msg.pop("word embeddings")
check_message(embeddings_msg)


# Deal with padding
def pad_weight(orig_word_embed, true_vocab_size):
if true_vocab_size is not None:
@@ -638,6 +644,8 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1):
mlp_l0_bias = torch.cat((mlp_l0_bias_W, mlp_l0_bias_V), dim=-1)
else:
mlp_l0_bias = chunk_bias(msg.pop("mlp l0 bias"), 'column', args.target_tensor_parallel_size, args.target_expert_parallel_size)
if md.qkv_bias:
qkv_bias = chunk_bias(msg.pop("qkv bias"), 'column', args.target_tensor_parallel_size)

# Save them to the model
for ep_rank in range(args.target_expert_parallel_size):
@@ -677,6 +685,9 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1):
"mlp_fc1_bias" : mlp_l0_bias[tp_rank],
"mlp_fc2_bias" : mlp_l1_bias
})
if md.qkv_bias:
params_dict.update({"self_attn_qkv_bias" : qkv_bias[tp_rank]})

if margs.num_experts:
params_dict.update({
"router_weight": router