Skip to content

Commit

Permalink
✨ Style Transfer: add trainable distilbert code inside score.py + hydra flags
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmeoni committed Feb 28, 2024
1 parent 38196e3 commit 6222c20
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
2 changes: 2 additions & 0 deletions lib/style-transfer/configs/score.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ sem_model:
_partial_: true
use_ground_truth: false
is_logged: false
is_trainable: true
checkpoint: null
37 changes: 21 additions & 16 deletions lib/style-transfer/style_transfer/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def main(cfg):

logging.info("Loading the Semantic Model 🐈‍")
sem_model = SentenceTransformer(cfg.sem_model.name)
if cfg.sem_model.checkpoint:
model_artifact = run.use_artifact(cfg.sem_model.checkpoint)
model_dir = model_artifact.download()
sem_model = sem_model.load(model_dir)

def sem_score(cfg, dataset, sem_model):
score_dict = {}
Expand Down Expand Up @@ -95,25 +99,26 @@ def add_prompt(data_point):
)
]
)
if cfg.sem_model.is_trainable:
if cfg.sem_model.use_ground_truth:
train_examples.extend(
[
InputExample(texts=[ground_text, ground_text], label=1)
for ground_text in train_gen_dataset["ground_texts"]
]
)
train_gen_dataloader = torch.utils.data.DataLoader(
train_examples,
batch_size=cfg.sem_model.batch_size,
)

if cfg.sem_model.use_ground_truth:
train_examples.extend(
[
InputExample(texts=[ground_text, ground_text], label=1)
for ground_text in train_gen_dataset["ground_texts"]
]
train_loss = hydra.utils.instantiate(cfg.sem_model.loss, sem_model)()
sem_model.fit(
train_objectives=[(train_gen_dataloader, train_loss)],
epochs=cfg.sem_model.epochs,
warmup_steps=cfg.sem_model.warmup_steps,
)
train_gen_dataloader = torch.utils.data.DataLoader(
train_examples,
batch_size=cfg.sem_model.batch_size,
)

train_loss = hydra.utils.instantiate(cfg.sem_model.loss, sem_model)()
sem_model.fit(
train_objectives=[(train_gen_dataloader, train_loss)],
epochs=cfg.sem_model.epochs,
warmup_steps=cfg.sem_model.warmup_steps,
)
if cfg.sem_model.is_logged:
sem_model.save(cfg.sem_model.path)
run.log_artifact(cfg.sem_model.path, type="model")
Expand Down

0 comments on commit 6222c20

Please sign in to comment.