feat: Self-Rewarding Algorithm with TRT Support #321
base: main
Conversation
Signed-off-by: Gerald Shen <[email protected]>
* trtllm0.9 changes
* fix typos
* address comments
* fixes
* fix
* fix nemo generations with PP
* add engine_unload
* cleanup trtllm
* address comments

Signed-off-by: jiemingz <=>
Co-authored-by: jiemingz <=>
- preference_loss: the raw DPO variant loss
- sft_loss: if adding an SFT loss (categorical cross-entropy loss) for the chosen response, then you can see that raw loss here

The ``reward`` in this case is calculated as the difference between model log probs and the reference log probs, multiplied by the KL penalty (beta in the original paper), for the ground truth and generated responses.
Fix punctuation.
The ``reward``, in this case, is calculated as the difference between model log probs and the reference log probs, multiplied by the KL penalty (beta in the original paper), for the ground truth and generated responses.
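As an aside for readers of this thread, a minimal sketch in Python of the reward computation described above (the function and variable names are illustrative, not taken from the PR):

import torch

def dpo_style_reward(policy_logprobs, ref_logprobs, kl_penalty):
    # Reward as described above: (policy log probs - reference log probs) * beta,
    # computed for both the chosen (ground truth) and the generated response.
    return kl_penalty * (policy_logprobs - ref_logprobs)

# e.g. per-response summed log probs for a chosen and a rejected generation
chosen_reward = dpo_style_reward(torch.tensor(-12.3), torch.tensor(-14.0), kl_penalty=0.1)
rejected_reward = dpo_style_reward(torch.tensor(-20.5), torch.tensor(-19.8), kl_penalty=0.1)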
All metrics will be grouped by either ``train/`` or ``val/`` in WandB, representing whether that metric is from the training or validation set, respectively.
You can also see a table which will print out the prompt, chosen response, and rejected response for each validation step. This allows you to keep track of response quality and hallucinations.

When it comes to ideal hyperparameters for Self-Rewarding training, much will depend on the characteristics of your SFT (or base/foundational) model and your training data, so there is no one-size-fits-all parameter set which will work in all cases.
Fix capitalization, revise sentence.
When it comes to ideal hyperparameters for self-rewarding training, much will depend on the characteristics of your SFT (or base/foundational) model and your training data. Therefore, there is no one-size-fits-all parameter set that will work in all cases.
Additionally, Self-Rewarding (with or without meta) is a complex algorithm with a lot of moving pieces and a lot of parameters, so finding what works well for your model and data can be difficult.
Fix capitalization, revise.
Additionally, self-rewarding training (with or without meta) is a complex algorithm with a lot of moving pieces and a lot of parameters, so finding what works well for your model and data can be difficult.
Below are some of observations from the Nvidia Alignment team as to what parameters we have seen work well:
Fix capitalization, revise sentence.
Below are some observations from the NVIDIA Alignment team regarding parameters that we have found to work well:
* global_batch_size: we recommend using 64, and going up to 128 only for large models (70B+) that are also training with large datasets
Revise
global_batch_size: We recommend using 64, and increasing to 128 only for large models (70B+) that are also training with large datasets.
* iterations/epochs: the original paper uses 3 iterations with 1 epoch per iteration, and we find this to be sufficient for most use cases
Revise
iterations/epochs: The original paper uses 3 iterations with 1 epoch per iteration. We find this to be sufficient for most use cases.
* learning rate: for SFT/aligned models, we recommend a smaller LR, between 3e-7 and 1e-7. If training a foundational model, then something between 3e-6 to 9e-7.
Revise
learning rate: For SFT/aligned models, we recommend a smaller LR, between 3e-7 and 1e-7. If training a foundational model, then something between 3e-6 and 9e-7 is recommended.
* ref_policy_kl_penalty: we did not see large changes from perturbations to this value; we recommend 0.1 - 0.001
Revise
ref_policy_kl_penalty: We did not see large changes from perturbations to this value. We recommend 0.1 - 0.001.
* length_control: depends very much on model size and data, but we found good results with [0,0,0.1]
Revise
length_control: This parameter depends very much on model size and data, but we found good results with [0,0,0.1].
* use_meta_judge: we have found stronger results when settings this to true, which is in line with the paper's results
Revise
use_meta_judge: We found stronger results when setting this parameter to true, which is in line with the paper's results.
* meta_judge_pcnt: we recommend you do not set this higher than 0.15 (15%). Any higher, and we have observed that the llm-as-a-judge model starts to output identical scores for every response (always a 5)
Revise
meta_judge_pcnt: We recommend not setting this higher than 0.15 (15%). Any higher, and we have observed that the LLM-as-a-judge model starts to output identical scores for every response (always a 5).
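Pulling the recommendations in this batch of comments together, a hedged consolidation in Python follows; the key names and paths are illustrative only and may not match the actual NeMo-Aligner YAML layout.

# Hedged consolidation of the recommendations above; key names/paths are
# illustrative and may not match the repo's real config structure.
recommended_overrides = {
    "model.global_batch_size": 64,        # go up to 128 only for 70B+ models with large datasets
    "num_iterations": 3,                  # 3 iterations x 1 epoch each, as in the paper
    "num_epochs_per_iteration": 1,
    "model.optim.lr": 3e-7,               # SFT/aligned: 3e-7 to 1e-7; foundational: 3e-6 to 9e-7
    "model.ref_policy_kl_penalty": 0.1,   # values in 0.1 - 0.001 behaved similarly
    "model.length_control": [0, 0, 0.1],  # highly model- and data-dependent
    "model.use_meta_judge": True,         # stronger results, consistent with the paper
    "model.meta_judge_pcnt": 0.15,        # higher values led to the judge emitting constant scores
}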
I completed the technical edit of CHANGELOG.md and docs/user-guide/self_rewarding.rst. Please review the edits, make the changes in the files, and mark each open thread "resolved."
Still WIP but submitting first batch of comments
Is this file needed for Self-Rewarding? If not, let's move it to a different PR.
It's needed if you want to follow the self-rewarding paper exactly to generate the EFT dataset.
I see, it'd be good to keep it then, but it also needs to be documented so that people understand how to generate this EFT dataset. At a quick glance I'm not seeing it referenced in the self-rewarding doc => could you add it to explain how to generate an EFT dataset?
Signed-off-by: Daniel Egert <[email protected]>
for more information, see https://pre-commit.ci Signed-off-by: NeMo-Aligner CI <[email protected]>
Just a couple of minor typos
Signed-off-by: Daniel Egert <[email protected]>
…/NeMo-Aligner into degert/self-rewarding-trt
I completed the technical edit of CHANGELOG.md and docs/user-guide/self_rewarding.rst.
Going to submit review in chunks so you can start addressing comments right away
max_steps: -1
limit_train_batches: 1.0

# Accelerate training times by accelerating inference stage using TRTLLM
For consistency
- # Accelerate training times by accelerating inference stage using TRTLLM
+ # Speed-up training by accelerating inference stage using TRTLLM
# reshard: False # reshard is not supported in generation

# TRTLLM preallocates activation memory according to the number of input tokens
# By default, assume the max input length is half of the model sequence length
I'd just remove this line
# By default, assume the max input length is half of the model sequence length
(btw, same in gpt_self_rewarding.yaml and gpt_spin.yaml)
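For context, the default quoted just below is derived from the sequence length and the generation budget rather than being a fixed half of the sequence length; a quick illustrative calculation in Python (numbers are made up):

# Illustrative numbers only; the real values come from the model config.
encoder_seq_length = 4096      # model.encoder_seq_length
max_generation_length = 1024   # model.generation.length_params.max_length

# mirrors ${subtract:${model.encoder_seq_length}, ${model.generation.length_params.max_length}}
max_input_len = encoder_seq_length - max_generation_length  # 3072, not encoder_seq_length / 2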
# By default, assume the max input length is half of the model sequence length
max_input_len: ${subtract:${model.encoder_seq_length}, ${model.generation.length_params.max_length}}

model_type: gptnext
- model_type: gptnext
+ model_type: gptnext # can be gptj, gptnext, llama, gemma, falcon
model_type: gptnext

# Unload and reload the TRTLLM engine before and after the training stage
- # Unload and reload the TRTLLM engine before and after the training stage
+ # Save GPU memory by unloading and reloading the TRTLLM engine before and after the training stage
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
Shouldn't this be set to False?
custom_trainer_state_dict = None
consumed_samples = 0

if os.path.exists(gen_file := os.path.join(cfg.exp_manager.explicit_log_dir, "generations", "generations.jsonl")):
Should this `if` block replace the previous `if` block above? (Both seem to reload state from a previous run, but I guess only this one matters?)
""" | ||
dp_group = parallel_state.get_data_parallel_group() | ||
calc_gbs = cfg.model.generation.rollout_micro_batch_size * dp_group.size() | ||
with open_dict(cfg): | ||
cfg.model.global_batch_size = calc_gbs | ||
with open_dict(ptl_model.cfg): | ||
ptl_model.cfg.global_batch_size = calc_gbs | ||
if hasattr(ptl_model, "global_batch_size"): | ||
ptl_model.global_batch_size = calc_gbs | ||
""" |
Looks like debug code that may be removed?
consumed_samples=consumed_samples,
mbs=cfg.model.micro_batch_size,
gbs=cfg.model.global_batch_size,
collate_fn=eye,
Can you use `identity_collate` from `nemo_aligner/data/nlp/builders.py`?
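For context, `identity_collate` is just a pass-through collate; a minimal sketch of what such a helper typically looks like (the actual implementation in `nemo_aligner/data/nlp/builders.py` may differ):

def identity_collate(batch):
    # Return the batch exactly as the dataloader assembled it (a list of dataset
    # items), deferring any padding or tensor stacking to the caller.
    return batch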
# eos_id = ptl_model.tokenizer.eos_id

# collate fn to pad to the max seq length in the batch
# collate_fn = partial(
#     self_rewarding_custom_collate,
#     eos_id=eos_id,
#     reset_position_ids=cfg.model.data.get("reset_position_ids", False),
#     reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
#     eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
# )
To be removed?
)

init_using_ptl(trainer, ptl_model, train_dataloader, train_ds)
# optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)
To remove
# optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)
Signed-off-by: Daniel Egert <[email protected]>
Just a few more comments
consumed_samples=consumed_samples,
mbs=cfg.model.micro_batch_size,
gbs=cfg.model.global_batch_size,
collate_fn=eye,
Use identity_collate
@@ -37,6 +37,12 @@
)
from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo, retrieve_model_state_dict_in_cpu

try:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
Please add a comment explaining why we need this
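For what it's worth, `torch._dynamo.config.suppress_errors = True` makes TorchDynamo log compilation failures and fall back to eager execution instead of raising. The PR's specific motivation isn't stated here, so the comment below is only a guess at possible wording:

try:
    import torch._dynamo

    # Hypothetical wording: fall back to eager execution instead of raising if
    # TorchDynamo fails to compile a graph (e.g. around generation code paths).
    torch._dynamo.config.suppress_errors = True
except ModuleNotFoundError:
    # older PyTorch builds may not ship torch._dynamo
    pass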
@@ -177,6 +181,7 @@ def main(cfg) -> None:
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
exp_manager=cfg.exp_manager,
Looks like this arg is unused in SPINTrainer, what's up with it?
Comments on generation
class GenerationTrainer:
    """Trainer to coordinate Self-Rewarding training
Comment to update
# input_ids = [item["input_ids"] for item in batch]
# masks = [item["mask"] for item in batch]
context_ids = [item["context_ids"] for item in batch]
# answer_ids = [item["answer_ids"] for item in batch]
context_lengths = torch.LongTensor([len(x) for x in context_ids])
# combined_lengths = torch.LongTensor([len(x) for x in input_ids])

# input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=eos_id)
# masks = torch.nn.utils.rnn.pad_sequence(masks, batch_first=True, padding_value=False)
context_ids = torch.nn.utils.rnn.pad_sequence(context_ids, batch_first=True, padding_value=eos_id)
# answer_ids = torch.nn.utils.rnn.pad_sequence(answer_ids, batch_first=True, padding_value=eos_id)

output = {
    # "prompts_and_answers": input_ids,
    # "masks": masks,
    "prompts_only": context_ids,
    # "answers_only": answer_ids,
    "prompt_lengths": context_lengths,
    # "combined_lengths": combined_lengths,
Any reason to keep all the commented stuff? If not let's remove it to make it more readable.
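For reference, this is roughly what the quoted block would look like with the commented-out lines dropped, as suggested (a sketch only; the function name and signature are assumed):

import torch

def generation_custom_collate(batch, eos_id):
    # Hypothetical name and signature; the body is the quoted block with the
    # commented-out lines removed.
    context_ids = [item["context_ids"] for item in batch]
    context_lengths = torch.LongTensor([len(x) for x in context_ids])

    # pad prompts to the longest prompt in the batch
    context_ids = torch.nn.utils.rnn.pad_sequence(context_ids, batch_first=True, padding_value=eos_id)

    return {
        "prompts_only": context_ids,
        "prompt_lengths": context_lengths,
    }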
self.set_max_steps()

'''
def augment_dataloader(self, dataloader):
Can we remove this whole code block that is commented out?
# assert (
#     self.model.cfg.generation.rollout_micro_batch_size % dp_batch_size == 0
# ), f"rollout_micro_batch_size [{self.model.cfg.generation.rollout_micro_batch_size}] must be a multiple of GBS [{self.model.cfg.global_batch_size}] // DP [{parallel_state.get_data_parallel_world_size()}]"
# self.rollout_micro_batch_size = self.model.cfg.generation.rollout_micro_batch_size
# assert self.rollout_micro_batch_size > 0, "`rollout_micro_batch_size` must be > 0"
Can be removed?
max_input_len=self.cfg.trt_llm.get(
    "max_input_len", self.model.cfg.encoder_seq_length - self.length_params["max_length"]
),
generation_batch_size=dp_batch_size,
`dp_batch_size` is based on the global batch size. I'd suggest instead using `micro_batch_size`, because it's a more natural hyperparameter to tweak to trade between generation speed and memory usage for any DP size. (And I would remove `global_batch_size` from the config, overriding it in the code to `micro_batch_size * DP`.)
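A rough sketch of the change being suggested here (the function is hypothetical; `parallel_state.get_data_parallel_world_size()` is the same call used elsewhere in the quoted code):

from megatron.core import parallel_state

def resolve_generation_batch_size(model_cfg):
    # Sketch: size generation batches per DP rank by micro_batch_size and derive
    # global_batch_size from it, rather than requiring the user to set both.
    dp_size = parallel_state.get_data_parallel_world_size()
    generation_batch_size = model_cfg.micro_batch_size
    model_cfg.global_batch_size = generation_batch_size * dp_size  # derived, not user-set
    return generation_batch_size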
return  # training ended

global_pbar = tqdm(
    self.augment_dataloader(self.train_dataloader),
Using `augment_dataloader()` seems somewhat convoluted. Why don't we just iterate on the dataloader (in the `for` loop below) and run generation on each batch?
self.consumed_samples += self.model.cfg.global_batch_size
self.step += 1

if torch.distributed.get_rank() == 0 and gen_tokens_list is not None:
As far as I can tell the second check is useless
- if torch.distributed.get_rank() == 0 and gen_tokens_list is not None:
+ if torch.distributed.get_rank() == 0:
prompt = self.model.tokenizer.ids_to_text(t_[:s_].long().tolist())
response = self.model.tokenizer.ids_to_text(t_[s_:e_].long().tolist())
Just a note that this might be potentially dangerous. Some tokenizers behave in a weird way, and I'm not 100% sure we can always guarantee that decoding a subset of the token IDs recovers the correct text of the response. No need to change it for now (you can resolve), since my quick tests suggest it should be fine, but IMO a safer approach is to decode the full sequence, ensure it starts with the original prompt (in text form), and keep only what's after this prompt. Just letting you know in case you run into some weird things in the future as new fancy tokenizers are introduced...
Also, not a huge deal, but those two lines may be moved under the `if v_:` below.
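A sketch of the safer decoding approach described above (the helper is hypothetical; `ids_to_text` is the same tokenizer call as in the quoted lines):

def split_prompt_and_response(tokenizer, token_ids, prompt_len):
    # Decode the full sequence, then recover the response as whatever follows the
    # decoded prompt text; fall back to subset decoding if the prefix doesn't match.
    full_text = tokenizer.ids_to_text(token_ids)
    prompt_text = tokenizer.ids_to_text(token_ids[:prompt_len])
    if full_text.startswith(prompt_text):
        return prompt_text, full_text[len(prompt_text):]
    return prompt_text, tokenizer.ids_to_text(token_ids[prompt_len:])  # current behaviour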
self.logger.finalize()

if torch.distributed.get_rank() == 0 and self.generations_fh is not None:
Should never be None, right?
- if torch.distributed.get_rank() == 0 and self.generations_fh is not None:
+ if torch.distributed.get_rank() == 0:
if self.use_trtllm_generation:
    self.trtllm_generate.free()

def save(self, extra_candidates=None, is_train_end=False):
Is `save()` called anywhere? Seems like we shouldn't need it since the state is saved in the JSONL output (then we could also probably get rid of `state_dict()`).
What does this PR do?
Adds support for the Self-Rewarding and Meta-Rewarding algorithms from the following two papers:
https://arxiv.org/abs/2401.10020
https://arxiv.org/abs/2407.19594
Changelog
Usage
Please see the new tutorial document at:
docs/user-guide/self_rewarding.rst
Before your PR is "Ready for review"
Pre checks:
Checklist when contributing a new algorithm
max_steps=-1 and validation?
Additional Information