✨ Style Transfer: add gen script
simonmeoni committed Jan 31, 2024
1 parent d85c5e8 commit ebea6ab
Showing 7 changed files with 131 additions and 32 deletions.
2 changes: 0 additions & 2 deletions lib/style-transfer/.dockerignore

This file was deleted.

23 changes: 23 additions & 0 deletions lib/style-transfer/configs/default.yaml
@@ -0,0 +1,23 @@
lora:
_target_: peft.LoraConfig
task_type: CAUSAL_LM
r: 8
lora_alpha: 16
lora_dropout: 0.05
bias: none
target_modules:
["q_proj", "v_proj", "k_proj", "out_proj", "fc_in", "fc_out", "wte"]

bnb_config:
_target_: transformers.BitsAndBytesConfig
load_in_4bit: true
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: True
bnb_4bit_compute_dtype: bfloat16

model: "mistralai/Mistral-7B-Instruct-v0.1"
dataset: "bio-datasets/mimic_style_transfer"
sft_ratio: 0.1
gen_ratio: 0.7
seed: 0
max_seq_length: 1024
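
Neither script reads the lora or bnb_config blocks field by field; they are handed to Hydra's object instantiation, as sft.py does further down with hydra.utils.instantiate(cfg.bnb_config). A minimal sketch of what that amounts to, illustrative only and not part of the commit:

from hydra.utils import instantiate
from omegaconf import OmegaConf
from peft import LoraConfig
from transformers import BitsAndBytesConfig

# Load the shared config and let Hydra build the library objects from it.
cfg = OmegaConf.load("lib/style-transfer/configs/default.yaml")

# _target_ names the class; the sibling keys become its constructor arguments.
lora_config = instantiate(cfg.lora)        # peft.LoraConfig(task_type="CAUSAL_LM", r=8, ...)
bnb_config = instantiate(cfg.bnb_config)   # transformers.BitsAndBytesConfig(load_in_4bit=True, ...)

assert isinstance(lora_config, LoraConfig)
assert isinstance(bnb_config, BitsAndBytesConfig)

The same mechanism covers training_args in sft.yaml, which sft.py turns into a transformers.TrainingArguments object with hydra.utils.instantiate(cfg.training_args).
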
8 changes: 8 additions & 0 deletions lib/style-transfer/configs/gen.yaml
@@ -0,0 +1,8 @@
# @package _global_
defaults:
- _self_
- default

checkpoint: "clinical-dream-team/sft-style-transfer/checkpoint-tc5l40v2:v0"
batch_size: 4
num_generated_sequences: 4
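
gen.yaml only declares what is specific to the generation stage; its defaults list pulls model, dataset, the split ratios, seed and max_seq_length in from default.yaml, and the checkpoint entry names the W&B artifact that gen.py downloads with run.use_artifact. The script would typically be launched as something like python style_transfer/gen.py batch_size=8, with Hydra turning key=value arguments into overrides. A small sketch of the same composition through Hydra's compose API, assuming it runs from a file next to gen.py so the relative config path matches:

from hydra import compose, initialize

# config_path is resolved relative to the calling file, matching gen.py's decorator.
with initialize(version_base="1.3", config_path="../configs"):
    cfg = compose(config_name="gen.yaml", overrides=["batch_size=8"])

print(cfg.model)       # mistralai/Mistral-7B-Instruct-v0.1, inherited from default.yaml
print(cfg.checkpoint)  # the W&B artifact reference declared above
print(cfg.batch_size)  # 8, from the override
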
27 changes: 3 additions & 24 deletions lib/style-transfer/configs/sft.yaml
@@ -1,4 +1,7 @@
# @package _global_
defaults:
- _self_
- default

training_args:
_target_: transformers.TrainingArguments
@@ -21,27 +24,3 @@ training_args:
num_train_epochs: 15
save_only_model: true
save_safetensors: false

lora:
_target_: peft.LoraConfig
task_type: CAUSAL_LM
r: 8
lora_alpha: 16
lora_dropout: 0.05
bias: none
target_modules:
["q_proj", "v_proj", "k_proj", "out_proj", "fc_in", "fc_out", "wte"]

bnb_config:
_target_: transformers.BitsAndBytesConfig
load_in_4bit: true
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: True
bnb_4bit_compute_dtype: bfloat16

model: "mistralai/Mistral-7B-Instruct-v0.1"
dataset: "bio-datasets/mimic_style_transfer"
sft_ratio: 0.1
gen_ratio: 0.7
seed: 0
max_seq_length: 1024
90 changes: 90 additions & 0 deletions lib/style-transfer/style_transfer/gen.py
@@ -0,0 +1,90 @@
import logging

import hydra
import mii
import pandas as pd
import torch
import wandb
from peft import AutoPeftModelForCausalLM
from style_transfer.utils import PROMPT, build_dataset, split_dataset
from tqdm import tqdm
from transformers import AutoTokenizer


@hydra.main(version_base="1.3", config_path="../configs", config_name="gen.yaml")
def main(cfg):
dataset = build_dataset(
dataset_name=cfg.dataset,
model_name=cfg.model,
max_sampler_length=cfg.max_seq_length,
)

def add_prompt(data_point):
data_point["prompts"] = (
"<s>[INST]"
+ str.format(
PROMPT,
data_point["keywords"],
)
+ "[/INST]\n"
)
return data_point

dataset = dataset.map(
add_prompt,
batched=False,
)
_, gen_dataset, _ = split_dataset(dataset, cfg.sft_ratio, cfg.gen_ratio)
gen_dataset = gen_dataset.remove_columns(["input_ids", "max_gen_len"])

with wandb.init(project="gen-style-transfer") as run:
my_model_artifact = run.use_artifact(cfg.checkpoint)
model_dir = my_model_artifact.download()

model = AutoPeftModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_dir,
torch_dtype=torch.bfloat16,
device_map="auto",
)
model = model.merge_and_unload()
model.save_pretrained("models/merged/")
tokenizer = AutoTokenizer.from_pretrained(cfg.model)
tokenizer.save_pretrained("models/merged/")
del model
del tokenizer
logging.info("Model + Tokenizer saved at models/merged/")
logging.info("Loading model to pipeline 🐉 ...")
pipe = mii.pipeline("models/merged/")
logging.info("Model loaded to pipeline ! 🎉")
dataloader = torch.utils.data.DataLoader(
gen_dataset,
batch_size=cfg.batch_size,
)
new_dataset = []
for batch in tqdm(dataloader):
generated_sequences = []
for _ in range(cfg.num_generated_sequences):
responses = pipe(batch["prompts"], max_new_tokens=12)
generated_sequences.append([response.generated_text for response in responses])

responses = list(map(list, zip(*generated_sequences)))
flattened_gs_dict = {
f"generation_{response_id}": response
for response_id, response in enumerate(responses)
}
batch_logs = {
"prompts": batch["prompts"],
"ground_texts": batch["ground_texts"],
}
batch_logs = {**batch_logs, **flattened_gs_dict}
new_dataset.extend([dict(zip(batch_logs, t)) for t in zip(*batch_logs.values())])
table = wandb.Table(dataframe=pd.DataFrame(batch_logs))
wandb.log({"generation_predictions": table})
break
df = pd.DataFrame(new_dataset)
wandb.log({"dataframe_table": wandb.Table(dataframe=df)})
wandb.finish()


if __name__ == "__main__":
main()
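
The grouping step inside the batch loop above is easy to misread: generated_sequences is filled one pipeline pass at a time (sample-major), and the zip(*) line transposes it so that each inner list gathers all samples belonging to one prompt. A toy, self-contained illustration of just that reshaping, with hypothetical strings and no mii or wandb involved:

# Toy illustration of the reshaping in the batch loop (hypothetical strings).
# One inner list per generation pass, each covering a batch of two prompts:
generated_sequences = [
    ["a_sample0", "b_sample0"],   # pass 0
    ["a_sample1", "b_sample1"],   # pass 1
    ["a_sample2", "b_sample2"],   # pass 2
]

# The zip(*) transpose regroups the same strings per prompt:
per_prompt = list(map(list, zip(*generated_sequences)))
print(per_prompt)
# [['a_sample0', 'a_sample1', 'a_sample2'],
#  ['b_sample0', 'b_sample1', 'b_sample2']]
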
11 changes: 6 additions & 5 deletions lib/style-transfer/style_transfer/sft.py
@@ -32,17 +32,18 @@ def add_prompt(data_point):
)
return data_point

dataset = dataset.map(
add_prompt,
batched=False,
)
train_dataset, gen_, test_dataset = split_dataset(dataset, cfg.sft_ratio, cfg.gen_ratio)

model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=cfg.model,
torch_dtype=torch.bfloat16,
device_map={"": Accelerator().local_process_index},
quantization_config=hydra.utils.instantiate(cfg.bnb_config),
)
dataset = dataset.map(
add_prompt,
batched=False,
)
train_dataset, gen_, test_dataset = split_dataset(dataset, cfg.sft_ratio, cfg.gen_ratio)

args = hydra.utils.instantiate(cfg.training_args)

2 changes: 1 addition & 1 deletion lib/style-transfer/style_transfer/utils.py
@@ -134,7 +134,7 @@ def build_dataset(dataset_name, model_name, max_sampler_length):
"""
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
ds = load_dataset(dataset_name, split="train", trust_remote_code=True)
ds = load_dataset(dataset_name, split="train")

ds_dict = {"keywords": [], "text": []}
for keywords, text in zip(ds["keywords"], ds["text"]):
