Skip to content

Commit

Permalink
✨ Style Transfer: add eval dataset from public seed at the sft step
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmeoni committed Nov 4, 2024
1 parent 795866c commit 898bb09
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct
dataset.gen_ratio=0.7 \
dataset.sft_dataset=null \
sft.training_args.eval_steps=30 \
score.train.train_size=0.3 \
dpo.training_args.num_train_epochs=40 \
score.train.train_size=0.6 \
dpo.training_args.num_train_epochs=10 \
dpo.percentile=70 \
score.batch_size=64
score.batch_size=8
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct
model.peft_config.target_modules='["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]' \
dataset.name=bio-datasets/mimic_style_transfer \
max_steps=5 \
dataset.num_generated_samples=1500 \
dataset.num_generated_samples=3500 \
score.model.model_name_or_path=sentence-transformers/all-mpnet-base-v2 \
dataset.sft_ratio=0.06 \
dataset.gen_ratio=0.7 \
sft.training_args.eval_steps=30 \
score.train.train_size=0.3 \
dataset.sft_dataset.size=300 \
dpo.training_args.num_train_epochs=40 \
score.train.train_size=0.6 \
dataset.sft_dataset.size=3000 \
dpo.training_args.num_train_epochs=20 \
dpo.percentile=70 \
score.batch_size=64
score.batch_size=8
8 changes: 6 additions & 2 deletions lib/style-transfer/style_transfer/rb_gen/steps/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,17 @@ def sft_train(

cfg.sft.training_args.output_dir = f"models/{wandb.run.id}/sft"
args = hydra.utils.instantiate(cfg.sft.training_args)
wandb.config.update({"state": "sft"}, allow_val_change=True)
test_sft_dataset = None
if cfg.dataset.sft_dataset is not None:
sft_dataset, test_sft_dataset = sft_dataset.train_test_split(
train_size=0.1, shuffle=False
).values()
args.load_best_model_at_end = True
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=sft_dataset,
eval_dataset=test_dataset,
eval_dataset=test_dataset if test_sft_dataset is None else test_sft_dataset,
callbacks=[CustomWandbCallback],
)
trainer.train()
Expand Down

0 comments on commit 898bb09

Please sign in to comment.