+
+# Lightning-Hydra-Template
+
+[![python](https://img.shields.io/badge/-Python_3.8_%7C_3.9_%7C_3.10-blue?logo=python&logoColor=white)](https://www.python.org/)
+[![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
+[![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/)
+[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
+[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
+[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
+[![tests](https://github.com/ashleve/lightning-hydra-template/actions/workflows/test.yml/badge.svg)](https://github.com/ashleve/lightning-hydra-template/actions/workflows/test.yml)
+[![code-quality](https://github.com/ashleve/lightning-hydra-template/actions/workflows/code-quality-main.yaml/badge.svg)](https://github.com/ashleve/lightning-hydra-template/actions/workflows/code-quality-main.yaml)
+[![codecov](https://codecov.io/gh/ashleve/lightning-hydra-template/branch/main/graph/badge.svg)](https://codecov.io/gh/ashleve/lightning-hydra-template)
+[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/ashleve/lightning-hydra-template#license)
+[![PRs](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/ashleve/lightning-hydra-template/pulls)
+[![contributors](https://img.shields.io/github/contributors/ashleve/lightning-hydra-template.svg)](https://github.com/ashleve/lightning-hydra-template/graphs/contributors)
+
+A clean template to kickstart your deep learning project 🚀⚡🔥
+Click on [Use this template](https://github.com/ashleve/lightning-hydra-template/generate) to initialize a new repository.
+
+_Suggestions are always welcome!_
+
+
+
+
+
+![](https://github.com/ashleve/lightning-hydra-template/blob/resources/terminal.png)
+
+
+
+
+
+
+# Your Project Name
+
+
+
+
+
+[![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539)
+[![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/paper/2020)
+
+
+
+## Description
+
+Sequence-to-sequence (T5) answer generation for the DEFT 2023 multiple-choice question answering task, built on the Lightning-Hydra template.
+
+## Installation
+
+#### Pip
+
+```bash
+# clone project
+git clone https://github.com/YourGithubName/your-repo-name
+cd your-repo-name
+
+# [OPTIONAL] create conda environment
+conda create -n myenv python=3.9
+conda activate myenv
+
+# install pytorch according to instructions
+# https://pytorch.org/get-started/
+
+# install requirements
+pip install -r requirements.txt
+```
+
+#### Conda
+
+```bash
+# clone project
+git clone https://github.com/YourGithubName/your-repo-name
+cd your-repo-name
+
+# create conda environment and install dependencies
+conda env create -f environment.yaml -n myenv
+
+# activate conda environment
+conda activate myenv
+```
+
+## How to run
+
+Train a model with the default configuration
+
+```bash
+# train on CPU
+python gen/train.py trainer=cpu
+
+# train on GPU
+python gen/train.py trainer=gpu
+```
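+
+By default no experiment logger is attached (`logger: null` in [configs/train.yaml](configs/train.yaml)). You can enable one from the command line, e.g. the W&B logger defined in [configs/logger/wandb.yaml](configs/logger/wandb.yaml); this assumes you have run `wandb login` and can access the W&B entity set in that config
+
+```bash
+# train on GPU and log to Weights & Biases
+python gen/train.py trainer=gpu logger=wandb
+```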
+
+Train a model with one of the experiment configurations from [configs/experiment/](configs/experiment/)
+
+```bash
+python gen/train.py experiment=experiment_name.yaml
+```
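+
+The experiment configs shipped with this project are `default`, `large`, `xl` and `xxl`, e.g.
+
+```bash
+# fine-tune t5-large with the settings from configs/experiment/large.yaml
+python gen/train.py experiment=large
+```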
+
+You can override any parameter from the command line like this
+
+```bash
+python gen/train.py trainer.max_epochs=20 data.batch_size=64
+```
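+
+The `default` experiment config ([configs/experiment/default.yaml](configs/experiment/default.yaml)) switches Hydra to multirun mode with the joblib launcher and sweeps over several base models, so launching it starts one run per model. A minimal sketch, assuming the listed Hugging Face checkpoints can be downloaded on your machine
+
+```bash
+# launch the model sweep defined in configs/experiment/default.yaml
+python gen/train.py experiment=default
+
+# or sweep ad-hoc values with Hydra's multirun flag
+python gen/train.py --multirun trainer=gpu data.batch_size=16,32
+```
+
+To generate a submission file from a trained checkpoint, pass a checkpoint path to the evaluation entry point (the path below is only a placeholder). Since the submission file is uploaded as a W&B artifact, keep a W&B logger attached
+
+```bash
+python gen/eval.py ckpt_path=/path/to/checkpoint.ckpt logger=wandb
+```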
diff --git a/lib/deft2023/gen/__init__.py b/lib/deft2023/gen/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lib/deft2023/gen/configs/__init__.py b/lib/deft2023/gen/configs/__init__.py
new file mode 100644
index 0000000..56bf7f4
--- /dev/null
+++ b/lib/deft2023/gen/configs/__init__.py
@@ -0,0 +1 @@
+# this file is needed here to include configs when building project as a package
diff --git a/lib/deft2023/gen/configs/callbacks/default.yaml b/lib/deft2023/gen/configs/callbacks/default.yaml
new file mode 100644
index 0000000..7f21121
--- /dev/null
+++ b/lib/deft2023/gen/configs/callbacks/default.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - model_checkpoint.yaml
+ - early_stopping.yaml
+ - model_summary.yaml
+ - rich_progress_bar.yaml
+ - _self_
+
+model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/exact_match_score"
+ mode: "max"
+ save_last: False
+ auto_insert_metric_name: False
+
+early_stopping:
+ monitor: "val/exact_match_score"
+ patience: 100
+ mode: "max"
+
+model_summary:
+ max_depth: -1
diff --git a/lib/deft2023/gen/configs/callbacks/early_stopping.yaml b/lib/deft2023/gen/configs/callbacks/early_stopping.yaml
new file mode 100644
index 0000000..c826c8d
--- /dev/null
+++ b/lib/deft2023/gen/configs/callbacks/early_stopping.yaml
@@ -0,0 +1,15 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
+
+early_stopping:
+ _target_: lightning.pytorch.callbacks.EarlyStopping
+ monitor: ??? # quantity to be monitored, must be specified !!!
+ min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
+ patience: 3 # number of checks with no improvement after which training will be stopped
+ verbose: False # verbosity mode
+ mode: "min" # "max" means higher metric value is better, can be also "min"
+ strict: True # whether to crash the training if monitor is not found in the validation metrics
+ check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
+ stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
+ divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
+ check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
+ # log_rank_zero_only: False # this keyword argument isn't available in stable version
diff --git a/lib/deft2023/gen/configs/callbacks/model_checkpoint.yaml b/lib/deft2023/gen/configs/callbacks/model_checkpoint.yaml
new file mode 100644
index 0000000..46f32a6
--- /dev/null
+++ b/lib/deft2023/gen/configs/callbacks/model_checkpoint.yaml
@@ -0,0 +1,17 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
+
+model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: null # directory to save the model file
+ filename: null # checkpoint filename
+ monitor: null # name of the logged metric which determines when model is improving
+ verbose: False # verbosity mode
+ save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 1 # save k best models (determined by above metric)
+ mode: "min" # "max" means higher metric value is better, can be also "min"
+ auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
+ save_weights_only: True # if True, then only the model’s weights will be saved
+ every_n_train_steps: null # number of training steps between checkpoints
+ train_time_interval: null # checkpoints are monitored at the specified time interval
+ every_n_epochs: null # number of epochs between checkpoints
+ save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
diff --git a/lib/deft2023/gen/configs/callbacks/model_summary.yaml b/lib/deft2023/gen/configs/callbacks/model_summary.yaml
new file mode 100644
index 0000000..b75981d
--- /dev/null
+++ b/lib/deft2023/gen/configs/callbacks/model_summary.yaml
@@ -0,0 +1,5 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
+
+model_summary:
+ _target_: lightning.pytorch.callbacks.RichModelSummary
+ max_depth: 1 # the maximum depth of layer nesting that the summary will include
diff --git a/lib/deft2023/gen/configs/callbacks/none.yaml b/lib/deft2023/gen/configs/callbacks/none.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/lib/deft2023/gen/configs/callbacks/rich_progress_bar.yaml b/lib/deft2023/gen/configs/callbacks/rich_progress_bar.yaml
new file mode 100644
index 0000000..de6f1cc
--- /dev/null
+++ b/lib/deft2023/gen/configs/callbacks/rich_progress_bar.yaml
@@ -0,0 +1,4 @@
+# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
+
+rich_progress_bar:
+ _target_: lightning.pytorch.callbacks.RichProgressBar
diff --git a/lib/deft2023/gen/configs/data/t5.yaml b/lib/deft2023/gen/configs/data/t5.yaml
new file mode 100644
index 0000000..49dfcd4
--- /dev/null
+++ b/lib/deft2023/gen/configs/data/t5.yaml
@@ -0,0 +1,6 @@
+_target_: gen.gen.data.t5_datamodule.T5DataModule
+num_workers: 0
+pin_memory: False
+seed: ${seed}
+tokenizer: ${model.model}
+batch_size: 16
diff --git a/lib/deft2023/gen/configs/debug/default.yaml b/lib/deft2023/gen/configs/debug/default.yaml
new file mode 100644
index 0000000..31b6208
--- /dev/null
+++ b/lib/deft2023/gen/configs/debug/default.yaml
@@ -0,0 +1,29 @@
+# @package _global_
+
+# default debugging setup, runs 1 full epoch
+# other debugging configs can inherit from this one
+
+# overwrite task name so debugging logs are stored in separate folder
+task_name: "debug"
+
+extras:
+ ignore_warnings: False
+ enforce_tags: False
+
+# sets level of all command line loggers to 'DEBUG'
+# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
+hydra:
+ job_logging:
+ root:
+ level: DEBUG
+
+ # use this to also set hydra loggers to 'DEBUG'
+ # verbose: True
+
+trainer:
+ max_epochs: 1
+ detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
+
+data:
+ num_workers: 0 # debuggers don't like multiprocessing
+ pin_memory: False # disable gpu memory pin
diff --git a/lib/deft2023/gen/configs/debug/fdr.yaml b/lib/deft2023/gen/configs/debug/fdr.yaml
new file mode 100644
index 0000000..98eba22
--- /dev/null
+++ b/lib/deft2023/gen/configs/debug/fdr.yaml
@@ -0,0 +1,9 @@
+# @package _global_
+
+# runs 1 train, 1 validation and 1 test step
+
+defaults:
+ - default.yaml
+
+trainer:
+ fast_dev_run: true
diff --git a/lib/deft2023/gen/configs/debug/limit.yaml b/lib/deft2023/gen/configs/debug/limit.yaml
new file mode 100644
index 0000000..cc28852
--- /dev/null
+++ b/lib/deft2023/gen/configs/debug/limit.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+# uses only 1% of the training data and 5% of validation/test data
+
+defaults:
+ - default.yaml
+
+trainer:
+ max_epochs: 3
+ limit_train_batches: 0.01
+ limit_val_batches: 0.05
+ limit_test_batches: 0.05
diff --git a/lib/deft2023/gen/configs/debug/overfit.yaml b/lib/deft2023/gen/configs/debug/overfit.yaml
new file mode 100644
index 0000000..d1f63e8
--- /dev/null
+++ b/lib/deft2023/gen/configs/debug/overfit.yaml
@@ -0,0 +1,13 @@
+# @package _global_
+
+# overfits to 3 batches
+
+defaults:
+ - default.yaml
+
+trainer:
+ max_epochs: 20
+ overfit_batches: 3
+
+# model ckpt and early stopping need to be disabled during overfitting
+callbacks: null
diff --git a/lib/deft2023/gen/configs/debug/profiler.yaml b/lib/deft2023/gen/configs/debug/profiler.yaml
new file mode 100644
index 0000000..e18df1c
--- /dev/null
+++ b/lib/deft2023/gen/configs/debug/profiler.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+# runs with execution time profiling
+
+defaults:
+ - default.yaml
+
+trainer:
+ max_epochs: 1
+ profiler: "simple"
+ # profiler: "advanced"
+ # profiler: "pytorch"
diff --git a/lib/deft2023/gen/configs/eval.yaml b/lib/deft2023/gen/configs/eval.yaml
new file mode 100644
index 0000000..adc37ee
--- /dev/null
+++ b/lib/deft2023/gen/configs/eval.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+
+defaults:
+ - _self_
+ - data: t5.yaml # choose datamodule with `test_dataloader()` for evaluation
+ - model: t5.yaml
+ - logger: null
+ - trainer: default.yaml
+ - paths: default.yaml
+ - callbacks: default.yaml
+ - extras: default.yaml
+ - hydra: default.yaml
+
+task_name: "eval"
+
+tags: ["dev"]
+
+# passing checkpoint path is necessary for evaluation
+ckpt_path: ???
diff --git a/lib/deft2023/gen/configs/experiment/default.yaml b/lib/deft2023/gen/configs/experiment/default.yaml
new file mode 100644
index 0000000..aa26c06
--- /dev/null
+++ b/lib/deft2023/gen/configs/experiment/default.yaml
@@ -0,0 +1,49 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /data: t5.yaml
+ - override /model: t5.yaml
+ - override /callbacks: default.yaml
+ - override /trainer: default.yaml
+ - override /hydra/sweeper: basic
+ - override /hydra/launcher: joblib
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["seq2seq"]
+
+seed: 42
+
+trainer:
+ min_epochs: 1
+ max_epochs: 10
+ gradient_clip_val: 0.5
+
+model:
+ model: "google/flan-t5-base"
+
+data:
+ batch_size: 24
+
+logger:
+ wandb:
+ tags: ["gen"]
+
+hydra:
+ mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
+ launcher:
+ _target_: hydra_plugins.hydra_joblib_launcher.joblib_launcher.JoblibLauncher
+ n_jobs: 1
+ backend: loky
+ timeout: null
+ sweeper:
+ params:
+ model.model: choice("t5-large",
+ "t5-base",
+ "google/flan-t5-base",
+ "google/mt5-base",
+ "razent/SciFive-base-Pubmed_PMC")
diff --git a/lib/deft2023/gen/configs/experiment/large.yaml b/lib/deft2023/gen/configs/experiment/large.yaml
new file mode 100644
index 0000000..e995603
--- /dev/null
+++ b/lib/deft2023/gen/configs/experiment/large.yaml
@@ -0,0 +1,39 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /data: t5.yaml
+ - override /model: t5.yaml
+ - override /callbacks: default.yaml
+ - override /trainer: default.yaml
+ - override /hydra/sweeper: basic
+ - override /hydra/launcher: joblib
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["seq2seq"]
+
+seed: 42
+
+trainer:
+ min_epochs: 10
+ max_epochs: 10
+ gradient_clip_val: 0.5
+ accumulate_grad_batches: 2
+
+model:
+ model: "t5-large"
+ optimizer:
+ _target_: torch.optim.Adam
+ _partial_: true
+ lr: 4e-5
+
+data:
+ batch_size: 8
+
+logger:
+ wandb:
+ tags: ["gen"]
diff --git a/lib/deft2023/gen/configs/experiment/xl.yaml b/lib/deft2023/gen/configs/experiment/xl.yaml
new file mode 100644
index 0000000..b8b6d43
--- /dev/null
+++ b/lib/deft2023/gen/configs/experiment/xl.yaml
@@ -0,0 +1,43 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /data: t5.yaml
+ - override /model: t5-xxl.yaml
+ - override /callbacks: default.yaml
+ - override /trainer: default.yaml
+ - override /hydra/sweeper: basic
+ - override /hydra/launcher: joblib
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["seq2seq"]
+
+seed: 42
+
+trainer:
+ min_epochs: 10
+ max_epochs: 10
+ gradient_clip_val: 0.5
+ accumulate_grad_batches: 4
+
+model:
+ model: "google/flan-t5-xl"
+ lora_config:
+ _target_: peft.LoraConfig
+ r: 24
+ lora_alpha: 32
+ target_modules: [q, v]
+ lora_dropout: 0.05
+ bias: none
+    task_type: SEQ_2_SEQ_LM # task type name expected by peft.TaskType
+
+data:
+ batch_size: 4
+
+logger:
+ wandb:
+ tags: ["gen"]
diff --git a/lib/deft2023/gen/configs/experiment/xxl.yaml b/lib/deft2023/gen/configs/experiment/xxl.yaml
new file mode 100644
index 0000000..705fc01
--- /dev/null
+++ b/lib/deft2023/gen/configs/experiment/xxl.yaml
@@ -0,0 +1,43 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /data: t5.yaml
+ - override /model: t5-xxl.yaml
+ - override /callbacks: default.yaml
+ - override /trainer: default.yaml
+ - override /hydra/sweeper: basic
+ - override /hydra/launcher: joblib
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["seq2seq"]
+
+seed: 42
+
+trainer:
+ min_epochs: 10
+ max_epochs: 10
+ gradient_clip_val: 0.5
+ accumulate_grad_batches: 4
+
+model:
+ model: "philschmid/flan-t5-xxl-sharded-fp16"
+ lora_config:
+ _target_: peft.LoraConfig
+ r: 16
+ lora_alpha: 32
+ target_modules: [q, v]
+ lora_dropout: 0.05
+ bias: none
+    task_type: SEQ_2_SEQ_LM # task type name expected by peft.TaskType
+
+data:
+ batch_size: 4
+
+logger:
+ wandb:
+ tags: ["gen"]
diff --git a/lib/deft2023/gen/configs/extras/default.yaml b/lib/deft2023/gen/configs/extras/default.yaml
new file mode 100644
index 0000000..b9c6b62
--- /dev/null
+++ b/lib/deft2023/gen/configs/extras/default.yaml
@@ -0,0 +1,8 @@
+# disable python warnings if they annoy you
+ignore_warnings: False
+
+# ask user for tags if none are provided in the config
+enforce_tags: True
+
+# pretty print config tree at the start of the run using Rich library
+print_config: True
diff --git a/lib/deft2023/gen/configs/hydra/default.yaml b/lib/deft2023/gen/configs/hydra/default.yaml
new file mode 100644
index 0000000..3533023
--- /dev/null
+++ b/lib/deft2023/gen/configs/hydra/default.yaml
@@ -0,0 +1,13 @@
+# https://hydra.cc/docs/configure_hydra/intro/
+
+# enable color logging
+defaults:
+ - override hydra_logging: colorlog
+ - override job_logging: colorlog
+
+# output directory, generated dynamically on each run
+run:
+ dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
+sweep:
+ dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
diff --git a/lib/deft2023/gen/configs/local/.gitkeep b/lib/deft2023/gen/configs/local/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/lib/deft2023/gen/configs/logger/aim.yaml b/lib/deft2023/gen/configs/logger/aim.yaml
new file mode 100644
index 0000000..8f9f6ad
--- /dev/null
+++ b/lib/deft2023/gen/configs/logger/aim.yaml
@@ -0,0 +1,28 @@
+# https://aimstack.io/
+
+# example usage in lightning module:
+# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
+
+# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
+# `aim up`
+
+aim:
+ _target_: aim.pytorch_lightning.AimLogger
+ repo: ${paths.root_dir} # .aim folder will be created here
+ # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
+
+  # aim allows grouping runs under an experiment name
+ experiment: null # any string, set to "default" if not specified
+
+ train_metric_prefix: "train/"
+ val_metric_prefix: "val/"
+ test_metric_prefix: "test/"
+
+ # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
+ system_tracking_interval: 10 # set to null to disable system metrics tracking
+
+ # enable/disable logging of system params such as installed packages, git info, env vars, etc.
+ log_system_params: true
+
+ # enable/disable tracking console logs (default value is true)
+ capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
diff --git a/lib/deft2023/gen/configs/logger/comet.yaml b/lib/deft2023/gen/configs/logger/comet.yaml
new file mode 100644
index 0000000..e078927
--- /dev/null
+++ b/lib/deft2023/gen/configs/logger/comet.yaml
@@ -0,0 +1,12 @@
+# https://www.comet.ml
+
+comet:
+ _target_: lightning.pytorch.loggers.comet.CometLogger
+ api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
+ save_dir: "${paths.output_dir}"
+ project_name: "lightning-hydra-template"
+ rest_api_key: null
+ # experiment_name: ""
+ experiment_key: null # set to resume experiment
+ offline: False
+ prefix: ""
diff --git a/lib/deft2023/gen/configs/logger/csv.yaml b/lib/deft2023/gen/configs/logger/csv.yaml
new file mode 100644
index 0000000..fa028e9
--- /dev/null
+++ b/lib/deft2023/gen/configs/logger/csv.yaml
@@ -0,0 +1,7 @@
+# csv logger built in lightning
+
+csv:
+ _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
+ save_dir: "${paths.output_dir}"
+ name: "csv/"
+ prefix: ""
diff --git a/lib/deft2023/gen/configs/logger/many_loggers.yaml b/lib/deft2023/gen/configs/logger/many_loggers.yaml
new file mode 100644
index 0000000..801444d
--- /dev/null
+++ b/lib/deft2023/gen/configs/logger/many_loggers.yaml
@@ -0,0 +1,9 @@
+# train with many loggers at once
+
+defaults:
+ # - comet.yaml
+ - csv.yaml
+ # - mlflow.yaml
+ # - neptune.yaml
+ - tensorboard.yaml
+ - wandb.yaml
diff --git a/lib/deft2023/gen/configs/logger/mlflow.yaml b/lib/deft2023/gen/configs/logger/mlflow.yaml
new file mode 100644
index 0000000..f8fb7e6
--- /dev/null
+++ b/lib/deft2023/gen/configs/logger/mlflow.yaml
@@ -0,0 +1,12 @@
+# https://mlflow.org
+
+mlflow:
+ _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
+ # experiment_name: ""
+ # run_name: ""
+ tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
+ tags: null
+ # save_dir: "./mlruns"
+ prefix: ""
+ artifact_location: null
+ # run_id: ""
diff --git a/lib/deft2023/gen/configs/logger/neptune.yaml b/lib/deft2023/gen/configs/logger/neptune.yaml
new file mode 100644
index 0000000..8233c14
--- /dev/null
+++ b/lib/deft2023/gen/configs/logger/neptune.yaml
@@ -0,0 +1,9 @@
+# https://neptune.ai
+
+neptune:
+ _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
+ api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
+ project: username/lightning-hydra-template
+ # name: ""
+ log_model_checkpoints: True
+ prefix: ""
diff --git a/lib/deft2023/gen/configs/logger/tensorboard.yaml b/lib/deft2023/gen/configs/logger/tensorboard.yaml
new file mode 100644
index 0000000..2bd31f6
--- /dev/null
+++ b/lib/deft2023/gen/configs/logger/tensorboard.yaml
@@ -0,0 +1,10 @@
+# https://www.tensorflow.org/tensorboard/
+
+tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.output_dir}/tensorboard/"
+ name: null
+ log_graph: False
+ default_hp_metric: True
+ prefix: ""
+ # version: ""
diff --git a/lib/deft2023/gen/configs/logger/wandb.yaml b/lib/deft2023/gen/configs/logger/wandb.yaml
new file mode 100644
index 0000000..ea6221f
--- /dev/null
+++ b/lib/deft2023/gen/configs/logger/wandb.yaml
@@ -0,0 +1,16 @@
+# https://wandb.ai
+
+wandb:
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # name: "" # name of the run (normally generated by wandb)
+ save_dir: "${paths.output_dir}"
+ offline: False
+ id: null # pass correct id to resume experiment!
+ anonymous: null # enable anonymous logging
+ project: "deft2023"
+ log_model: False # upload lightning ckpts
+ prefix: "" # a string to put at the beginning of metric keys
+ entity: "clinical-dream-team" # set to name of your wandb team
+ group: ""
+ tags: ["gen"]
+ job_type: ""
diff --git a/lib/deft2023/gen/configs/model/t5-xxl.yaml b/lib/deft2023/gen/configs/model/t5-xxl.yaml
new file mode 100644
index 0000000..a0c0b18
--- /dev/null
+++ b/lib/deft2023/gen/configs/model/t5-xxl.yaml
@@ -0,0 +1,20 @@
+_target_: gen.gen.models.t5_xxl_module.T5XllModule
+
+optimizer:
+ _target_: torch.optim.Adam
+ _partial_: true
+ lr: 2e-5
+ weight_decay: 0.01
+
+lora_config:
+ _target_: peft.LoraConfig
+ r: 24
+ lora_alpha: 32
+ target_modules: [q, v]
+ lora_dropout: 0.05
+ bias: none
+  task_type: SEQ_2_SEQ_LM # task type name expected by peft.TaskType
+
+scheduler: null
+
+model: "t5-small"
diff --git a/lib/deft2023/gen/configs/model/t5.yaml b/lib/deft2023/gen/configs/model/t5.yaml
new file mode 100644
index 0000000..a0221a8
--- /dev/null
+++ b/lib/deft2023/gen/configs/model/t5.yaml
@@ -0,0 +1,11 @@
+_target_: gen.gen.models.t5_module.T5Module
+
+optimizer:
+ _target_: torch.optim.Adam
+ _partial_: true
+ lr: 2e-5
+ weight_decay: 0.01
+
+scheduler: null
+
+model: "t5-small"
diff --git a/lib/deft2023/gen/configs/paths/default.yaml b/lib/deft2023/gen/configs/paths/default.yaml
new file mode 100644
index 0000000..ec81db2
--- /dev/null
+++ b/lib/deft2023/gen/configs/paths/default.yaml
@@ -0,0 +1,18 @@
+# path to root directory
+# this requires PROJECT_ROOT environment variable to exist
+# you can replace it with "." if you want the root to be the current working directory
+root_dir: ${oc.env:PROJECT_ROOT}
+
+# path to data directory
+data_dir: ${paths.root_dir}/data/
+
+# path to logging directory
+log_dir: ${paths.root_dir}/logs/
+
+# path to output directory, created dynamically by hydra
+# path generation pattern is specified in `configs/hydra/default.yaml`
+# use it to store all files generated during the run, like ckpts and metrics
+output_dir: ${hydra:runtime.output_dir}
+
+# path to working directory
+work_dir: ${hydra:runtime.cwd}
diff --git a/lib/deft2023/gen/configs/train.yaml b/lib/deft2023/gen/configs/train.yaml
new file mode 100644
index 0000000..86b4d60
--- /dev/null
+++ b/lib/deft2023/gen/configs/train.yaml
@@ -0,0 +1,52 @@
+# @package _global_
+
+# specify here default configuration
+# order of defaults determines the order in which configs override each other
+defaults:
+ - _self_
+ - data: t5.yaml
+ - model: t5.yaml
+ - callbacks: default.yaml
+ - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+ - trainer: default.yaml
+ - paths: default.yaml
+ - extras: default.yaml
+ - hydra: default.yaml
+
+ # experiment configs allow for version control of specific hyperparameters
+ # e.g. best hyperparameters for given model and datamodule
+ - experiment: null
+
+ # config for hyperparameter optimization
+ - hparams_search: null
+
+ # optional local config for machine/user specific settings
+ # it's optional since it doesn't need to exist and is excluded from version control
+ - optional local: default.yaml
+
+  # debugging config (enable through command line, e.g. `python train.py debug=default`)
+ - debug: null
+
+# task name, determines output directory path
+task_name: "train"
+
+# tags to help you identify your experiments
+# you can overwrite this in experiment configs
+# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
+tags: ["dev"]
+
+# set False to skip model training
+train: True
+
+# evaluate on test set, using best model weights achieved during training
+# lightning chooses best weights based on the metric specified in checkpoint callback
+test: True
+
+# compile model for faster training with pytorch 2.0
+compile: False
+
+# simply provide checkpoint path to resume training
+ckpt_path: null
+
+# seed for random number generators in pytorch, numpy and python.random
+seed: 42
diff --git a/lib/deft2023/gen/configs/trainer/cpu.yaml b/lib/deft2023/gen/configs/trainer/cpu.yaml
new file mode 100644
index 0000000..640f71d
--- /dev/null
+++ b/lib/deft2023/gen/configs/trainer/cpu.yaml
@@ -0,0 +1,5 @@
+defaults:
+ - default.yaml
+
+accelerator: cpu
+devices: 1
diff --git a/lib/deft2023/gen/configs/trainer/ddp.yaml b/lib/deft2023/gen/configs/trainer/ddp.yaml
new file mode 100644
index 0000000..96bef39
--- /dev/null
+++ b/lib/deft2023/gen/configs/trainer/ddp.yaml
@@ -0,0 +1,9 @@
+defaults:
+ - default.yaml
+
+strategy: ddp
+
+accelerator: gpu
+devices: 4
+num_nodes: 1
+sync_batchnorm: True
diff --git a/lib/deft2023/gen/configs/trainer/ddp_sim.yaml b/lib/deft2023/gen/configs/trainer/ddp_sim.yaml
new file mode 100644
index 0000000..42626be
--- /dev/null
+++ b/lib/deft2023/gen/configs/trainer/ddp_sim.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - default.yaml
+
+# simulate DDP on CPU, useful for debugging
+accelerator: cpu
+devices: 2
+strategy: ddp_spawn
diff --git a/lib/deft2023/gen/configs/trainer/default.yaml b/lib/deft2023/gen/configs/trainer/default.yaml
new file mode 100644
index 0000000..50905e7
--- /dev/null
+++ b/lib/deft2023/gen/configs/trainer/default.yaml
@@ -0,0 +1,19 @@
+_target_: lightning.pytorch.trainer.Trainer
+
+default_root_dir: ${paths.output_dir}
+
+min_epochs: 1 # prevents early stopping
+max_epochs: 10
+
+accelerator: cpu
+devices: 1
+
+# mixed precision for extra speed-up
+# precision: 16
+
+# perform a validation loop every N training epochs
+check_val_every_n_epoch: 1
+
+# set True to ensure deterministic results
+# makes training slower but gives more reproducibility than just setting seeds
+deterministic: False
diff --git a/lib/deft2023/gen/configs/trainer/gpu.yaml b/lib/deft2023/gen/configs/trainer/gpu.yaml
new file mode 100644
index 0000000..d5e5773
--- /dev/null
+++ b/lib/deft2023/gen/configs/trainer/gpu.yaml
@@ -0,0 +1,5 @@
+defaults:
+ - default.yaml
+
+accelerator: gpu
+devices: 1
diff --git a/lib/deft2023/gen/configs/trainer/mps.yaml b/lib/deft2023/gen/configs/trainer/mps.yaml
new file mode 100644
index 0000000..73d2cdd
--- /dev/null
+++ b/lib/deft2023/gen/configs/trainer/mps.yaml
@@ -0,0 +1,5 @@
+defaults:
+ - default.yaml
+
+accelerator: mps
+devices: 1
diff --git a/lib/deft2023/gen/data/.gitkeep b/lib/deft2023/gen/data/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/lib/deft2023/gen/gen/__init__.py b/lib/deft2023/gen/gen/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lib/deft2023/gen/gen/data/__init__.py b/lib/deft2023/gen/gen/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lib/deft2023/gen/gen/data/components/__init__.py b/lib/deft2023/gen/gen/data/components/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lib/deft2023/gen/gen/data/t5_datamodule.py b/lib/deft2023/gen/gen/data/t5_datamodule.py
new file mode 100644
index 0000000..27e9825
--- /dev/null
+++ b/lib/deft2023/gen/gen/data/t5_datamodule.py
@@ -0,0 +1,183 @@
+from typing import Any, Dict, Optional
+
+import torch
+from commons.data.load_and_preprocess_dataset import load_and_preprocess_dataset
+from lightning import LightningDataModule
+from lightning.pytorch.utilities.types import EVAL_DATALOADERS
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoTokenizer
+
+
+class T5DataModule(LightningDataModule):
+    """LightningDataModule for the DEFT 2023 multiple-choice question answering dataset.
+
+ A DataModule implements 6 key methods:
+ def prepare_data(self):
+ # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
+ # download data, pre-process, split, save to disk, etc...
+ def setup(self, stage):
+ # things to do on every process in DDP
+ # load data, set variables, etc...
+ def train_dataloader(self):
+ # return train dataloader
+ def val_dataloader(self):
+ # return validation dataloader
+ def test_dataloader(self):
+ # return test dataloader
+ def teardown(self):
+ # called on every process in DDP
+ # clean up after fit or test
+
+ This allows you to share a full dataset without explaining how to download,
+ split, transform and process the data.
+
+ Read the docs:
+ https://lightning.ai/docs/pytorch/latest/data/datamodule.html
+ """
+
+ def __init__(
+ self,
+ num_workers: int = 0,
+ batch_size: int = 8,
+ pin_memory: bool = False,
+ seed: int = 42,
+ tokenizer: str = "",
+ ):
+ super().__init__()
+
+        # this line allows accessing init params with the 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False)
+
+ self.dataset: Dataset = None
+ # data transformations
+
+ self.data_train: Optional[Dataset] = None
+ self.data_val: Optional[Dataset] = None
+ self.data_test: Optional[Dataset] = None
+ self.data_predict: Optional[Dataset] = None
+ self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.tokenizer)
+
+ def setup(self, stage: Optional[str] = None):
+ """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+ This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
+ careful not to execute things like random split twice!
+ """
+ # load and split datasets only if not loaded already
+ (
+ self.data_train,
+ self.data_val,
+ self.data_test,
+ self.data_predict,
+ ) = load_and_preprocess_dataset()
+
+ def train_dataloader(self):
+ return DataLoader(
+ collate_fn=self.collate_fn,
+ dataset=self.data_train,
+ batch_size=self.hparams.batch_size,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ shuffle=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ collate_fn=self.collate_fn,
+ dataset=self.data_val,
+ batch_size=self.hparams.batch_size,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ shuffle=False,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ collate_fn=self.collate_fn,
+ dataset=self.data_test,
+ batch_size=self.hparams.batch_size,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ shuffle=False,
+ )
+
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
+ return DataLoader(
+ collate_fn=self.collate_fn,
+ dataset=self.data_test,
+ batch_size=1,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ shuffle=False,
+ )
+
+ def teardown(self, stage: Optional[str] = None):
+ """Clean up after fit or test."""
+ pass
+
+ def state_dict(self):
+ """Extra things to save to checkpoint."""
+ return {}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]):
+ """Things to do when loading checkpoint."""
+ pass
+
+ def collate_fn(self, batch):
+ prompts = []
+ responses = []
+ labels = []
+ for point in batch:
+ prompt = (
+ "Please find the right answers between the possible answers "
+ "to the following question.\n"
+ "Question: {q}\n"
+ "Possible answers:\n"
+ "- a: {a} \n"
+ "- b: {b} \n"
+ "- c: {c} \n"
+ "- d: {d} \n"
+ "- e: {e}"
+ )
+ prompt = prompt.format(
+ q=point["question"],
+ a=point["answer_a"],
+ b=point["answer_b"],
+ c=point["answer_c"],
+ d=point["answer_d"],
+ e=point["answer_e"],
+ )
+ response = ",".join(
+ [["a", "b", "c", "d", "e"][response] for response in point["correct_answers"]]
+ )
+
+ label = [0.0] * 5
+
+ for answer_id in point["correct_answers"]:
+ label[answer_id] = 1.0
+
+ prompts.append(prompt)
+ responses.append(response)
+ labels.append(label)
+
+        # Tokenize prompts, padding to the longest sequence in the batch and truncating to the model max length
+ tokenized_prompts = self.tokenizer(
+ prompts,
+ padding="longest",
+ truncation=True,
+ )
+ tokenized_responses = self.tokenizer(responses, padding="longest")
+
+ return (
+ {
+ "input_ids": torch.tensor(tokenized_prompts["input_ids"]),
+ "attention_mask": torch.tensor(tokenized_prompts["attention_mask"]),
+ "labels": torch.tensor(tokenized_responses["input_ids"]),
+ },
+ {"labels_text": responses, "input_ids_text": prompts, "ground_truth": labels},
+ )
+
+
+if __name__ == "__main__":
+ _ = T5DataModule()
diff --git a/lib/deft2023/gen/gen/eval.py b/lib/deft2023/gen/gen/eval.py
new file mode 100644
index 0000000..ce3db04
--- /dev/null
+++ b/lib/deft2023/gen/gen/eval.py
@@ -0,0 +1,116 @@
+import os
+import sys
+import time
+from typing import List, Tuple
+
+import hydra
+import pyrootutils
+import wandb
+from lightning import LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+sys.path.append(parent_dir)
+from commons.submission.to_txt import to_txt # noqa: E402
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+# ------------------------------------------------------------------------------------ #
+# the setup_root above is equivalent to:
+# - adding project root dir to PYTHONPATH
+# (so you don't need to force user to install project as a package)
+# (necessary before importing any local modules e.g. `from gen import utils`)
+# - setting up PROJECT_ROOT environment variable
+# (which is used as a base for paths in "configs/paths/default.yaml")
+# (this way all filepaths are the same no matter where you run the code)
+# - loading environment variables from ".env" in root dir
+#
+# you can remove it if you:
+# 1. either install project as a package or move entry files to project root dir
+# 2. set `root_dir` to "." in "configs/paths/default.yaml"
+#
+# more info: https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+from gen.gen import utils # noqa: E402
+
+log = utils.get_pylogger(__name__)
+
+
+@utils.task_wrapper
+def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
+ """Evaluates given checkpoint on a datamodule testset.
+
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """
+
+ assert cfg.ckpt_path
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating loggers...")
+ logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ utils.log_hyperparameters(object_dict)
+
+ log.info("Starting testing!")
+ # trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+
+ predictions = trainer.predict(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
+ run_name = "fdr" if logger == [] else logger[0].name
+ submission_file_path = f"submissions/submission-gen_{run_name}_{timestamp}.txt"
+ id2label = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
+ to_txt(
+ predictions,
+ datamodule.data_predict,
+ submission_file_path,
+ id2label=id2label,
+ )
+ # log submission file to wandb
+ artifact = wandb.Artifact("submission-file", type="submission")
+ artifact.add_file(submission_file_path)
+ wandb.log_artifact(artifact)
+ # for predictions use trainer.predict(...)
+ # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
+
+ metric_dict = trainer.callback_metrics
+
+ return metric_dict, object_dict
+
+
+@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
+def main(cfg: DictConfig) -> None:
+ # apply extra utilities
+ # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+ utils.extras(cfg)
+
+ evaluate(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/lib/deft2023/gen/gen/models/__init__.py b/lib/deft2023/gen/gen/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lib/deft2023/gen/gen/models/backbone.py b/lib/deft2023/gen/gen/models/backbone.py
new file mode 100644
index 0000000..33e95bc
--- /dev/null
+++ b/lib/deft2023/gen/gen/models/backbone.py
@@ -0,0 +1,147 @@
+import re
+import warnings
+from typing import Any
+
+import numpy as np
+import torch
+from commons.metrics.exact_match_ratio import exact_match_ratio
+from commons.metrics.hamming_score import hamming_score
+from lightning import LightningModule
+from torchmetrics import MeanMetric
+
+
+class Backbone(LightningModule):
+    """Base LightningModule shared by the seq2seq answer-generation models.
+
+ A LightningModule organizes your PyTorch code into 6 sections:
+ - Initialization (__init__)
+ - Train Loop (training_step)
+ - Validation loop (validation_step)
+ - Test loop (test_step)
+ - Prediction Loop (predict_step)
+ - Optimizers and LR Schedulers (configure_optimizers)
+
+ Docs:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
+ """
+
+ def __init__(
+ self,
+ ):
+ super().__init__()
+ # loss function
+ self.criterion = torch.nn.CrossEntropyLoss()
+ # for averaging loss across batches
+ self.train_loss = MeanMetric()
+ self.val_loss = MeanMetric()
+ self.test_loss = MeanMetric()
+
+ def forward(self, x: dict):
+ return self.model(**x)
+
+ def on_train_start(self):
+ # by default lightning executes validation step sanity checks before training starts,
+ # so it's worth to make sure validation metrics don't store results from these checks
+ self.val_loss.reset()
+
+ def model_step(self, batch: Any):
+ input_model, input_text = batch
+ output = self.forward(input_model)
+ loss = output.loss
+ generated_text = self.model.generate(
+ input_ids=input_model["input_ids"],
+ attention_mask=input_model["attention_mask"],
+ max_length=32,
+ num_beams=4,
+ early_stopping=True,
+ )
+ generated_text = self.tokenizer.batch_decode(generated_text, skip_special_tokens=True)
+ return (loss, input_text, generated_text)
+
+ @staticmethod
+ def is_valid_output(output_string):
+ return bool(re.match(r"^[a-z](,[a-z])*$|^[a-z]$", output_string))
+
+ def training_step(self, batch: Any, batch_idx: int):
+ loss, input_text, generated_text = self.model_step(batch)
+ # update and log metrics
+ self.train_loss(loss)
+ self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
+ # return loss or backpropagation will fail
+ return loss
+
+ def on_train_epoch_end(self):
+ pass
+
+ def validation_step(self, batch: Any, batch_idx: int):
+ # batch ({inputs, attention, label}, {label_text ...})
+ loss, input_text, generated_text = self.model_step(batch)
+ preds, targets = self.text_to_one_hot(generated_text, input_text)
+ # update and log metrics
+ self.val_loss(loss)
+ self.log_dict(
+ {
+ "val/hamming_score": hamming_score(np.array(preds), np.array(targets)),
+ "val/exact_match_score": exact_match_ratio(preds, targets),
+ },
+ on_step=False,
+ on_epoch=True,
+ prog_bar=True,
+ )
+ self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
+
+ def text_to_one_hot(self, generated_text, input_text):
+ preds = []
+ for text in generated_text:
+ if not self.is_valid_output(text):
+                # flag malformed generations; they simply yield no recognized answer letters below
+                warnings.warn(f"Badly formatted generated answer: {text!r}")
+ predicted_answers_letters = text.split(",")
+ predicted_answers_letters = [s.strip() for s in predicted_answers_letters]
+ p = [int(letter in predicted_answers_letters) for letter in "abcde"]
+ preds.append(p)
+ targets = input_text["ground_truth"]
+ return preds, targets
+
+ def test_step(self, batch: Any, batch_idx: int):
+ loss, input_text, generated_text = self.model_step(batch)
+ preds, targets = self.text_to_one_hot(generated_text, input_text)
+ # update and log metrics
+ self.test_loss(loss)
+ self.log_dict(
+ {
+ "test/hamming_score": hamming_score(np.array(preds), np.array(targets)),
+ "test/exact_match_score": exact_match_ratio(preds, targets),
+ },
+ on_step=False,
+ on_epoch=True,
+ prog_bar=True,
+ )
+ self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
+
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+ _, input_text, generated_text = self.model_step(batch)
+ preds, targets = self.text_to_one_hot(generated_text, input_text)
+ return preds
+
+ def on_test_epoch_end(self):
+ pass
+
+ def configure_optimizers(self):
+ """Choose what optimizers and learning-rate schedulers to use in your optimization.
+ Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+ Examples:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+ """
+ optimizer = self.hparams.optimizer(params=self.parameters())
+ if self.hparams.scheduler is not None:
+ scheduler = self.hparams.scheduler(optimizer=optimizer)
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": scheduler,
+ "monitor": "val/loss",
+ "interval": "epoch",
+ "frequency": 1,
+ },
+ }
+ return {"optimizer": optimizer}
diff --git a/lib/deft2023/gen/gen/models/t5_module.py b/lib/deft2023/gen/gen/models/t5_module.py
new file mode 100644
index 0000000..0bc128b
--- /dev/null
+++ b/lib/deft2023/gen/gen/models/t5_module.py
@@ -0,0 +1,31 @@
+import torch
+from gen.gen.models.backbone import Backbone
+from transformers import AutoTokenizer, T5ForConditionalGeneration
+
+
+class T5Module(Backbone):
+    """LightningModule wrapping a Hugging Face T5 model for answer generation.
+
+ A LightningModule organizes your PyTorch code into 6 sections:
+ - Initialization (__init__)
+ - Train Loop (training_step)
+ - Validation loop (validation_step)
+ - Test loop (test_step)
+ - Prediction Loop (predict_step)
+ - Optimizers and LR Schedulers (configure_optimizers)
+
+ Docs:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
+ """
+
+ def __init__(
+ self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, model: str
+ ):
+ super().__init__()
+
+        # this line allows accessing init params with the 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False)
+
+ self.model = T5ForConditionalGeneration.from_pretrained(model)
+ self.tokenizer = AutoTokenizer.from_pretrained(model)
diff --git a/lib/deft2023/gen/gen/models/t5_xxl_module.py b/lib/deft2023/gen/gen/models/t5_xxl_module.py
new file mode 100644
index 0000000..d89de75
--- /dev/null
+++ b/lib/deft2023/gen/gen/models/t5_xxl_module.py
@@ -0,0 +1,46 @@
+import peft
+import torch
+from gen.gen.models.backbone import Backbone
+from peft import get_peft_model, prepare_model_for_int8_training
+from transformers import AutoTokenizer, T5ForConditionalGeneration
+
+
+class T5XllModule(Backbone):
+    """LightningModule wrapping a large T5 checkpoint loaded in int8 with LoRA adapters.
+
+ A LightningModule organizes your PyTorch code into 6 sections:
+ - Initialization (__init__)
+ - Train Loop (training_step)
+ - Validation loop (validation_step)
+ - Test loop (test_step)
+ - Prediction Loop (predict_step)
+ - Optimizers and LR Schedulers (configure_optimizers)
+
+ Docs:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler,
+ lora_config: peft.LoraConfig,
+ model: str,
+ ):
+ super().__init__()
+        # this line allows accessing init params with the 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False)
+ # Define LoRA Config
+
+ lora_config = self.hparams.lora_config
+ # prepare int-8 model for training
+ model_id = model
+ self.model = T5ForConditionalGeneration.from_pretrained(
+ model_id, load_in_8bit=True, device_map="auto"
+ )
+
+ self.model = prepare_model_for_int8_training(self.model)
+ self.model = get_peft_model(self.model, lora_config)
+
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
diff --git a/lib/deft2023/gen/gen/train.py b/lib/deft2023/gen/gen/train.py
new file mode 100644
index 0000000..7abcb11
--- /dev/null
+++ b/lib/deft2023/gen/gen/train.py
@@ -0,0 +1,150 @@
+import os
+import sys
+import time
+from typing import List, Optional, Tuple
+
+import hydra
+import lightning as L
+import pyrootutils
+import torch
+import wandb
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+sys.path.append(parent_dir)
+from commons.submission.to_txt import to_txt # noqa: E402
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+# ------------------------------------------------------------------------------------ #
+# the setup_root above is equivalent to:
+# - adding project root dir to PYTHONPATH
+# (so you don't need to force user to install project as a package)
+# (necessary before importing any local modules e.g. `from gen import utils`)
+# - setting up PROJECT_ROOT environment variable
+# (which is used as a base for paths in "configs/paths/default.yaml")
+# (this way all filepaths are the same no matter where you run the code)
+# - loading environment variables from ".env" in root dir
+#
+# you can remove it if you:
+# 1. either install project as a package or move entry files to project root dir
+# 2. set `root_dir` to "." in "configs/paths/default.yaml"
+#
+# more info: https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+from gen.gen import utils # noqa: E402
+
+log = utils.get_pylogger(__name__)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> Tuple[dict, dict]:
+ """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+ training.
+
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """
+
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ L.seed_everything(cfg.seed, workers=True)
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating callbacks...")
+ callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+ log.info("Instantiating loggers...")
+ logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ utils.log_hyperparameters(object_dict)
+
+ if cfg.get("compile"):
+ log.info("Compiling model!")
+ model = torch.compile(model)
+
+ if cfg.get("train"):
+ log.info("Starting training!")
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+ train_metrics = trainer.callback_metrics
+
+ if cfg.get("test"):
+ log.info("Starting testing!")
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for testing...")
+ ckpt_path = None
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+ predictions = trainer.predict(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
+ run_name = "fdr" if logger == [] else logger[0].name
+ submission_file_path = f"submissions/submission-gen_{run_name}_{timestamp}.txt"
+        id2label = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
+ to_txt(
+ predictions,
+ datamodule.data_predict,
+ submission_file_path,
+ id2label=id2label,
+ )
+ # log submission file to wandb
+ artifact = wandb.Artifact("submission-file", type="submission")
+ artifact.add_file(submission_file_path)
+ wandb.log_artifact(artifact)
+ log.info(f"Best ckpt path: {ckpt_path}")
+
+ test_metrics = trainer.callback_metrics
+
+ # merge train and test metrics
+ metric_dict = {**train_metrics, **test_metrics}
+
+ return metric_dict, object_dict
+
+
+@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
+def main(cfg: DictConfig) -> Optional[float]:
+ # apply extra utilities
+ # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+ utils.extras(cfg)
+
+ # train the model
+ metric_dict, _ = train(cfg)
+
+ # safely retrieve metric value for hydra-based hyperparameter optimization
+ metric_value = utils.get_metric_value(
+ metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
+ )
+
+ # return optimized metric
+ return metric_value
+
+
+if __name__ == "__main__":
+ main()
diff --git a/lib/deft2023/gen/gen/utils/__init__.py b/lib/deft2023/gen/gen/utils/__init__.py
new file mode 100644
index 0000000..92485a0
--- /dev/null
+++ b/lib/deft2023/gen/gen/utils/__init__.py
@@ -0,0 +1,5 @@
+from gen.gen.utils.instantiators import instantiate_callbacks, instantiate_loggers
+from gen.gen.utils.logging_utils import log_hyperparameters
+from gen.gen.utils.pylogger import get_pylogger
+from gen.gen.utils.rich_utils import enforce_tags, print_config_tree
+from gen.gen.utils.utils import extras, get_metric_value, task_wrapper
diff --git a/lib/deft2023/gen/gen/utils/instantiators.py b/lib/deft2023/gen/gen/utils/instantiators.py
new file mode 100644
index 0000000..02c2e2d
--- /dev/null
+++ b/lib/deft2023/gen/gen/utils/instantiators.py
@@ -0,0 +1,49 @@
+from typing import List
+
+import hydra
+from gen.gen.utils import pylogger
+from omegaconf import DictConfig
+from lightning import Callback
+from lightning.pytorch.loggers import Logger
+
+log = pylogger.get_pylogger(__name__)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/lib/deft2023/gen/gen/utils/logging_utils.py b/lib/deft2023/gen/gen/utils/logging_utils.py
new file mode 100644
index 0000000..151925c
--- /dev/null
+++ b/lib/deft2023/gen/gen/utils/logging_utils.py
@@ -0,0 +1,49 @@
+from gen.gen.utils import pylogger
+from lightning.pytorch.utilities import rank_zero_only
+
+log = pylogger.get_pylogger(__name__)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/lib/deft2023/gen/gen/utils/pylogger.py b/lib/deft2023/gen/gen/utils/pylogger.py
new file mode 100644
index 0000000..62176ca
--- /dev/null
+++ b/lib/deft2023/gen/gen/utils/pylogger.py
@@ -0,0 +1,17 @@
+import logging
+
+from lightning.pytorch.utilities import rank_zero_only
+
+
+def get_pylogger(name=__name__) -> logging.Logger:
+ """Initializes multi-GPU-friendly python command line logger."""
+
+ logger = logging.getLogger(name)
+
+ # this ensures all logging levels get marked with the rank zero decorator
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+ logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+ for level in logging_levels:
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+ return logger
diff --git a/lib/deft2023/gen/gen/utils/rich_utils.py b/lib/deft2023/gen/gen/utils/rich_utils.py
new file mode 100644
index 0000000..8a07e43
--- /dev/null
+++ b/lib/deft2023/gen/gen/utils/rich_utils.py
@@ -0,0 +1,97 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from gen.gen.utils import pylogger
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+log = pylogger.get_pylogger(__name__)
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "data",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ print_order (Sequence[str], optional): Determines in what order config
+ components are printed.
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
+ """
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+ queue.append(field) if field in cfg else log.warning(
+ f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+ )
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config."""
+
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
diff --git a/lib/deft2023/gen/gen/utils/utils.py b/lib/deft2023/gen/gen/utils/utils.py
new file mode 100644
index 0000000..326b7b7
--- /dev/null
+++ b/lib/deft2023/gen/gen/utils/utils.py
@@ -0,0 +1,113 @@
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+from gen.gen.utils import pylogger, rich_utils
+from omegaconf import DictConfig
+
+log = pylogger.get_pylogger(__name__)
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+ """
+
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found!