Skip to content

Commit

Permalink
✨ Style Transfer: merge eval + score and add experimentation scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmeoni committed Feb 13, 2024
1 parent 3bb153d commit c35062b
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 70 deletions.
3 changes: 1 addition & 2 deletions lib/style-transfer/bash/gen_09-2024.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/bin/bash

cd ..
python style_transfer/gen.py checkpoint=clinical-dream-team/sft-style-transfer/checkpoint-sft-ratio-0.99_gen-ratio-0.7:v281 sft_ratio=0.99 gen_ratio=0.7
python style_transfer/gen.py checkpoint=clinical-dream-team/sft-style-transfer/checkpoint-sft-ratio-0.99_gen-ratio-0.7:v281 sft_ratio=0.99 gen_ratio=0.7 &&
python style_transfer/gen.py checkpoint=clinical-dream-team/sft-style-transfer/checkpoint-sft-ratio-0.02_gen-ratio-0.7:v11 sft_ratio=0.02 gen_ratio=0.7
6 changes: 6 additions & 0 deletions lib/style-transfer/bash/gen_10-2024.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash

# Sequentially run generation for four SFT checkpoints, pulling each W&B
# checkpoint artifact and passing the sft_ratio/gen_ratio encoded in its name.
# The `&&` chain stops at the first failing run.
# NOTE(review): `nohup` here only shields each individual python process from
# SIGHUP; without a trailing `&` the chain still runs in the foreground, and
# the shell evaluating the `&&` chain is itself unprotected — confirm whether
# this script is meant to be launched as `nohup ./gen_10-2024.sh &` instead.
nohup python style_transfer/gen.py checkpoint=clinical-dream-team/sft-style-transfer/checkpoint-sft-ratio-0.004_gen-ratio-0.7:v5 sft_ratio=0.004 gen_ratio=0.7 &&
nohup python style_transfer/gen.py checkpoint=clinical-dream-team/sft-style-transfer/checkpoint-sft-ratio-0.006_gen-ratio-0.7:v7 sft_ratio=0.006 gen_ratio=0.7 &&
nohup python style_transfer/gen.py checkpoint=clinical-dream-team/sft-style-transfer/checkpoint-sft-ratio-0.01_gen-ratio-0.7:v7 sft_ratio=0.01 gen_ratio=0.7 &&
nohup python style_transfer/gen.py checkpoint=clinical-dream-team/sft-style-transfer/checkpoint-sft-ratio-0.99_gen-ratio-0.7:v281 sft_ratio=0.99 gen_ratio=0.7
19 changes: 19 additions & 0 deletions lib/style-transfer/bash/score_10-2024.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash

# Score the generated and test datasets of three SFT-ratio runs
# (gen_ratio fixed at 0.7). Each entry is "<wandb-run-id> <sft_ratio>";
# artifacts live under clinical-dream-team/gen-style-transfer.
# Runs are sequential; a failing run does not stop the remaining ones.
for entry in "byeksy0m 0.004" "2vkesygs 0.006" "ka1qr4r7 0.01"; do
    # Split the entry into its run-id and ratio fields.
    set -- $entry
    run_id=$1
    ratio=$2
    python style_transfer/score.py \
        gen_dataset=clinical-dream-team/gen-style-transfer/run-${run_id}-gen_dataset:v0 \
        test_dataset=clinical-dream-team/gen-style-transfer/run-${run_id}-test_dataset:v0 \
        sft_ratio=${ratio} \
        gen_ratio=0.7
done
19 changes: 19 additions & 0 deletions lib/style-transfer/bash/score_13-2024.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash

# Score gen/test dataset pairs for three SFT-ratio runs (gen_ratio 0.7).
# NOTE(review): this file has the same content as score_10-2024.sh —
# confirm whether both are needed.

# score_run <wandb-run-id> <sft_ratio>: score one run's gen/test artifacts.
score_run() {
    python style_transfer/score.py \
        gen_dataset=clinical-dream-team/gen-style-transfer/run-$1-gen_dataset:v0 \
        test_dataset=clinical-dream-team/gen-style-transfer/run-$1-test_dataset:v0 \
        sft_ratio=$2 \
        gen_ratio=0.7
}

# Sequential; a failing run does not stop the remaining ones.
score_run byeksy0m 0.004
score_run 2vkesygs 0.006
score_run ka1qr4r7 0.01
4 changes: 1 addition & 3 deletions lib/style-transfer/bash/sft_09-2024.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#!/bin/bash

