Issue with Model and Config File Mismatch in Wenet Conformer #2645

Open
yhuangece7 opened this issue Oct 22, 2024 · 7 comments

@yhuangece7

Describe the bug
There is a mismatch between the train.yaml configuration file and the loaded model weights (final.pt) when using the Wenet pretrained model wenetspeech_u2pp_conformer_exp. Specifically, when attempting to load the weights with the given configuration, several missing and unexpected keys are reported, which may indicate inconsistency between the model architecture defined in the YAML file and the actual pretrained weights.

To Reproduce
Steps to reproduce the behavior:

  1. Download the wenetspeech_u2pp_conformer_exp.tar.gz model from the Wenet pretrained models page (https://wenet.org.cn/wenet/pretrained_models.en.html).
  2. Extract the downloaded archive on a Windows machine.
  3. Update the train.yaml file to adjust the paths for units.txt and global_cmvn (the default paths do not account for the fact that the yaml file sits in the same directory as units.txt and global_cmvn).
  4. Use the following Python script to verify the consistency between train.yaml and final.pt:
import torch
import yaml
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.encoder import ConformerEncoder
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.ctc import CTC
import os

model_dir = "C:/path/to/model_directory"

# Load the configuration file with UTF-8 encoding
config_file = os.path.join(model_dir, "train.yaml")
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Extract relevant parameters from the config
vocab_size = config['output_dim']
encoder_conf = config['encoder_conf']
decoder_conf = config['decoder_conf']

# Remove unsupported parameters for TransformerDecoder
unsupported_decoder_keys = {'r_num_blocks'}
filtered_decoder_conf = {k: v for k, v in decoder_conf.items() if k not in unsupported_decoder_keys}

# Initialize Encoder and Decoder
encoder = ConformerEncoder(input_size=80, **encoder_conf)
decoder = TransformerDecoder(vocab_size=vocab_size, encoder_output_size=encoder_conf['output_size'], **filtered_decoder_conf)
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_conf['output_size'])

# Initialize the ASR model
model = ASRModel(vocab_size=vocab_size, encoder=encoder, decoder=decoder, ctc=ctc, **config['model_conf'])

# Load pretrained model weights
checkpoint_path = os.path.join(model_dir, "final.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Load weights and obtain missing and unexpected keys
load_result = model.load_state_dict(checkpoint, strict=False)

# Print missing and unexpected keys
print(f"Missing keys: {load_result.missing_keys}")
print(f"Unexpected keys: {load_result.unexpected_keys}")

Expected behavior
The script should load the pretrained weights without any missing or unexpected keys, indicating a consistent configuration between train.yaml and final.pt.

Screenshots
When I run the code above, it prints the missing keys and unexpected keys below:

Missing keys: ['encoder.embed.pos_enc.pe', 'decoder.embed.0.weight', 'decoder.embed.1.pe', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.decoders.1.self_attn.linear_q.weight', 'decoder.decoders.1.self_attn.linear_q.bias', 'decoder.decoders.1.self_attn.linear_k.weight', 'decoder.decoders.1.self_attn.linear_k.bias', 'decoder.decoders.1.self_attn.linear_v.weight', 'decoder.decoders.1.self_attn.linear_v.bias', 'decoder.decoders.1.self_attn.linear_out.weight', 'decoder.decoders.1.self_attn.linear_out.bias', 'decoder.decoders.1.src_attn.linear_q.weight', 'decoder.decoders.1.src_attn.linear_q.bias', 'decoder.decoders.1.src_attn.linear_k.weight', 'decoder.decoders.1.src_attn.linear_k.bias', 'decoder.decoders.1.src_attn.linear_v.weight', 'decoder.decoders.1.src_attn.linear_v.bias', 'decoder.decoders.1.src_attn.linear_out.weight', 'decoder.decoders.1.src_attn.linear_out.bias', 'decoder.decoders.1.feed_forward.w_1.weight', 'decoder.decoders.1.feed_forward.w_1.bias', 'decoder.decoders.1.feed_forward.w_2.weight', 'decoder.decoders.1.feed_forward.w_2.bias', 'decoder.decoders.1.norm1.weight', 'decoder.decoders.1.norm1.bias', 'decoder.decoders.1.norm2.weight', 'decoder.decoders.1.norm2.bias', 'decoder.decoders.1.norm3.weight', 'decoder.decoders.1.norm3.bias', 'decoder.decoders.2.self_attn.linear_q.weight', 'decoder.decoders.2.self_attn.linear_q.bias', 'decoder.decoders.2.self_attn.linear_k.weight', 'decoder.decoders.2.self_attn.linear_k.bias', 'decoder.decoders.2.self_attn.linear_v.weight', 'decoder.decoders.2.self_attn.linear_v.bias', 'decoder.decoders.2.self_attn.linear_out.weight', 'decoder.decoders.2.self_attn.linear_out.bias', 'decoder.decoders.2.src_attn.linear_q.weight', 'decoder.decoders.2.src_attn.linear_q.bias', 'decoder.decoders.2.src_attn.linear_k.weight', 'decoder.decoders.2.src_attn.linear_k.bias', 'decoder.decoders.2.src_attn.linear_v.weight', 'decoder.decoders.2.src_attn.linear_v.bias', 'decoder.decoders.2.src_attn.linear_out.weight', 'decoder.decoders.2.src_attn.linear_out.bias', 'decoder.decoders.2.feed_forward.w_1.weight', 'decoder.decoders.2.feed_forward.w_1.bias', 'decoder.decoders.2.feed_forward.w_2.weight', 'decoder.decoders.2.feed_forward.w_2.bias', 'decoder.decoders.2.norm1.weight', 'decoder.decoders.2.norm1.bias', 'decoder.decoders.2.norm2.weight', 'decoder.decoders.2.norm2.bias', 
'decoder.decoders.2.norm3.weight', 'decoder.decoders.2.norm3.bias']
Unexpected keys: ['encoder.global_cmvn.mean', 'encoder.global_cmvn.istd', 'decoder.left_decoder.embed.0.weight', 'decoder.left_decoder.after_norm.weight', 'decoder.left_decoder.after_norm.bias', 'decoder.left_decoder.output_layer.weight', 'decoder.left_decoder.output_layer.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_k.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_v.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.0.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.0.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.0.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.0.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.0.norm1.weight', 'decoder.left_decoder.decoders.0.norm1.bias', 'decoder.left_decoder.decoders.0.norm2.weight', 'decoder.left_decoder.decoders.0.norm2.bias', 'decoder.left_decoder.decoders.0.norm3.weight', 'decoder.left_decoder.decoders.0.norm3.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_k.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_v.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.1.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.1.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.1.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.1.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.1.norm1.weight', 'decoder.left_decoder.decoders.1.norm1.bias', 'decoder.left_decoder.decoders.1.norm2.weight', 'decoder.left_decoder.decoders.1.norm2.bias', 'decoder.left_decoder.decoders.1.norm3.weight', 'decoder.left_decoder.decoders.1.norm3.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_k.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_v.weight', 
'decoder.left_decoder.decoders.2.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.2.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.2.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.2.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.2.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.2.norm1.weight', 'decoder.left_decoder.decoders.2.norm1.bias', 'decoder.left_decoder.decoders.2.norm2.weight', 'decoder.left_decoder.decoders.2.norm2.bias', 'decoder.left_decoder.decoders.2.norm3.weight', 'decoder.left_decoder.decoders.2.norm3.bias', 'decoder.right_decoder.embed.0.weight', 'decoder.right_decoder.after_norm.weight', 'decoder.right_decoder.after_norm.bias', 'decoder.right_decoder.output_layer.weight', 'decoder.right_decoder.output_layer.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_q.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_q.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.0.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.0.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.0.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.0.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.0.norm1.weight', 'decoder.right_decoder.decoders.0.norm1.bias', 'decoder.right_decoder.decoders.0.norm2.weight', 'decoder.right_decoder.decoders.0.norm2.bias', 'decoder.right_decoder.decoders.0.norm3.weight', 'decoder.right_decoder.decoders.0.norm3.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_q.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_q.bias', 
'decoder.right_decoder.decoders.1.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.1.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.1.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.1.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.1.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.1.norm1.weight', 'decoder.right_decoder.decoders.1.norm1.bias', 'decoder.right_decoder.decoders.1.norm2.weight', 'decoder.right_decoder.decoders.1.norm2.bias', 'decoder.right_decoder.decoders.1.norm3.weight', 'decoder.right_decoder.decoders.1.norm3.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_q.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_q.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.2.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.2.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.2.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.2.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.2.norm1.weight', 'decoder.right_decoder.decoders.2.norm1.bias', 'decoder.right_decoder.decoders.2.norm2.weight', 'decoder.right_decoder.decoders.2.norm2.bias', 'decoder.right_decoder.decoders.2.norm3.weight', 'decoder.right_decoder.decoders.2.norm3.bias']

Desktop (please complete the following information):

  • OS: Windows 11
  • Python Version: 3.8.17
  • PyTorch Version: 2.0.1+cu117
  • Wenet Version: 20220506_u2pp_conformer_exp_wenetspeech

Smartphone (please complete the following information):
N/A

Additional context

  1. The missing keys include parameters such as 'encoder.embed.pos_enc.pe', 'decoder.embed.0.weight', and other weights related to the decoder structure.
  2. The unexpected keys include parameters like 'encoder.global_cmvn.mean', 'decoder.left_decoder.embed.0.weight', and others.
  3. I am concerned that these mismatches may affect the fine-tuning process, as my goal is to use this model as a base for further training.
  4. I have checked the README but couldn't find a way to resolve this issue by modifying the train.yaml file. I am wondering if there is an updated version of the model or if this is a known issue that can be safely ignored.
  5. Any guidance or suggestions would be greatly appreciated, especially if there's a way to download a compatible version of the pretrained model.

Thank you!

@xingchensong
Member

use bitransformerdecoder instead of transformerdecoder

@yhuangece7
Author

> use bitransformerdecoder instead of transformerdecoder

Hi Xingchen,
I checked the train.yaml in the exp folder and the default line is decoder: bitransformer.
From my understanding it is already using bitransformer, not transformer.
Are you saying that if I change it to decoder: bitransformerdecoder, that would solve this mismatch between the pt file and the yaml?

Thank you!

@xingchensong
Member

[screenshot]

@yhuangece7
Author


Hi Xingchen,
Thank you for your advice!
I tried your suggestion but still find several missing keys between final.pt and the encoder/decoder built from ConformerEncoder/BiTransformerDecoder. Therefore, I am attaching my test code and the printed result for your reference.

1. This is the key part of my checking code

import torch
import yaml
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.encoder import ConformerEncoder
from wenet.transformer.decoder import BiTransformerDecoder
from wenet.transformer.ctc import CTC
import os

# Path to the model directory
model_dir = "<local path>/wenet/exp/20220506_u2pp_conformer_exp_wenetspeech"

# Load the configuration file
config_file = os.path.join(model_dir, "train.yaml")
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Extract configuration parameters
vocab_size = config['output_dim']
encoder_conf = config['encoder_conf']
decoder_conf = config['decoder_conf']

# Filter out unsupported parameters
unsupported_decoder_keys = {'r_num_blocks'}
filtered_decoder_conf = {k: v for k, v in decoder_conf.items() if k not in unsupported_decoder_keys}

# Initialize the Encoder, Decoder and CTC
encoder = ConformerEncoder(input_size=80, **encoder_conf)
decoder = BiTransformerDecoder(vocab_size=vocab_size, encoder_output_size=encoder_conf['output_size'], **filtered_decoder_conf)
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_conf['output_size'])

# Initialize the ASR model
model = ASRModel(vocab_size=vocab_size, encoder=encoder, decoder=decoder, ctc=ctc, **config['model_conf'])

# Load the checkpoint
checkpoint_path = os.path.join(model_dir, "final.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Load the weights into the model (non-strict) and record the result
load_result = model.load_state_dict(checkpoint, strict=False)

# Print the missing and unexpected keys
missing_keys = load_result.missing_keys
unexpected_keys = load_result.unexpected_keys

print("======= Missing Keys =======")
for key in missing_keys:
    print(f"Missing Key: {key}")
print("======= Unexpected Keys =======")
for key in unexpected_keys:
    print(f"Unexpected Key: {key}")

2. The following are the printed missing keys and unexpected keys:

======= Missing Keys =======
Missing Key: encoder.embed.pos_enc.pe
Missing Key: decoder.left_decoder.embed.1.pe
Missing Key: decoder.right_decoder.embed.1.pe
======= Unexpected Keys =======
Unexpected Key: encoder.global_cmvn.mean
Unexpected Key: encoder.global_cmvn.istd
Unexpected Key: decoder.right_decoder.decoders.0.self_attn.linear_q.weight
Unexpected Key: decoder.right_decoder.decoders.0.self_attn.linear_q.bias
Unexpected Key: decoder.right_decoder.decoders.0.self_attn.linear_k.weight
Unexpected Key: decoder.right_decoder.decoders.0.self_attn.linear_k.bias
Unexpected Key: decoder.right_decoder.decoders.0.self_attn.linear_v.weight
Unexpected Key: decoder.right_decoder.decoders.0.self_attn.linear_v.bias
Unexpected Key: decoder.right_decoder.decoders.0.self_attn.linear_out.weight
Unexpected Key: decoder.right_decoder.decoders.0.self_attn.linear_out.bias
Unexpected Key: decoder.right_decoder.decoders.0.src_attn.linear_q.weight
Unexpected Key: decoder.right_decoder.decoders.0.src_attn.linear_q.bias
Unexpected Key: decoder.right_decoder.decoders.0.src_attn.linear_k.weight
Unexpected Key: decoder.right_decoder.decoders.0.src_attn.linear_k.bias
Unexpected Key: decoder.right_decoder.decoders.0.src_attn.linear_v.weight
Unexpected Key: decoder.right_decoder.decoders.0.src_attn.linear_v.bias
Unexpected Key: decoder.right_decoder.decoders.0.src_attn.linear_out.weight
Unexpected Key: decoder.right_decoder.decoders.0.src_attn.linear_out.bias
Unexpected Key: decoder.right_decoder.decoders.0.feed_forward.w_1.weight
Unexpected Key: decoder.right_decoder.decoders.0.feed_forward.w_1.bias
Unexpected Key: decoder.right_decoder.decoders.0.feed_forward.w_2.weight
Unexpected Key: decoder.right_decoder.decoders.0.feed_forward.w_2.bias
Unexpected Key: decoder.right_decoder.decoders.0.norm1.weight
Unexpected Key: decoder.right_decoder.decoders.0.norm1.bias
Unexpected Key: decoder.right_decoder.decoders.0.norm2.weight
Unexpected Key: decoder.right_decoder.decoders.0.norm2.bias
Unexpected Key: decoder.right_decoder.decoders.0.norm3.weight
Unexpected Key: decoder.right_decoder.decoders.0.norm3.bias
Unexpected Key: decoder.right_decoder.decoders.1.self_attn.linear_q.weight
Unexpected Key: decoder.right_decoder.decoders.1.self_attn.linear_q.bias
Unexpected Key: decoder.right_decoder.decoders.1.self_attn.linear_k.weight
Unexpected Key: decoder.right_decoder.decoders.1.self_attn.linear_k.bias
Unexpected Key: decoder.right_decoder.decoders.1.self_attn.linear_v.weight
Unexpected Key: decoder.right_decoder.decoders.1.self_attn.linear_v.bias
Unexpected Key: decoder.right_decoder.decoders.1.self_attn.linear_out.weight
Unexpected Key: decoder.right_decoder.decoders.1.self_attn.linear_out.bias
Unexpected Key: decoder.right_decoder.decoders.1.src_attn.linear_q.weight
Unexpected Key: decoder.right_decoder.decoders.1.src_attn.linear_q.bias
Unexpected Key: decoder.right_decoder.decoders.1.src_attn.linear_k.weight
Unexpected Key: decoder.right_decoder.decoders.1.src_attn.linear_k.bias
Unexpected Key: decoder.right_decoder.decoders.1.src_attn.linear_v.weight
Unexpected Key: decoder.right_decoder.decoders.1.src_attn.linear_v.bias
Unexpected Key: decoder.right_decoder.decoders.1.src_attn.linear_out.weight
Unexpected Key: decoder.right_decoder.decoders.1.src_attn.linear_out.bias
Unexpected Key: decoder.right_decoder.decoders.1.feed_forward.w_1.weight
Unexpected Key: decoder.right_decoder.decoders.1.feed_forward.w_1.bias
Unexpected Key: decoder.right_decoder.decoders.1.feed_forward.w_2.weight
Unexpected Key: decoder.right_decoder.decoders.1.feed_forward.w_2.bias
Unexpected Key: decoder.right_decoder.decoders.1.norm1.weight
Unexpected Key: decoder.right_decoder.decoders.1.norm1.bias
Unexpected Key: decoder.right_decoder.decoders.1.norm2.weight
Unexpected Key: decoder.right_decoder.decoders.1.norm2.bias
Unexpected Key: decoder.right_decoder.decoders.1.norm3.weight
Unexpected Key: decoder.right_decoder.decoders.1.norm3.bias
Unexpected Key: decoder.right_decoder.decoders.2.self_attn.linear_q.weight
Unexpected Key: decoder.right_decoder.decoders.2.self_attn.linear_q.bias
Unexpected Key: decoder.right_decoder.decoders.2.self_attn.linear_k.weight
Unexpected Key: decoder.right_decoder.decoders.2.self_attn.linear_k.bias
Unexpected Key: decoder.right_decoder.decoders.2.self_attn.linear_v.weight
Unexpected Key: decoder.right_decoder.decoders.2.self_attn.linear_v.bias
Unexpected Key: decoder.right_decoder.decoders.2.self_attn.linear_out.weight
Unexpected Key: decoder.right_decoder.decoders.2.self_attn.linear_out.bias
Unexpected Key: decoder.right_decoder.decoders.2.src_attn.linear_q.weight
Unexpected Key: decoder.right_decoder.decoders.2.src_attn.linear_q.bias
Unexpected Key: decoder.right_decoder.decoders.2.src_attn.linear_k.weight
Unexpected Key: decoder.right_decoder.decoders.2.src_attn.linear_k.bias
Unexpected Key: decoder.right_decoder.decoders.2.src_attn.linear_v.weight
Unexpected Key: decoder.right_decoder.decoders.2.src_attn.linear_v.bias
Unexpected Key: decoder.right_decoder.decoders.2.src_attn.linear_out.weight
Unexpected Key: decoder.right_decoder.decoders.2.src_attn.linear_out.bias
Unexpected Key: decoder.right_decoder.decoders.2.feed_forward.w_1.weight
Unexpected Key: decoder.right_decoder.decoders.2.feed_forward.w_1.bias
Unexpected Key: decoder.right_decoder.decoders.2.feed_forward.w_2.weight
Unexpected Key: decoder.right_decoder.decoders.2.feed_forward.w_2.bias
Unexpected Key: decoder.right_decoder.decoders.2.norm1.weight
Unexpected Key: decoder.right_decoder.decoders.2.norm1.bias
Unexpected Key: decoder.right_decoder.decoders.2.norm2.weight
Unexpected Key: decoder.right_decoder.decoders.2.norm2.bias
Unexpected Key: decoder.right_decoder.decoders.2.norm3.weight
Unexpected Key: decoder.right_decoder.decoders.2.norm3.bias

========
Given these results, and since my goal is to fine-tune final.pt:
(1) if something in my setup is still wrong and is causing these messages, could you advise?
(2) if you confirm that the unexpected and missing keys do exist, do you think it is safe to ignore them during fine-tuning?
Thank you!

@xingchensong
Member

[screenshot]

@yhuangece7
Author


Hi Xingchen,

Thank you! That solved my decoder problems. However, the missing keys are the same as before, and the remaining unexpected keys are related to cmvn.

======= Missing Keys =======
Missing Key: encoder.embed.pos_enc.pe
Missing Key: decoder.left_decoder.embed.1.pe
Missing Key: decoder.right_decoder.embed.1.pe

======= Unexpected Keys =======
Unexpected Key: encoder.global_cmvn.mean
Unexpected Key: encoder.global_cmvn.istd

My questions are:

  1. how should I handle the three missing keys mentioned above?
  2. for the unexpected keys, since the global_cmvn file has mean_stat, var_stat, and frame_num, should I change the text "mean_stat" in the cmvn file to "mean", and change the text "var_stat" to "istd"?

Thank you!

@yhuangece7
Author

Hi Xingchen,
Regarding the missing-key and unexpected-key issues, I have some new findings and questions:

======= Missing Keys =======
Missing Key: encoder.embed.pos_enc.pe
Missing Key: decoder.left_decoder.embed.1.pe
Missing Key: decoder.right_decoder.embed.1.pe

======= Unexpected Keys =======
Unexpected Key: encoder.global_cmvn.mean
Unexpected Key: encoder.global_cmvn.istd

(1) regarding the missing keys, I traced bin/export_jit.py to utils/init_model.py and found that the same missing keys are also reported when the model is loaded via the load_checkpoint function imported from utils/checkpoint.py. Therefore, I assume the missing keys should not be an issue once the pt file is loaded using export_jit.py (a small sanity check of this is sketched just below).
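
All three missing keys end in .pe, i.e. they look like positional-encoding buffers, which for sinusoidal positional encodings are computed deterministically when the module is constructed and therefore do not need to come from final.pt. A small sanity check along these lines (a sketch on my side, appended after load_state_dict in the checking script from my earlier comment) could confirm that nothing else is missing before fine-tuning:

# Sanity check (assumption: keys ending in '.pe' are positional-encoding buffers
# that are re-created at construction time and need not be restored from final.pt)
remaining = [k for k in load_result.missing_keys if not k.endswith('.pe')]
if remaining:
    raise RuntimeError(f"Real parameters are missing from final.pt: {remaining}")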

(2) regarding the unexpected keys, since we have "mean_stat", "var_stat", and "frame_num" in global_cmvn, do you think I should perform the following conversion in my fine-tuning program, based on my understanding of what these parameters mean? (A sketch of this conversion follows the formulas below.)

                mean = mean_stat / frame_num
                istd = (var_stat / frame_num - mean ** 2) ** -0.5 
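
For reference, a minimal sketch of that conversion (the "mean_stat"/"var_stat"/"frame_num" field names are taken from the global_cmvn file described above; the JSON format, the file path, and the variance floor are illustrative assumptions on my side, not necessarily wenet's exact implementation):

import json

import numpy as np

# Read the accumulated CMVN statistics (assumed to be JSON with the fields above)
with open("global_cmvn", "r", encoding="utf-8") as f:
    stats = json.load(f)

mean_stat = np.asarray(stats["mean_stat"], dtype=np.float64)
var_stat = np.asarray(stats["var_stat"], dtype=np.float64)
frame_num = stats["frame_num"]

# mean = E[x]; var = E[x^2] - (E[x])^2; istd = 1 / sqrt(var)
mean = mean_stat / frame_num
var = var_stat / frame_num - mean ** 2
var = np.maximum(var, 1e-20)  # floor the variance to avoid division by zero
istd = 1.0 / np.sqrt(var)

If wenet attaches a GlobalCMVN module to the encoder (which the encoder.global_cmvn.mean / encoder.global_cmvn.istd key names suggest), then constructing the encoder with such a module initialized from these values, rather than renaming fields in the cmvn file, would be what makes those keys match.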

Thank you!
