palindrome changes
AlexGibson0 committed Jan 25, 2024
1 parent d2db3d3 commit d0cf251
Showing 3 changed files with 114 additions and 33 deletions.
68 changes: 60 additions & 8 deletions gbmi/exp_group_finetuning/groups.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, cast, Literal, Generic, TypeVar

Check failure on line 2 in gbmi/exp_group_finetuning/groups.py (GitHub Actions / ci (3.11, 1.7.1, ubuntu-latest)): 'typing.Any', 'typing.Dict', 'typing.Optional', 'typing.cast', and 'typing.Literal' imported but unused

import json
import torch

T = TypeVar("T")

@@ -11,6 +12,10 @@ class Group(ABC, Generic[T]):
def id() -> T:
...

@abstractmethod
def toJSON(self):
...

@abstractmethod
def name(self) -> str:
...
@@ -19,29 +24,72 @@ def name(self) -> str:
def size(self) -> int:
...

@abstractmethod
def index(self) -> int:
...

@staticmethod
@abstractmethod
def parameternames() -> List[str]:
...

@staticmethod
@abstractmethod
def op(a: T, b: T) -> T:
def op(self, a: T, b: T) -> T:
...

@classmethod
def reduce(cls, xs: T) -> T:
accumulator = cls.id()
def reduce(self, xs: T) -> T:
accumulator = self.__class__.id()
for x in xs:
accumulator = cls.op(accumulator, x)
accumulator = self.op(accumulator, x)

return accumulator


class DihedralGroup(Group):
def __init__(self, n: int):
self.n = n
self.lookup = []
for x in range(2 * n):
self.lookup.append([])
for y in range(2 * n):
j = x % 2
if j == 0:
result = (y % 2 + (2 * ((x // 2 + y // 2) % n))) % (2 * n)
else:
result = ((y % 2 + 1) % 2 + (2 * ((x // 2 - y // 2) % n))) % (2 * n)

self.lookup[x].append(result)
self.lookup = torch.tensor(self.lookup).to("cuda")

def toJSON(self):
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)

def name(self) -> str:
return "DihedralGroup" + str(2 * self.n)

def size(self) -> int:
return 2 * self.n

def index(self) -> int:
return self.n

def parameternames() -> List[str]:
return ["modulus"]

def id():
return 0

def op(self, x, y):
return self.lookup[x][:, y]


class CyclicGroup(Group):
def __init__(self, n: int):
self.n = n

def toJSON(self):
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)

def name(self) -> str:
return "CyclicGroup" + str(self.n)

@@ -51,12 +99,16 @@ def size(self) -> int:
def parameternames() -> List[str]:
return ["modulus"]

def index(self) -> int:
return self.n

def id():
return 0

def op(self, x, y):
return (x + y) % self.n


GroupDict = {"Cyclic": CyclicGroup}
GroupDict = {"CyclicGroup": CyclicGroup, "DihedralGroup": DihedralGroup}
cycle = CyclicGroup(5)
dihedral = DihedralGroup(4)
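
The DihedralGroup lookup table built above appears to encode index k as the element r**(k // 2) * s**(k % 2) of the dihedral group of order 2n (r a rotation, s a reflection); that convention is inferred from the diff rather than stated in it. A minimal sketch, in plain Python with no torch or CUDA, that checks the table's composition rule against an explicit action on the n-gon's vertices:

    # Hedged sketch: verify the inferred index encoding for a small n.
    def as_permutation(idx, n):
        rot, flip = idx // 2, idx % 2
        def act(v):
            if flip:
                v = (-v) % n          # apply the reflection s first (rightmost factor)
            return (v + rot) % n      # then the rotation r**rot
        return tuple(act(v) for v in range(n))

    def compose_indices(x, y, n):
        # Same arithmetic as the lookup-table construction in DihedralGroup.__init__.
        if x % 2 == 0:
            return (y % 2 + 2 * ((x // 2 + y // 2) % n)) % (2 * n)
        return ((y % 2 + 1) % 2 + 2 * ((x // 2 - y // 2) % n)) % (2 * n)

    n = 4
    index_of = {as_permutation(i, n): i for i in range(2 * n)}
    for x in range(2 * n):
        for y in range(2 * n):
            xy = tuple(as_permutation(x, n)[v] for v in as_permutation(y, n))
            assert index_of[xy] == compose_indices(x, y, n)
    print("lookup rule matches the permutation composition for n =", n)
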
52 changes: 36 additions & 16 deletions gbmi/exp_group_finetuning/train.py
@@ -4,9 +4,14 @@
from dataclasses import field
from collections.abc import Callable

Check failure on line 5 in gbmi/exp_group_finetuning/train.py (GitHub Actions / ci (3.11, 1.7.1, ubuntu-latest)): 'collections.abc.Callable' imported but unused

from groups import Group, GroupDict, CyclicGroup
from gbmi.exp_group_finetuning.groups import (
Group,
GroupDict,
CyclicGroup,
DihedralGroup,
)
import sys
from typing import Any, Dict, List, Optional, cast, Literal, Generic, TypeVar
from typing import Any, Dict, List, Optional, cast, Literal, Generic, TypeVar, Type
from gbmi import utils

import numpy as np
@@ -42,7 +47,9 @@
class ModularFineTuning(ExperimentConfig):
model_config: HookedTransformerConfig
# using int instead of abstract class because i'm clueless what's going on with typing
group: Group
group_family: str
group_index: int
group_size: int
group_name: str
zero_biases: bool = True
attention_rate: float = 0 # 0 is use attention, 1 is uniformly constant attention
@@ -62,7 +69,7 @@ def get_datamodule(self):

def get_summary_slug(self, config: Config[ModularFineTuning]) -> str:
return (
f"GroupFineTuning-{config.experiment.model_config.n_ctx}-{config.train_for[0]}-"
f"GroupFineTuning-{config.experiment.group_family+str(config.experiment.group_index)}-{config.experiment.model_config.n_ctx}-{config.train_for[0]}-"
f"{config.train_for[1]}-attention-rate-{config.experiment.attention_rate}"
f"{'-nondeterministic' if not config.deterministic else ''}"
)
@@ -83,7 +90,9 @@ def modular_addition_config(attn_rate: float, group: Group, elements: int):
attn_only=False,
normalization_type=None,
),
group=group,
group_family=type(group).__name__,
group_index=group.index(),
group_size=group.size(),
group_name=group.name(),
zero_biases=True,
attention_rate=attn_rate,
@@ -105,6 +114,12 @@ def modular_addition_config(attn_rate: float, group: Group, elements: int):
MODULAR_ADDITION_113_PIZZA_CONFIG = modular_addition_config(
attn_rate=1, group=CyclicGroup(113), elements=2
)
DIHEDRAL_100_CLOCK_CONFIG = modular_addition_config(
attn_rate=0, group=DihedralGroup(104), elements=2
)
DIHEDRAL_100_PIZZA_CONFIG = modular_addition_config(
attn_rate=1, group=DihedralGroup(104), elements=2
)


class ModularFineTuningTrainingWrapper(TrainingWrapper[ModularFineTuning]):
@@ -120,8 +135,8 @@ def build_model(config: Config[ModularFineTuning]) -> HookedTransformer:
model_config,
{
"seed": reseed(config.seed, "model"),
"d_vocab": config.experiment.group.size() + 1,
"d_vocab_out": config.experiment.group.size(),
"d_vocab": config.experiment.group_size + 1,
"d_vocab_out": config.experiment.group_size,
},
warn_if_not_default=False,
)
@@ -138,9 +153,14 @@ def loss_fn(
logits: Float[Tensor, "batch pos d_vocab"], # noqa: F722
labels: Integer[Tensor, "batch"], # noqa: F821
) -> Float[Tensor, ""]: # noqa: F722
logits = logits
labels = labels
logits = logits[:, -1, :].to(torch.float64)

log_probs = utils.log_softmax(logits, dim=-1)

correct_log_probs = log_probs.gather(-1, labels.unsqueeze(-1))[:, 0]

return -correct_log_probs.mean()

@staticmethod
@@ -162,14 +182,18 @@ def run_batch(
self, x: Float[Tensor, "batch pos"], prefix: str # noqa: F722
) -> Float[Tensor, ""]: # noqa: F722
self.model.to(x.device, print_details=False)
labels = self.config.experiment.group.reduce(list(x[:, :-1]))

labels = GroupDict[self.config.experiment.group_family](
self.config.experiment.group_index
).reduce(list(x[:, :-1].T))
assert (
len(labels.shape) == 1
), f"labels.shape == {labels.shape} != 1 (from x.shape == {x.shape})"
y_preds = self.model.run_with_hooks(
x, fwd_hooks=[("blocks.0.attn.hook_pattern", self.attention_hook)]
)
loss = self.loss_fn(y_preds, labels)

self.log(f"{prefix}loss", loss, prog_bar=True)
acc = self.acc_fn(y_preds, labels)
self.log(f"{prefix}acc", acc, prog_bar=True)
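
The run_batch change above rebuilds the group from the config's new primitive fields (GroupDict[group_family](group_index)) and folds the input positions with reduce; passing x[:, :-1].T means reduce sees one tensor per position rather than one per sequence, so op is applied element-wise across the batch and returns one label per sequence. A minimal sketch of that fold, using a trimmed-down CyclicGroup copy (not the real gbmi class) so it runs on its own:

    import torch

    class CyclicGroup:
        def __init__(self, n):
            self.n = n

        @staticmethod
        def id():
            return 0

        def op(self, x, y):
            return (x + y) % self.n

        def reduce(self, xs):
            accumulator = self.__class__.id()
            for x in xs:
                accumulator = self.op(accumulator, x)
            return accumulator

    group = CyclicGroup(113)
    x = torch.tensor([[3, 7, 113],          # two sequences "a b =", where 113 is the '=' token
                      [110, 5, 113]])
    labels = group.reduce(list(x[:, :-1].T))   # fold the a and b columns with op
    print(labels)                              # tensor([10, 2]) -> (3 + 7) % 113, (110 + 5) % 113
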
@@ -206,10 +230,11 @@ def setup(self, stage: str):
# Full dataset
rng = np.random.default_rng(self.dataset_seed)
pairs = generate_all_sequences(
self.config.experiment.group.size(), self.model_config.n_ctx - 1
self.config.experiment.group_size,
self.model_config.n_ctx - 1,
)
# concat a special token of value self.config.experiment.group_size to the end of each sequence for '='
equals_token = self.config.experiment.group.size()
equals_token = self.config.experiment.group_size
data = torch.cat(
[pairs, equals_token * torch.ones((len(pairs), 1))], dim=1
).long()
@@ -296,12 +321,7 @@ def main(argv: List[str] = sys.argv):

add_force_argument(parser)
add_no_save_argument(parser)
HOOKED_TRANSFORMER_CONFIG_EXCLUDE_ARGS = set(
(
"d_vocab",
"d_vocab_out",
)
)
HOOKED_TRANSFORMER_CONFIG_EXCLUDE_ARGS = set(("d_vocab", "d_vocab_out", "group"))
Config.add_arguments(parser)
add_HookedTransformerConfig_arguments(
parser, exclude_arguments=HOOKED_TRANSFORMER_CONFIG_EXCLUDE_ARGS
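
For reference, a self-contained sketch of the last-position cross-entropy that the updated loss_fn computes, substituting torch.log_softmax for gbmi's utils.log_softmax (assumed here to behave the same):

    import torch

    def loss_fn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = logits[:, -1, :].to(torch.float64)      # keep only the '=' position
        log_probs = torch.log_softmax(logits, dim=-1)
        correct_log_probs = log_probs.gather(-1, labels.unsqueeze(-1))[:, 0]
        return -correct_log_probs.mean()                 # mean negative log-likelihood

    logits = torch.randn(4, 3, 10)       # (batch, pos, d_vocab)
    labels = torch.tensor([1, 4, 2, 7])
    print(loss_fn(logits, labels))
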
27 changes: 18 additions & 9 deletions notebooks_alex/pizzaclock.py
@@ -1,5 +1,14 @@
from gbmi.exp_modular_fine_tuning.train import MODULAR_ADDITION_113_CLOCK_CONFIG
from gbmi.exp_modular_fine_tuning.train import MODULAR_ADDITION_113_PIZZA_CONFIG
from gbmi.exp_group_finetuning.train import MODULAR_ADDITION_113_CLOCK_CONFIG
from gbmi.exp_group_finetuning.train import MODULAR_ADDITION_113_PIZZA_CONFIG
from gbmi.exp_group_finetuning.train import DIHEDRAL_100_CLOCK_CONFIG
from gbmi.exp_group_finetuning.train import DIHEDRAL_100_PIZZA_CONFIG

from gbmi.exp_group_finetuning.groups import (
Group,
GroupDict,
CyclicGroup,
DihedralGroup,
)
from gbmi.model import train_or_load_model
import torch
from math import sqrt
@@ -8,11 +17,11 @@
import tqdm

device = "cuda"
p = 113
q = p
freeze_model = False
config = MODULAR_ADDITION_113_PIZZA_CONFIG

freeze_model = False
config = DIHEDRAL_100_PIZZA_CONFIG
p = config.experiment.group_index
q = p * 2
frac_train = 0.3
seed = 999
num_epochs = 5000
@@ -92,6 +101,7 @@ def loss_fn(logits, labels, softmax=True):
logits = logits[:, :, -1].squeeze(-1)
else:
logits = logits[:, -1, :]

logits = logits.to(torch.float64)
log_probs = logits.log_softmax(dim=-1)
correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
@@ -116,9 +126,8 @@ def loss_fn(logits, labels, softmax=True):
b_vector = einops.repeat(torch.arange(q), "j -> (i j)", i=q)
equals_vector = einops.repeat(torch.tensor(q), " -> (i j)", i=q, j=q)
dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)


labels = (dataset[:, 0] - dataset[:, 1]) % q
labels = DihedralGroup(104).op(dataset[:, 0], dataset[:, 1]).flatten()
print(labels)
optimizer = torch.optim.AdamW(
full_model.parameters(), lr=1e-3, weight_decay=1, betas=(0.9, 0.98)
)
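
The exhaustive (a, b, '=') dataset built above with einops.repeat is easier to read at a small size; a sketch with q shrunk to 3 (the script itself uses q = 2 * group_index):

    import einops
    import torch

    q = 3
    a_vector = einops.repeat(torch.arange(q), "i -> (i j)", j=q)            # 0 0 0 1 1 1 2 2 2
    b_vector = einops.repeat(torch.arange(q), "j -> (i j)", i=q)            # 0 1 2 0 1 2 0 1 2
    equals_vector = einops.repeat(torch.tensor(q), " -> (i j)", i=q, j=q)   # q marks '='
    dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1)
    print(dataset)      # every ordered pair (a, b) followed by the '=' token, shape (q * q, 3)
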