#TODO run this script
cd ..
python style_transfer/sft.py -m sft_ratio=0.01,0.008,0.006,0.004,0.002 gen_ratio=0.7 training_args.num_train_epochs=50
accelerate launch --num_cpu_threads_per_process=16 --config_file=accelerate-config.yaml style_transfer/sft.py -m sft_ratio=0.01,0.008,0.006,0.004,0.002 gen_ratio=0.7 training_args.num_train_epochs=50
1 change: 1 addition & 0 deletions lib/style-transfer/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ gen_ratio: 0.7
seed: 0
max_seq_length: 1024
num_generated_sequences: 4
dpo_gen: 0
4 changes: 2 additions & 2 deletions lib/style-transfer/configs/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ training_args:
seed: 0

beta: 0.1
checkpoint: clinical-dream-team/sft-style-transfer/checkpoint-pgj3j74y:v11
dataset: clinical-dream-team/score-style-transfer/run-jrhhyjad-score_dataset:v0
checkpoint: ???
dataset: ???
max_length: 1024
max_prompt_length: 512
6 changes: 0 additions & 6 deletions lib/style-transfer/configs/eval.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions lib/style-transfer/configs/score.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ defaults:
- _self_

evaluator: kaist-ai/Prometheus-13b-v1.0
dataset: ???
gen_dataset: ???
test_dataset: ???
use_sem_score: true
use_g_score: false
batch_size: 4
max_new_tokens: 1024
dpo_gen: 0
9 changes: 7 additions & 2 deletions lib/style-transfer/style_transfer/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@


@hydra.main(version_base="1.3", config_path="../configs", config_name="dpo.yaml")
def main(cfg):
def dpo(cfg):
api = wandb.Api()
dataset = api.artifact(cfg.dataset)
dataset = dataset.files()[0].download(replace=True)
Expand Down Expand Up @@ -78,10 +78,15 @@ def add_preferences(data_point):
filtered_columns = pd.DataFrame(df_point).filter(regex="^eval_sem_scores")
max_labels = filtered_columns.max().idxmax()[-1]
best_generation = df_point[f"generation_{max_labels}"].values[0]
best_score = filtered_columns.max().max()
min_labels = filtered_columns.min().idxmin()[-1]
worst_generation = df_point[f"generation_{min_labels}"].values[0]
worst_score = filtered_columns.min().min()
data_point["chosen"] = best_generation
data_point["rejected"] = worst_generation
data_point["chosen_score"] = best_score
data_point["rejected_score"] = worst_score
data_point["deviation_score"] = best_score - worst_score
return data_point

dataset = dataset.map(
Expand Down Expand Up @@ -121,4 +126,4 @@ def setup(self, args, state, model, **kwargs):


if __name__ == "__main__":
main()
dpo()
32 changes: 0 additions & 32 deletions lib/style-transfer/style_transfer/eval.py

This file was deleted.

16 changes: 9 additions & 7 deletions lib/style-transfer/style_transfer/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@hydra.main(version_base="1.3", config_path="../configs", config_name="gen.yaml")
def main(cfg):
def gen(cfg):
api = wandb.Api()
model_artifact = api.artifact(cfg.checkpoint)
model_dir = model_artifact.download()
Expand Down Expand Up @@ -59,7 +59,7 @@ def add_prompt(data_point):
)
sft_dataset, gen_dataset, test_dataset = split_dataset(dataset, cfg.sft_ratio, cfg.gen_ratio)
gen_dataset = gen_dataset.remove_columns(["input_ids", "max_gen_len"])
test_dataset = sft_dataset.remove_columns(["input_ids", "max_gen_len"])
test_dataset = test_dataset.remove_columns(["input_ids", "max_gen_len"])
dataloader = torch.utils.data.DataLoader(
gen_dataset,
batch_size=cfg.batch_size,
Expand All @@ -80,6 +80,7 @@ def add_prompt(data_point):
project="gen-style-transfer",
name=f"sft-ratio-{cfg.sft_ratio}_gen-ratio-{cfg.gen_ratio}",
)
gen_df = None
for batch in tqdm(dataloader):
flattened_gs_dict = {}
for g_seq in range(cfg.num_generated_sequences):
Expand All @@ -95,12 +96,13 @@ def add_prompt(data_point):
"ground_texts": batch["ground_texts"],
}
batch_logs = {**batch_logs, **flattened_gs_dict}
df = pd.DataFrame.from_dict(batch_logs)
dataset.append(df)
gen_df = pd.DataFrame.from_dict(batch_logs)
dataset.append(gen_df)

wandb.log({"gen_dataset": wandb.Table(dataframe=pd.concat(dataset))})

test_dataset = []
test_df = None
for batch in tqdm(test_dataloader):
flattened_gs_dict = {}
for g_seq in range(cfg.num_generated_sequences):
Expand All @@ -116,12 +118,12 @@ def add_prompt(data_point):
"ground_texts": batch["ground_texts"],
}
batch_logs = {**batch_logs, **flattened_gs_dict}
df = pd.DataFrame.from_dict(batch_logs)
test_dataset.append(df)
test_df = pd.DataFrame.from_dict(batch_logs)
test_dataset.append(test_df)

