Tune the rb_gen DPO and score default configs.

dpo/default.yaml
  * learning_rate: 4e-6 to 5e-6
  * add "beta: 0.1" directly after report_to (presumably nested under
    training_args, per the hunk heading)

score/default.yaml
  * warmup_steps: 10 to 50
  * epochs: 5 to 2
  * train_size: 0.3 to 0.6
  * loss _target_: ContrastiveTensionLoss to ContrastiveLoss

NOTE(review): the line breaks and YAML indentation below were reconstructed
from the hunk headers and line counts; verify against the target files
before applying.
NOTE(review): the existing top-level "beta: 0.1" is kept as context, so the
value now appears twice; confirm which one the trainer reads and drop the
other if it is dead.
NOTE(review): ContrastiveLoss presumably expects labelled sentence pairs,
unlike ContrastiveTensionLoss; confirm the training data matches.

diff --git a/lib/style-transfer/configs/rb_gen/dpo/default.yaml b/lib/style-transfer/configs/rb_gen/dpo/default.yaml
index c95b7a7..c689cd1 100644
--- a/lib/style-transfer/configs/rb_gen/dpo/default.yaml
+++ b/lib/style-transfer/configs/rb_gen/dpo/default.yaml
@@ -5,7 +5,7 @@ training_args:
   save_steps: 50
   gradient_accumulation_steps: 16
   gradient_checkpointing: false
-  learning_rate: 4e-6
+  learning_rate: 5e-6
   weight_decay: 1e-7
   eval_strategy: "no"
   num_train_epochs: 5
@@ -19,6 +19,7 @@ training_args:
   max_length: 1024
   max_prompt_length: 512
   report_to: "none"
+  beta: 0.1
 
 beta: 0.1
 checkpoint: null
diff --git a/lib/style-transfer/configs/rb_gen/score/default.yaml b/lib/style-transfer/configs/rb_gen/score/default.yaml
index 5806ae9..5832123 100644
--- a/lib/style-transfer/configs/rb_gen/score/default.yaml
+++ b/lib/style-transfer/configs/rb_gen/score/default.yaml
@@ -6,10 +6,10 @@ model:
   model_name_or_path: "sentence-transformers/all-mpnet-base-v2"
 
 train:
-  warmup_steps: 10
+  warmup_steps: 50
   use_ground_truth: true
-  epochs: 5
-  train_size: 0.3
+  epochs: 2
+  train_size: 0.6
   loss:
-    _target_: sentence_transformers.losses.ContrastiveTensionLoss
+    _target_: sentence_transformers.losses.ContrastiveLoss
     _partial_: true