✨ Style Transfer: the model is served in a separate file
simonmeoni committed Feb 7, 2024
1 parent d526de8 commit 07e92aa
Showing 3 changed files with 64 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -70,7 +70,7 @@ repos:
     hooks:
       - id: mypy
         exclude: apps/|notebooks/|tests/
-        additional_dependencies: ["types-requests"]
+        additional_dependencies: ["types-requests", "types-PyYAML"]

   # Safety is used to check if there are hardcoded secrets inside code and history
   - repo: https://github.com/gitleaks/gitleaks
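The added types-PyYAML entry is presumably needed because the new gen_server.py below imports yaml; without the stubs, the mypy hook would flag yaml as an untyped import. A minimal sketch of the kind of call the stubs let mypy check (path illustrative):

# Illustrative sketch: with types-PyYAML installed, mypy has signatures
# for the yaml module and can check this call instead of flagging it.
import yaml

with open("configs/gen.yaml", "r") as f:
    cfg_dict = yaml.safe_load(f)  # the stubs type this as returning Any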
71 changes: 26 additions & 45 deletions lib/style-transfer/style_transfer/gen.py
@@ -1,14 +1,11 @@
-import logging

 import hydra
 import mii
 import pandas as pd
-import torch
 import wandb
-from peft import AutoPeftModelForCausalLM
+from omegaconf import omegaconf
 from style_transfer.utils import PROMPT, build_dataset, split_dataset
 from tqdm import tqdm
-from transformers import AutoTokenizer


 @hydra.main(version_base="1.3", config_path="../configs", config_name="gen.yaml")
@@ -40,49 +37,33 @@ def add_prompt(data_point):
         gen_dataset,
         batch_size=cfg.batch_size,
     )
+    client = mii.client("models/merged/")

-    with wandb.init(project="gen-style-transfer") as run:
-        model_artifact = run.use_artifact(cfg.checkpoint)
-        model_dir = model_artifact.download()
-
-        model = AutoPeftModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path=model_dir,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        )
-        model = model.merge_and_unload()
-        model.save_pretrained("models/merged/")
-        tokenizer = AutoTokenizer.from_pretrained(cfg.model)
-        tokenizer.save_pretrained("models/merged/")
-        del model
-        del tokenizer
-        logging.info("Model + Tokenizer saved at models/merged/")
-        logging.info("Loading model to pipeline 🐉 ...")
-        pipe = mii.pipeline("models/merged/")
-        logging.info("Model loaded to pipeline ! 🎉")
-
-        new_dataset = []
-        for batch in tqdm(dataloader):
-            generated_sequences = []
-            for _ in range(cfg.num_generated_sequences):
-                responses = pipe(batch["prompts"], max_new_tokens=cfg.max_new_tokens)
-                generated_sequences.append([response.generated_text for response in responses])
+    wandb.config = omegaconf.OmegaConf.to_container(
+        cfg,
+        resolve=True,
+        throw_on_missing=True,
+    )
+    wandb.init(project="gen-style-transfer")
+    for batch in tqdm(dataloader):
+        flattened_gs_dict = {}
+        for g_seq in range(cfg.num_generated_sequences):
+            responses = client.generate(
+                batch["prompts"],
+                max_new_tokens=cfg.max_new_tokens,
+            )
+            flattened_gs_dict[f"generation_{g_seq}"] = [
+                response.generated_text for response in responses
+            ]
+        batch_logs = {
+            "prompts": batch["prompts"],
+            "ground_texts": batch["ground_texts"],
+        }
+        batch_logs = {**batch_logs, **flattened_gs_dict}
+        df = pd.DataFrame.from_dict(batch_logs)
+        wandb.log({"generation_predictions": wandb.Table(dataframe=df)})

-            responses = list(map(list, zip(*generated_sequences)))
-            flattened_gs_dict = {
-                f"generation_{reponse_id}": response
-                for reponse_id, response in enumerate(responses)
-            }
-            batch_logs = {
-                "prompts": batch["prompts"],
-                "ground_texts": batch["ground_texts"],
-            }
-            batch_logs = {**batch_logs, **flattened_gs_dict}
-            new_dataset.extend([dict(zip(batch_logs, t)) for t in zip(*batch_logs.values())])
-            table = wandb.Table(dataframe=pd.DataFrame(batch_logs))
-            wandb.log({"generation_predictions": table})
-        df = pd.DataFrame(new_dataset)
-        wandb.log({"dataframe_table": wandb.Table(dataframe=df)})
+    client.terminate_server()
     wandb.finish()


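The reworked gen.py loop above logs a wandb.Table per batch instead of accumulating a new_dataset list and logging one big table at the end. A self-contained sketch of that per-batch logging pattern, with hypothetical strings standing in for real prompts and generations:

import pandas as pd
import wandb

# Hypothetical batch: two prompts, each with two sampled generations,
# shaped like the flattened_gs_dict built in gen.py.
batch_logs = {
    "prompts": ["prompt A", "prompt B"],
    "ground_texts": ["reference A", "reference B"],
    "generation_0": ["first sample for A", "first sample for B"],
    "generation_1": ["second sample for A", "second sample for B"],
}

wandb.init(project="gen-style-transfer")
df = pd.DataFrame.from_dict(batch_logs)
wandb.log({"generation_predictions": wandb.Table(dataframe=df)})
wandb.finish()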
37 changes: 37 additions & 0 deletions lib/style-transfer/style_transfer/gen_server.py
@@ -0,0 +1,37 @@
+import logging
+
+import mii
+import torch
+import wandb
+import yaml
+from omegaconf import omegaconf
+from peft import AutoPeftModelForCausalLM
+from transformers import AutoTokenizer
+
+with open("configs/default.yaml", "r") as f:
+    cfg = omegaconf.OmegaConf.create(yaml.safe_load(f))
+with open("configs/gen.yaml", "r") as f:
+    cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.OmegaConf.create(yaml.safe_load(f)))
+
+api = wandb.Api()
+model_artifact = api.artifact(cfg.checkpoint)
+model_dir = model_artifact.download()
+model = AutoPeftModelForCausalLM.from_pretrained(
+    pretrained_model_name_or_path=model_dir,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+)
+model = model.merge_and_unload()
+model.save_pretrained("models/merged/")
+tokenizer = AutoTokenizer.from_pretrained(cfg.model)
+tokenizer.save_pretrained("models/merged/")
+del model
+del tokenizer
+logging.info("Model + Tokenizer saved at models/merged/")
+logging.info("Loading model to pipeline 🐉 ...")
+client = mii.serve(
+    "models/merged/",
+    tensor_parallel=4,
+    deployment_name=cfg.checkpoint,
+)
+logging.info("Model loaded to pipeline ! 🎉")
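With this commit the two scripts form a client/server pair: gen_server.py merges the PEFT adapter into the base model, saves the result under models/merged/, and starts a persistent DeepSpeed-MII deployment, while gen.py only connects to it. A minimal sketch of the client side, assuming the server process is already up and the deployment is reachable with the same identifier gen.py passes to mii.client:

# Client-side sketch (gen.py reduced to the essentials); the prompt and
# token budget are placeholders, gen.py reads them from its config.
import mii

client = mii.client("models/merged/")
responses = client.generate(["An example prompt"], max_new_tokens=128)
print(responses[0].generated_text)

# Shut the persistent deployment down once generation is done,
# as gen.py does at the end of its run.
client.terminate_server()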