wandb.log({"test_dataset": wandb.Table(dataframe=pd.concat(test_dataset))})
wandb.finish()


if __name__ == "__main__":
main()
gen()
41 changes: 29 additions & 12 deletions lib/style-transfer/style_transfer/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


@hydra.main(version_base="1.3", config_path="../configs", config_name="score.yaml")
def main(cfg):
def score(cfg):
if cfg.use_g_score:
model = AutoPeftModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=cfg.evaluator,
Expand All @@ -41,15 +41,12 @@ def main(cfg):
else:
client = None

wandb.config = omegaconf.OmegaConf.to_container(
cfg,
)
with wandb.init(
project="score-style-transfer",
name=f"sft-ratio-{cfg.sft_ratio}_gen-ratio-{cfg.gen_ratio}"
f"{'' if cfg.dpo_gen == 0 else f'_dpo{cfg.dpo_gen}'}",
) as run:
gen_dataset = run.use_artifact(cfg.dataset)
gen_dataset = run.use_artifact(cfg.gen_dataset)
json_file = json.load(gen_dataset.files()[0].download(replace=True))
df = pd.DataFrame(data=json_file["data"], columns=json_file["columns"])
gen_dataset = datasets.Dataset.from_pandas(df)
Expand Down Expand Up @@ -121,7 +118,7 @@ def sem_scores(batch, seq):
]
batch.setdefault(f"eval_sem_scores_{seq}", []).extend(scores)

gen_dataset = []
gen_df = []
for batch in tqdm(gen_dataloader):
for seq in range(cfg.num_generated_sequences):
if cfg.use_sem_score:
Expand All @@ -130,11 +127,11 @@ def sem_scores(batch, seq):
g_scores(batch, seq)

df = pd.DataFrame(batch)
gen_dataset.append(df)
gen_df.append(df)

wandb.log({"gen_score_dataset": wandb.Table(dataframe=pd.concat(gen_dataset))})
wandb.log({"gen_score_dataset": wandb.Table(dataframe=pd.concat(gen_df))})

test_dataset = []
test_df = []
for batch in tqdm(test_dataloader):
for seq in range(cfg.num_generated_sequences):
if cfg.use_sem_score:
Expand All @@ -143,14 +140,34 @@ def sem_scores(batch, seq):
g_scores(batch, seq)

df = pd.DataFrame(batch)
test_dataset.append(df)
test_df.append(df)

eval_cols = [f"eval_sem_scores_{i}" for i in range(cfg.num_generated_sequences)]

df["max_score"] = df[eval_cols].max(axis=1)
df["min_score"] = df[eval_cols].min(axis=1)
df["mean_score"] = df[eval_cols].mean(axis=1)
df["std_score"] = df[eval_cols].std(axis=1)
# Determine the best generated text based on the maximum score
best_generation_indices = df[eval_cols].idxmax(axis=1).apply(lambda x: int(x[-1]))
df["best_generation"] = best_generation_indices.apply(
lambda x: df["generation_" + str(x)].iloc[x]
)
df["worst_generation"] = df[eval_cols].idxmin(axis=1).apply(lambda x: int(x[-1]))

wandb.log({"test_score_dataset": wandb.Table(dataframe=pd.concat(test_dataset))})
wandb.log({"test/max/mean": df["max_score"].mean()})
wandb.log({"test/min/mean": df["min_score"].mean()})
wandb.log({"test/mean/mean": df["mean_score"].mean()})
wandb.log({"test/std/mean": df["std_score"].mean()})
wandb.log({"test_score_dataset": wandb.Table(dataframe=pd.concat(test_df))})

if cfg.use_g_score:
client.terminate_server()
wandb.config = omegaconf.OmegaConf.to_container(
cfg,
)
wandb.finish()


if __name__ == "__main__":
main()
score()
5 changes: 3 additions & 2 deletions lib/style-transfer/style_transfer/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


@hydra.main(version_base="1.3", config_path="../configs", config_name="sft.yaml")
def main(cfg):
def sft(cfg):
set_seed(cfg.seed)
dataset = build_dataset(
dataset_name=cfg.dataset,
Expand Down Expand Up @@ -67,6 +67,7 @@ def setup(self, args, state, model, **kwargs):
self._wandb.config["test_dataset_size"] = len(test_dataset)

args.run_name = f"sft-ratio-{cfg.sft_ratio}_gen-ratio-{cfg.gen_ratio}"
args.load_best_model_at_end = True
trainer = SFTTrainer(
model=model,
args=args,
Expand All @@ -82,4 +83,4 @@ def setup(self, args, state, model, **kwargs):


if __name__ == "__main__":
main()
sft()

0 comments on commit c35062b

Please sign in to comment.