✨ Style Transfer: add score script
simonmeoni committed Jan 31, 2024
1 parent ebea6ab commit 192ccfc
Showing 6 changed files with 111 additions and 15 deletions.
lib/style-transfer/configs/default.yaml (2 additions, 2 deletions)
@@ -5,8 +5,7 @@ lora:
   lora_alpha: 16
   lora_dropout: 0.05
   bias: none
-  target_modules:
-    ["q_proj", "v_proj", "k_proj", "out_proj", "fc_in", "fc_out", "wte"]
+  target_modules: ["q_proj", "v_proj", "k_proj", "out_proj", "fc_in", "fc_out", "wte"]
 
 bnb_config:
   _target_: transformers.BitsAndBytesConfig
@@ -21,3 +20,4 @@ sft_ratio: 0.1
 gen_ratio: 0.7
 seed: 0
 max_seq_length: 1024
+num_generated_sequences: 4
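Note on the block above: default.yaml only declares configuration; this commit contains no code that reads it directly. As a rough, hypothetical sketch of how such a block can be turned into objects (the LoraConfig mapping and the direct OmegaConf.load call are illustrative assumptions, not calls made anywhere in this commit):

from hydra.utils import instantiate
from omegaconf import OmegaConf
from peft import LoraConfig

# Load the file directly; in the real scripts the values arrive composed through Hydra.
cfg = OmegaConf.load("lib/style-transfer/configs/default.yaml")
# The lora block carries standard PEFT fields (lora_alpha, lora_dropout, bias, target_modules).
lora_config = LoraConfig(**OmegaConf.to_container(cfg.lora))
# bnb_config declares _target_: transformers.BitsAndBytesConfig, so Hydra can build it as-is.
bnb_config = instantiate(cfg.bnb_config)
print(lora_config)
print(bnb_config)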
lib/style-transfer/configs/gen.yaml (3 additions, 3 deletions)
@@ -1,8 +1,8 @@
 # @package _global_
 defaults:
-  - _self_
   - default
+  - _self_
 
-checkpoint: "clinical-dream-team/sft-style-transfer/checkpoint-tc5l40v2:v0"
+checkpoint: ???
 batch_size: 4
-num_generated_sequences: 4
+max_new_tokens: 1024
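Two things change here. The defaults list now puts _self_ after default, so values set in gen.yaml override those inherited from default.yaml (later entries in a Hydra defaults list take precedence). And checkpoint becomes ???, OmegaConf's mandatory-value marker, so every run must name the checkpoint artifact explicitly, typically as a command-line override of the form checkpoint=<entity>/<project>/<artifact>:v0 (a placeholder pattern, not a real artifact). A minimal sketch of the ??? behaviour:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"checkpoint": "???", "batch_size": 4})
print(OmegaConf.is_missing(cfg, "checkpoint"))  # True: reading cfg.checkpoint now would raise
cfg.checkpoint = "entity/project/artifact:v0"   # placeholder standing in for the CLI override
print(OmegaConf.is_missing(cfg, "checkpoint"))  # False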
lib/style-transfer/configs/score.yaml (8 additions, 0 deletions, new file)
@@ -0,0 +1,8 @@
defaults:
  - default
  - _self_

evaluator: "kaist-ai/Prometheus-13b-v1.0"
dataset: "clinical-dream-team/gen-style-transfer/run-yjc36777-dataframe_table:v0"
batch_size: 4
max_new_tokens: 1024
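score.yaml never sets num_generated_sequences, yet score.py reads cfg.num_generated_sequences; the value arrives through the defaults list from default.yaml, which this commit extends with num_generated_sequences: 4. A small sketch of that composition using Hydra's Compose API (illustration only, not code from the commit; the relative config path assumes the snippet lives at the repository root):

from hydra import compose, initialize

with initialize(version_base="1.3", config_path="lib/style-transfer/configs"):
    cfg = compose(config_name="score.yaml")

print(cfg.num_generated_sequences)  # 4, inherited from default.yaml
print(cfg.evaluator)                # kaist-ai/Prometheus-13b-v1.0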
lib/style-transfer/configs/sft.yaml (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 # @package _global_
 defaults:
-  - _self_
   - default
+  - _self_
 
 training_args:
   _target_: transformers.TrainingArguments
lib/style-transfer/style_transfer/gen.py (9 additions, 9 deletions)
@@ -26,7 +26,7 @@ def add_prompt(data_point):
                 PROMPT,
                 data_point["keywords"],
             )
-            + "[INST]\n"
+            + "[/INST]\n"
         )
         return data_point
 
@@ -36,10 +36,14 @@ def add_prompt(data_point):
     )
     _, gen_dataset, _ = split_dataset(dataset, cfg.sft_ratio, cfg.gen_ratio)
     gen_dataset = dataset.remove_columns(["input_ids", "max_gen_len"])
+    dataloader = torch.utils.data.DataLoader(
+        gen_dataset,
+        batch_size=cfg.batch_size,
+    )
 
     with wandb.init(project="gen-style-transfer") as run:
-        my_model_artifact = run.use_artifact(cfg.checkpoint)
-        model_dir = my_model_artifact.download()
+        model_artifact = run.use_artifact(cfg.checkpoint)
+        model_dir = model_artifact.download()
 
         model = AutoPeftModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path=model_dir,
@@ -56,15 +60,12 @@ def add_prompt(data_point):
         logging.info("Loading model to pipeline 🐉 ...")
         pipe = mii.pipeline("models/merged/")
         logging.info("Model loaded to pipeline ! 🎉")
-        dataloader = torch.utils.data.DataLoader(
-            gen_dataset,
-            batch_size=cfg.batch_size,
-        )
+
         new_dataset = []
         for batch in tqdm(dataloader):
             generated_sequences = []
             for _ in range(cfg.num_generated_sequences):
-                responses = pipe(batch["prompts"], max_new_tokens=12)
+                responses = pipe(batch["prompts"], max_new_tokens=cfg.max_new_tokens)
                 generated_sequences.append([response.generated_text for response in responses])
 
             responses = list(map(list, zip(*generated_sequences)))
@@ -80,7 +81,6 @@ def add_prompt(data_point):
             new_dataset.extend([dict(zip(batch_logs, t)) for t in zip(*batch_logs.values())])
             table = wandb.Table(dataframe=pd.DataFrame(batch_logs))
             wandb.log({"generation_predictions": table})
-            break
         df = pd.DataFrame(new_dataset)
         wandb.log({"dataframe_table": wandb.Table(dataframe=df)})
         wandb.finish()
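The generation loop above calls the pipeline num_generated_sequences times per batch and then transposes the collected lists so that each row holds every generation for one prompt. A standalone illustration of that reshaping, with made-up strings:

# Two passes over a batch of two prompts: one generation per prompt per pass.
generated_sequences = [
    ["prompt-A gen 1", "prompt-B gen 1"],  # pass 1
    ["prompt-A gen 2", "prompt-B gen 2"],  # pass 2
]
# Transpose: group the results by prompt instead of by pass.
responses = list(map(list, zip(*generated_sequences)))
# responses == [["prompt-A gen 1", "prompt-A gen 2"], ["prompt-B gen 1", "prompt-B gen 2"]]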
lib/style-transfer/style_transfer/score.py (88 additions, 0 deletions, new file)
@@ -0,0 +1,88 @@
import json
import logging

import datasets
import hydra
import mii
import pandas as pd
import torch
import wandb
from fastchat.conversation import get_conv_template
from style_transfer.utils import EVAL_PROMPT
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


@hydra.main(version_base="1.3", config_path="../configs", config_name="score.yaml")
def main(cfg):
    with wandb.init(project="score-style-transfer") as run:
        dataset = run.use_artifact(cfg.dataset)
        json_file = json.load(dataset.files()[0].download(replace=True))
        df = pd.DataFrame(data=json_file["data"], columns=json_file["columns"])
        dataset = datasets.Dataset.from_pandas(df)

        def add_prompt(data_point):
            for seq in range(cfg.num_generated_sequences):
                data_point[f"eval_prompt_{seq}"] = str.format(
                    EVAL_PROMPT,
                    data_point["prompts"],
                    data_point[f"generation_{seq}"],
                    data_point["ground_texts"],
                )
                conv = get_conv_template("llama-2")
                conv.set_system_message("You are a fair evaluator language model.")
                conv.append_message(conv.roles[0], data_point[f"eval_prompt_{seq}"])
                conv.append_message(conv.roles[1], None)
                data_point[f"eval_prompt_{seq}"] = conv.get_prompt()
            return data_point

        dataset = dataset.map(
            add_prompt,
            batched=False,
        )

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.batch_size,
        )

        logging.info("Model + Tokenizer saved at models/merged/")
        logging.info("Loading model to pipeline 🐉 ...")
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=cfg.evaluator,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        model.save_pretrained("models/evaluator/")
        tokenizer = AutoTokenizer.from_pretrained(cfg.evaluator)
        tokenizer.save_pretrained("models/evaluator/")
        del model
        del tokenizer
        pipe = mii.pipeline("models/evaluator/")
        logging.info("Model loaded to pipeline ! 🎉")

        new_dataset = []
        for batch in tqdm(dataloader):
            for seq in range(cfg.num_generated_sequences):
                responses = pipe(batch[f"eval_prompt_{seq}"], max_new_tokens=cfg.max_new_tokens)
                scores = [response.generated_text[-1] for response in responses]
                scores = [
                    float(score) if score.isdigit() and 0 <= float(score) <= 5 else 0
                    for score in scores
                ]
                feedbacks = [
                    response.generated_text.split("[RESULT]")[0].strip() for response in responses
                ]
                batch.setdefault(f"eval_scores_{seq}", []).extend(scores)
                batch.setdefault(f"eval_feedbacks_{seq}", []).extend(feedbacks)

            new_dataset.extend([dict(zip(batch, t)) for t in zip(*batch.values())])
            table = wandb.Table(dataframe=pd.DataFrame(batch))
            wandb.log({"generation_predictions": table})
        df = pd.DataFrame(new_dataset)
        wandb.log({"dataframe_table": wandb.Table(dataframe=df)})
        wandb.finish()


if __name__ == "__main__":
    main()
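On the score parsing in the loop above: the script assumes each evaluator generation ends with a "[RESULT] <score>" tail, reads the final character as the 0-5 score (falling back to 0 when it is not a digit in that range), and keeps the text before "[RESULT]" as the feedback. A small standalone sketch with an invented output string:

generated_text = "The paraphrase keeps the clinical content but shifts the register. [RESULT] 4"
score_char = generated_text[-1]
score = float(score_char) if score_char.isdigit() and 0 <= float(score_char) <= 5 else 0
feedback = generated_text.split("[RESULT]")[0].strip()
print(score)     # 4.0
print(feedback)  # The paraphrase keeps the clinical content but shifts the register.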
