Simplify api (#24)

* fix(ilql): sampling on variable sized prompts & stage simplified api

* Save strategy (#23)

* Had to add py_modules=trlx to setup.

* Added a save strategy.

* Cleaned up a few things.

* Added save_steps to ilql_config.yaml and a save-steps strategy to accelerate_ilql_model.py for consistency. The save_steps parameter must now be set because of how TrainConfig.from_dict operates: if no save_steps parameter is given in the configs, it throws an error.

* Adding minimal changes to enable a step-based save strategy in configs/ppo_config.yml, trlx/data/configs.py, and trlx/model/accelerate_ppo_model.py

* Some problems crept in despite the merge check. This fixes them.

* Realized I am merging into stage-api, not main, so fixed an issue with ilql_config.yml

* fix(ilql): eval on a set of betas & add simple timers

* fix: saving checkpoints

* refactor(ilql): subsume under base_model

* fix(ilql): mask prompts

* merge hydra

* fix(ppo): generalize and stage for api

* feat: add architext examples

* fix(ppo,ilql): ddp + accelerate

* refactor: clean pipelines

* feat: add simulacra example

* fix(ppo): single token prompting

* refactor: fully merge models

* refactor(configs): lower batch_sizes & remove dead entries

* refactor(examples): update for new api

* fix(tests,style): one way to pass tests is to change them

* fix(ppo): warnings of the most recent version of transformers

4.23.1 complains if `.generate()` starts with a single BOS token when bos=eos=pad token (see the sketch after this commit list)

* refactor(readme): add api

* chore: add doc strings

* fix: remove dropout

* chore: keep gpt2 small in examples

* chore: revert to previous default configs

* chore(docs): rename classes, remove unused, add examples

* chore(readme): add contributing.md & deepspeed note

* style(readme): US spelling

* chore(examples): add explanations for each task
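
The transformers 4.23.1 note above concerns prompts that consist of a lone BOS token while `bos_token == eos_token == pad_token`, as with GPT-2. A minimal sketch of that situation follows, assuming a stock `gpt2` checkpoint; the explicit `attention_mask` and `pad_token_id` are the standard way to quiet the warning, not something this commit adds.

```python
# Hedged sketch of the warning scenario, not code from this commit: GPT-2 has
# bos == eos and no dedicated pad token, so a prompt that is a single BOS token
# looks like padding, and transformers 4.23.1 warns unless the attention mask
# (and a pad_token_id) are passed explicitly.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # bos == eos == pad, as in the bullet above
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = torch.tensor([[tokenizer.bos_token_id]])  # single-BOS prompt
samples = model.generate(
    prompt,
    attention_mask=torch.ones_like(prompt),  # mark the BOS as real input, not padding
    pad_token_id=tokenizer.eos_token_id,
    max_length=32,
    do_sample=True,
)
print(tokenizer.decode(samples[0]))
```
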
maxreciprocate authored Oct 21, 2022
1 parent 4ff712b commit 06cd30f
Showing 30 changed files with 1,088 additions and 1,253 deletions.
155 changes: 24 additions & 131 deletions README.md
@@ -1,152 +1,45 @@
[docs-image]: https://readthedocs.org/projects/trlX/badge/?version=latest
[docs-url]: https://trlX.readthedocs.io/en/latest/?badge=latest

# Welcome to Transformer Reinforcement Learning X (`trlX`)
> A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
# Transformer Reinforcement Learning X

[![Docs Status][docs-image]][docs-url]
`trlx` allows you to fine-tune 🤗 Hugging Face supported language models (`gpt2`, `gpt-j`, `gpt-neo` and `gpt-neox` based) up to 20B parameters using reinforcement learning via either a provided reward function or reward-labeled dataset. Proximal Policy Optimization ([PPO](https://arxiv.org/pdf/1909.08593.pdf)) and Implicit Language Q-Learning ([ILQL](https://sea-snell.github.io/ILQL_site/)) are implemented.

**[Documentation](https://trlX.readthedocs.io)**
## Train

## Overview
Inspired by the popular `trl` library, the `trlX` repo allows you to fine-tune Huggingface supported language models up to 20B parameters via reinforcement learning, using either a provided scoring function or a reward-labeled dataset. We aim to support a range of both online and offline RL algorithms including Proximal Policy Optimization (PPO), Natural Language Policy Optimization (NLPO), Actor Critic (A2C), and Implicit Q Learning (ILQL).

The library supports `gpt2` and `gptj` with plans to include `GPT-NeoX`, `T5` and more. PPO and ILQL algorithms are implemented. Distributed training has been implemented via HF Accelerate and tested up to two nodes, each with 8 GPUs.

## Structure

The training pipeline is broken into four pieces:
```python
import trlx

- Prompt pipeline: Handles loading of prompts/text used to prompt model for exploration in online methods
- Rollout pipeline: Handles loading and storage of reward labeled data used
- Orchestrator: Handles exploration/rollout collection of online methods. Pushes collected rollouts to the rollout pipeline.
- Model: Wraps the supplied base model (ex: `gpt2`) and implements the desired training method loss (ex: PPO).
# optimize some reward function
model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])

Adding a task for RLHF training depends on the desired training method and pre-existing data. If we are online and have no reward labeled data this is as simple as writing a new prompt pipeline, which supplies prompts for exploration, and a new reward function to be passed into the `PPOOrchestrator` class.
# or steer a model with a collection of rated samples
model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])

## Installation
```bash
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install -e ".[dev]"
pre-commit install # see .pre-commit-config.yaml
# model is a wrapper with some logit preprocessing
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
```

## Example: How to add a task

In the below we implement a sentiment learning task.

### Configure `accelerate`
Launch distributed training with 🤗 Accelerate (only DeepSpeed integration is tested)

```bash
accelerate config
accelerate launch examples/simulacra.py
```

### Implement a prompt pipeline

```python
@register_datapipeline
class PPOPipeline(BasePipeline):
    def __init__(self, tokenizer, config, prompt_dataset_path=None):
        super().__init__()

        ds = load_dataset("imdb", split="test")
        ds = ds.rename_columns({"text": "review", "label": "sentiment"})
        ds = ds.filter(lambda x: len(x["review"]) < 500, batched=False)

        self.tokens = [
            tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=config.train.input_size,
                return_tensors="pt",
            )["input_ids"]
            .long()
            .flatten()
            for text in ds["review"]
        ]
        self.text = [tokenizer.decode(tokens.tolist()) for tokens in self.tokens]

    def __getitem__(self, index: int) -> PromptElement:
        return PromptElement(self.text[index], self.tokens[index])

    def __len__(self) -> int:
        return len(self.text)

    def create_loader(
        self,
        batch_size: int,
        shuffle: bool,
        prep_fn: Callable = None,
        num_workers: int = 0,
    ) -> DataLoader:
        # TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on fly
        def collate_fn(elems: Iterable[PromptElement]) -> PromptElement:
            return PromptBatch(
                [elem.text for elem in elems],
                torch.stack(
                    [elem.tokens for elem in elems]
                ),  # Assumes token tensors all same size
            )
For more usage see [examples](./examples)

        return DataLoader(
            self, batch_size, shuffle, collate_fn=collate_fn, num_workers=num_workers
        )
```

### Launch training

```python
from typing import List

import torch
from transformers import pipeline

import wandb
from trlx.data.configs import TRLConfig
from trlx.model.accelerate_ppo_model import AcceleratePPOModel
from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator
from trlx.pipeline.ppo_pipeline import PPOPipeline
from trlx.utils.loading import get_model, get_orchestrator, get_pipeline

if __name__ == "__main__":
    cfg = TRLConfig.load_yaml("configs/ppo_config.yml")

    sentiment_pipe = pipeline(
        "sentiment-analysis", "lvwerra/distilbert-imdb", device=-1
    )

    def reward_fn(samples: List[str]):
        sent_kwargs = {
            "return_all_scores": True,
            "function_to_apply": None,
            "batch_size": cfg.method.chunk_size,
        }
        pipe_outputs = sentiment_pipe(samples, **sent_kwargs)
        scores = torch.tensor([output[1]["score"] for output in pipe_outputs])
        return scores

    model: AcceleratePPOModel = get_model(cfg.model.model_type)(cfg)
    if model.accelerator.is_main_process:
        wandb.watch(model.model)

    pipeline: PPOPipeline = get_pipeline(cfg.train.pipeline)(model.tokenizer, cfg)
    orch: PPOOrchestrator = get_orchestrator(cfg.train.orchestrator)(
        model, pipeline, reward_fn=reward_fn, chunk_size=cfg.method.chunk_size
    )
    orch.make_experience(cfg.method.num_rollouts)
    model.learn()

    print("DONE!")
## Install
```bash
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113 # for cuda
pip install -e .
```

And run `accelerate launch my_script.py`

## References
For development check out these [guidelines](./CONTRIBUTING.md)
and also read our [docs](https://trlX.readthedocs.io)

### Proximal Policy Optimisation
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
## Acknowledgements

### Language models
The language models utilize the `transformers` library by 🤗 Hugging Face.
Thanks Leandro for starting the original [trl](https://github.com/lvwerra/trl/)
40 changes: 17 additions & 23 deletions configs/ilql_config.yml
@@ -1,42 +1,36 @@
model:
model_path : "gpt2"
model_type : "ILQLModel"
device : "cuda"
tokenizer_path: "gpt2"
model_type : "ILQLModel"
num_layers_unfrozen: -1

train:
n_ctx : 512
epochs : 1
total_steps : 80000
batch_size : 80
grad_clip : 1.0
seq_length: 64
batch_size: 128
epochs: 10
total_steps: 10000

lr_ramp_steps : 100
lr_decay_steps : 3366
weight_decay : 1.0e-6
learning_rate_init : 1.0e-3
learning_rate_target : 1.0e-3
lr_ramp_steps: 100
lr_decay_steps: 3366
weight_decay: 1e-6
learning_rate_init: 1e-4
learning_rate_target: 1e-4
opt_betas: [0.9, 0.95]

log_interval : 25
checkpoint_interval : 100
eval_interval : 50

input_size: 1
gen_size: 32
checkpoint_interval: 1000
eval_interval: 16

pipeline : "OfflinePipeline"
orchestrator : "OfflineOrchestrator"

accelerate : true
seed: 1000

method:
name: "ilqlconfig"
tau: 0.7
gamma: 0.99
cql_scale: 0.1
awac_scale: 1
alpha: 1
steps_for_target_q_sync: 10
beta: 4
alpha: 0.005
steps_for_target_q_sync: 1
betas: [16]
two_qs: true
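
As a quick orientation for the config above, here is a hedged sketch of loading it with `TRLConfig.load_yaml`, the call used in the launch-training example this diff removes; that the call survives the refactor, and that the YAML sections map to attributes with exactly these names, are assumptions.

```python
# Hedged sketch: load the ILQL config shown above and read back a few values.
# TRLConfig.load_yaml comes from the old README example in this diff; attribute
# names mirroring the YAML keys is an assumption about the config dataclasses.
from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ilql_config.yml")

print(config.train.seq_length)  # 64, sequence length (the old config used n_ctx)
print(config.train.batch_size)  # 128
print(config.method.betas)      # [16], evaluation now runs over a set of betas
print(config.method.two_qs)     # True
```
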
80 changes: 36 additions & 44 deletions configs/ppo_config.yml
@@ -1,52 +1,44 @@
model:
model_path : "lvwerra/gpt2-imdb" # Name of hf model to load
tokenizer_path : "gpt2" # Name of hf tokenizer to load
model_type : "AcceleratePPOModel" # Name of accelerate model type to load
device : "cuda" # Train device
num_layers_unfrozen : 2 # Number of bottom layers to freeze during training
model_path: "lvwerra/gpt2-imdb" # Name of hf model to load
tokenizer_path: "gpt2" # Name of hf tokenizer to load
model_type: "AcceleratePPOModel" # Name of accelerate model type to load
num_layers_unfrozen: 2 # Number of bottom layers to freeze during training

train:
n_ctx : 512 # Size of LM context
epochs : 10 # Train for max(epochs, total_steps)
total_steps : 80000 # Train for max(epochs, total_steps)
batch_size : 128 # batch size
grad_clip : 1.0 # gradient clipping threshold
seq_length: 48 # Size of LM context
epochs: 1000 # Train for max(epochs, total_steps)
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size

lr_ramp_steps : 100 # learning rate warm up
lr_decay_steps : 79000 # learning rate decay
weight_decay : 1.0e-6 # weight decay param
learning_rate_init : 1.412e-4 # init learning rate
learning_rate_target : 1.412e-4 # target final learning rate
lr_ramp_steps: 100 # learning rate warm up
lr_decay_steps: 79000 # learning rate decay
weight_decay: 1.0e-6 # weight decay param
learning_rate_init: 1.412e-4 # init learning rate
learning_rate_target: 1.412e-4 # target final learning rate
opt_betas: [0.9, 0.95] # adam betas

log_interval : 25 # log interval
checkpoint_interval : 1000000 # checkpoint interval
eval_interval : 16 # eval interval
checkpoint_interval: 10000 # checkpoint interval
eval_interval: 16 # eval interval

pipeline : "PPOPipeline" # prompt pipeline to load
orchestrator : "PPOOrchestrator" # orchestrator to load

input_size : 4 # max input size
gen_size : 48 # max gen size

accelerate : True # Use accelerate
accelerate_config_path : "" # Path to accelerate config(for logging purposes)
pipeline: "PPOPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load

method:
name : 'ppoconfig' # Name of RL method config
num_rollouts : 128 # Number of rollouts to collect per epoch
chunk_size : 128 # Number of rollouts to collect in one loop of orchestrator
ppo_epochs : 4 # Number of ppo epochs
init_kl_coef : 0.2 # init kl coefficient
target : 6 # target kl coefficient, set None for fixed kl coef
horizon : 10000 # PPO horizon
gamma : 1 # PPO discount
lam : 0.95 # PPO lambda
cliprange : 0.2 # clip range
cliprange_value : 0.2 # clip range
vf_coef : 2.3 # value term weight
gen_kwargs :
max_length : 48 # LM max sample gen length
min_length : 48 # LM min sample gen length
top_k : 0.0 # top k
top_p : 1.0 # top p
do_sample : True # sample
name: 'ppoconfig' # Name of RL method config
num_rollouts: 128 # Number of rollouts to collect per epoch
chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
ppo_epochs: 4 # Number of ppo epochs
init_kl_coef: 0.2 # init kl coefficient
target: 6 # target kl coefficient, set None for fixed kl coef
horizon: 10000 # PPO horizon
gamma: 1 # PPO discount
lam: 0.95 # PPO lambda
cliprange: 0.2 # clip range
cliprange_value: 0.2 # clip range
vf_coef: 2.3 # value term weight
gen_kwargs:
max_length: 48 # LM max sample gen length
min_length: 48 # LM min sample gen length
top_k: 0.0 # top k
top_p: 1.0 # top p
do_sample: True # sample
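
To tie the PPO config above back to the simplified API shown in the README diff, here is a hedged sketch that reuses the sentiment reward from the launch-training example this commit removes; only the `reward_fn` keyword of `trlx.train` is confirmed by the diff, and whether training picks up `configs/ppo_config.yml` (or an equivalent default) on its own is an assumption.

```python
# Hedged sketch: the simplified trlx.train API from the README diff, paired with
# the sentiment reward from the removed launch-training example. That the PPO
# config above is applied by default is an assumption, not something the diff states.
from typing import List

from transformers import pipeline

import trlx

sentiment_pipe = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)

def reward_fn(samples: List[str]) -> List[float]:
    # score each sample with the POSITIVE class score, as in the removed example
    outputs = sentiment_pipe(samples, return_all_scores=True, function_to_apply=None)
    return [output[1]["score"] for output in outputs]

model = trlx.train("lvwerra/gpt2-imdb", reward_fn=reward_fn)
```
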