In the course of learning about Transformers for text, I came across multiple implementations, all purportedly built on the publicly available details of OpenAI's GPT-2. So I decided to set up a testbed to compare them. I also added a non-transformer baseline: a simple 2-layer feed-forward MLP. To be clear, we're not comparing pre-trained models; we're comparing implementations of the architecture, i.e. how they perform in terms of accuracy and speed when configured similarly and trained on the same data.
If you care about having something simple that you understand in its entirety, I would say TLTransformer is simpler than NanoGPT. The performance differences between the implementations probably don't matter in the big picture. If you have a really simple application, consider throwing in an MLP as a baseline - you may be surprised by how well it does.
!nvidia-smi
Mon Mar 13 06:16:43 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0    53W / 400W |   8105MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
try:
import google.colab
IN_COLAB = True
print("Running as a Colab notebook")
%pip install einops
%pip install transformers
%pip install fancy_einsum
%pip install git+https://github.com/neelnanda-io/TransformerLens.git
except ImportError:
IN_COLAB = False
import torch
import torch.nn as nn
import time
import torch.nn.functional as F
from typing import Any
import random
from transformers import OpenAIGPTConfig, OpenAIGPTModel
from transformer_lens.utils import gelu_new
from dataclasses import dataclass
import dataclasses
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import sys
if IN_COLAB:
!git clone https://github.com/karpathy/nanoGPT.git
sys.path.insert(0, 'nanoGPT')
else:
sys.path.insert(0, '../../nanoGPT')
from model import GPTConfig, GPT
@dataclass
class DataConfig:
reverse: bool = False
upper_bound: int = 999
each_digit_of_result_separate: bool = False
    # Note: these defaults are computed once, at class-definition time, from the
    # class-level upper_bound above (999); override them if you change upper_bound.
    lhs_number_width: int = len(str(upper_bound))
    rhs_number_width: int = len(str(upper_bound)) + 1
num_plus_examples: int = int(3e5)
num_minus_examples: int = int(3e5)
test_size: int = 4096
@dataclass
class GeneratedData:
all_strs: set[str]
train_x: torch.Tensor
train_y: torch.Tensor
test_x: torch.Tensor
test_y: torch.Tensor
# m stands for unary negative, | indicates end of sequence.
# "." is reserved for potential future use, but currently unused.
chars = '0123456789.=+-m|'
end_of_sequence = "|"
encoder = dict((c, i) for i, c in enumerate(chars))
decoder = dict((i, c) for i, c in enumerate(chars))
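As a quick illustration (using the chars ordering above), the string "12+34=" encodes to the ids [1, 2, 12, 3, 4, 11] and decodes back to itself:
example = "12+34="
ids = [encoder[c] for c in example]
assert ids == [1, 2, 12, 3, 4, 11]
assert "".join(decoder[i] for i in ids) == example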
def reverse_str(s):
return s[::-1]
def stringify_problem(x, y, math_symbol, reverse=False, lhs_number_width=3, rhs_number_width=4):
if math_symbol == "+":
rhs = f"{x+y:0{rhs_number_width}}"
elif math_symbol == "-":
if x < y:
# since we're adding a unary negative, we need to subtract 1 from the width
if len(str(y-x)) > rhs_number_width - 1:
raise ValueError(f"{x} minus {y} doesn't fit in {rhs_number_width} digits")
rhs = "m" + f"{y-x:0{rhs_number_width-1}}"
else:
rhs = f"{x-y:0{rhs_number_width}}"
else:
raise ValueError(f"Unsupported math symbol {math_symbol}")
if reverse:
rhs = reverse_str(rhs)
    # pad rhs with end-of-sequence markers ("|") on the right, up to rhs_number_width
padded_rhs = rhs + end_of_sequence + end_of_sequence * (rhs_number_width - len(rhs))
str_x = f"{x:0{lhs_number_width}}"
str_y = f"{y:0{lhs_number_width}}"
if reverse:
lhs = reverse_str(str_x) + math_symbol + reverse_str(str_y) + "="
else:
lhs = str_x + math_symbol + str_y + "="
return lhs + padded_rhs
def generate_math_example(cfg: DataConfig, math_symbol = "+"):
x = random.randint(0, cfg.upper_bound)
y = random.randint(0, cfg.upper_bound)
return stringify_problem(x, y, math_symbol, cfg.reverse, cfg.lhs_number_width, cfg.rhs_number_width)
def tensorify_example(example):
return torch.tensor([encoder[c] for c in example])
## Let's create a function that takes X and Y and folds the digits of Y into X
## one at a time, so that each resulting row predicts a single next digit of Y.
## To keep all rows of X the same length, it pads with zeros on the left.
def include_digits_of_y_in_x(X, Y):
n = Y.shape[1]
# We need to create X0, X1, ..., Xn-1
Xs = []
Ys = []
for i in range(n):
Xs.append(torch.cat([torch.zeros(X.shape[0], n - 1 - i, dtype=torch.long), X, Y[:, :i]], dim=1))
Ys.append(Y[:, i])
X = torch.cat(Xs, dim=0)
Y = torch.cat(Ys, dim=0)
return X, Y
def generate_data(cfg: DataConfig) -> GeneratedData:
plus_strs = set([generate_math_example(cfg, math_symbol = "+") for _ in range(cfg.num_plus_examples)])
minus_strs = set([generate_math_example(cfg, math_symbol = "-") for _ in range(cfg.num_minus_examples)])
all_strs = plus_strs.union(minus_strs)
all_examples = torch.stack([tensorify_example(i) for i in all_strs])
assert cfg.test_size < all_examples.shape[0] * 0.9, "Test size requested is more than 90% of all data"
train_size = all_examples.shape[0] - cfg.test_size
rhs_len = cfg.rhs_number_width + 1
train_x = all_examples[:train_size, :-rhs_len]
train_y = all_examples[:train_size, -rhs_len:]
test_x = all_examples[train_size:, :-rhs_len]
test_y = all_examples[train_size:, -rhs_len:]
if cfg.each_digit_of_result_separate:
train_x, train_y = include_digits_of_y_in_x(train_x, train_y)
test_x, test_y = include_digits_of_y_in_x(test_x, test_y)
return GeneratedData(all_strs=all_strs, train_x=train_x, train_y=train_y, test_x=test_x, test_y=test_y)
## Let's make sure the include_digits_of_y_in_x function works as expected
def test_for_include_digits_of_y_in_x():
X = tensorify_example("123+456=").unsqueeze(dim=0)
Y = tensorify_example("579" + end_of_sequence).unsqueeze(dim=0)
X, Y = include_digits_of_y_in_x(X, Y)
assert torch.equal(X[0], tensorify_example("000123+456="))
assert torch.equal(Y[0], tensorify_example("5").squeeze())
assert torch.equal(X[1] , tensorify_example("00123+456=5"))
assert torch.equal(Y[1] , tensorify_example("7").squeeze())
assert torch.equal(X[2] , tensorify_example("0123+456=57"))
assert torch.equal(Y[2] , tensorify_example("9").squeeze())
assert torch.equal(X[3] , tensorify_example("123+456=579"))
assert torch.equal(Y[3] , tensorify_example(end_of_sequence).squeeze())
test_for_include_digits_of_y_in_x()
stringify_problem(123, 456, "+", lhs_number_width=3, rhs_number_width=4, reverse=True)
'321+654=9750|'
def test_for_stringify_problem():
assert stringify_problem(123, 456, "+", lhs_number_width=3, rhs_number_width=4, reverse=False) == "123+456=0579|"
assert stringify_problem(123, 999, "+", lhs_number_width=5, rhs_number_width=4, reverse=False) == "00123+00999=1122|"
assert stringify_problem(123, 999, "-", lhs_number_width=3, rhs_number_width=4, reverse=False) == "123-999=m876|"
assert stringify_problem(999, 123, "-", lhs_number_width=3, rhs_number_width=4, reverse=False) == "999-123=0876|"
assert stringify_problem(123, 456, "+", lhs_number_width=3, rhs_number_width=4, reverse=True) == "321+654=9750|"
assert stringify_problem(23, 45, "+", lhs_number_width=3, rhs_number_width=4, reverse=True) == "320+540=8600|"
assert stringify_problem(123, 999, "-", lhs_number_width=3, rhs_number_width=4, reverse=True) == "321-999=678m|"
test_for_stringify_problem()
reversed_add_3digit_separated_cfg = DataConfig(
each_digit_of_result_separate=True, upper_bound=999,
num_plus_examples=int(1e5), num_minus_examples=0, test_size=4096, reverse=True,
)
reversed_add_3digit_separated = generate_data(reversed_add_3digit_separated_cfg)
list(reversed_add_3digit_separated.all_strs)[0]
'272+926=1090|'
results_dict = {}
@dataclass
class TrainConfig:
epochs: int = 1000
train_batch_size: int = 128
lr: float = 1e-3
weight_decay: float = 1e-4
epoch_interval: int = 100
time_budget_seconds: int = 120
@dataclass
class ResultRow:
model_config: Any
train_config: TrainConfig
num_parameters: int
epochs: int
train_loss: float
train_accuracy: float
test_loss: float
test_accuracy: float
train_time_in_seconds: float
time_per_example_in_micros: float
train_losses: dict[int, float]
train_accuracies: dict[int, float]
test_losses: dict[int, float]
test_accuracies: dict[int, float]
def train_and_eval(model_config: Any, m: nn.Module, data: GeneratedData, train_config: TrainConfig):
m = m.to(device)
num_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params}")
optimizer = torch.optim.AdamW(m.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
train_x = data.train_x.to(device)
train_y = data.train_y.to(device)
test_x = data.test_x.to(device)
test_y = data.test_y.to(device)
training_losses = {}
test_losses = {}
training_accuracies = {}
test_accuracies = {}
outer_start = time.time()
ep = 0
while ep < train_config.epochs:
start = time.time()
optimizer.zero_grad()
rind = torch.randint(0, train_x.shape[0], (train_config.train_batch_size,))
X = train_x[rind]
Y = train_y[rind]
output = m(X)
if type(output) is tuple:
# some models output other stuff besides the logits, let's just use the logits
logits = output[0][:, -1, :] #get logits for last token
else:
logits = output[:, -1, :] #get logits for last token
loss = F.cross_entropy(logits, Y)
if ep % train_config.epoch_interval == 0:
preds = torch.argmax(logits, dim=-1)
training_losses[ep] = loss.item()
training_accuracies[ep] = torch.sum(preds == Y).item() / preds.shape[0]
loss.backward()
optimizer.step()
elapsed = time.time() - start
if ep % train_config.epoch_interval == 0:
with torch.no_grad():
#calculate test loss
output = m(test_x)
if type(output) is tuple:
# some models output other stuff besides the logits, let's just use the logits
test_logits = output[0][:, -1, :]
else:
test_logits = output[:, -1, :]
test_loss = F.cross_entropy(test_logits, test_y)
test_preds = torch.argmax(test_logits, dim=-1)
test_losses[ep] = test_loss.item()
test_accuracies[ep] = torch.sum(test_preds == test_y).item() / test_preds.shape[0]
print(f"Epoch {ep}, train loss {training_losses[ep]: .3E}, test loss {test_losses[ep]: .3f}, " +
f"training accuracy {training_accuracies[ep]: .2f}, test accuracy {test_accuracies[ep]: .2f}, " +
f"time per example {elapsed * 1e6 / train_config.train_batch_size: .2f} µs")
if time.time() - outer_start > train_config.time_budget_seconds:
print("Time budget exceeded, hence stopping training")
break
if test_accuracies[ep] > 0.995:
print("Test accuracy > 99.5%, hence stopping training")
break
ep += 1
    if len(training_losses) == 0 or len(training_accuracies) == 0:
        raise RuntimeError("Training did not run at all")
    if len(test_losses) == 0 or len(test_accuracies) == 0:
        raise RuntimeError("Tests did not run at all")
total_elapsed = time.time() - outer_start
print(f"Total training time {total_elapsed: .2f} s")
result_row = ResultRow(
model_config=model_config,
train_config=train_config,
num_parameters=num_params,
epochs=ep+1,
train_loss=training_losses[max(training_losses.keys())],
train_accuracy=training_accuracies[max(training_accuracies.keys())],
test_loss=test_losses[max(test_losses.keys())],
test_accuracy=test_accuracies[max(test_accuracies.keys())],
train_time_in_seconds=total_elapsed,
time_per_example_in_micros=total_elapsed * 1e6 / ((ep + 1) * train_config.train_batch_size),
train_losses=training_losses,
train_accuracies=training_accuracies,
test_losses=test_losses,
test_accuracies=test_accuracies,
)
return result_row
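The ResultRow objects collected in results_dict can later be flattened into a quick summary. A minimal sketch of such a helper (hypothetical, not part of the runs below; it only reads fields defined on ResultRow):
def summarize_results(results):
    # One line per run: model name, run id, parameter count, final test accuracy, throughput.
    for (name, run), row in results.items():
        print(f"{name:15s} {run:20s} params={row.num_parameters:6d} "
              f"test_acc={row.test_accuracy:.2f} us/example={row.time_per_example_in_micros:6.1f}")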
## Create a wrapper around OpenAIGPTModel so that the forward pass returns logits instead of the final hidden states.
class OpenAIGPTLMHeadModel(OpenAIGPTModel):
def __init__(self, config):
super().__init__(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
#self.apply(self.init_weights)
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
transformer_outputs = super().forward(input_ids, position_ids, token_type_ids, past, head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
return lm_logits
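As a sanity check of the wrapper (an illustrative snippet with a small made-up config, mirroring the shape check done for the MLP further down), the forward pass should now return per-position logits of shape (batch, seq_len, vocab_size):
check_config = OpenAIGPTConfig(
    vocab_size=len(chars),
    n_positions=reversed_add_3digit_separated.train_x.shape[1],
    n_embd=16, n_layer=2, n_head=4,
)
# For this data config that is a batch of 10 length-12 sequences over a 16-character
# vocabulary, so the expected result is torch.Size([10, 12, 16]).
check_input = torch.randint(0, len(chars), (10, reversed_add_3digit_separated.train_x.shape[1]))
print(OpenAIGPTLMHeadModel(check_config)(check_input).shape)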
train_config = TrainConfig(
epochs=10000,
train_batch_size=2048,
lr=1e-3,
weight_decay=1e-4,
epoch_interval=50,
time_budget_seconds=60,
)
# Use dataclasses.replace to get a copy of train_config with a different learning rate:
dataclasses.replace(train_config, lr=1e-2)
TrainConfig(epochs=10000, train_batch_size=2048, lr=0.01, weight_decay=0.0001, epoch_interval=50, time_budget_seconds=60)
model_config = OpenAIGPTConfig(
vocab_size=len(chars),
n_positions=reversed_add_3digit_separated.train_x.shape[1],
n_embd=16,
n_layer=4,
n_head=4,
)
train_config = TrainConfig(
epochs=10000,
train_batch_size=2048,
lr=1e-3,
weight_decay=1e-4,
epoch_interval=50,
time_budget_seconds=60,
)
results_dict[
("HuggingFaceGPT", f"run_{int(time.time())}t")
] = train_and_eval(
model_config, OpenAIGPTLMHeadModel(model_config), reversed_add_3digit_separated, train_config
)
results_dict[
("HuggingFaceGPT", f"run_{int(time.time())}t")
] = train_and_eval(
model_config, OpenAIGPTLMHeadModel(model_config), reversed_add_3digit_separated,
dataclasses.replace(train_config, lr=1e-4)
)
results_dict[
("HuggingFaceGPT", f"run_{int(time.time())}t")
] = train_and_eval(
model_config, OpenAIGPTLMHeadModel(model_config), reversed_add_3digit_separated,
dataclasses.replace(train_config, lr=1e-2)
)
Number of parameters: 13824 Epoch 0, train loss 3.035E+00, test loss 2.953, training accuracy 0.03, test accuracy 0.05, time per example 23.54 µs Epoch 50, train loss 2.088E+00, test loss 2.081, training accuracy 0.36, test accuracy 0.37, time per example 13.81 µs Epoch 100, train loss 1.771E+00, test loss 1.738, training accuracy 0.38, test accuracy 0.40, time per example 13.90 µs Epoch 150, train loss 1.578E+00, test loss 1.550, training accuracy 0.42, test accuracy 0.43, time per example 15.00 µs Epoch 200, train loss 1.461E+00, test loss 1.449, training accuracy 0.44, test accuracy 0.46, time per example 14.44 µs Epoch 250, train loss 1.412E+00, test loss 1.424, training accuracy 0.47, test accuracy 0.46, time per example 15.62 µs Epoch 300, train loss 1.402E+00, test loss 1.412, training accuracy 0.46, test accuracy 0.46, time per example 13.55 µs Epoch 350, train loss 1.366E+00, test loss 1.405, training accuracy 0.47, test accuracy 0.46, time per example 14.52 µs Epoch 400, train loss 1.382E+00, test loss 1.399, training accuracy 0.47, test accuracy 0.46, time per example 13.42 µs Epoch 450, train loss 1.406E+00, test loss 1.397, training accuracy 0.45, test accuracy 0.46, time per example 13.53 µs Epoch 500, train loss 1.350E+00, test loss 1.396, training accuracy 0.48, test accuracy 0.46, time per example 13.64 µs Epoch 550, train loss 1.383E+00, test loss 1.392, training accuracy 0.47, test accuracy 0.46, time per example 13.87 µs Epoch 600, train loss 1.436E+00, test loss 1.389, training accuracy 0.43, test accuracy 0.46, time per example 13.49 µs Epoch 650, train loss 1.385E+00, test loss 1.391, training accuracy 0.46, test accuracy 0.46, time per example 14.66 µs Epoch 700, train loss 1.375E+00, test loss 1.389, training accuracy 0.47, test accuracy 0.46, time per example 15.45 µs Epoch 750, train loss 1.349E+00, test loss 1.389, training accuracy 0.47, test accuracy 0.46, time per example 13.62 µs Epoch 800, train loss 1.389E+00, test loss 1.387, training accuracy 0.47, test accuracy 0.46, time per example 13.88 µs Epoch 850, train loss 1.384E+00, test loss 1.387, training accuracy 0.46, test accuracy 0.46, time per example 13.67 µs Epoch 900, train loss 1.389E+00, test loss 1.386, training accuracy 0.46, test accuracy 0.46, time per example 13.87 µs Epoch 950, train loss 1.382E+00, test loss 1.385, training accuracy 0.46, test accuracy 0.46, time per example 15.09 µs Epoch 1000, train loss 1.381E+00, test loss 1.387, training accuracy 0.46, test accuracy 0.46, time per example 14.85 µs Epoch 1050, train loss 1.369E+00, test loss 1.385, training accuracy 0.47, test accuracy 0.46, time per example 15.62 µs Epoch 1100, train loss 1.368E+00, test loss 1.385, training accuracy 0.48, test accuracy 0.46, time per example 14.98 µs Epoch 1150, train loss 1.400E+00, test loss 1.385, training accuracy 0.46, test accuracy 0.46, time per example 15.74 µs Epoch 1200, train loss 1.410E+00, test loss 1.385, training accuracy 0.45, test accuracy 0.46, time per example 15.64 µs Epoch 1250, train loss 1.364E+00, test loss 1.383, training accuracy 0.48, test accuracy 0.46, time per example 13.96 µs Epoch 1300, train loss 1.348E+00, test loss 1.376, training accuracy 0.47, test accuracy 0.46, time per example 13.51 µs Epoch 1350, train loss 1.402E+00, test loss 1.369, training accuracy 0.46, test accuracy 0.47, time per example 13.80 µs Epoch 1400, train loss 1.404E+00, test loss 1.325, training accuracy 0.45, test accuracy 0.49, time per example 13.73 µs Epoch 1450, train loss 1.113E+00, test 
loss 1.132, training accuracy 0.56, test accuracy 0.55, time per example 13.81 µs Epoch 1500, train loss 8.473E-01, test loss 0.834, training accuracy 0.65, test accuracy 0.65, time per example 13.79 µs Epoch 1550, train loss 7.404E-01, test loss 0.717, training accuracy 0.69, test accuracy 0.70, time per example 14.00 µs Epoch 1600, train loss 6.300E-01, test loss 0.631, training accuracy 0.74, test accuracy 0.74, time per example 14.94 µs Epoch 1650, train loss 5.645E-01, test loss 0.564, training accuracy 0.77, test accuracy 0.77, time per example 15.62 µs Epoch 1700, train loss 5.171E-01, test loss 0.520, training accuracy 0.80, test accuracy 0.78, time per example 14.15 µs Epoch 1750, train loss 4.675E-01, test loss 0.467, training accuracy 0.80, test accuracy 0.81, time per example 13.81 µs Epoch 1800, train loss 4.360E-01, test loss 0.449, training accuracy 0.83, test accuracy 0.82, time per example 14.00 µs Epoch 1850, train loss 4.382E-01, test loss 0.424, training accuracy 0.82, test accuracy 0.83, time per example 13.74 µs Epoch 1900, train loss 3.987E-01, test loss 0.388, training accuracy 0.85, test accuracy 0.84, time per example 13.86 µs Epoch 1950, train loss 3.442E-01, test loss 0.367, training accuracy 0.87, test accuracy 0.85, time per example 13.94 µs Time budget exceeded, hence stopping training Total training time 60.25 s Number of parameters: 13824 Epoch 0, train loss 2.870E+00, test loss 2.859, training accuracy 0.06, test accuracy 0.06, time per example 14.24 µs Epoch 50, train loss 2.545E+00, test loss 2.520, training accuracy 0.26, test accuracy 0.26, time per example 13.58 µs Epoch 100, train loss 2.360E+00, test loss 2.344, training accuracy 0.33, test accuracy 0.33, time per example 14.92 µs Epoch 150, train loss 2.236E+00, test loss 2.244, training accuracy 0.36, test accuracy 0.35, time per example 15.68 µs Epoch 200, train loss 2.174E+00, test loss 2.176, training accuracy 0.37, test accuracy 0.36, time per example 15.54 µs Epoch 250, train loss 2.156E+00, test loss 2.137, training accuracy 0.36, test accuracy 0.36, time per example 13.87 µs Epoch 300, train loss 2.128E+00, test loss 2.109, training accuracy 0.35, test accuracy 0.36, time per example 14.06 µs Epoch 350, train loss 2.083E+00, test loss 2.078, training accuracy 0.36, test accuracy 0.37, time per example 13.85 µs Epoch 400, train loss 2.049E+00, test loss 2.030, training accuracy 0.37, test accuracy 0.37, time per example 13.76 µs Epoch 450, train loss 1.953E+00, test loss 1.944, training accuracy 0.39, test accuracy 0.39, time per example 14.04 µs Epoch 500, train loss 1.906E+00, test loss 1.890, training accuracy 0.40, test accuracy 0.41, time per example 13.96 µs Epoch 550, train loss 1.853E+00, test loss 1.844, training accuracy 0.41, test accuracy 0.41, time per example 15.53 µs Epoch 600, train loss 1.762E+00, test loss 1.780, training accuracy 0.43, test accuracy 0.41, time per example 14.94 µs Epoch 650, train loss 1.729E+00, test loss 1.725, training accuracy 0.41, test accuracy 0.41, time per example 15.39 µs Epoch 700, train loss 1.703E+00, test loss 1.694, training accuracy 0.41, test accuracy 0.41, time per example 16.42 µs Epoch 750, train loss 1.669E+00, test loss 1.670, training accuracy 0.40, test accuracy 0.42, time per example 15.88 µs Epoch 800, train loss 1.652E+00, test loss 1.650, training accuracy 0.42, test accuracy 0.42, time per example 14.50 µs Epoch 850, train loss 1.625E+00, test loss 1.631, training accuracy 0.43, test accuracy 0.43, time per example 14.20 µs 
Epoch 900, train loss 1.632E+00, test loss 1.614, training accuracy 0.43, test accuracy 0.43, time per example 13.94 µs Epoch 950, train loss 1.604E+00, test loss 1.597, training accuracy 0.42, test accuracy 0.44, time per example 13.64 µs Epoch 1000, train loss 1.584E+00, test loss 1.581, training accuracy 0.44, test accuracy 0.44, time per example 13.94 µs Epoch 1050, train loss 1.553E+00, test loss 1.564, training accuracy 0.45, test accuracy 0.45, time per example 13.98 µs Epoch 1100, train loss 1.539E+00, test loss 1.549, training accuracy 0.45, test accuracy 0.45, time per example 14.77 µs Epoch 1150, train loss 1.533E+00, test loss 1.534, training accuracy 0.45, test accuracy 0.45, time per example 15.14 µs Epoch 1200, train loss 1.518E+00, test loss 1.514, training accuracy 0.45, test accuracy 0.46, time per example 15.82 µs Epoch 1250, train loss 1.472E+00, test loss 1.502, training accuracy 0.47, test accuracy 0.46, time per example 13.87 µs Epoch 1300, train loss 1.498E+00, test loss 1.493, training accuracy 0.45, test accuracy 0.46, time per example 14.19 µs Epoch 1350, train loss 1.508E+00, test loss 1.483, training accuracy 0.46, test accuracy 0.46, time per example 13.91 µs Epoch 1400, train loss 1.488E+00, test loss 1.476, training accuracy 0.46, test accuracy 0.46, time per example 13.92 µs Epoch 1450, train loss 1.447E+00, test loss 1.469, training accuracy 0.46, test accuracy 0.46, time per example 13.94 µs Epoch 1500, train loss 1.444E+00, test loss 1.464, training accuracy 0.47, test accuracy 0.46, time per example 13.91 µs Epoch 1550, train loss 1.424E+00, test loss 1.458, training accuracy 0.48, test accuracy 0.46, time per example 14.01 µs Epoch 1600, train loss 1.456E+00, test loss 1.454, training accuracy 0.46, test accuracy 0.46, time per example 14.68 µs Epoch 1650, train loss 1.426E+00, test loss 1.451, training accuracy 0.46, test accuracy 0.46, time per example 18.00 µs Epoch 1700, train loss 1.453E+00, test loss 1.445, training accuracy 0.46, test accuracy 0.46, time per example 14.04 µs Epoch 1750, train loss 1.473E+00, test loss 1.443, training accuracy 0.45, test accuracy 0.46, time per example 14.18 µs Epoch 1800, train loss 1.440E+00, test loss 1.441, training accuracy 0.46, test accuracy 0.46, time per example 13.72 µs Epoch 1850, train loss 1.423E+00, test loss 1.437, training accuracy 0.47, test accuracy 0.46, time per example 14.47 µs Epoch 1900, train loss 1.420E+00, test loss 1.435, training accuracy 0.47, test accuracy 0.46, time per example 14.02 µs Epoch 1950, train loss 1.420E+00, test loss 1.431, training accuracy 0.46, test accuracy 0.46, time per example 13.88 µs Time budget exceeded, hence stopping training Total training time 60.90 s Number of parameters: 13824 Epoch 0, train loss 2.934E+00, test loss 2.483, training accuracy 0.02, test accuracy 0.22, time per example 14.37 µs Epoch 50, train loss 1.503E+00, test loss 1.482, training accuracy 0.42, test accuracy 0.42, time per example 13.97 µs Epoch 100, train loss 1.381E+00, test loss 1.397, training accuracy 0.47, test accuracy 0.46, time per example 14.94 µs Epoch 150, train loss 1.383E+00, test loss 1.391, training accuracy 0.47, test accuracy 0.46, time per example 14.93 µs Epoch 200, train loss 1.386E+00, test loss 1.388, training accuracy 0.46, test accuracy 0.46, time per example 15.31 µs Epoch 250, train loss 1.434E+00, test loss 1.387, training accuracy 0.44, test accuracy 0.46, time per example 15.94 µs Epoch 300, train loss 1.360E+00, test loss 1.387, training accuracy 0.47, 
test accuracy 0.46, time per example 16.37 µs Epoch 350, train loss 1.388E+00, test loss 1.386, training accuracy 0.45, test accuracy 0.46, time per example 14.01 µs Epoch 400, train loss 1.393E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 13.82 µs Epoch 450, train loss 1.417E+00, test loss 1.383, training accuracy 0.45, test accuracy 0.46, time per example 14.33 µs Epoch 500, train loss 1.367E+00, test loss 1.384, training accuracy 0.47, test accuracy 0.46, time per example 14.47 µs Epoch 550, train loss 1.402E+00, test loss 1.387, training accuracy 0.45, test accuracy 0.46, time per example 14.08 µs Epoch 600, train loss 1.350E+00, test loss 1.384, training accuracy 0.47, test accuracy 0.46, time per example 15.24 µs Epoch 650, train loss 1.397E+00, test loss 1.383, training accuracy 0.46, test accuracy 0.46, time per example 15.65 µs Epoch 700, train loss 1.392E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 15.69 µs Epoch 750, train loss 1.365E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 16.24 µs Epoch 800, train loss 1.403E+00, test loss 1.383, training accuracy 0.45, test accuracy 0.46, time per example 14.10 µs Epoch 850, train loss 1.417E+00, test loss 1.384, training accuracy 0.45, test accuracy 0.46, time per example 14.08 µs Epoch 900, train loss 1.413E+00, test loss 1.385, training accuracy 0.45, test accuracy 0.46, time per example 14.37 µs Epoch 950, train loss 1.361E+00, test loss 1.383, training accuracy 0.46, test accuracy 0.46, time per example 14.30 µs Epoch 1000, train loss 1.409E+00, test loss 1.384, training accuracy 0.45, test accuracy 0.46, time per example 14.16 µs Epoch 1050, train loss 1.354E+00, test loss 1.383, training accuracy 0.48, test accuracy 0.46, time per example 14.05 µs Epoch 1100, train loss 1.439E+00, test loss 1.384, training accuracy 0.43, test accuracy 0.46, time per example 14.91 µs Epoch 1150, train loss 1.381E+00, test loss 1.383, training accuracy 0.46, test accuracy 0.46, time per example 15.09 µs Epoch 1200, train loss 1.361E+00, test loss 1.384, training accuracy 0.47, test accuracy 0.46, time per example 16.11 µs Epoch 1250, train loss 1.400E+00, test loss 1.383, training accuracy 0.46, test accuracy 0.46, time per example 14.04 µs Epoch 1300, train loss 1.350E+00, test loss 1.386, training accuracy 0.47, test accuracy 0.46, time per example 13.99 µs Epoch 1350, train loss 1.376E+00, test loss 1.384, training accuracy 0.47, test accuracy 0.46, time per example 13.97 µs Epoch 1400, train loss 1.394E+00, test loss 1.383, training accuracy 0.45, test accuracy 0.46, time per example 13.87 µs Epoch 1450, train loss 1.400E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 14.07 µs Epoch 1500, train loss 1.349E+00, test loss 1.383, training accuracy 0.46, test accuracy 0.46, time per example 13.78 µs Epoch 1550, train loss 1.395E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 13.93 µs Epoch 1600, train loss 1.370E+00, test loss 1.385, training accuracy 0.46, test accuracy 0.46, time per example 16.47 µs Epoch 1650, train loss 1.394E+00, test loss 1.385, training accuracy 0.45, test accuracy 0.46, time per example 19.23 µs Epoch 1700, train loss 1.364E+00, test loss 1.425, training accuracy 0.46, test accuracy 0.45, time per example 24.51 µs Epoch 1750, train loss 1.391E+00, test loss 1.418, training accuracy 0.46, test accuracy 0.45, time per example 16.25 µs Epoch 1800, 
train loss 1.383E+00, test loss 1.399, training accuracy 0.47, test accuracy 0.46, time per example 14.36 µs Epoch 1850, train loss 1.381E+00, test loss 1.389, training accuracy 0.46, test accuracy 0.46, time per example 14.05 µs Epoch 1900, train loss 1.352E+00, test loss 1.385, training accuracy 0.47, test accuracy 0.46, time per example 14.29 µs Time budget exceeded, hence stopping training Total training time 60.35 s
@dataclass
class MLPForSeq2SeqConfig:
vocab_size: int
input_len: int
n_embed: int
n_hidden: int
output_len: int
## MLP for sequence to sequence problems.
## It operates in 3 steps: Embed each input token into a vector, concatenate all the vectors, and then pass through an MLP.
## The output of the MLP gives the logits.
class MLPForSeq2Seq(nn.Module):
def __init__(self, cfg: MLPForSeq2SeqConfig):
super().__init__()
self.cfg = cfg
self.embed = nn.Embedding(cfg.vocab_size, cfg.n_embed)
self.mlp = nn.Sequential(
nn.Linear(cfg.n_embed * cfg.input_len, cfg.n_hidden),
nn.ReLU(),
nn.Linear(cfg.n_hidden, cfg.vocab_size * cfg.output_len)
)
def forward(self, x):
# x is of shape (batch_size, input_len)
x = self.embed(x)
# now x is of shape (batch_size, input_len, n_embed)
# reshape x to have shape (batch_size, input_len * n_embed)
x = x.view(-1, self.cfg.n_embed * self.cfg.input_len)
x = self.mlp(x)
# now x is of shape (batch_size, vocab_size * output_len)
# reshape x to have shape (batch_size, output_len, vocab_size)
x = x.view(-1, self.cfg.output_len, self.cfg.vocab_size)
return x
model_config = MLPForSeq2SeqConfig(
vocab_size=len(chars),
n_embed=8,
n_hidden=128,
input_len=reversed_add_3digit_separated.train_x.shape[1],
output_len=1,
)
m = MLPForSeq2Seq(model_config)
m(torch.randint(0, len(chars), (10, reversed_add_3digit_separated.train_x.shape[1])))[:, -1, :].shape
torch.Size([10, 16])
type(m(torch.randint(0, len(chars), (10, reversed_add_3digit_separated.train_x.shape[1]))))
torch.Tensor
train_config = TrainConfig(
epochs=20000,
train_batch_size=2048,
lr=1e-3,
weight_decay=1e-4,
epoch_interval=500,
time_budget_seconds=60,
)
model_config = MLPForSeq2SeqConfig(
vocab_size=len(chars),
n_embed=8,
n_hidden=128,
input_len=reversed_add_3digit_separated.train_x.shape[1],
output_len=1,
)
results_dict[("MLP", f"run_{int(time.time())}t")] = train_and_eval(
model_config, MLPForSeq2Seq(model_config), reversed_add_3digit_separated, train_config
)
results_dict[("MLP", f"run_{int(time.time())}t")] = train_and_eval(
model_config, MLPForSeq2Seq(model_config), reversed_add_3digit_separated,
dataclasses.replace(train_config, lr=1e-4)
)
results_dict[("MLP", f"run_{int(time.time())}t")] = train_and_eval(
model_config, MLPForSeq2Seq(model_config), reversed_add_3digit_separated,
dataclasses.replace(train_config, lr=1e-2)
)
Number of parameters: 14608 Epoch 0, train loss 2.796E+00, test loss 2.749, training accuracy 0.10, test accuracy 0.13, time per example 1.47 µs Epoch 500, train loss 3.596E-01, test loss 0.355, training accuracy 0.89, test accuracy 0.90, time per example 1.04 µs Epoch 1000, train loss 3.558E-02, test loss 0.035, training accuracy 1.00, test accuracy 1.00, time per example 1.15 µs Test accuracy > 99.5%, hence stopping training Total training time 2.17 s Number of parameters: 14608 Epoch 0, train loss 2.755E+00, test loss 2.745, training accuracy 0.06, test accuracy 0.08, time per example 1.52 µs Epoch 500, train loss 1.565E+00, test loss 1.552, training accuracy 0.45, test accuracy 0.45, time per example 1.03 µs Epoch 1000, train loss 1.434E+00, test loss 1.428, training accuracy 0.47, test accuracy 0.48, time per example 1.41 µs Epoch 1500, train loss 1.346E+00, test loss 1.345, training accuracy 0.51, test accuracy 0.53, time per example 1.28 µs Epoch 2000, train loss 1.243E+00, test loss 1.228, training accuracy 0.57, test accuracy 0.57, time per example 1.58 µs Epoch 2500, train loss 1.111E+00, test loss 1.102, training accuracy 0.62, test accuracy 0.63, time per example 1.46 µs Epoch 3000, train loss 9.809E-01, test loss 0.974, training accuracy 0.68, test accuracy 0.69, time per example 1.09 µs Epoch 3500, train loss 8.241E-01, test loss 0.845, training accuracy 0.75, test accuracy 0.73, time per example 1.02 µs Epoch 4000, train loss 7.244E-01, test loss 0.715, training accuracy 0.78, test accuracy 0.79, time per example 1.10 µs Epoch 4500, train loss 5.737E-01, test loss 0.589, training accuracy 0.84, test accuracy 0.83, time per example 1.15 µs Epoch 5000, train loss 4.787E-01, test loss 0.473, training accuracy 0.88, test accuracy 0.88, time per example 1.12 µs Epoch 5500, train loss 3.490E-01, test loss 0.372, training accuracy 0.93, test accuracy 0.92, time per example 1.04 µs Epoch 6000, train loss 2.905E-01, test loss 0.284, training accuracy 0.95, test accuracy 0.95, time per example 1.04 µs Epoch 6500, train loss 2.122E-01, test loss 0.211, training accuracy 0.98, test accuracy 0.98, time per example 1.03 µs Epoch 7000, train loss 1.420E-01, test loss 0.152, training accuracy 0.99, test accuracy 0.99, time per example 1.06 µs Epoch 7500, train loss 1.012E-01, test loss 0.108, training accuracy 1.00, test accuracy 1.00, time per example 1.25 µs Test accuracy > 99.5%, hence stopping training Total training time 17.04 s Number of parameters: 14608 Epoch 0, train loss 2.790E+00, test loss 2.352, training accuracy 0.06, test accuracy 0.33, time per example 1.51 µs Epoch 500, train loss 5.589E-04, test loss 0.001, training accuracy 1.00, test accuracy 1.00, time per example 1.31 µs Test accuracy > 99.5%, hence stopping training Total training time 1.31 s
model_config = GPTConfig(
block_size=reversed_add_3digit_separated.train_x.shape[1],
vocab_size=len(chars),
n_layer=6, n_head=4, n_embd=16, dropout=0.1
)
train_config = TrainConfig(
epochs=20000,
train_batch_size=2048,
lr=1e-3,
weight_decay=1e-4,
epoch_interval=50,
time_budget_seconds=60,
)
results_dict[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval(
model_config, GPT(model_config), reversed_add_3digit_separated, train_config
)
results_dict[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval(
model_config, GPT(model_config), reversed_add_3digit_separated,
dataclasses.replace(train_config, lr=1e-4)
)
results_dict[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval(
model_config, GPT(model_config), reversed_add_3digit_separated,
dataclasses.replace(train_config, lr=1e-2)
)
fatal: destination path 'nanoGPT' already exists and is not an empty directory. WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 number of parameters: 0.02M Number of parameters: 20160 Epoch 0, train loss 2.780E+00, test loss 2.728, training accuracy 0.06, test accuracy 0.20, time per example 33.63 µs Epoch 50, train loss 2.224E+00, test loss 2.219, training accuracy 0.28, test accuracy 0.29, time per example 20.93 µs Epoch 100, train loss 1.890E+00, test loss 1.867, training accuracy 0.37, test accuracy 0.39, time per example 22.87 µs Epoch 150, train loss 1.747E+00, test loss 1.686, training accuracy 0.38, test accuracy 0.40, time per example 20.56 µs Epoch 200, train loss 1.587E+00, test loss 1.604, training accuracy 0.41, test accuracy 0.40, time per example 21.32 µs Epoch 250, train loss 1.526E+00, test loss 1.540, training accuracy 0.43, test accuracy 0.43, time per example 21.29 µs Epoch 300, train loss 1.429E+00, test loss 1.469, training accuracy 0.47, test accuracy 0.45, time per example 22.66 µs Epoch 350, train loss 1.412E+00, test loss 1.433, training accuracy 0.46, test accuracy 0.46, time per example 23.59 µs Epoch 400, train loss 1.428E+00, test loss 1.418, training accuracy 0.45, test accuracy 0.46, time per example 22.90 µs Epoch 450, train loss 1.409E+00, test loss 1.409, training accuracy 0.47, test accuracy 0.46, time per example 23.32 µs Epoch 500, train loss 1.352E+00, test loss 1.404, training accuracy 0.48, test accuracy 0.46, time per example 21.11 µs Epoch 550, train loss 1.417E+00, test loss 1.401, training accuracy 0.45, test accuracy 0.46, time per example 21.14 µs Epoch 600, train loss 1.406E+00, test loss 1.398, training accuracy 0.46, test accuracy 0.46, time per example 23.29 µs Epoch 650, train loss 1.379E+00, test loss 1.397, training accuracy 0.46, test accuracy 0.46, time per example 21.46 µs Epoch 700, train loss 1.402E+00, test loss 1.395, training accuracy 0.45, test accuracy 0.46, time per example 20.79 µs Epoch 750, train loss 1.449E+00, test loss 1.393, training accuracy 0.44, test accuracy 0.46, time per example 21.24 µs Epoch 800, train loss 1.358E+00, test loss 1.392, training accuracy 0.47, test accuracy 0.46, time per example 20.87 µs Epoch 850, train loss 1.398E+00, test loss 1.390, training accuracy 0.46, test accuracy 0.46, time per example 21.09 µs Epoch 900, train loss 1.345E+00, test loss 1.390, training accuracy 0.48, test accuracy 0.46, time per example 22.85 µs Epoch 950, train loss 1.377E+00, test loss 1.390, training accuracy 0.46, test accuracy 0.46, time per example 24.05 µs Epoch 1000, train loss 1.388E+00, test loss 1.388, training accuracy 0.46, test accuracy 0.46, time per example 21.32 µs Epoch 1050, train loss 1.415E+00, test loss 1.387, training accuracy 0.45, test accuracy 0.46, time per example 20.80 µs Epoch 1100, train loss 1.384E+00, test loss 1.387, training accuracy 0.46, test accuracy 0.46, time per example 20.86 µs Epoch 1150, train loss 1.381E+00, test loss 1.387, training accuracy 0.46, test accuracy 0.46, time per example 22.35 µs Epoch 1200, train loss 1.356E+00, test loss 1.386, training accuracy 0.46, test accuracy 0.46, time per example 
22.89 µs Epoch 1250, train loss 1.415E+00, test loss 1.386, training accuracy 0.45, test accuracy 0.46, time per example 23.45 µs Epoch 1300, train loss 1.365E+00, test loss 1.387, training accuracy 0.47, test accuracy 0.46, time per example 21.26 µs Time budget exceeded, hence stopping training Total training time 61.96 s WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 number of parameters: 0.02M Number of parameters: 20160 Epoch 0, train loss 2.788E+00, test loss 2.785, training accuracy 0.05, test accuracy 0.05, time per example 21.03 µs Epoch 50, train loss 2.683E+00, test loss 2.678, training accuracy 0.24, test accuracy 0.25, time per example 20.64 µs Epoch 100, train loss 2.607E+00, test loss 2.606, training accuracy 0.35, test accuracy 0.34, time per example 22.48 µs Epoch 150, train loss 2.539E+00, test loss 2.538, training accuracy 0.35, test accuracy 0.35, time per example 23.84 µs Epoch 200, train loss 2.480E+00, test loss 2.473, training accuracy 0.34, test accuracy 0.35, time per example 26.10 µs Epoch 250, train loss 2.407E+00, test loss 2.411, training accuracy 0.36, test accuracy 0.35, time per example 23.87 µs Epoch 300, train loss 2.356E+00, test loss 2.352, training accuracy 0.35, test accuracy 0.35, time per example 20.95 µs Epoch 350, train loss 2.286E+00, test loss 2.297, training accuracy 0.37, test accuracy 0.35, time per example 20.96 µs Epoch 400, train loss 2.238E+00, test loss 2.243, training accuracy 0.36, test accuracy 0.36, time per example 20.79 µs Epoch 450, train loss 2.205E+00, test loss 2.193, training accuracy 0.36, test accuracy 0.37, time per example 21.65 µs Epoch 500, train loss 2.134E+00, test loss 2.146, training accuracy 0.39, test accuracy 0.38, time per example 23.42 µs Epoch 550, train loss 2.104E+00, test loss 2.101, training accuracy 0.38, test accuracy 0.38, time per example 23.71 µs Epoch 600, train loss 2.056E+00, test loss 2.058, training accuracy 0.39, test accuracy 0.39, time per example 21.78 µs Epoch 650, train loss 2.018E+00, test loss 2.019, training accuracy 0.40, test accuracy 0.40, time per example 21.38 µs Epoch 700, train loss 1.967E+00, test loss 1.983, training accuracy 0.40, test accuracy 0.40, time per example 20.99 µs Epoch 750, train loss 1.916E+00, test loss 1.949, training accuracy 0.42, test accuracy 0.40, time per example 24.10 µs Epoch 800, train loss 1.914E+00, test loss 1.917, training accuracy 0.40, test accuracy 0.40, time per example 22.57 µs Epoch 850, train loss 1.883E+00, test loss 1.888, training accuracy 0.40, test accuracy 0.40, time per example 24.72 µs Epoch 900, train loss 1.863E+00, test loss 1.860, training accuracy 0.40, test accuracy 0.41, time per example 21.32 µs Epoch 950, train loss 1.821E+00, test loss 1.836, training accuracy 0.40, test accuracy 0.40, time per example 20.65 µs Epoch 1000, train loss 1.799E+00, test loss 1.813, training accuracy 0.41, test accuracy 0.41, time per example 21.26 µs Epoch 1050, train loss 1.802E+00, test loss 1.792, training accuracy 0.40, test accuracy 0.41, time per example 21.10 µs Epoch 1100, train loss 1.811E+00, test loss 1.772, training accuracy 0.39, test accuracy 0.41, time per 
example 29.70 µs Epoch 1150, train loss 1.760E+00, test loss 1.754, training accuracy 0.41, test accuracy 0.41, time per example 24.86 µs Epoch 1200, train loss 1.702E+00, test loss 1.737, training accuracy 0.43, test accuracy 0.41, time per example 25.73 µs Epoch 1250, train loss 1.728E+00, test loss 1.720, training accuracy 0.40, test accuracy 0.41, time per example 21.66 µs Time budget exceeded, hence stopping training Total training time 60.69 s WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 number of parameters: 0.02M Number of parameters: 20160 Epoch 0, train loss 2.782E+00, test loss 2.696, training accuracy 0.09, test accuracy 0.16, time per example 21.20 µs Epoch 50, train loss 1.565E+00, test loss 1.555, training accuracy 0.38, test accuracy 0.38, time per example 21.16 µs Epoch 100, train loss 1.380E+00, test loss 1.403, training accuracy 0.46, test accuracy 0.46, time per example 21.58 µs Epoch 150, train loss 1.372E+00, test loss 1.390, training accuracy 0.46, test accuracy 0.46, time per example 21.23 µs Epoch 200, train loss 1.430E+00, test loss 1.387, training accuracy 0.44, test accuracy 0.46, time per example 22.43 µs Epoch 250, train loss 1.434E+00, test loss 1.386, training accuracy 0.44, test accuracy 0.46, time per example 23.99 µs Epoch 300, train loss 1.379E+00, test loss 1.389, training accuracy 0.46, test accuracy 0.46, time per example 21.07 µs Epoch 350, train loss 1.398E+00, test loss 1.389, training accuracy 0.45, test accuracy 0.46, time per example 21.58 µs Epoch 400, train loss 1.378E+00, test loss 1.388, training accuracy 0.46, test accuracy 0.46, time per example 21.29 µs Epoch 450, train loss 1.423E+00, test loss 1.384, training accuracy 0.45, test accuracy 0.46, time per example 20.87 µs Epoch 500, train loss 1.444E+00, test loss 1.383, training accuracy 0.44, test accuracy 0.46, time per example 22.98 µs Epoch 550, train loss 1.368E+00, test loss 1.372, training accuracy 0.47, test accuracy 0.47, time per example 23.73 µs Epoch 600, train loss 1.265E+00, test loss 1.275, training accuracy 0.51, test accuracy 0.50, time per example 21.27 µs Epoch 650, train loss 9.377E-01, test loss 0.900, training accuracy 0.59, test accuracy 0.63, time per example 21.47 µs Epoch 700, train loss 5.093E-01, test loss 0.538, training accuracy 0.80, test accuracy 0.77, time per example 21.60 µs Epoch 750, train loss 4.023E-01, test loss 0.386, training accuracy 0.84, test accuracy 0.85, time per example 21.73 µs Epoch 800, train loss 3.261E-01, test loss 0.316, training accuracy 0.88, test accuracy 0.88, time per example 22.41 µs Epoch 850, train loss 2.294E-01, test loss 0.254, training accuracy 0.91, test accuracy 0.91, time per example 24.94 µs Epoch 900, train loss 2.458E-01, test loss 0.229, training accuracy 0.91, test accuracy 0.91, time per example 23.72 µs Epoch 950, train loss 1.855E-01, test loss 0.185, training accuracy 0.94, test accuracy 0.93, time per example 21.34 µs Epoch 1000, train loss 1.861E-01, test loss 0.167, training accuracy 0.93, test accuracy 0.94, time per example 22.70 µs Epoch 1050, train loss 1.632E-01, test loss 0.167, training accuracy 0.94, test accuracy 0.94, time 
per example 21.32 µs Epoch 1100, train loss 1.790E-01, test loss 0.164, training accuracy 0.94, test accuracy 0.94, time per example 23.76 µs Epoch 1150, train loss 1.296E-01, test loss 0.158, training accuracy 0.95, test accuracy 0.94, time per example 23.91 µs Epoch 1200, train loss 1.128E-01, test loss 0.134, training accuracy 0.96, test accuracy 0.95, time per example 21.23 µs Epoch 1250, train loss 1.174E-01, test loss 0.122, training accuracy 0.96, test accuracy 0.96, time per example 21.77 µs Time budget exceeded, hence stopping training Total training time 60.06 s
TLTransformer code (it's a bit long, hence hidden)
import einops
from fancy_einsum import einsum
import math
@dataclass
class TLConfig:
d_model: int = 768
debug: bool = True
layer_norm_eps: float = 1e-5
d_vocab: int = 50257
init_range: float = 0.02
n_ctx: int = 1024
d_head: int = 64
d_mlp: int = 3072
n_heads: int = 12
n_layers: int = 12
class LayerNorm(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.w = nn.Parameter(torch.ones(cfg.d_model))
self.b = nn.Parameter(torch.zeros(cfg.d_model))
def forward(self, residual):
# residual: [batch, position, d_model]
if self.cfg.debug: print("Residual:", residual.shape)
residual = residual - einops.reduce(residual, "batch position d_model -> batch position 1", "mean")
# Calculate the variance, square root it. Add in an epsilon to prevent divide by zero.
scale = (einops.reduce(residual.pow(2), "batch position d_model -> batch position 1", "mean") + self.cfg.layer_norm_eps).sqrt()
normalized = residual / scale
normalized = normalized * self.w + self.b
if self.cfg.debug: print("Normalized:", residual.shape)
return normalized
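With w initialized to ones and b to zeros, this hand-rolled LayerNorm should agree with PyTorch's built-in nn.LayerNorm. A quick check (illustrative, with a small made-up config):
_ln_cfg = TLConfig(d_model=16, debug=False)
_x = torch.randn(2, 5, 16)
# Expect True: same mean-centering, same variance-based scaling, same eps inside the sqrt.
print(torch.allclose(LayerNorm(_ln_cfg)(_x), nn.LayerNorm(16, eps=_ln_cfg.layer_norm_eps)(_x), atol=1e-5))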
class Embed(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
nn.init.normal_(self.W_E, std=self.cfg.init_range)
def forward(self, tokens):
# tokens: [batch, position]
if self.cfg.debug: print("Tokens:", tokens.shape)
embed = self.W_E[tokens, :] # [batch, position, d_model]
if self.cfg.debug: print("Embeddings:", embed.shape)
return embed
class PosEmbed(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
nn.init.normal_(self.W_pos, std=self.cfg.init_range)
def forward(self, tokens):
# tokens: [batch, position]
if self.cfg.debug: print("Tokens:", tokens.shape)
pos_embed = self.W_pos[:tokens.size(1), :] # [position, d_model]
pos_embed = einops.repeat(pos_embed, "position d_model -> batch position d_model", batch=tokens.size(0))
if self.cfg.debug: print("pos_embed:", pos_embed.shape)
return pos_embed
class Attention(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
nn.init.normal_(self.W_Q, std=self.cfg.init_range)
self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
nn.init.normal_(self.W_K, std=self.cfg.init_range)
self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
nn.init.normal_(self.W_V, std=self.cfg.init_range)
self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
nn.init.normal_(self.W_O, std=self.cfg.init_range)
self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))
def forward(self, normalized_resid_pre):
# normalized_resid_pre: [batch, position, d_model]
if self.cfg.debug: print("Normalized_resid_pre:", normalized_resid_pre.shape)
q = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
k = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K
attn_scores = einsum("batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos", q, k)
attn_scores = attn_scores / math.sqrt(self.cfg.d_head)
attn_scores = self.apply_causal_mask(attn_scores)
pattern = attn_scores.softmax(dim=-1) # [batch, n_head, query_pos, key_pos]
v = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V
z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", pattern, v)
attn_out = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_O) + self.b_O
return attn_out
def apply_causal_mask(self, attn_scores):
# attn_scores: [batch, n_heads, query_pos, key_pos]
mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
attn_scores.masked_fill_(mask, self.IGNORE)
return attn_scores
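To see what apply_causal_mask does, here is a small illustration (made up for this post) on an all-zeros score tensor for a length-4 sequence; entries above the diagonal are filled with the large negative IGNORE value, so after the softmax each query position can only attend to itself and earlier positions:
_attn = Attention(TLConfig(d_model=16, d_head=4, n_heads=4, debug=False))
# The upper triangle becomes -1e5; the diagonal and lower triangle stay 0.
print(_attn.apply_causal_mask(torch.zeros(1, 1, 4, 4))[0, 0])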
class MLP(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
nn.init.normal_(self.W_in, std=self.cfg.init_range)
self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
nn.init.normal_(self.W_out, std=self.cfg.init_range)
self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))
def forward(self, normalized_resid_mid):
# normalized_resid_mid: [batch, position, d_model]
if self.cfg.debug: print("Normalized_resid_mid:", normalized_resid_mid.shape)
pre = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in
post = gelu_new(pre)
mlp_out = einsum("batch position d_mlp, d_mlp d_model -> batch position d_model", post, self.W_out) + self.b_out
return mlp_out
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.ln1 = LayerNorm(cfg)
self.attn = Attention(cfg)
self.ln2 = LayerNorm(cfg)
self.mlp = MLP(cfg)
def forward(self, resid_pre):
# resid_pre [batch, position, d_model]
normalized_resid_pre = self.ln1(resid_pre)
attn_out = self.attn(normalized_resid_pre)
resid_mid = resid_pre + attn_out
normalized_resid_mid = self.ln2(resid_mid)
mlp_out = self.mlp(normalized_resid_mid)
resid_post = resid_mid + mlp_out
return resid_post
class Unembed(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
nn.init.normal_(self.W_U, std=self.cfg.init_range)
self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab), requires_grad=False))
def forward(self, normalized_resid_final):
# normalized_resid_final [batch, position, d_model]
if self.cfg.debug: print("Normalized_resid_final:", normalized_resid_final.shape)
logits = einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U
return logits
class TLTransformer(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.embed = Embed(cfg)
self.pos_embed = PosEmbed(cfg)
self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
self.ln_final = LayerNorm(cfg)
self.unembed = Unembed(cfg)
def forward(self, tokens):
# tokens [batch, position]
embed = self.embed(tokens)
pos_embed = self.pos_embed(tokens)
residual = embed + pos_embed
for block in self.blocks:
residual = block(residual)
normalized_resid_final = self.ln_final(residual)
logits = self.unembed(normalized_resid_final)
# logits have shape [batch, position, logits]
return logits
tlconfig = TLConfig(
d_model = 16,
debug = False,
d_vocab = len(chars),
n_ctx = reversed_add_3digit_separated.train_x.shape[1],
d_head = 4, n_heads = 4, n_layers = 4, d_mlp = 64)
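Before training, the same kind of shape check used for the MLP (illustrative): TLTransformer should return per-position logits of shape (batch, seq_len, d_vocab), i.e. torch.Size([10, 12, 16]) for this data config.
_tl_check = TLTransformer(tlconfig)
print(_tl_check(torch.randint(0, len(chars), (10, reversed_add_3digit_separated.train_x.shape[1]))).shape)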
train_config = TrainConfig(
epochs=20000,
train_batch_size=2048,
lr=1e-3,
weight_decay=1e-4,
epoch_interval=50,
time_budget_seconds=60,
)
results_dict[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval(
tlconfig, TLTransformer(tlconfig), reversed_add_3digit_separated, train_config
)
results_dict[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval(
tlconfig, TLTransformer(tlconfig), reversed_add_3digit_separated,
dataclasses.replace(train_config, lr=1e-4)
)
results_dict[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval(
tlconfig, TLTransformer(tlconfig), reversed_add_3digit_separated,
dataclasses.replace(train_config, lr=1e-2)
)
Number of parameters: 13872 Epoch 0, train loss 2.778E+00, test loss 2.730, training accuracy 0.03, test accuracy 0.16, time per example 20.18 µs Epoch 50, train loss 2.171E+00, test loss 2.161, training accuracy 0.40, test accuracy 0.39, time per example 13.98 µs Epoch 100, train loss 1.790E+00, test loss 1.791, training accuracy 0.43, test accuracy 0.43, time per example 14.48 µs Epoch 150, train loss 1.585E+00, test loss 1.598, training accuracy 0.47, test accuracy 0.46, time per example 15.03 µs Epoch 200, train loss 1.496E+00, test loss 1.484, training accuracy 0.46, test accuracy 0.46, time per example 14.75 µs Epoch 250, train loss 1.455E+00, test loss 1.440, training accuracy 0.45, test accuracy 0.46, time per example 19.40 µs Epoch 300, train loss 1.421E+00, test loss 1.421, training accuracy 0.46, test accuracy 0.46, time per example 24.74 µs Epoch 350, train loss 1.474E+00, test loss 1.465, training accuracy 0.43, test accuracy 0.44, time per example 14.37 µs Epoch 400, train loss 1.423E+00, test loss 1.409, training accuracy 0.45, test accuracy 0.46, time per example 13.88 µs Epoch 450, train loss 1.439E+00, test loss 1.402, training accuracy 0.45, test accuracy 0.47, time per example 14.26 µs Epoch 500, train loss 1.369E+00, test loss 1.398, training accuracy 0.46, test accuracy 0.46, time per example 13.60 µs Epoch 550, train loss 1.404E+00, test loss 1.395, training accuracy 0.45, test accuracy 0.46, time per example 14.09 µs Epoch 600, train loss 1.384E+00, test loss 1.393, training accuracy 0.46, test accuracy 0.46, time per example 14.54 µs Epoch 650, train loss 1.366E+00, test loss 1.391, training accuracy 0.47, test accuracy 0.46, time per example 20.68 µs Epoch 700, train loss 1.417E+00, test loss 1.389, training accuracy 0.45, test accuracy 0.46, time per example 22.56 µs Epoch 750, train loss 1.371E+00, test loss 1.389, training accuracy 0.47, test accuracy 0.46, time per example 20.36 µs Epoch 800, train loss 1.370E+00, test loss 1.388, training accuracy 0.46, test accuracy 0.46, time per example 23.01 µs Epoch 850, train loss 1.382E+00, test loss 1.387, training accuracy 0.46, test accuracy 0.46, time per example 14.48 µs Epoch 900, train loss 1.353E+00, test loss 1.387, training accuracy 0.48, test accuracy 0.46, time per example 14.04 µs Epoch 950, train loss 1.400E+00, test loss 1.388, training accuracy 0.45, test accuracy 0.46, time per example 14.23 µs Epoch 1000, train loss 1.375E+00, test loss 1.386, training accuracy 0.47, test accuracy 0.46, time per example 14.25 µs Epoch 1050, train loss 1.423E+00, test loss 1.387, training accuracy 0.44, test accuracy 0.46, time per example 20.35 µs Epoch 1100, train loss 1.414E+00, test loss 1.386, training accuracy 0.44, test accuracy 0.46, time per example 22.99 µs Epoch 1150, train loss 1.395E+00, test loss 1.386, training accuracy 0.46, test accuracy 0.46, time per example 15.31 µs Epoch 1200, train loss 1.381E+00, test loss 1.385, training accuracy 0.45, test accuracy 0.46, time per example 14.67 µs Epoch 1250, train loss 1.392E+00, test loss 1.386, training accuracy 0.46, test accuracy 0.46, time per example 13.73 µs Epoch 1300, train loss 1.398E+00, test loss 1.385, training accuracy 0.47, test accuracy 0.46, time per example 13.96 µs Epoch 1350, train loss 1.351E+00, test loss 1.384, training accuracy 0.47, test accuracy 0.45, time per example 14.18 µs Epoch 1400, train loss 1.405E+00, test loss 1.384, training accuracy 0.45, test accuracy 0.46, time per example 14.48 µs Epoch 1450, train loss 1.371E+00, test 
loss 1.384, training accuracy 0.48, test accuracy 0.46, time per example 23.40 µs Epoch 1500, train loss 1.382E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 25.16 µs Epoch 1550, train loss 1.373E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 14.13 µs Epoch 1600, train loss 1.358E+00, test loss 1.383, training accuracy 0.47, test accuracy 0.46, time per example 14.47 µs Epoch 1650, train loss 1.397E+00, test loss 1.383, training accuracy 0.44, test accuracy 0.46, time per example 14.04 µs Epoch 1700, train loss 1.382E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 14.62 µs Time budget exceeded, hence stopping training Total training time 60.25 s Number of parameters: 13872 Epoch 0, train loss 2.789E+00, test loss 2.785, training accuracy 0.05, test accuracy 0.05, time per example 14.73 µs Epoch 50, train loss 2.659E+00, test loss 2.654, training accuracy 0.30, test accuracy 0.30, time per example 14.08 µs Epoch 100, train loss 2.582E+00, test loss 2.583, training accuracy 0.36, test accuracy 0.35, time per example 13.97 µs Epoch 150, train loss 2.513E+00, test loss 2.507, training accuracy 0.40, test accuracy 0.41, time per example 19.63 µs Epoch 200, train loss 2.443E+00, test loss 2.434, training accuracy 0.43, test accuracy 0.43, time per example 22.57 µs Epoch 250, train loss 2.365E+00, test loss 2.366, training accuracy 0.46, test accuracy 0.45, time per example 14.14 µs Epoch 300, train loss 2.310E+00, test loss 2.301, training accuracy 0.45, test accuracy 0.46, time per example 13.73 µs Epoch 350, train loss 2.246E+00, test loss 2.242, training accuracy 0.46, test accuracy 0.46, time per example 22.13 µs Epoch 400, train loss 2.197E+00, test loss 2.186, training accuracy 0.45, test accuracy 0.46, time per example 23.04 µs Epoch 450, train loss 2.138E+00, test loss 2.133, training accuracy 0.46, test accuracy 0.46, time per example 13.21 µs Epoch 500, train loss 2.099E+00, test loss 2.084, training accuracy 0.45, test accuracy 0.46, time per example 13.59 µs Epoch 550, train loss 2.024E+00, test loss 2.037, training accuracy 0.48, test accuracy 0.46, time per example 19.00 µs Epoch 600, train loss 1.991E+00, test loss 1.994, training accuracy 0.47, test accuracy 0.46, time per example 23.00 µs Epoch 650, train loss 1.968E+00, test loss 1.953, training accuracy 0.44, test accuracy 0.46, time per example 13.40 µs Epoch 700, train loss 1.917E+00, test loss 1.915, training accuracy 0.46, test accuracy 0.46, time per example 13.68 µs Epoch 750, train loss 1.864E+00, test loss 1.879, training accuracy 0.47, test accuracy 0.46, time per example 13.62 µs Epoch 800, train loss 1.845E+00, test loss 1.845, training accuracy 0.45, test accuracy 0.46, time per example 13.94 µs Epoch 850, train loss 1.831E+00, test loss 1.814, training accuracy 0.46, test accuracy 0.46, time per example 13.95 µs Epoch 900, train loss 1.767E+00, test loss 1.786, training accuracy 0.46, test accuracy 0.46, time per example 18.74 µs Epoch 950, train loss 1.788E+00, test loss 1.758, training accuracy 0.44, test accuracy 0.46, time per example 19.80 µs Epoch 1000, train loss 1.724E+00, test loss 1.733, training accuracy 0.47, test accuracy 0.46, time per example 24.58 µs Epoch 1050, train loss 1.690E+00, test loss 1.709, training accuracy 0.47, test accuracy 0.46, time per example 14.11 µs Epoch 1100, train loss 1.670E+00, test loss 1.687, training accuracy 0.47, test accuracy 0.46, time per example 13.99 µs Epoch 
1150, train loss 1.635E+00, test loss 1.667, training accuracy 0.48, test accuracy 0.46, time per example 14.32 µs Epoch 1200, train loss 1.630E+00, test loss 1.647, training accuracy 0.47, test accuracy 0.46, time per example 14.52 µs Epoch 1250, train loss 1.613E+00, test loss 1.630, training accuracy 0.47, test accuracy 0.46, time per example 18.26 µs Epoch 1300, train loss 1.602E+00, test loss 1.614, training accuracy 0.47, test accuracy 0.46, time per example 14.09 µs Epoch 1350, train loss 1.556E+00, test loss 1.599, training accuracy 0.48, test accuracy 0.46, time per example 13.91 µs Epoch 1400, train loss 1.577E+00, test loss 1.585, training accuracy 0.47, test accuracy 0.46, time per example 19.98 µs Epoch 1450, train loss 1.623E+00, test loss 1.571, training accuracy 0.44, test accuracy 0.46, time per example 23.34 µs Epoch 1500, train loss 1.549E+00, test loss 1.560, training accuracy 0.46, test accuracy 0.46, time per example 15.32 µs Epoch 1550, train loss 1.526E+00, test loss 1.549, training accuracy 0.47, test accuracy 0.46, time per example 14.06 µs Epoch 1600, train loss 1.558E+00, test loss 1.538, training accuracy 0.45, test accuracy 0.46, time per example 14.50 µs Epoch 1650, train loss 1.529E+00, test loss 1.528, training accuracy 0.46, test accuracy 0.46, time per example 13.68 µs Epoch 1700, train loss 1.524E+00, test loss 1.519, training accuracy 0.44, test accuracy 0.46, time per example 19.25 µs Time budget exceeded, hence stopping training Total training time 60.11 s Number of parameters: 13872 Epoch 0, train loss 2.749E+00, test loss 2.689, training accuracy 0.18, test accuracy 0.27, time per example 19.48 µs Epoch 50, train loss 1.536E+00, test loss 1.644, training accuracy 0.42, test accuracy 0.34, time per example 22.88 µs Epoch 100, train loss 1.383E+00, test loss 1.392, training accuracy 0.47, test accuracy 0.46, time per example 18.39 µs Epoch 150, train loss 1.378E+00, test loss 1.359, training accuracy 0.46, test accuracy 0.47, time per example 22.33 µs Epoch 200, train loss 1.339E+00, test loss 1.387, training accuracy 0.49, test accuracy 0.46, time per example 13.97 µs Epoch 250, train loss 1.351E+00, test loss 1.384, training accuracy 0.47, test accuracy 0.46, time per example 13.96 µs Epoch 300, train loss 1.377E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 13.93 µs Epoch 350, train loss 1.350E+00, test loss 1.388, training accuracy 0.48, test accuracy 0.46, time per example 15.97 µs Epoch 400, train loss 1.392E+00, test loss 1.384, training accuracy 0.46, test accuracy 0.46, time per example 13.52 µs Epoch 450, train loss 1.384E+00, test loss 1.383, training accuracy 0.47, test accuracy 0.47, time per example 14.41 µs Epoch 500, train loss 1.386E+00, test loss 1.373, training accuracy 0.47, test accuracy 0.47, time per example 20.66 µs Epoch 550, train loss 1.279E+00, test loss 1.292, training accuracy 0.50, test accuracy 0.50, time per example 22.88 µs Epoch 600, train loss 1.188E+00, test loss 1.121, training accuracy 0.54, test accuracy 0.54, time per example 13.41 µs Epoch 650, train loss 7.683E-01, test loss 0.825, training accuracy 0.68, test accuracy 0.63, time per example 13.63 µs Epoch 700, train loss 5.573E-01, test loss 0.616, training accuracy 0.78, test accuracy 0.73, time per example 13.49 µs Epoch 750, train loss 1.453E-01, test loss 0.137, training accuracy 0.97, test accuracy 0.98, time per example 14.06 µs Epoch 800, train loss 2.238E-02, test loss 0.014, training accuracy 1.00, test accuracy 
1.00, time per example 14.01 µs Test accuracy > 99.5%, hence stopping training Total training time 28.83 s
simple_results = []
for key, value in results_dict.items():
simple_results.append((key[0], f"{value.train_config.lr: .1e}", value.train_loss, value.test_loss, value.test_accuracy, value.num_parameters, value.time_per_example_in_micros, value.train_time_in_seconds))
import pandas as pd
pd.DataFrame.from_records(simple_results,
columns=["model", "Learning Rate", "train_loss", "test_loss", "test_accuracy",
"num_parameters", "time_per_example_in_micros", "train_time_in_seconds"]
).round(3)
| | model | Learning Rate | train_loss | test_loss | test_accuracy | num_parameters | time_per_example_in_micros | train_time_in_seconds |
|---|---|---|---|---|---|---|---|---|
| 0 | HuggingFaceGPT | 1.0e-03 | 0.344 | 0.367 | 0.852 | 13824 | 15.079 | 60.249 |
| 1 | HuggingFaceGPT | 1.0e-04 | 1.420 | 1.431 | 0.461 | 13824 | 15.241 | 60.898 |
| 2 | HuggingFaceGPT | 1.0e-02 | 1.352 | 1.385 | 0.460 | 13824 | 15.501 | 60.349 |
| 3 | MLP | 1.0e-03 | 0.036 | 0.035 | 1.000 | 14608 | 1.058 | 2.169 |
| 4 | MLP | 1.0e-04 | 0.101 | 0.108 | 0.999 | 14608 | 1.109 | 17.040 |
| 5 | MLP | 1.0e-02 | 0.001 | 0.001 | 1.000 | 14608 | 1.278 | 1.311 |
| 6 | NanoGPT | 1.0e-03 | 1.365 | 1.387 | 0.460 | 20160 | 23.255 | 61.963 |
| 7 | NanoGPT | 1.0e-04 | 1.728 | 1.720 | 0.412 | 20160 | 23.687 | 60.688 |
| 8 | NanoGPT | 1.0e-02 | 0.117 | 0.122 | 0.956 | 20160 | 23.442 | 60.061 |
| 9 | TLTransformer | 1.0e-03 | 1.382 | 1.384 | 0.463 | 13872 | 17.296 | 60.254 |
| 10 | TLTransformer | 1.0e-04 | 1.524 | 1.519 | 0.462 | 13872 | 17.255 | 60.110 |
| 11 | TLTransformer | 1.0e-02 | 0.022 | 0.014 | 1.000 | 13872 | 17.572 | 28.826 |
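If you'd rather not eyeball the table, a small follow-up snippet can pick out the best run per implementation. This is just a sketch; the name df is introduced here and is not part of the original code above.
# Sketch: rank runs per model by test accuracy (df is a name introduced here for illustration).
df = pd.DataFrame.from_records(simple_results,
    columns=["model", "Learning Rate", "train_loss", "test_loss", "test_accuracy",
             "num_parameters", "time_per_example_in_micros", "train_time_in_seconds"])
best_runs = df.loc[df.groupby("model")["test_accuracy"].idxmax()]
print(best_runs[["model", "Learning Rate", "test_accuracy", "train_time_in_seconds"]])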
@dataclass
class BardData:
train: torch.Tensor
test: torch.Tensor
vocab_size: int
stoi: dict[str, int]
itos: dict[int, str]
def make_shakespeare_data():
import os
import requests
filename = 'shakespeare.txt'
if not os.path.exists(filename):
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(filename, 'w') as f:
f.write(requests.get(url).text)
with open(filename, 'r') as f:
text = f.read()
print("length of dataset in characters: ", len(text))
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print("vocab size:", vocab_size)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
return [stoi[c] for c in s]
n = len(text)
train_data = text[:int(n*0.9)]
val_data = text[int(n*0.9):]
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids)} tokens")
print(f"val has {len(val_ids)} tokens")
# Don't set dtype to uint8 in an effort to save memory:
# PyTorch embedding lookups and cross_entropy targets expect int64 (long) indices, so other dtypes fail.
train_ids = torch.tensor(train_ids, dtype=torch.long)
val_ids = torch.tensor(val_ids, dtype=torch.long)
return BardData(train_ids, val_ids, vocab_size, stoi, itos)
bd = make_shakespeare_data()
length of dataset in characters: 1115394 all the unique characters: !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz vocab size: 65 train has 1003854 tokens val has 111540 tokens
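As a quick sanity check on the character-level tokenization, we can round-trip a few tokens back to text. This is a sketch: decode is not defined anywhere in the original code and is introduced here only for illustration, using the stoi/itos maps stored on the BardData above.
# Illustrative only: decode is a hypothetical helper built from bd.itos.
def decode(ids):
    return ''.join(bd.itos[int(i)] for i in ids)
sample = bd.train[:40]
print(decode(sample))  # prints the first 40 characters of the training split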
def get_batch(data: BardData, is_train: bool, batch_size: int, block_size: int, is_y_single_token: bool = False):
d = data.train if is_train else data.test
ix = torch.randint(len(d) - block_size, (batch_size,))
x = torch.stack([d[i:i+block_size] for i in ix])
if is_y_single_token:
y = torch.stack([d[i+block_size] for i in ix])
else:
y = torch.stack([d[i+1:i+1+block_size] for i in ix])
x, y = x.to(device), y.to(device)
return x, y
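To see what get_batch returns in its two modes, here is a small shape check using the Shakespeare data loaded above (a sketch; the batch and block sizes are arbitrary).
# Sketch: inspect batch shapes for the two target modes.
xb, yb = get_batch(bd, is_train=True, batch_size=4, block_size=8, is_y_single_token=False)
print(xb.shape, yb.shape)  # torch.Size([4, 8]) torch.Size([4, 8]) -- targets for every position
xb, yb = get_batch(bd, is_train=True, batch_size=4, block_size=8, is_y_single_token=True)
print(xb.shape, yb.shape)  # torch.Size([4, 8]) torch.Size([4]) -- only the next token after the block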
def train_and_eval_bard(model_config: Any, m: nn.Module, data: BardData,
train_config: TrainConfig, block_size: int,
is_y_single_token: bool = True, is_nano_gpt: bool = False):
m = m.to(device)
num_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params}")
optimizer = torch.optim.AdamW(m.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
training_losses = {}
test_losses = {}
training_accuracies = {}
test_accuracies = {}
outer_start = time.time()
ep = 0
while ep < train_config.epochs:
start = time.time()
optimizer.zero_grad()
x, y = get_batch(data, is_train=True, batch_size=train_config.train_batch_size,
block_size=block_size, is_y_single_token=is_y_single_token)
if is_nano_gpt:
output, loss = m(x, y)
logits = einops.rearrange(output, "b s v -> b v s")
else:
output = m(x)
if type(output) is tuple:
output = output[0]
if is_y_single_token:
logits = output.squeeze()
else:
logits = einops.rearrange(output, "b s v -> b v s")
loss = F.cross_entropy(logits, y)
if ep % train_config.epoch_interval == 0:
training_losses[ep] = loss.item()
if is_y_single_token:
preds = torch.argmax(logits, dim=-1)
next_tokens = y
else:
# logits were rearranged to (batch, vocab, seq); score accuracy on the last position only
preds = torch.argmax(logits[:, :, -1], dim=-1)
next_tokens = y[:, -1]
training_accuracies[ep] = torch.sum(preds == next_tokens).item() / preds.shape[0]
loss.backward()
optimizer.step()
elapsed = time.time() - start
if ep % train_config.epoch_interval == 0:
with torch.no_grad():
#calculate test loss
test_x, test_y = get_batch(data, is_train=False, batch_size=train_config.train_batch_size,
block_size=block_size, is_y_single_token=is_y_single_token)
if is_nano_gpt:
output, test_loss = m(test_x, test_y)
test_logits = einops.rearrange(output, "b s v -> b v s")
else:
output = m(test_x)
if type(output) is tuple:
output = output[0]
if is_y_single_token:
test_logits = output.squeeze()
else:
test_logits = einops.rearrange(output, "b s v -> b v s")
test_loss = F.cross_entropy(test_logits, test_y)
test_losses[ep] = test_loss.item()
if is_y_single_token:
test_preds = torch.argmax(test_logits, dim=-1)
next_tokens = test_y
else:
# test_logits are (batch, vocab, seq); score accuracy on the last position only
test_preds = torch.argmax(test_logits[:, :, -1], dim=-1)
next_tokens = test_y[:, -1]
test_accuracies[ep] = torch.sum(test_preds == next_tokens).item() / test_preds.shape[0]
print(f"Epoch {ep}, train loss {training_losses[ep]: .3E}, test loss {test_losses[ep]: .3f}, " +
f"training accuracy {training_accuracies[ep]: .2f}, test accuracy {test_accuracies[ep]: .2f}, " +
f"time per example {elapsed * 1e6 / train_config.train_batch_size: .2f} µs")
if time.time() - outer_start > train_config.time_budget_seconds:
print("Time budget exceeded, hence stopping training")
break
ep += 1
if not training_losses or not training_accuracies:
raise RuntimeError("Training did not run at all")
if not test_losses or not test_accuracies:
raise RuntimeError("Tests did not run at all")
total_elapsed = time.time() - outer_start
print(f"Total training time {total_elapsed: .2f} s")
result_row = ResultRow(
model_config=model_config,
train_config=train_config,
num_parameters=num_params,
epochs=ep+1,
train_loss=training_losses[max(training_losses.keys())],
train_accuracy=training_accuracies[max(training_accuracies.keys())],
test_loss=test_losses[max(test_losses.keys())],
test_accuracy=test_accuracies[max(test_accuracies.keys())],
train_time_in_seconds=total_elapsed,
time_per_example_in_micros=total_elapsed * 1e6 / ((ep + 1) * train_config.train_batch_size),
train_losses=training_losses,
train_accuracies=training_accuracies,
test_losses=test_losses,
test_accuracies=test_accuracies,
)
return result_row
bard_results = {}
tlconfig = TLConfig(
d_model = 64,
debug = False,
d_vocab = bd.vocab_size,
n_ctx = 128,
d_head = 16,
n_heads = 4,
n_layers = 4,
d_mlp = 256)
train_config = TrainConfig(
epochs=20000,
train_batch_size=1024,
lr=1e-2,
weight_decay=1e-4,
epoch_interval=50,
time_budget_seconds=300,
)
bard_results[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval_bard(
tlconfig, TLTransformer(tlconfig), bd, train_config, block_size=tlconfig.n_ctx, is_y_single_token=False
)
tlconfig = TLConfig(
d_model = 96,
debug = False,
d_vocab = bd.vocab_size,
n_ctx = 128,
d_head = 24,
n_heads = 4,
n_layers = 4,
d_mlp = 384
)
bard_results[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval_bard(
tlconfig, TLTransformer(tlconfig), bd, train_config, block_size=tlconfig.n_ctx, is_y_single_token=False
)
Number of parameters: 216641 Epoch 0, train loss 4.221E+00, test loss 4.210, training accuracy 0.02, test accuracy 0.00, time per example 86.75 µs Epoch 50, train loss 3.300E+00, test loss 3.335, training accuracy 0.00, test accuracy 0.00, time per example 103.44 µs Epoch 100, train loss 2.947E+00, test loss 2.962, training accuracy 0.00, test accuracy 0.01, time per example 103.55 µs Epoch 150, train loss 2.789E+00, test loss 2.784, training accuracy 0.01, test accuracy 0.01, time per example 103.28 µs Epoch 200, train loss 2.644E+00, test loss 2.639, training accuracy 0.01, test accuracy 0.01, time per example 104.05 µs Epoch 250, train loss 2.569E+00, test loss 2.556, training accuracy 0.01, test accuracy 0.01, time per example 103.91 µs Epoch 300, train loss 2.486E+00, test loss 2.509, training accuracy 0.01, test accuracy 0.01, time per example 103.95 µs Epoch 350, train loss 2.435E+00, test loss 2.458, training accuracy 0.01, test accuracy 0.01, time per example 103.47 µs Epoch 400, train loss 2.369E+00, test loss 2.418, training accuracy 0.00, test accuracy 0.00, time per example 123.16 µs Epoch 450, train loss 2.375E+00, test loss 2.403, training accuracy 0.01, test accuracy 0.00, time per example 100.34 µs Epoch 500, train loss 2.197E+00, test loss 2.266, training accuracy 0.01, test accuracy 0.01, time per example 104.06 µs Epoch 550, train loss 2.097E+00, test loss 2.179, training accuracy 0.01, test accuracy 0.00, time per example 103.48 µs Epoch 600, train loss 2.054E+00, test loss 2.143, training accuracy 0.00, test accuracy 0.01, time per example 103.09 µs Epoch 650, train loss 1.995E+00, test loss 2.085, training accuracy 0.01, test accuracy 0.01, time per example 103.53 µs Epoch 700, train loss 1.933E+00, test loss 2.075, training accuracy 0.01, test accuracy 0.01, time per example 102.97 µs Epoch 750, train loss 1.779E+00, test loss 1.926, training accuracy 0.01, test accuracy 0.01, time per example 103.28 µs Epoch 800, train loss 1.736E+00, test loss 1.894, training accuracy 0.01, test accuracy 0.01, time per example 103.34 µs Epoch 850, train loss 1.629E+00, test loss 1.829, training accuracy 0.01, test accuracy 0.01, time per example 104.22 µs Epoch 900, train loss 1.609E+00, test loss 1.779, training accuracy 0.01, test accuracy 0.01, time per example 103.54 µs Epoch 950, train loss 1.553E+00, test loss 1.756, training accuracy 0.01, test accuracy 0.01, time per example 103.18 µs Epoch 1000, train loss 1.523E+00, test loss 1.753, training accuracy 0.01, test accuracy 0.01, time per example 103.07 µs Epoch 1050, train loss 1.517E+00, test loss 1.746, training accuracy 0.00, test accuracy 0.01, time per example 102.80 µs Epoch 1100, train loss 1.469E+00, test loss 1.709, training accuracy 0.01, test accuracy 0.00, time per example 103.87 µs Epoch 1150, train loss 1.461E+00, test loss 1.692, training accuracy 0.01, test accuracy 0.01, time per example 104.11 µs Epoch 1200, train loss 1.420E+00, test loss 1.663, training accuracy 0.01, test accuracy 0.01, time per example 103.55 µs Epoch 1250, train loss 1.423E+00, test loss 1.657, training accuracy 0.01, test accuracy 0.00, time per example 111.39 µs Epoch 1300, train loss 1.400E+00, test loss 1.658, training accuracy 0.01, test accuracy 0.01, time per example 93.21 µs Epoch 1350, train loss 1.389E+00, test loss 1.650, training accuracy 0.00, test accuracy 0.00, time per example 104.23 µs Epoch 1400, train loss 1.787E+00, test loss 1.917, training accuracy 0.01, test accuracy 0.00, time per example 103.62 µs Epoch 1450, 
train loss 1.392E+00, test loss 1.640, training accuracy 0.01, test accuracy 0.00, time per example 103.04 µs Epoch 1500, train loss 1.363E+00, test loss 1.653, training accuracy 0.00, test accuracy 0.01, time per example 103.73 µs Epoch 1550, train loss 1.356E+00, test loss 1.621, training accuracy 0.01, test accuracy 0.01, time per example 102.62 µs Epoch 1600, train loss 1.356E+00, test loss 1.640, training accuracy 0.00, test accuracy 0.01, time per example 102.62 µs Epoch 1650, train loss 1.342E+00, test loss 1.618, training accuracy 0.01, test accuracy 0.01, time per example 103.59 µs Epoch 1700, train loss 1.336E+00, test loss 1.627, training accuracy 0.01, test accuracy 0.00, time per example 103.35 µs Epoch 1750, train loss 1.326E+00, test loss 1.613, training accuracy 0.00, test accuracy 0.00, time per example 103.23 µs Epoch 1800, train loss 1.315E+00, test loss 1.622, training accuracy 0.00, test accuracy 0.00, time per example 104.13 µs Epoch 1850, train loss 1.323E+00, test loss 1.615, training accuracy 0.00, test accuracy 0.00, time per example 105.23 µs Epoch 1900, train loss 1.310E+00, test loss 1.617, training accuracy 0.00, test accuracy 0.00, time per example 103.34 µs Epoch 1950, train loss 1.311E+00, test loss 1.611, training accuracy 0.00, test accuracy 0.01, time per example 103.51 µs Epoch 2000, train loss 1.305E+00, test loss 1.636, training accuracy 0.01, test accuracy 0.00, time per example 103.68 µs Epoch 2050, train loss 1.289E+00, test loss 1.598, training accuracy 0.01, test accuracy 0.01, time per example 104.22 µs Epoch 2100, train loss 1.300E+00, test loss 1.623, training accuracy 0.01, test accuracy 0.01, time per example 103.89 µs Epoch 2150, train loss 1.296E+00, test loss 1.582, training accuracy 0.01, test accuracy 0.01, time per example 104.03 µs Epoch 2200, train loss 1.283E+00, test loss 1.607, training accuracy 0.00, test accuracy 0.00, time per example 103.20 µs Epoch 2250, train loss 1.278E+00, test loss 1.609, training accuracy 0.01, test accuracy 0.00, time per example 103.32 µs Epoch 2300, train loss 1.286E+00, test loss 1.607, training accuracy 0.01, test accuracy 0.01, time per example 103.30 µs Epoch 2350, train loss 1.269E+00, test loss 1.613, training accuracy 0.01, test accuracy 0.01, time per example 103.16 µs Epoch 2400, train loss 1.263E+00, test loss 1.617, training accuracy 0.01, test accuracy 0.01, time per example 103.18 µs Epoch 2450, train loss 1.272E+00, test loss 1.590, training accuracy 0.01, test accuracy 0.01, time per example 103.24 µs Epoch 2500, train loss 1.255E+00, test loss 1.625, training accuracy 0.01, test accuracy 0.01, time per example 103.02 µs Epoch 2550, train loss 1.257E+00, test loss 1.618, training accuracy 0.01, test accuracy 0.01, time per example 103.09 µs Epoch 2600, train loss 1.259E+00, test loss 1.633, training accuracy 0.01, test accuracy 0.01, time per example 102.88 µs Epoch 2650, train loss 1.255E+00, test loss 1.649, training accuracy 0.01, test accuracy 0.01, time per example 103.66 µs Epoch 2700, train loss 1.256E+00, test loss 1.634, training accuracy 0.01, test accuracy 0.00, time per example 103.62 µs Epoch 2750, train loss 1.258E+00, test loss 1.623, training accuracy 0.00, test accuracy 0.01, time per example 103.01 µs Epoch 2800, train loss 1.237E+00, test loss 1.628, training accuracy 0.00, test accuracy 0.01, time per example 102.92 µs Epoch 2850, train loss 1.247E+00, test loss 1.625, training accuracy 0.00, test accuracy 0.01, time per example 103.67 µs Epoch 2900, train loss 
1.236E+00, test loss 1.613, training accuracy 0.01, test accuracy 0.00, time per example 104.43 µs Time budget exceeded, hence stopping training Total training time 300.81 s Number of parameters: 472385 Epoch 0, train loss 4.199E+00, test loss 4.143, training accuracy 0.00, test accuracy 0.01, time per example 99.80 µs Epoch 50, train loss 3.158E+00, test loss 3.144, training accuracy 0.00, test accuracy 0.00, time per example 130.82 µs Epoch 100, train loss 2.906E+00, test loss 2.905, training accuracy 0.01, test accuracy 0.00, time per example 131.30 µs Epoch 150, train loss 2.671E+00, test loss 2.659, training accuracy 0.03, test accuracy 0.03, time per example 130.64 µs Epoch 200, train loss 2.590E+00, test loss 2.595, training accuracy 0.02, test accuracy 0.03, time per example 130.83 µs Epoch 250, train loss 2.531E+00, test loss 2.523, training accuracy 0.02, test accuracy 0.01, time per example 133.76 µs Epoch 300, train loss 2.487E+00, test loss 2.490, training accuracy 0.01, test accuracy 0.01, time per example 131.56 µs Epoch 350, train loss 2.424E+00, test loss 2.452, training accuracy 0.02, test accuracy 0.02, time per example 130.67 µs Epoch 400, train loss 2.347E+00, test loss 2.398, training accuracy 0.01, test accuracy 0.01, time per example 131.11 µs Epoch 450, train loss 2.273E+00, test loss 2.311, training accuracy 0.01, test accuracy 0.00, time per example 130.74 µs Epoch 500, train loss 2.254E+00, test loss 2.298, training accuracy 0.01, test accuracy 0.01, time per example 131.04 µs Epoch 550, train loss 2.120E+00, test loss 2.195, training accuracy 0.01, test accuracy 0.01, time per example 131.05 µs Epoch 600, train loss 2.090E+00, test loss 2.181, training accuracy 0.01, test accuracy 0.01, time per example 131.11 µs Epoch 650, train loss 2.001E+00, test loss 2.142, training accuracy 0.01, test accuracy 0.01, time per example 131.90 µs Epoch 700, train loss 1.888E+00, test loss 2.029, training accuracy 0.00, test accuracy 0.01, time per example 132.29 µs Epoch 750, train loss 1.843E+00, test loss 1.989, training accuracy 0.00, test accuracy 0.00, time per example 131.19 µs Epoch 800, train loss 1.757E+00, test loss 1.926, training accuracy 0.01, test accuracy 0.01, time per example 130.61 µs Epoch 850, train loss 1.720E+00, test loss 1.923, training accuracy 0.01, test accuracy 0.01, time per example 131.67 µs Epoch 900, train loss 1.654E+00, test loss 1.845, training accuracy 0.01, test accuracy 0.01, time per example 130.85 µs Epoch 950, train loss 1.682E+00, test loss 1.879, training accuracy 0.01, test accuracy 0.01, time per example 130.97 µs Epoch 1000, train loss 1.549E+00, test loss 1.768, training accuracy 0.01, test accuracy 0.01, time per example 131.76 µs Epoch 1050, train loss 1.506E+00, test loss 1.736, training accuracy 0.01, test accuracy 0.01, time per example 132.47 µs Epoch 1100, train loss 1.475E+00, test loss 1.701, training accuracy 0.01, test accuracy 0.01, time per example 131.71 µs Epoch 1150, train loss 1.452E+00, test loss 1.723, training accuracy 0.01, test accuracy 0.01, time per example 130.95 µs Epoch 1200, train loss 1.413E+00, test loss 1.698, training accuracy 0.01, test accuracy 0.00, time per example 131.56 µs Epoch 1250, train loss 1.396E+00, test loss 1.676, training accuracy 0.00, test accuracy 0.00, time per example 130.85 µs Epoch 1300, train loss 1.382E+00, test loss 1.663, training accuracy 0.01, test accuracy 0.01, time per example 131.45 µs Epoch 1350, train loss 1.369E+00, test loss 1.632, training accuracy 0.01, test 
accuracy 0.01, time per example 130.92 µs Epoch 1400, train loss 1.352E+00, test loss 1.625, training accuracy 0.01, test accuracy 0.00, time per example 131.19 µs Epoch 1450, train loss 1.339E+00, test loss 1.648, training accuracy 0.00, test accuracy 0.01, time per example 131.13 µs Epoch 1500, train loss 1.318E+00, test loss 1.625, training accuracy 0.01, test accuracy 0.01, time per example 131.40 µs Epoch 1550, train loss 1.295E+00, test loss 1.605, training accuracy 0.01, test accuracy 0.00, time per example 131.12 µs Epoch 1600, train loss 1.297E+00, test loss 1.608, training accuracy 0.01, test accuracy 0.01, time per example 132.11 µs Epoch 1650, train loss 1.292E+00, test loss 1.616, training accuracy 0.01, test accuracy 0.01, time per example 131.17 µs Epoch 1700, train loss 1.279E+00, test loss 1.619, training accuracy 0.01, test accuracy 0.01, time per example 130.97 µs Epoch 1750, train loss 1.270E+00, test loss 1.613, training accuracy 0.01, test accuracy 0.01, time per example 130.65 µs Epoch 1800, train loss 1.266E+00, test loss 1.621, training accuracy 0.01, test accuracy 0.01, time per example 130.82 µs Epoch 1850, train loss 1.252E+00, test loss 1.629, training accuracy 0.01, test accuracy 0.01, time per example 132.94 µs Epoch 1900, train loss 1.242E+00, test loss 1.626, training accuracy 0.01, test accuracy 0.01, time per example 130.14 µs Epoch 1950, train loss 1.245E+00, test loss 1.624, training accuracy 0.00, test accuracy 0.01, time per example 130.73 µs Epoch 2000, train loss 1.232E+00, test loss 1.622, training accuracy 0.00, test accuracy 0.01, time per example 130.76 µs Epoch 2050, train loss 1.227E+00, test loss 1.613, training accuracy 0.01, test accuracy 0.01, time per example 130.94 µs Epoch 2100, train loss 1.218E+00, test loss 1.660, training accuracy 0.01, test accuracy 0.00, time per example 131.15 µs Epoch 2150, train loss 1.209E+00, test loss 1.647, training accuracy 0.01, test accuracy 0.00, time per example 134.50 µs Epoch 2200, train loss 1.194E+00, test loss 1.659, training accuracy 0.00, test accuracy 0.00, time per example 130.24 µs Epoch 2250, train loss 1.203E+00, test loss 1.648, training accuracy 0.00, test accuracy 0.00, time per example 130.75 µs Time budget exceeded, hence stopping training Total training time 303.39 s
cfg1 = MLPForSeq2SeqConfig(
vocab_size=bd.vocab_size,
input_len=128,
n_embed=16,
n_hidden=256,
output_len=1)
train_config = TrainConfig(
epochs=20000,
train_batch_size=2048,
lr=1e-2,
weight_decay=1e-4,
epoch_interval=200,
time_budget_seconds=300,
)
bard_results[("MLP", f"run_{int(time.time())}t")] = train_and_eval_bard(
cfg1, MLPForSeq2Seq(cfg1), bd, train_config, block_size=cfg1.input_len, is_y_single_token=True
)
cfg2 = MLPForSeq2SeqConfig(
vocab_size=bd.vocab_size,
input_len=128,
n_embed=8,
n_hidden=128,
output_len=1)
train_config = TrainConfig(
epochs=20000,
train_batch_size=2048,
lr=1e-2,
weight_decay=1e-4,
epoch_interval=200,
time_budget_seconds=300,
)
bard_results[("MLP", f"run_{int(time.time())}t")] = train_and_eval_bard(
cfg2, MLPForSeq2Seq(cfg2), bd, train_config, block_size=cfg2.input_len, is_y_single_token=True
)
Number of parameters: 542289 Epoch 0, train loss 4.216E+00, test loss 5.999, training accuracy 0.01, test accuracy 0.16, time per example 19.51 µs Epoch 200, train loss 2.293E+00, test loss 2.493, training accuracy 0.33, test accuracy 0.29, time per example 18.44 µs Epoch 400, train loss 2.301E+00, test loss 2.315, training accuracy 0.33, test accuracy 0.32, time per example 17.59 µs Epoch 600, train loss 2.148E+00, test loss 2.259, training accuracy 0.37, test accuracy 0.35, time per example 27.56 µs Epoch 800, train loss 2.176E+00, test loss 2.232, training accuracy 0.37, test accuracy 0.36, time per example 18.03 µs Epoch 1000, train loss 2.117E+00, test loss 2.216, training accuracy 0.37, test accuracy 0.38, time per example 17.59 µs Epoch 1200, train loss 2.098E+00, test loss 2.210, training accuracy 0.40, test accuracy 0.37, time per example 18.21 µs Epoch 1400, train loss 2.049E+00, test loss 2.166, training accuracy 0.41, test accuracy 0.39, time per example 17.85 µs Epoch 1600, train loss 2.038E+00, test loss 2.206, training accuracy 0.40, test accuracy 0.38, time per example 17.80 µs Epoch 1800, train loss 2.033E+00, test loss 2.208, training accuracy 0.40, test accuracy 0.38, time per example 17.91 µs Epoch 2000, train loss 1.944E+00, test loss 2.181, training accuracy 0.43, test accuracy 0.38, time per example 17.90 µs Epoch 2200, train loss 1.977E+00, test loss 2.215, training accuracy 0.42, test accuracy 0.37, time per example 18.36 µs Epoch 2400, train loss 1.966E+00, test loss 2.184, training accuracy 0.42, test accuracy 0.38, time per example 17.77 µs Epoch 2600, train loss 1.969E+00, test loss 2.240, training accuracy 0.44, test accuracy 0.37, time per example 17.69 µs Epoch 2800, train loss 1.950E+00, test loss 2.210, training accuracy 0.43, test accuracy 0.39, time per example 17.81 µs Epoch 3000, train loss 1.978E+00, test loss 2.164, training accuracy 0.42, test accuracy 0.39, time per example 19.78 µs Epoch 3200, train loss 1.917E+00, test loss 2.166, training accuracy 0.44, test accuracy 0.40, time per example 18.15 µs Epoch 3400, train loss 1.908E+00, test loss 2.263, training accuracy 0.44, test accuracy 0.36, time per example 18.26 µs Epoch 3600, train loss 1.895E+00, test loss 2.189, training accuracy 0.44, test accuracy 0.39, time per example 18.25 µs Epoch 3800, train loss 1.931E+00, test loss 2.216, training accuracy 0.43, test accuracy 0.39, time per example 17.69 µs Epoch 4000, train loss 1.920E+00, test loss 2.156, training accuracy 0.44, test accuracy 0.39, time per example 18.02 µs Epoch 4200, train loss 1.856E+00, test loss 2.158, training accuracy 0.44, test accuracy 0.40, time per example 18.22 µs Epoch 4400, train loss 1.918E+00, test loss 2.265, training accuracy 0.44, test accuracy 0.38, time per example 17.73 µs Epoch 4600, train loss 1.930E+00, test loss 2.173, training accuracy 0.43, test accuracy 0.40, time per example 17.80 µs Epoch 4800, train loss 1.900E+00, test loss 2.227, training accuracy 0.44, test accuracy 0.36, time per example 18.13 µs Epoch 5000, train loss 1.847E+00, test loss 2.218, training accuracy 0.45, test accuracy 0.39, time per example 17.68 µs Epoch 5200, train loss 1.874E+00, test loss 2.182, training accuracy 0.45, test accuracy 0.38, time per example 26.37 µs Epoch 5400, train loss 1.841E+00, test loss 2.208, training accuracy 0.45, test accuracy 0.39, time per example 17.81 µs Epoch 5600, train loss 1.850E+00, test loss 2.178, training accuracy 0.45, test accuracy 0.41, time per example 18.08 µs Epoch 5800, train loss 
1.860E+00, test loss 2.261, training accuracy 0.44, test accuracy 0.38, time per example 17.93 µs Epoch 6000, train loss 1.889E+00, test loss 2.233, training accuracy 0.44, test accuracy 0.38, time per example 17.84 µs Epoch 6200, train loss 1.875E+00, test loss 2.121, training accuracy 0.44, test accuracy 0.41, time per example 17.65 µs Epoch 6400, train loss 1.848E+00, test loss 2.098, training accuracy 0.45, test accuracy 0.42, time per example 18.00 µs Epoch 6600, train loss 1.835E+00, test loss 2.122, training accuracy 0.46, test accuracy 0.38, time per example 19.32 µs Epoch 6800, train loss 1.821E+00, test loss 2.099, training accuracy 0.47, test accuracy 0.40, time per example 18.07 µs Epoch 7000, train loss 1.874E+00, test loss 2.208, training accuracy 0.46, test accuracy 0.39, time per example 17.72 µs Epoch 7200, train loss 1.824E+00, test loss 2.192, training accuracy 0.45, test accuracy 0.41, time per example 17.69 µs Epoch 7400, train loss 1.888E+00, test loss 2.162, training accuracy 0.45, test accuracy 0.40, time per example 19.13 µs Epoch 7600, train loss 1.848E+00, test loss 2.182, training accuracy 0.46, test accuracy 0.39, time per example 18.02 µs Time budget exceeded, hence stopping training Total training time 302.81 s Number of parameters: 140105 Epoch 0, train loss 4.215E+00, test loss 3.657, training accuracy 0.01, test accuracy 0.14, time per example 21.02 µs Epoch 200, train loss 2.320E+00, test loss 2.374, training accuracy 0.33, test accuracy 0.33, time per example 106.10 µs Epoch 400, train loss 2.138E+00, test loss 2.230, training accuracy 0.37, test accuracy 0.34, time per example 18.01 µs Epoch 600, train loss 2.061E+00, test loss 2.218, training accuracy 0.40, test accuracy 0.36, time per example 22.09 µs Epoch 800, train loss 1.994E+00, test loss 2.136, training accuracy 0.42, test accuracy 0.39, time per example 17.73 µs Epoch 1000, train loss 1.926E+00, test loss 2.165, training accuracy 0.44, test accuracy 0.38, time per example 17.38 µs Epoch 1200, train loss 1.967E+00, test loss 2.192, training accuracy 0.41, test accuracy 0.37, time per example 17.61 µs Epoch 1400, train loss 1.972E+00, test loss 2.099, training accuracy 0.42, test accuracy 0.39, time per example 17.68 µs Epoch 1600, train loss 1.919E+00, test loss 2.155, training accuracy 0.43, test accuracy 0.38, time per example 17.81 µs Epoch 1800, train loss 1.938E+00, test loss 2.146, training accuracy 0.43, test accuracy 0.39, time per example 17.54 µs Epoch 2000, train loss 1.948E+00, test loss 2.110, training accuracy 0.43, test accuracy 0.40, time per example 18.01 µs Epoch 2200, train loss 1.978E+00, test loss 2.128, training accuracy 0.42, test accuracy 0.39, time per example 21.03 µs Epoch 2400, train loss 1.901E+00, test loss 2.104, training accuracy 0.44, test accuracy 0.39, time per example 18.00 µs Epoch 2600, train loss 1.926E+00, test loss 2.126, training accuracy 0.43, test accuracy 0.40, time per example 17.99 µs Epoch 2800, train loss 1.862E+00, test loss 2.117, training accuracy 0.47, test accuracy 0.38, time per example 17.66 µs Epoch 3000, train loss 1.865E+00, test loss 2.098, training accuracy 0.45, test accuracy 0.41, time per example 17.76 µs Epoch 3200, train loss 1.895E+00, test loss 2.087, training accuracy 0.44, test accuracy 0.42, time per example 17.57 µs Epoch 3400, train loss 1.872E+00, test loss 2.025, training accuracy 0.45, test accuracy 0.43, time per example 17.48 µs Epoch 3600, train loss 1.846E+00, test loss 2.057, training accuracy 0.45, test accuracy 
0.41, time per example 18.18 µs Epoch 3800, train loss 1.904E+00, test loss 2.157, training accuracy 0.44, test accuracy 0.38, time per example 18.20 µs Epoch 4000, train loss 1.820E+00, test loss 2.060, training accuracy 0.46, test accuracy 0.41, time per example 17.69 µs Epoch 4200, train loss 1.847E+00, test loss 2.148, training accuracy 0.45, test accuracy 0.40, time per example 18.02 µs Epoch 4400, train loss 1.885E+00, test loss 2.066, training accuracy 0.44, test accuracy 0.42, time per example 17.57 µs Epoch 4600, train loss 1.877E+00, test loss 2.058, training accuracy 0.45, test accuracy 0.41, time per example 17.71 µs Epoch 4800, train loss 1.795E+00, test loss 2.097, training accuracy 0.47, test accuracy 0.40, time per example 17.78 µs Epoch 5000, train loss 1.871E+00, test loss 2.150, training accuracy 0.44, test accuracy 0.39, time per example 17.54 µs Epoch 5200, train loss 1.852E+00, test loss 2.071, training accuracy 0.47, test accuracy 0.40, time per example 17.56 µs Epoch 5400, train loss 1.853E+00, test loss 2.201, training accuracy 0.46, test accuracy 0.36, time per example 17.37 µs Epoch 5600, train loss 1.840E+00, test loss 2.063, training accuracy 0.46, test accuracy 0.41, time per example 17.35 µs Epoch 5800, train loss 1.844E+00, test loss 2.103, training accuracy 0.44, test accuracy 0.41, time per example 17.73 µs Epoch 6000, train loss 1.856E+00, test loss 2.212, training accuracy 0.46, test accuracy 0.39, time per example 17.63 µs Epoch 6200, train loss 1.804E+00, test loss 2.143, training accuracy 0.46, test accuracy 0.41, time per example 17.56 µs Epoch 6400, train loss 1.832E+00, test loss 2.145, training accuracy 0.45, test accuracy 0.39, time per example 17.54 µs Epoch 6600, train loss 1.790E+00, test loss 2.162, training accuracy 0.47, test accuracy 0.38, time per example 17.70 µs Epoch 6800, train loss 1.809E+00, test loss 2.100, training accuracy 0.46, test accuracy 0.41, time per example 17.56 µs Epoch 7000, train loss 1.799E+00, test loss 2.115, training accuracy 0.47, test accuracy 0.41, time per example 17.38 µs Epoch 7200, train loss 1.809E+00, test loss 2.154, training accuracy 0.46, test accuracy 0.39, time per example 18.01 µs Epoch 7400, train loss 1.874E+00, test loss 2.168, training accuracy 0.45, test accuracy 0.38, time per example 17.74 µs Epoch 7600, train loss 1.795E+00, test loss 2.068, training accuracy 0.46, test accuracy 0.42, time per example 17.78 µs Epoch 7800, train loss 1.831E+00, test loss 2.135, training accuracy 0.46, test accuracy 0.40, time per example 17.61 µs Time budget exceeded, hence stopping training Total training time 307.70 s
cfg = MLPForSeq2SeqConfig(
vocab_size=bd.vocab_size,
input_len=128,
n_embed=16,
n_hidden=64,
output_len=1)
train_config = TrainConfig(
epochs=20000,
train_batch_size=2048,
lr=1e-2,
weight_decay=1e-4,
epoch_interval=25,
time_budget_seconds=600,
)
bard_results[("MLP", f"run_{int(time.time())}t")] = train_and_eval_bard(
cfg, MLPForSeq2Seq(cfg), bd, train_config, block_size=cfg.input_len, is_y_single_token=True
)
Number of parameters: 136401 Epoch 0, train loss 4.149E+00, test loss 4.607, training accuracy 0.02, test accuracy 0.17, time per example 51.37 µs Epoch 25, train loss 3.201E+00, test loss 3.214, training accuracy 0.18, test accuracy 0.18, time per example 85.03 µs Epoch 50, train loss 2.968E+00, test loss 2.920, training accuracy 0.21, test accuracy 0.22, time per example 88.08 µs Epoch 75, train loss 2.727E+00, test loss 2.719, training accuracy 0.24, test accuracy 0.24, time per example 45.28 µs Epoch 100, train loss 2.569E+00, test loss 2.584, training accuracy 0.28, test accuracy 0.27, time per example 50.96 µs Epoch 125, train loss 2.512E+00, test loss 2.555, training accuracy 0.29, test accuracy 0.28, time per example 62.23 µs Epoch 150, train loss 2.437E+00, test loss 2.460, training accuracy 0.29, test accuracy 0.30, time per example 82.49 µs Epoch 175, train loss 2.460E+00, test loss 2.432, training accuracy 0.30, test accuracy 0.29, time per example 52.60 µs Epoch 200, train loss 2.372E+00, test loss 2.405, training accuracy 0.33, test accuracy 0.32, time per example 50.60 µs Epoch 225, train loss 2.392E+00, test loss 2.345, training accuracy 0.30, test accuracy 0.32, time per example 51.51 µs Epoch 250, train loss 2.376E+00, test loss 2.399, training accuracy 0.33, test accuracy 0.31, time per example 53.17 µs Epoch 275, train loss 2.361E+00, test loss 2.300, training accuracy 0.33, test accuracy 0.35, time per example 89.10 µs Epoch 300, train loss 2.300E+00, test loss 2.345, training accuracy 0.35, test accuracy 0.32, time per example 54.12 µs Epoch 325, train loss 2.310E+00, test loss 2.331, training accuracy 0.35, test accuracy 0.32, time per example 52.23 µs Epoch 350, train loss 2.251E+00, test loss 2.405, training accuracy 0.35, test accuracy 0.32, time per example 56.03 µs Epoch 375, train loss 2.283E+00, test loss 2.366, training accuracy 0.35, test accuracy 0.32, time per example 83.41 µs Epoch 400, train loss 2.235E+00, test loss 2.322, training accuracy 0.35, test accuracy 0.33, time per example 62.75 µs Epoch 425, train loss 2.297E+00, test loss 2.283, training accuracy 0.35, test accuracy 0.35, time per example 84.59 µs Epoch 450, train loss 2.267E+00, test loss 2.300, training accuracy 0.34, test accuracy 0.34, time per example 58.65 µs Epoch 475, train loss 2.259E+00, test loss 2.359, training accuracy 0.35, test accuracy 0.32, time per example 80.48 µs Epoch 500, train loss 2.256E+00, test loss 2.327, training accuracy 0.35, test accuracy 0.33, time per example 51.15 µs Epoch 525, train loss 2.210E+00, test loss 2.213, training accuracy 0.37, test accuracy 0.36, time per example 53.42 µs Epoch 550, train loss 2.243E+00, test loss 2.329, training accuracy 0.36, test accuracy 0.35, time per example 52.59 µs Epoch 575, train loss 2.205E+00, test loss 2.285, training accuracy 0.36, test accuracy 0.36, time per example 54.17 µs Epoch 600, train loss 2.220E+00, test loss 2.285, training accuracy 0.36, test accuracy 0.34, time per example 75.96 µs Epoch 625, train loss 2.194E+00, test loss 2.276, training accuracy 0.38, test accuracy 0.34, time per example 53.92 µs Epoch 650, train loss 2.209E+00, test loss 2.340, training accuracy 0.36, test accuracy 0.33, time per example 52.06 µs Epoch 675, train loss 2.155E+00, test loss 2.271, training accuracy 0.38, test accuracy 0.36, time per example 57.34 µs Epoch 700, train loss 2.185E+00, test loss 2.316, training accuracy 0.38, test accuracy 0.36, time per example 75.45 µs Epoch 725, train loss 2.133E+00, test loss 2.260, 
training accuracy 0.38, test accuracy 0.36, time per example 55.67 µs Epoch 750, train loss 2.135E+00, test loss 2.308, training accuracy 0.37, test accuracy 0.35, time per example 51.90 µs Epoch 775, train loss 2.186E+00, test loss 2.280, training accuracy 0.36, test accuracy 0.35, time per example 52.92 µs Epoch 800, train loss 2.158E+00, test loss 2.344, training accuracy 0.38, test accuracy 0.33, time per example 57.37 µs Epoch 825, train loss 2.176E+00, test loss 2.240, training accuracy 0.37, test accuracy 0.35, time per example 95.01 µs Epoch 850, train loss 2.102E+00, test loss 2.281, training accuracy 0.38, test accuracy 0.34, time per example 85.03 µs Epoch 875, train loss 2.121E+00, test loss 2.286, training accuracy 0.40, test accuracy 0.33, time per example 59.99 µs Epoch 900, train loss 2.132E+00, test loss 2.226, training accuracy 0.39, test accuracy 0.36, time per example 51.90 µs Epoch 925, train loss 2.161E+00, test loss 2.249, training accuracy 0.37, test accuracy 0.36, time per example 73.39 µs Epoch 950, train loss 2.121E+00, test loss 2.279, training accuracy 0.39, test accuracy 0.37, time per example 79.24 µs Epoch 975, train loss 2.193E+00, test loss 2.186, training accuracy 0.37, test accuracy 0.37, time per example 52.89 µs Epoch 1000, train loss 2.092E+00, test loss 2.274, training accuracy 0.40, test accuracy 0.35, time per example 52.59 µs Epoch 1025, train loss 2.105E+00, test loss 2.209, training accuracy 0.39, test accuracy 0.37, time per example 51.29 µs Epoch 1050, train loss 2.143E+00, test loss 2.237, training accuracy 0.38, test accuracy 0.35, time per example 83.31 µs Epoch 1075, train loss 2.122E+00, test loss 2.244, training accuracy 0.38, test accuracy 0.35, time per example 52.49 µs Epoch 1100, train loss 2.161E+00, test loss 2.269, training accuracy 0.38, test accuracy 0.35, time per example 52.72 µs Epoch 1125, train loss 2.085E+00, test loss 2.250, training accuracy 0.40, test accuracy 0.36, time per example 49.41 µs Epoch 1150, train loss 2.137E+00, test loss 2.284, training accuracy 0.37, test accuracy 0.35, time per example 48.66 µs Epoch 1175, train loss 2.104E+00, test loss 2.195, training accuracy 0.38, test accuracy 0.37, time per example 84.92 µs Epoch 1200, train loss 2.121E+00, test loss 2.253, training accuracy 0.40, test accuracy 0.36, time per example 48.26 µs Epoch 1225, train loss 2.103E+00, test loss 2.269, training accuracy 0.39, test accuracy 0.36, time per example 83.12 µs Epoch 1250, train loss 2.113E+00, test loss 2.240, training accuracy 0.40, test accuracy 0.36, time per example 44.94 µs Epoch 1275, train loss 2.059E+00, test loss 2.322, training accuracy 0.41, test accuracy 0.34, time per example 80.86 µs Epoch 1300, train loss 2.188E+00, test loss 2.246, training accuracy 0.36, test accuracy 0.35, time per example 52.28 µs Epoch 1325, train loss 2.114E+00, test loss 2.241, training accuracy 0.40, test accuracy 0.35, time per example 57.34 µs Epoch 1350, train loss 2.111E+00, test loss 2.221, training accuracy 0.38, test accuracy 0.36, time per example 57.76 µs Epoch 1375, train loss 2.105E+00, test loss 2.229, training accuracy 0.39, test accuracy 0.36, time per example 60.24 µs Epoch 1400, train loss 2.142E+00, test loss 2.230, training accuracy 0.37, test accuracy 0.37, time per example 76.50 µs Epoch 1425, train loss 2.067E+00, test loss 2.222, training accuracy 0.40, test accuracy 0.36, time per example 50.90 µs Epoch 1450, train loss 2.101E+00, test loss 2.256, training accuracy 0.37, test accuracy 0.35, time per 
example 59.67 µs Epoch 1475, train loss 2.119E+00, test loss 2.245, training accuracy 0.39, test accuracy 0.36, time per example 58.59 µs Epoch 1500, train loss 2.164E+00, test loss 2.272, training accuracy 0.37, test accuracy 0.34, time per example 83.59 µs Epoch 1525, train loss 2.091E+00, test loss 2.249, training accuracy 0.40, test accuracy 0.36, time per example 52.30 µs Epoch 1550, train loss 2.117E+00, test loss 2.156, training accuracy 0.39, test accuracy 0.37, time per example 65.28 µs Epoch 1575, train loss 2.106E+00, test loss 2.270, training accuracy 0.39, test accuracy 0.36, time per example 51.39 µs Epoch 1600, train loss 2.137E+00, test loss 2.280, training accuracy 0.39, test accuracy 0.35, time per example 53.02 µs Epoch 1625, train loss 2.126E+00, test loss 2.318, training accuracy 0.38, test accuracy 0.33, time per example 85.76 µs Epoch 1650, train loss 2.060E+00, test loss 2.220, training accuracy 0.40, test accuracy 0.36, time per example 83.03 µs Epoch 1675, train loss 2.048E+00, test loss 2.279, training accuracy 0.40, test accuracy 0.33, time per example 54.12 µs Epoch 1700, train loss 2.113E+00, test loss 2.218, training accuracy 0.39, test accuracy 0.37, time per example 46.88 µs Epoch 1725, train loss 2.091E+00, test loss 2.210, training accuracy 0.40, test accuracy 0.36, time per example 81.26 µs Epoch 1750, train loss 2.110E+00, test loss 2.257, training accuracy 0.38, test accuracy 0.36, time per example 59.73 µs Epoch 1775, train loss 2.062E+00, test loss 2.277, training accuracy 0.40, test accuracy 0.37, time per example 52.60 µs Epoch 1800, train loss 2.083E+00, test loss 2.197, training accuracy 0.39, test accuracy 0.37, time per example 51.79 µs Epoch 1825, train loss 2.072E+00, test loss 2.267, training accuracy 0.40, test accuracy 0.36, time per example 53.23 µs Epoch 1850, train loss 2.060E+00, test loss 2.199, training accuracy 0.42, test accuracy 0.37, time per example 77.29 µs Epoch 1875, train loss 2.038E+00, test loss 2.240, training accuracy 0.39, test accuracy 0.36, time per example 52.67 µs Epoch 1900, train loss 2.071E+00, test loss 2.282, training accuracy 0.40, test accuracy 0.36, time per example 58.44 µs Epoch 1925, train loss 2.116E+00, test loss 2.239, training accuracy 0.39, test accuracy 0.37, time per example 51.54 µs Epoch 1950, train loss 2.077E+00, test loss 2.214, training accuracy 0.40, test accuracy 0.36, time per example 90.86 µs Epoch 1975, train loss 2.069E+00, test loss 2.260, training accuracy 0.41, test accuracy 0.36, time per example 56.73 µs Epoch 2000, train loss 2.117E+00, test loss 2.180, training accuracy 0.39, test accuracy 0.38, time per example 54.01 µs Epoch 2025, train loss 2.118E+00, test loss 2.194, training accuracy 0.39, test accuracy 0.37, time per example 53.36 µs Epoch 2050, train loss 2.124E+00, test loss 2.220, training accuracy 0.40, test accuracy 0.38, time per example 89.01 µs Epoch 2075, train loss 2.061E+00, test loss 2.199, training accuracy 0.40, test accuracy 0.36, time per example 85.78 µs Epoch 2100, train loss 2.059E+00, test loss 2.146, training accuracy 0.40, test accuracy 0.39, time per example 54.12 µs Epoch 2125, train loss 2.068E+00, test loss 2.237, training accuracy 0.40, test accuracy 0.36, time per example 58.74 µs Epoch 2150, train loss 2.010E+00, test loss 2.205, training accuracy 0.41, test accuracy 0.36, time per example 212.56 µs Epoch 2175, train loss 2.060E+00, test loss 2.193, training accuracy 0.40, test accuracy 0.38, time per example 83.75 µs Epoch 2200, train loss 
2.005E+00, test loss 2.242, training accuracy 0.41, test accuracy 0.36, time per example 51.65 µs Epoch 2225, train loss 2.006E+00, test loss 2.221, training accuracy 0.42, test accuracy 0.36, time per example 52.76 µs Epoch 2250, train loss 2.030E+00, test loss 2.235, training accuracy 0.41, test accuracy 0.36, time per example 51.50 µs Epoch 2275, train loss 2.104E+00, test loss 2.198, training accuracy 0.40, test accuracy 0.36, time per example 52.65 µs Epoch 2300, train loss 2.064E+00, test loss 2.186, training accuracy 0.39, test accuracy 0.36, time per example 82.65 µs Epoch 2325, train loss 1.985E+00, test loss 2.273, training accuracy 0.43, test accuracy 0.34, time per example 55.74 µs Epoch 2350, train loss 2.097E+00, test loss 2.284, training accuracy 0.39, test accuracy 0.35, time per example 58.31 µs Epoch 2375, train loss 2.031E+00, test loss 2.244, training accuracy 0.41, test accuracy 0.36, time per example 67.23 µs Epoch 2400, train loss 2.059E+00, test loss 2.237, training accuracy 0.40, test accuracy 0.36, time per example 89.22 µs Epoch 2425, train loss 2.065E+00, test loss 2.233, training accuracy 0.41, test accuracy 0.35, time per example 53.10 µs Epoch 2450, train loss 2.066E+00, test loss 2.263, training accuracy 0.40, test accuracy 0.36, time per example 70.09 µs Epoch 2475, train loss 2.038E+00, test loss 2.234, training accuracy 0.41, test accuracy 0.36, time per example 56.24 µs Epoch 2500, train loss 2.015E+00, test loss 2.195, training accuracy 0.41, test accuracy 0.38, time per example 72.66 µs Epoch 2525, train loss 2.027E+00, test loss 2.191, training accuracy 0.40, test accuracy 0.36, time per example 82.10 µs Epoch 2550, train loss 2.070E+00, test loss 2.207, training accuracy 0.41, test accuracy 0.36, time per example 52.18 µs Epoch 2575, train loss 2.043E+00, test loss 2.227, training accuracy 0.40, test accuracy 0.36, time per example 43.97 µs Epoch 2600, train loss 2.075E+00, test loss 2.157, training accuracy 0.40, test accuracy 0.39, time per example 51.23 µs Epoch 2625, train loss 2.021E+00, test loss 2.216, training accuracy 0.41, test accuracy 0.37, time per example 77.24 µs Epoch 2650, train loss 2.087E+00, test loss 2.256, training accuracy 0.38, test accuracy 0.36, time per example 44.60 µs Epoch 2675, train loss 2.064E+00, test loss 2.231, training accuracy 0.41, test accuracy 0.36, time per example 58.65 µs Epoch 2700, train loss 2.070E+00, test loss 2.240, training accuracy 0.40, test accuracy 0.35, time per example 53.43 µs Epoch 2725, train loss 2.090E+00, test loss 2.190, training accuracy 0.39, test accuracy 0.38, time per example 51.50 µs Epoch 2750, train loss 2.055E+00, test loss 2.208, training accuracy 0.41, test accuracy 0.37, time per example 81.13 µs Epoch 2775, train loss 1.991E+00, test loss 2.209, training accuracy 0.42, test accuracy 0.36, time per example 51.20 µs Epoch 2800, train loss 2.017E+00, test loss 2.253, training accuracy 0.42, test accuracy 0.38, time per example 44.74 µs Epoch 2825, train loss 2.000E+00, test loss 2.257, training accuracy 0.41, test accuracy 0.36, time per example 49.46 µs Epoch 2850, train loss 2.044E+00, test loss 2.245, training accuracy 0.41, test accuracy 0.37, time per example 47.12 µs Epoch 2875, train loss 2.035E+00, test loss 2.176, training accuracy 0.42, test accuracy 0.37, time per example 95.37 µs Epoch 2900, train loss 2.008E+00, test loss 2.221, training accuracy 0.43, test accuracy 0.35, time per example 52.43 µs Epoch 2925, train loss 2.068E+00, test loss 2.213, training accuracy 
0.41, test accuracy 0.37, time per example 48.14 µs Epoch 2950, train loss 2.028E+00, test loss 2.276, training accuracy 0.42, test accuracy 0.35, time per example 63.92 µs Epoch 2975, train loss 2.044E+00, test loss 2.222, training accuracy 0.41, test accuracy 0.38, time per example 54.70 µs Epoch 3000, train loss 2.042E+00, test loss 2.272, training accuracy 0.40, test accuracy 0.36, time per example 75.18 µs Epoch 3025, train loss 2.025E+00, test loss 2.220, training accuracy 0.41, test accuracy 0.37, time per example 53.82 µs Epoch 3050, train loss 2.044E+00, test loss 2.180, training accuracy 0.40, test accuracy 0.36, time per example 51.23 µs Epoch 3075, train loss 2.041E+00, test loss 2.262, training accuracy 0.39, test accuracy 0.37, time per example 56.45 µs Epoch 3100, train loss 2.085E+00, test loss 2.246, training accuracy 0.39, test accuracy 0.36, time per example 104.63 µs Epoch 3125, train loss 2.065E+00, test loss 2.256, training accuracy 0.41, test accuracy 0.36, time per example 52.41 µs Epoch 3150, train loss 2.028E+00, test loss 2.169, training accuracy 0.41, test accuracy 0.37, time per example 50.84 µs Epoch 3175, train loss 2.023E+00, test loss 2.211, training accuracy 0.40, test accuracy 0.38, time per example 53.56 µs Epoch 3200, train loss 2.076E+00, test loss 2.223, training accuracy 0.40, test accuracy 0.36, time per example 65.03 µs Epoch 3225, train loss 2.005E+00, test loss 2.211, training accuracy 0.42, test accuracy 0.37, time per example 81.99 µs Epoch 3250, train loss 2.070E+00, test loss 2.209, training accuracy 0.39, test accuracy 0.36, time per example 52.90 µs Epoch 3275, train loss 2.061E+00, test loss 2.171, training accuracy 0.40, test accuracy 0.37, time per example 81.10 µs Epoch 3300, train loss 2.036E+00, test loss 2.243, training accuracy 0.39, test accuracy 0.36, time per example 50.79 µs Epoch 3325, train loss 1.998E+00, test loss 2.253, training accuracy 0.42, test accuracy 0.38, time per example 82.06 µs Epoch 3350, train loss 2.002E+00, test loss 2.291, training accuracy 0.42, test accuracy 0.35, time per example 47.06 µs Epoch 3375, train loss 1.991E+00, test loss 2.199, training accuracy 0.41, test accuracy 0.39, time per example 45.69 µs Epoch 3400, train loss 2.047E+00, test loss 2.242, training accuracy 0.40, test accuracy 0.37, time per example 45.34 µs Epoch 3425, train loss 2.023E+00, test loss 2.260, training accuracy 0.42, test accuracy 0.36, time per example 51.17 µs Epoch 3450, train loss 1.984E+00, test loss 2.218, training accuracy 0.42, test accuracy 0.37, time per example 81.51 µs Epoch 3475, train loss 1.973E+00, test loss 2.148, training accuracy 0.42, test accuracy 0.37, time per example 51.27 µs Epoch 3500, train loss 2.072E+00, test loss 2.197, training accuracy 0.39, test accuracy 0.36, time per example 53.87 µs Epoch 3525, train loss 2.063E+00, test loss 2.213, training accuracy 0.41, test accuracy 0.37, time per example 53.11 µs Epoch 3550, train loss 2.040E+00, test loss 2.226, training accuracy 0.40, test accuracy 0.37, time per example 82.27 µs Epoch 3575, train loss 1.959E+00, test loss 2.233, training accuracy 0.42, test accuracy 0.37, time per example 59.19 µs Epoch 3600, train loss 1.972E+00, test loss 2.270, training accuracy 0.43, test accuracy 0.36, time per example 50.42 µs Epoch 3625, train loss 2.041E+00, test loss 2.223, training accuracy 0.40, test accuracy 0.37, time per example 52.94 µs Epoch 3650, train loss 2.049E+00, test loss 2.208, training accuracy 0.40, test accuracy 0.37, time per example 
58.56 µs Epoch 3675, train loss 2.011E+00, test loss 2.230, training accuracy 0.40, test accuracy 0.37, time per example 93.26 µs Epoch 3700, train loss 1.992E+00, test loss 2.214, training accuracy 0.42, test accuracy 0.36, time per example 88.48 µs Epoch 3725, train loss 2.044E+00, test loss 2.241, training accuracy 0.43, test accuracy 0.36, time per example 56.74 µs Epoch 3750, train loss 2.011E+00, test loss 2.162, training accuracy 0.42, test accuracy 0.38, time per example 58.43 µs Epoch 3775, train loss 1.965E+00, test loss 2.210, training accuracy 0.43, test accuracy 0.38, time per example 52.24 µs Epoch 3800, train loss 1.982E+00, test loss 2.208, training accuracy 0.42, test accuracy 0.37, time per example 82.27 µs Epoch 3825, train loss 2.039E+00, test loss 2.182, training accuracy 0.41, test accuracy 0.37, time per example 52.78 µs Epoch 3850, train loss 2.042E+00, test loss 2.307, training accuracy 0.41, test accuracy 0.36, time per example 55.92 µs Epoch 3875, train loss 1.999E+00, test loss 2.226, training accuracy 0.41, test accuracy 0.38, time per example 158.04 µs Epoch 3900, train loss 1.992E+00, test loss 2.299, training accuracy 0.42, test accuracy 0.36, time per example 56.08 µs Epoch 3925, train loss 2.075E+00, test loss 2.191, training accuracy 0.40, test accuracy 0.37, time per example 80.45 µs Epoch 3950, train loss 2.014E+00, test loss 2.221, training accuracy 0.42, test accuracy 0.37, time per example 52.09 µs Epoch 3975, train loss 2.029E+00, test loss 2.260, training accuracy 0.42, test accuracy 0.35, time per example 51.94 µs Epoch 4000, train loss 1.946E+00, test loss 2.207, training accuracy 0.44, test accuracy 0.37, time per example 51.97 µs Epoch 4025, train loss 2.004E+00, test loss 2.247, training accuracy 0.41, test accuracy 0.35, time per example 85.34 µs Epoch 4050, train loss 2.007E+00, test loss 2.161, training accuracy 0.41, test accuracy 0.38, time per example 51.78 µs Epoch 4075, train loss 2.014E+00, test loss 2.179, training accuracy 0.41, test accuracy 0.38, time per example 83.09 µs Epoch 4100, train loss 2.018E+00, test loss 2.202, training accuracy 0.41, test accuracy 0.36, time per example 61.74 µs Epoch 4125, train loss 2.014E+00, test loss 2.212, training accuracy 0.43, test accuracy 0.38, time per example 86.39 µs Epoch 4150, train loss 1.993E+00, test loss 2.190, training accuracy 0.42, test accuracy 0.38, time per example 157.08 µs Epoch 4175, train loss 1.963E+00, test loss 2.199, training accuracy 0.44, test accuracy 0.39, time per example 59.74 µs Epoch 4200, train loss 2.027E+00, test loss 2.252, training accuracy 0.41, test accuracy 0.37, time per example 52.88 µs Epoch 4225, train loss 1.978E+00, test loss 2.256, training accuracy 0.42, test accuracy 0.37, time per example 50.34 µs Epoch 4250, train loss 1.966E+00, test loss 2.182, training accuracy 0.42, test accuracy 0.38, time per example 81.53 µs Epoch 4275, train loss 2.010E+00, test loss 2.202, training accuracy 0.42, test accuracy 0.38, time per example 51.86 µs Epoch 4300, train loss 2.054E+00, test loss 2.258, training accuracy 0.40, test accuracy 0.36, time per example 52.43 µs Epoch 4325, train loss 1.994E+00, test loss 2.180, training accuracy 0.41, test accuracy 0.39, time per example 52.35 µs Epoch 4350, train loss 2.020E+00, test loss 2.241, training accuracy 0.42, test accuracy 0.38, time per example 72.38 µs Epoch 4375, train loss 2.018E+00, test loss 2.247, training accuracy 0.39, test accuracy 0.36, time per example 78.79 µs Epoch 4400, train loss 2.052E+00, 
test loss 2.308, training accuracy 0.40, test accuracy 0.37, time per example 54.31 µs Epoch 4425, train loss 1.978E+00, test loss 2.165, training accuracy 0.42, test accuracy 0.39, time per example 160.59 µs Epoch 4450, train loss 2.023E+00, test loss 2.202, training accuracy 0.41, test accuracy 0.36, time per example 60.88 µs Time budget exceeded, hence stopping training Total training time 600.99 s
# NanoGPT: a small char-level GPT configuration, trained under a fixed time budget.
model_config = GPTConfig(
    block_size=128,
    vocab_size=65,
    n_layer=4, n_head=4, n_embd=64, dropout=0.1
)
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=2048,
    lr=1e-2,
    weight_decay=1e-4,
    epoch_interval=50,
    time_budget_seconds=300,
)
bard_results[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval_bard(
    model_config, GPT(model_config), bd, train_config, block_size=128, is_y_single_token=False, is_nano_gpt=True
)
# Repeat with a wider embedding (n_embd=96) to see the effect of model size under the same budget.
cfg2 = dataclasses.replace(model_config, n_embd=96)
bard_results[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval_bard(
    cfg2, GPT(cfg2), bd, train_config, block_size=128, is_y_single_token=False, is_nano_gpt=True
)
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 number of parameters: 0.20M Number of parameters: 212416 Epoch 0, train loss 4.193E+00, test loss 3.763, training accuracy 0.01, test accuracy 0.01, time per example 136.27 µs Epoch 50, train loss 3.315E+00, test loss 3.315, training accuracy 0.01, test accuracy 0.01, time per example 126.46 µs Epoch 100, train loss 3.312E+00, test loss 3.312, training accuracy 0.01, test accuracy 0.01, time per example 125.70 µs Epoch 150, train loss 3.308E+00, test loss 3.307, training accuracy 0.01, test accuracy 0.00, time per example 126.57 µs Epoch 200, train loss 3.117E+00, test loss 3.116, training accuracy 0.01, test accuracy 0.01, time per example 125.18 µs Epoch 250, train loss 2.947E+00, test loss 2.941, training accuracy 0.01, test accuracy 0.01, time per example 125.28 µs Epoch 300, train loss 2.744E+00, test loss 2.741, training accuracy 0.01, test accuracy 0.01, time per example 125.31 µs Epoch 350, train loss 2.626E+00, test loss 2.625, training accuracy 0.01, test accuracy 0.01, time per example 125.22 µs Epoch 400, train loss 2.563E+00, test loss 2.569, training accuracy 0.01, test accuracy 0.01, time per example 125.77 µs Epoch 450, train loss 2.500E+00, test loss 2.500, training accuracy 0.01, test accuracy 0.01, time per example 125.50 µs Epoch 500, train loss 2.511E+00, test loss 2.509, training accuracy 0.01, test accuracy 0.01, time per example 125.82 µs Epoch 550, train loss 2.452E+00, test loss 2.452, training accuracy 0.01, test accuracy 0.01, time per example 125.27 µs Epoch 600, train loss 2.418E+00, test loss 2.415, training accuracy 0.01, test accuracy 0.01, time per example 126.10 µs Epoch 650, train loss 2.378E+00, test loss 2.378, training accuracy 0.01, test accuracy 0.00, time per example 128.21 µs Epoch 700, train loss 2.344E+00, test loss 2.345, training accuracy 0.01, test accuracy 0.01, time per example 125.52 µs Epoch 750, train loss 2.308E+00, test loss 2.306, training accuracy 0.01, test accuracy 0.01, time per example 125.64 µs Epoch 800, train loss 2.281E+00, test loss 2.270, training accuracy 0.01, test accuracy 0.01, time per example 125.77 µs Epoch 850, train loss 2.267E+00, test loss 2.261, training accuracy 0.01, test accuracy 0.01, time per example 128.29 µs Epoch 900, train loss 2.158E+00, test loss 2.147, training accuracy 0.01, test accuracy 0.01, time per example 130.54 µs Epoch 950, train loss 2.062E+00, test loss 2.065, training accuracy 0.01, test accuracy 0.01, time per example 125.06 µs Epoch 1000, train loss 1.994E+00, test loss 1.992, training accuracy 0.01, test accuracy 0.01, time per example 125.06 µs Epoch 1050, train loss 1.921E+00, test loss 1.920, training accuracy 0.01, test accuracy 0.01, time per example 125.04 µs Epoch 1100, train loss 1.879E+00, test loss 1.875, training accuracy 0.01, test accuracy 0.01, time per example 125.37 µs Epoch 1150, train loss 1.815E+00, test loss 1.812, training accuracy 0.01, test accuracy 0.01, time per example 125.81 µs Time budget exceeded, hence stopping training Total training time 302.89 s WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0 WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0 number of parameters: 0.45M Number of parameters: 466080 Epoch 0, train loss 4.188E+00, test loss 3.797, training accuracy 0.01, test accuracy 0.03, time per example 149.12 µs Epoch 50, train loss 2.971E+00, test loss 2.969, training accuracy 0.01, test accuracy 0.01, time per example 147.15 µs Epoch 100, train loss 2.744E+00, test loss 2.729, training accuracy 0.01, test accuracy 0.01, time per example 147.11 µs Epoch 150, train loss 2.595E+00, test loss 2.591, training accuracy 0.00, test accuracy 0.00, time per example 147.27 µs Epoch 200, train loss 2.510E+00, test loss 2.511, training accuracy 0.01, test accuracy 0.01, time per example 147.33 µs Epoch 250, train loss 2.402E+00, test loss 2.398, training accuracy 0.00, test accuracy 0.00, time per example 147.21 µs Epoch 300, train loss 2.286E+00, test loss 2.287, training accuracy 0.01, test accuracy 0.01, time per example 147.28 µs Epoch 350, train loss 2.198E+00, test loss 2.193, training accuracy 0.01, test accuracy 0.01, time per example 147.10 µs Epoch 400, train loss 2.101E+00, test loss 2.095, training accuracy 0.01, test accuracy 0.01, time per example 146.94 µs Epoch 450, train loss 2.008E+00, test loss 2.009, training accuracy 0.01, test accuracy 0.01, time per example 147.14 µs Epoch 500, train loss 1.924E+00, test loss 1.916, training accuracy 0.01, test accuracy 0.01, time per example 147.40 µs Epoch 550, train loss 1.847E+00, test loss 1.859, training accuracy 0.01, test accuracy 0.01, time per example 147.47 µs Epoch 600, train loss 1.769E+00, test loss 1.761, training accuracy 0.01, test accuracy 0.01, time per example 146.95 µs Epoch 650, train loss 1.706E+00, test loss 1.701, training accuracy 0.01, test accuracy 0.01, time per example 147.16 µs Epoch 700, train loss 1.658E+00, test loss 1.654, training accuracy 0.01, test accuracy 0.00, time per example 146.91 µs Epoch 750, train loss 1.623E+00, test loss 1.620, training accuracy 0.01, test accuracy 0.01, time per example 147.35 µs Epoch 800, train loss 1.591E+00, test loss 1.587, training accuracy 0.01, test accuracy 0.01, time per example 147.44 µs Epoch 850, train loss 1.555E+00, test loss 1.560, training accuracy 0.01, test accuracy 0.01, time per example 147.65 µs Epoch 900, train loss 1.521E+00, test loss 1.518, training accuracy 0.01, test accuracy 0.01, time per example 147.25 µs Epoch 950, train loss 1.489E+00, test loss 1.489, training accuracy 0.00, test accuracy 0.01, time per example 147.07 µs Epoch 1000, train loss 1.474E+00, test loss 1.470, training accuracy 0.01, test accuracy 0.00, time per example 147.37 µs Time budget exceeded, hence stopping training Total training time 308.73 s
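The repeated "using slow attention" lines mean nanoGPT fell back to its manual attention path because this environment runs PyTorch older than 2.0. A minimal check of whether the fused kernel is available (an assumption on my part that this mirrors the hasattr test behind nanoGPT's warning):
# Rough availability check for PyTorch 2.0's fused attention kernel.
# (Assumption: nanoGPT gates its "slow attention" warning on this same hasattr test.)
import torch
import torch.nn.functional as F

has_fused_attention = hasattr(F, "scaled_dot_product_attention")
print(f"torch {torch.__version__}, fused attention available: {has_fused_attention}")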
# Collect each run's key metrics into a single summary table.
import pandas as pd

sb_results = []
for key, value in bard_results.items():
    sb_results.append((key[0], f"{value.train_config.lr: .1e}", value.train_loss, value.test_loss,
                       value.num_parameters, value.time_per_example_in_micros, value.train_time_in_seconds))

pd.DataFrame.from_records(
    sb_results,
    columns=["model", "Learning Rate", "train_loss", "test_loss",
             "num_parameters", "time_per_example_in_micros", "train_time_in_seconds"]
).round(3)
|   | model | Learning Rate | train_loss | test_loss | num_parameters | time_per_example_in_micros | train_time_in_seconds |
|---|---|---|---|---|---|---|---|
| 0 | MLP | 1.0e-02 | 1.848 | 2.182 | 542289 | 19.452 | 302.810 |
| 1 | MLP | 1.0e-02 | 1.831 | 2.135 | 140105 | 19.259 | 307.696 |
| 2 | TLTransformer | 1.0e-02 | 1.236 | 1.613 | 216641 | 101.261 | 300.808 |
| 3 | TLTransformer | 1.0e-02 | 1.203 | 1.648 | 472385 | 131.621 | 303.390 |
| 4 | NanoGPT | 1.0e-02 | 1.815 | 1.812 | 212416 | 128.491 | 302.885 |
| 5 | NanoGPT | 1.0e-02 | 1.474 | 1.470 | 466080 | 150.595 | 308.727 |
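To make the timing column a little more intuitive, microseconds per example convert directly to throughput; the conversion below is my own addition, using two values taken from the table above:
# Examples per second = 1e6 / (microseconds per example); values copied from the summary table.
for name, us_per_example in [("MLP", 19.452), ("NanoGPT, 466080 params", 150.595)]:
    print(f"{name}: {1e6 / us_per_example:,.0f} examples/s")
# Roughly 51,000 examples/s for the MLP vs. about 6,600 examples/s for the larger NanoGPT run.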