🐛 fix some issues + hyperparam #43

Merged · 5 commits · Nov 4, 2024
@@ -1,4 +1,4 @@
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=$1
 python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct \
 model.peft_config.target_modules='["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]' \
 dataset.name=bio-datasets/mimic_style_transfer \
@@ -9,6 +9,7 @@ python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct
 dataset.gen_ratio=0.7 \
 dataset.sft_dataset=null \
 sft.training_args.eval_steps=30 \
-score.train.train_size=0.3 \
-dpo.training_args.num_train_epochs=80 \
-dpo.percentile=70
+score.train.train_size=0.6 \
+dpo.training_args.num_train_epochs=10 \
+dpo.percentile=70 \
+score.batch_size=8
@@ -1,14 +1,15 @@
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=$1
 python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct \
 model.peft_config.target_modules='["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]' \
 dataset.name=bio-datasets/mimic_style_transfer \
 max_steps=5 \
-dataset.num_generated_samples=1500 \
+dataset.num_generated_samples=3500 \
 score.model.model_name_or_path=sentence-transformers/all-mpnet-base-v2 \
 dataset.sft_ratio=0.06 \
 dataset.gen_ratio=0.7 \
 sft.training_args.eval_steps=30 \
-score.train.train_size=0.3 \
-dataset.sft_dataset.size=300 \
-dpo.training_args.num_train_epochs=80 \
-dpo.percentile=70
+score.train.train_size=0.6 \
+dataset.sft_dataset.size=3000 \
+dpo.training_args.num_train_epochs=20 \
+dpo.percentile=70 \
+score.batch_size=8
lib/style-transfer/bash/experiment/rb_gen/az/llama3.2-3b.sh (4 additions, 3 deletions)

@@ -1,4 +1,4 @@
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=$1
 python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct \
 model.peft_config.target_modules='["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]' \
 dataset.name=bio-datasets/mimic_style_transfer \
@@ -10,5 +10,6 @@ python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct
 dataset.sft_dataset=null \
 sft.training_args.eval_steps=30 \
 score.train.train_size=0.3 \
-dpo.training_args.num_train_epochs=80 \
-dpo.percentile=70
+dpo.training_args.num_train_epochs=40 \
+dpo.percentile=70 \
+score.batch_size=64
lib/style-transfer/bash/experiment/rb_gen/az/test-azure.sh (1 addition, 1 deletion)

@@ -1,2 +1,2 @@
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=$1
 python style_transfer/run_rb_gen.py
lib/style-transfer/configs/rb_gen/dpo/default.yaml (2 additions, 1 deletion)

@@ -5,7 +5,7 @@ training_args:
   save_steps: 50
   gradient_accumulation_steps: 16
   gradient_checkpointing: false
-  learning_rate: 4e-6
+  learning_rate: 5e-6
   weight_decay: 1e-7
   eval_strategy: "no"
   num_train_epochs: 5
@@ -19,6 +19,7 @@ training_args:
   max_length: 1024
   max_prompt_length: 512
   report_to: "none"
+  beta: 0.1

 beta: 0.1
 checkpoint: null
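With `beta` now nested under `training_args`, it travels through the same `hydra.utils.instantiate(cfg.dpo.training_args)` call that dpo.py uses to build the trainer arguments (see the dpo.py diff further down), instead of only existing as a top-level key. A minimal sketch of that mechanism, assuming the node's `_target_` is `trl.DPOConfig` (the target class is not visible in this diff) and using illustrative values:

```python
# Minimal sketch: hydra instantiates the DPO config straight from the YAML
# node, so any key nested under training_args (such as beta) becomes a field
# of the resulting config object. The trl.DPOConfig target is an assumption;
# only learning_rate=5e-6 and beta=0.1 come from this PR.
import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "training_args": {
            "_target_": "trl.DPOConfig",       # assumed target class
            "output_dir": "models/example/dpo/0",
            "learning_rate": 5e-6,
            "beta": 0.1,
        }
    }
)

args = hydra.utils.instantiate(cfg.training_args)
print(type(args).__name__, args.beta)  # -> DPOConfig 0.1
```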
lib/style-transfer/configs/rb_gen/score/default.yaml (4 additions, 4 deletions)

@@ -7,9 +7,9 @@ model:

 train:
   warmup_steps: 50
-  use_ground_truth: false
-  epochs: 1
-  train_size: 0.5
+  use_ground_truth: true
+  epochs: 2
+  train_size: 0.6
   loss:
-    _target_: sentence_transformers.losses.ContrastiveTensionLoss
+    _target_: sentence_transformers.losses.ContrastiveLoss
     _partial_: true
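The scorer's loss moves from the unsupervised `ContrastiveTensionLoss` to the supervised `ContrastiveLoss`, which expects labelled positive/negative pairs; that matches `use_ground_truth` flipping to `true` in the same hunk. A minimal sketch of how such a loss is typically wired up in sentence-transformers (the model name comes from the launch scripts above; the toy pairs and the `fit` call are illustrative, not the repo's actual scorer training loop, where hydra's `_partial_` defers passing the model):

```python
# Minimal sketch, not the repo's code: ContrastiveLoss is supervised, so each
# training pair carries a label (1 = similar, 0 = dissimilar), unlike the
# unsupervised ContrastiveTensionLoss it replaces.
from sentence_transformers import InputExample, SentenceTransformer, losses
from torch.utils.data import DataLoader

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

train_examples = [
    InputExample(texts=["generated clinical note", "ground-truth clinical note"], label=1),
    InputExample(texts=["generated clinical note", "unrelated note"], label=0),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)

loss = losses.ContrastiveLoss(model=model)  # in the config, _partial_ supplies `model` later
model.fit(train_objectives=[(train_dataloader, loss)], epochs=2, warmup_steps=50)
```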
lib/style-transfer/style_transfer/rb_gen/steps/dpo.py (19 additions, 4 deletions)

@@ -1,11 +1,12 @@
import hydra
import numpy as np
import pandas as pd
import peft
import wandb
from datasets import Dataset
from peft import AutoPeftModelForCausalLM
from omegaconf import ListConfig
from style_transfer.rb_gen.utils.utils import CustomWandbCallback
from transformers import PreTrainedTokenizerBase
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from trl import DPOTrainer


@@ -62,15 +63,29 @@ def dpo_train(
    cfg.dpo.training_args.output_dir = f"models/{wandb.run.id}/dpo/{step}"
    args = hydra.utils.instantiate(cfg.dpo.training_args)
    args.padding_value = tokenizer.eos_token_id
    model = AutoPeftModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_path)
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=f"models/{wandb.run.id}/merged/"
    )
    model.enable_input_require_grads()
    peft_config = hydra.utils.instantiate(cfg.model.peft_config)
    peft_config.target_modules = (
        list(peft_config.target_modules)
        if isinstance(peft_config.target_modules, ListConfig)
        else peft_config.target_modules
    )
    model = peft.get_peft_model(
        model,
        peft_config,
    )
    model.add_adapter(peft_config=peft_config, adapter_name="reference")
    dpo_trainer = DPOTrainer(
        args=args,
        ref_model=None,
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        callbacks=[CustomWandbCallback],
        model_adapter_name="default",
        ref_adapter_name="reference",
    )
    dpo_trainer.train()

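The block above drops the separate `AutoPeftModelForCausalLM` reference model: the merged base is loaded once, wrapped with a trainable "default" LoRA adapter, and given a second "reference" adapter that `DPOTrainer` switches to when computing reference log-probabilities, so only one copy of the base weights sits in GPU memory. A self-contained sketch of that pattern (model name, LoRA rank and the toy preference pair are illustrative; the real code loads `models/<run_id>/merged/` and takes its LoRA settings from the hydra config):

```python
# Minimal two-adapter DPO sketch (illustrative values, not the repo's script).
import peft
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_name = "meta-llama/Llama-3.2-3B-Instruct"  # stands in for models/<run_id>/merged/
tokenizer = AutoTokenizer.from_pretrained(model_name)
base = AutoModelForCausalLM.from_pretrained(model_name)
base.enable_input_require_grads()

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = peft.get_peft_model(base, peft_config)                         # trainable "default" adapter
model.add_adapter(peft_config=peft_config, adapter_name="reference")  # second adapter used as the ref policy

train_dataset = Dataset.from_dict(
    {
        "prompt": ["Rewrite this note in a formal style: pt c/o sob."],
        "chosen": ["The patient complains of shortness of breath."],
        "rejected": ["pt has sob."],
    }
)

args = DPOConfig(
    output_dir="models/example/dpo/0",
    beta=0.1,
    num_train_epochs=1,
    per_device_train_batch_size=1,
)
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,                      # the "reference" adapter plays this role
    args=args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    model_adapter_name="default",
    ref_adapter_name="reference",
)
dpo_trainer.train()
```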
lib/style-transfer/style_transfer/rb_gen/steps/generate.py (1 addition, 3 deletions)

@@ -2,7 +2,6 @@
 import json
 import logging
 import os
-import shutil
 import sqlite3
 from typing import Callable

@@ -78,7 +77,6 @@ def generate(
     del llm
     gc.collect()
     torch.cuda.empty_cache()
-    shutil.rmtree(f"models/{wandb.run.id}/merged/")
     return gen_pred_dataset


@@ -96,7 +94,7 @@ def batch_generate(cfg, step, dataloader, llm, wb_ds_name) -> Dataset:
         )
         batch_logs = {
             "prompts": batch["query"],
-            "ground_texts": batch["text"],
+            "ground_texts": batch["ground_texts"],
         }
         batch_logs = {**batch_logs, **flattened_gs_dict}
         gen_df = pd.DataFrame.from_dict(batch_logs)
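A small illustration of the logging path touched above: each batch's reference text now comes from `batch["ground_texts"]` (the key the dataloader actually carries) rather than `batch["text"]`, before the per-batch dict is turned into a DataFrame. The toy batch, offline W&B run and table name below are illustrative; the rest of `batch_generate` is not shown in this diff:

```python
# Illustrative only: mirrors the batch_logs -> DataFrame step from the diff
# above with a toy batch; wandb runs offline so nothing is uploaded.
import pandas as pd
import wandb

batch = {
    "query": ["Rewrite this note in a formal style: pt c/o sob."],
    "ground_texts": ["The patient complains of shortness of breath."],
}

batch_logs = {
    "prompts": batch["query"],
    "ground_texts": batch["ground_texts"],  # previously read from batch["text"]
}
gen_df = pd.DataFrame.from_dict(batch_logs)

run = wandb.init(project="style-transfer-example", mode="offline")
run.log({"example_generation_table": wandb.Table(dataframe=gen_df)})
run.finish()
```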
lib/style-transfer/style_transfer/rb_gen/steps/sft.py (6 additions, 2 deletions)

@@ -60,13 +60,17 @@ def sft_train(

    cfg.sft.training_args.output_dir = f"models/{wandb.run.id}/sft"
    args = hydra.utils.instantiate(cfg.sft.training_args)
    wandb.config.update({"state": "sft"}, allow_val_change=True)
    test_sft_dataset = None
    if cfg.dataset.sft_dataset is not None:
        sft_dataset, test_sft_dataset = sft_dataset.train_test_split(
            train_size=0.1, shuffle=False
        ).values()
        args.load_best_model_at_end = True
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=sft_dataset,
        eval_dataset=test_dataset,
        eval_dataset=test_dataset if test_sft_dataset is None else test_sft_dataset,
        callbacks=[CustomWandbCallback],
    )
    trainer.train()
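The SFT step now carves its evaluation set out of the SFT data itself whenever `cfg.dataset.sft_dataset` is configured, and enables `load_best_model_at_end`, instead of always evaluating on the generation test set. A toy sketch of the split mechanics (`train_test_split` returns a DatasetDict whose `.values()` yields the train split and then the test split):

```python
# Toy illustration of the split used above: with train_size=0.1 and
# shuffle=False, the first 10% of rows stay in the training split and the
# remaining 90% become the evaluation split.
from datasets import Dataset

sft_dataset = Dataset.from_dict({"text": [f"note {i}" for i in range(100)]})
train_split, eval_split = sft_dataset.train_test_split(train_size=0.1, shuffle=False).values()
print(len(train_split), len(eval_split))  # -> 10 90
```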
lib/style-transfer/style_transfer/run_rb_gen.py (4 additions, 1 deletion)

@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import shutil

 import hydra
 import wandb
@@ -14,7 +15,7 @@
 logger = logging.getLogger(__name__)

 os.environ["TOKENIZERS_PARALLELISM"] = "true"
-os.environ["WANDB_LOG_MODEL"] = "checkpoint"
+os.environ["WANDB_LOG_MODEL"] = "none"
 os.environ["WANDB_START_METHOD"] = "thread"
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 tqdm.pandas()
@@ -65,6 +66,7 @@ def main(cfg: DictConfig):
     sft_train(cfg, sft_dataset, test_dataset, current_model_path)
     logger.info("Bootstrapping done, Iterative Reward-based Generation Training begins...")
     for step in range(cfg.max_steps):
+        logger.info(f"🔄 Step {step} ...")
         sth_dataset = generate(
             cfg,
             step,
@@ -97,6 +99,7 @@
             sth_dataset,
             checkpoint=eval_model_path,
         )
+        shutil.rmtree(f"models/{wandb.run.id}/merged/")
     wandb.finish()

