From 15b76bdc764e65c834dee2d363d08d72d28c2f19 Mon Sep 17 00:00:00 2001 From: Simon Meoni Date: Thu, 17 Oct 2024 14:33:00 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9A=97=EF=B8=8F=20Style=20Transfer:=20modify?= =?UTF-8?q?=20hyperparameters=20to=20fix=20dpo=20learning=20problem?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../experiment/rb_gen/az/llama3.2-3b-complete.sh | 14 ++++++++++++++ .../experiment/rb_gen/az/llama3.2-3b-pb-seed.sh | 4 +++- .../bash/experiment/rb_gen/az/llama3.2-3b.sh | 4 +++- lib/style-transfer/configs/rb_gen/dpo/default.yaml | 3 ++- 4 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b-complete.sh diff --git a/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b-complete.sh b/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b-complete.sh new file mode 100644 index 0000000..4d64aca --- /dev/null +++ b/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b-complete.sh @@ -0,0 +1,14 @@ +export CUDA_VISIBLE_DEVICES=1 +python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct \ + model.peft_config.target_modules='["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]' \ + dataset.name=bio-datasets/mimic_style_transfer \ + max_steps=5 \ + dataset.num_generated_samples=3500 \ + score.model.model_name_or_path=sentence-transformers/all-mpnet-base-v2 \ + dataset.sft_ratio=0.06 \ + dataset.gen_ratio=0.7 \ + dataset.sft_dataset=null \ + sft.training_args.eval_steps=30 \ + score.train.train_size=0.3 \ + dpo.training_args.num_train_epochs=80 \ + dpo.percentile=70 diff --git a/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b-pb-seed.sh b/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b-pb-seed.sh index b6d109f..4c05844 100644 --- a/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b-pb-seed.sh +++ b/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b-pb-seed.sh @@ -9,4 +9,6 @@ python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct dataset.gen_ratio=0.7 \ sft.training_args.eval_steps=30 \ score.train.train_size=0.3 \ - dataset.sft_dataset.size=300 + dataset.sft_dataset.size=300 \ + dpo.training_args.num_train_epochs=80 \ + dpo.percentile=70 diff --git a/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b.sh b/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b.sh index 6dd65c9..78c2fab 100644 --- a/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b.sh +++ b/lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b.sh @@ -9,4 +9,6 @@ python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct dataset.gen_ratio=0.7 \ dataset.sft_dataset=null \ sft.training_args.eval_steps=30 \ - score.train.train_size=0.3 + score.train.train_size=0.3 \ + dpo.training_args.num_train_epochs=80 \ + dpo.percentile=70 diff --git a/lib/style-transfer/configs/rb_gen/dpo/default.yaml b/lib/style-transfer/configs/rb_gen/dpo/default.yaml index 4cb6130..c95b7a7 100644 --- a/lib/style-transfer/configs/rb_gen/dpo/default.yaml +++ b/lib/style-transfer/configs/rb_gen/dpo/default.yaml @@ -5,10 +5,11 @@ training_args: save_steps: 50 gradient_accumulation_steps: 16 gradient_checkpointing: false - learning_rate: 5e-7 + learning_rate: 4e-6 weight_decay: 1e-7 eval_strategy: "no" num_train_epochs: 5 + output_dir: "models/dpo/" optim: "adafactor" save_only_model: true remove_unused_columns: false