
[QUESTION] How to use loader_mcore and why it requires torch distributed #1266

Open

KookHoiKim opened this issue Oct 29, 2024 · 1 comment

@KookHoiKim

I trained a model with 'args.ckpt_format = torch_dist', and the checkpoint files were saved as '_0.distcp, ..., common.pt, metadata.json'.
When I resume training, load_checkpoint works fine.

However, when I try to convert my checkpoint using loader_mcore, the checkpoint is not loaded and the errors below occur:

Traceback (most recent call last):
  File "/workspace/code/Megatron-LM/tools/checkpoint/convert.py", line 160, in <module>
    main()
  File "/workspace/code/Megatron-LM/tools/checkpoint/convert.py", line 153, in main
    loader.load_checkpoint(queue, args)
  File "/workspace/code/Megatron-LM/tools/checkpoint/loader_mcore.py", line 384, in load_checkpoint
    _load_checkpoint(queue, args)
  File "/workspace/code/Megatron-LM/tools/checkpoint/loader_mcore.py", line 246, in _load_checkpoint
    all_models = [get_models(tp_size, md.params_dtype)]
  File "/workspace/code/Megatron-LM/tools/checkpoint/loader_mcore.py", line 164, in get_models
    load_checkpoint(model_, None, None)
  File "/workspace/code/Megatron-LM/megatron/training/checkpointing.py", line 1094, in load_checkpoint
    state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint(
  File "/workspace/code/Megatron-LM/megatron/training/checkpointing.py", line 850, in _load_base_checkpoint
    return _load_global_dist_base_checkpoint(
  File "/workspace/code/Megatron-LM/megatron/training/checkpointing.py", line 778, in _load_global_dist_base_checkpoint
    state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_name, load_strategy, strict=args.dist_ckpt_strictness)
  File "/workspace/code/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 126, in load
    local_metadata, global_metadata = determine_global_metadata(sharded_state_dict)
  File "/workspace/code/Megatron-LM/megatron/core/dist_checkpointing/validation.py", line 496, in determine_global_metadata
    global_metadata = [None] * torch.distributed.get_world_size()
  File "/miniforge3/envs/megatron/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1832, in get_world_size
    return _get_group_size(group)
  File "/miniforge3/envs/megatron/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 864, in _get_group_size
    default_pg = _get_default_group()
  File "/miniforge3/envs/megatron/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1025, in _get_default_group
    raise ValueError(
ValueError: Default process group has not been initialized, please make sure to call init_process_group.

The load path requires torch.distributed to be initialized, but neither convert.py nor loader_mcore initializes the default process group.
Does it really need distributed initialization, or am I doing something wrong?

@lmcafee-nvidia (Contributor) commented Oct 30, 2024

@KookHoiKim , you are correct that tools/checkpoint does not support converting distributed checkpoints. If you'd like to convert a distributed checkpoint via tools/checkpoint, there's a two-step process (an example command sequence is sketched after this list):

  • Convert from torch_dist to torch format. This conversion is done by launching a slightly modified version of your normal training script, with two arguments added (and all other args left the same). The new checkpoint is saved and the system exits before running any training iterations:
    • --ckpt-convert-format torch: This sets the format for saving the new checkpoint, and we want it to be torch format.
    • --ckpt-convert-save ${PATH_TO_TORCH_CKPT}: This path should be different from your existing --load/--save path, to avoid overwriting your existing checkpoint. This path will also be used for loading in the next step.
  • Convert using loader_mcore.py. For this step, point the converter at the newly saved checkpoint by passing --load-dir ${PATH_TO_TORCH_CKPT}.
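For concreteness, here is a rough sketch of the two commands, assuming a GPT model launched with torchrun. The script name, parallel sizes, and the --saver/--target-* flags for convert.py are placeholders for your actual setup; the only essential additions are the two --ckpt-convert-* flags in step 1 and --load-dir in step 2 (see python tools/checkpoint/convert.py --help for the exact options in your version):

    # Step 1: re-launch your existing training command with two extra arguments.
    # Megatron re-saves the torch_dist checkpoint in plain "torch" format and exits
    # before running any training iterations.
    torchrun --nproc_per_node 8 pretrain_gpt.py \
        ${YOUR_EXISTING_TRAINING_ARGS} \
        --load ${PATH_TO_TORCH_DIST_CKPT} \
        --ckpt-convert-format torch \
        --ckpt-convert-save ${PATH_TO_TORCH_CKPT}

    # Step 2: run the checkpoint converter on the newly saved torch-format checkpoint.
    # The loader/saver and target parallelism values below are illustrative only.
    python tools/checkpoint/convert.py \
        --model-type GPT \
        --loader mcore \
        --saver mcore \
        --load-dir ${PATH_TO_TORCH_CKPT} \
        --save-dir ${PATH_TO_CONVERTED_CKPT} \
        --target-tensor-parallel-size 1 \
        --target-pipeline-parallel-size 1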

Please let me know if you have any questions about this.
