Open In Colab

Evaluating Transformer implementations for small-scale applications¶

In the course of learning about Transformers for text, I came across multiple different implementations, all purportedly built on the publicly available details of OpenAI's GPT-2. So I decided to setup a testbed to compare the different implementations. I also decided to add a simple non-transformer baseline, and I chose a simple 2-layer feed-forward MLP. To be clear, we're not comparing pre-trained models, but rather we're comparing different implementations of the model (i.e. how would these different implementations perform when configured similarly, and trained on the same data, in terms of accuracy and speed).

Methods being compared¶

  1. HuggingFaceGPT: HuggingFace's GPT implementation (to be clear, we are not using the pre-trained weights or architecture)
  2. NanoGPT: Andrej Karpathy's implementation
  3. TLTransformer: GPT implementation based on Neel Nanda's lectures, referred to as "TLTransformer" based on his package TransformerLens. Note that a cleanly reusable public implementation of this does not exist, hence I needed to reimplement this with some copy-pasting here, in this notebook. The code is fairly small, so this did not pose a problem.
  4. 2-layer MLP: A simple MLP, see implementation in the Section named "MLP" below.

Test beds¶

  1. n-digit addition: Specifically, here I used 3-digit addition problems of the sort '272+926=1090'. I have also run tests on subtraction and using more digits than just 3; I have omitted those results here since the general patterns observed here have held there as well.
  2. Character-level modeling of Shakespeare's poems.

Results¶

  1. The most surprising result here is that for n-digit addition (and other math problems as well), MLP significantly outperforms all transformer implementations. It trains much faster, gets perfect accuracy, and running inference is also much faster.
  2. For n-digit addition, among the Transformer implementations, TLTransformer does the best - it both gets near perfect accuracy as well as trains faster, compared to HuggingFaceGPT and NanoGPT.
  3. For Shakespeare data, however, MLP is unable to do well (phew!). No matter how many parameters it is configured with, the test loss obtained by MLP is clearly much worse than what's achieved by either of the Transformer implementations.
  4. For Shakespeare data, there's roughly similar performance between TLTransformer and NanoGPT.

Takeaways¶

If you care about having something simple which you understand in its entirety, I would say TLTransformer is simpler than NanoGPT. The differences in performance between the different implementations probably don't matter in the big picture. If you have a really simple application, consider just throwing in a MLP as a baseline - you may be surprised by its performance.

Table of Contents¶

  1. Setup
  2. N-digit math problems
    1. Common code for training models and storing the results
    2. Huggingface's GPT implementation
    3. MLP
    4. NanoGPT
    5. TLTransformer
    6. Results for n-digit addition
  3. Shakespeare Data
    1. Using TL Transformer
    2. Using MLP
    3. Using NanoGPT
    4. Results

Setup¶

In [ ]:
!nvidia-smi
Mon Mar 13 06:16:43 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0    53W / 400W |   8105MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
In [ ]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")
  %pip install einops
  %pip install transformers
  %pip install fancy_einsum
  %pip install git+https://github.com/neelnanda-io/TransformerLens.git
except:
  IN_COLAB = False

import torch
import torch.nn as nn
import time
import torch.nn.functional as F
from typing import Any
import random
from transformers import OpenAIGPTConfig, OpenAIGPTModel
from transformer_lens.utils import gelu_new
from dataclasses import dataclass
import dataclasses
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
In [ ]:
import sys
if IN_COLAB:
   !git clone https://github.com/karpathy/nanoGPT.git
   sys.path.insert(0, 'nanoGPT')
else: 
    sys.path.insert(0, '../../nanoGPT')
from model import GPTConfig, GPT

N-digit math problems¶

In [ ]:
@dataclass
class DataConfig:
    reverse: bool = False
    upper_bound: int = 999
    each_digit_of_result_separate: bool = False
    lhs_number_width: int = len(str(upper_bound))
    rhs_number_width: int = len(str(upper_bound))+1
    num_plus_examples: int = int(3e5)
    num_minus_examples: int = int(3e5)
    test_size: int = 4096

@dataclass
class GeneratedData:
    all_strs: set[str]
    train_x: torch.Tensor
    train_y: torch.Tensor
    test_x: torch.Tensor
    test_y: torch.Tensor
    
In [ ]:
# m stands for unary negative, | indicates end of sequence. 
# "." is reserved for potential future use, but currently unused. 
chars = '0123456789.=+-m|'
end_of_sequence = "|"
encoder = dict((c, i) for i, c in enumerate(chars))
decoder = dict((i, c) for i, c in enumerate(chars))

def reverse_str(s):
    return s[::-1]

def stringify_problem(x, y, math_symbol, reverse=False, lhs_number_width=3, rhs_number_width=4):
    if math_symbol == "+":
        rhs = f"{x+y:0{rhs_number_width}}"
    elif math_symbol == "-":
        if x < y:
            # since we're adding a unary negative, we need to subtract 1 from the width
            if len(str(y-x)) > rhs_number_width - 1:
                raise ValueError(f"{x} minus {y} doesn't fit in {rhs_number_width} digits")
            rhs = "m" + f"{y-x:0{rhs_number_width-1}}" 
        else:
            rhs = f"{x-y:0{rhs_number_width}}"
    else:
        raise ValueError(f"Unsupported math symbol {math_symbol}")
    if reverse:
        rhs = reverse_str(rhs)
    # pad z with |s on the right, upto max_len
    padded_rhs = rhs + end_of_sequence + end_of_sequence * (rhs_number_width - len(rhs))
    str_x = f"{x:0{lhs_number_width}}"
    str_y = f"{y:0{lhs_number_width}}"
    if reverse:
        lhs = reverse_str(str_x) + math_symbol + reverse_str(str_y) + "="
    else:
        lhs = str_x + math_symbol + str_y + "="
    return lhs + padded_rhs

def generate_math_example(cfg: DataConfig, math_symbol = "+"):
    x = random.randint(0, cfg.upper_bound)
    y = random.randint(0, cfg.upper_bound)
    return stringify_problem(x, y, math_symbol, cfg.reverse, cfg.lhs_number_width, cfg.rhs_number_width)

def tensorify_example(example):
    return torch.tensor([encoder[c] for c in example])

## Let's create a function for taking X and Y, and including the digits of Y in X
## This function needs to ensure all the rows of X have the same length, hence
## it's going to be padding with the appropriate number of zeros on the left. 
def include_digits_of_y_in_x(X, Y):
    n = Y.shape[1]
    # We need to create X0, X1, ..., Xn-1
    Xs = []
    Ys = []
    for i in range(n):
        Xs.append(torch.cat([torch.zeros(X.shape[0], n - 1 - i, dtype=torch.long), X, Y[:, :i]], dim=1))
        Ys.append(Y[:, i])
    X = torch.cat(Xs, dim=0)
    Y = torch.cat(Ys, dim=0)
    return X, Y

def generate_data(cfg: DataConfig) -> GeneratedData:
  plus_strs = set([generate_math_example(cfg, math_symbol = "+") for _ in range(cfg.num_plus_examples)])
  minus_strs = set([generate_math_example(cfg, math_symbol = "-") for _ in range(cfg.num_minus_examples)])
  all_strs = plus_strs.union(minus_strs)

  all_examples = torch.stack([tensorify_example(i) for i in all_strs])
  assert cfg.test_size < all_examples.shape[0] * 0.9, "Test size requested is more than 90% of all data"
  train_size = all_examples.shape[0] - cfg.test_size
  rhs_len = cfg.rhs_number_width + 1
  train_x = all_examples[:train_size, :-rhs_len]
  train_y = all_examples[:train_size, -rhs_len:]
  test_x = all_examples[train_size:, :-rhs_len]
  test_y = all_examples[train_size:, -rhs_len:]

  if cfg.each_digit_of_result_separate:
    train_x, train_y = include_digits_of_y_in_x(train_x, train_y)
    test_x, test_y = include_digits_of_y_in_x(test_x, test_y)

  return GeneratedData(all_strs=all_strs, train_x=train_x, train_y=train_y, test_x=test_x, test_y=test_y)
In [ ]:
## Let's make sure the include_digits_of_y_in_x function works as expected
def test_for_include_digits_of_y_in_x():
    X = tensorify_example("123+456=").unsqueeze(dim=0)
    Y = tensorify_example("579" + end_of_sequence).unsqueeze(dim=0)
    X, Y = include_digits_of_y_in_x(X, Y)
    assert torch.equal(X[0], tensorify_example("000123+456="))
    assert torch.equal(Y[0], tensorify_example("5").squeeze())
    assert torch.equal(X[1] , tensorify_example("00123+456=5"))
    assert torch.equal(Y[1] , tensorify_example("7").squeeze())
    assert torch.equal(X[2] , tensorify_example("0123+456=57"))
    assert torch.equal(Y[2] , tensorify_example("9").squeeze())
    assert torch.equal(X[3] , tensorify_example("123+456=579"))
    assert torch.equal(Y[3] , tensorify_example(end_of_sequence).squeeze())

test_for_include_digits_of_y_in_x()
In [ ]:
stringify_problem(123, 456, "+", lhs_number_width=3, rhs_number_width=4, reverse=True)
Out[ ]:
'321+654=9750|'
In [ ]:
def test_for_stringify_problem():
    assert stringify_problem(123, 456, "+", lhs_number_width=3, rhs_number_width=4, reverse=False) == "123+456=0579|"
    assert stringify_problem(123, 999, "+", lhs_number_width=5, rhs_number_width=4, reverse=False) == "00123+00999=1122|"
    assert stringify_problem(123, 999, "-", lhs_number_width=3, rhs_number_width=4, reverse=False) == "123-999=m876|"
    assert stringify_problem(999, 123, "-", lhs_number_width=3, rhs_number_width=4, reverse=False) == "999-123=0876|"
    assert stringify_problem(123, 456, "+", lhs_number_width=3, rhs_number_width=4, reverse=True) == "321+654=9750|"
    assert stringify_problem(23, 45, "+", lhs_number_width=3, rhs_number_width=4, reverse=True) == "320+540=8600|"
    assert stringify_problem(123, 999, "-", lhs_number_width=3, rhs_number_width=4, reverse=True) == "321-999=678m|"

test_for_stringify_problem()
In [ ]:
reversed_add_3digit_separated_cfg = DataConfig(
    each_digit_of_result_separate=True, upper_bound=999, 
    num_plus_examples=int(1e5), num_minus_examples=0, test_size=4096, reverse=True,
)
reversed_add_3digit_separated = generate_data(reversed_add_3digit_separated_cfg)
In [ ]:
list(reversed_add_3digit_separated.all_strs)[0]
Out[ ]:
'272+926=1090|'

Common code for training models and storing the results¶

In [ ]:
results_dict = {}
In [ ]:
@dataclass
class TrainConfig:
    epochs: int = 1000
    train_batch_size: int = 128
    lr: float = 1e-3
    weight_decay: float = 1e-4
    epoch_interval: int = 100
    time_budget_seconds: int = 120
In [ ]:
@dataclass
class ResultRow:
    model_config: Any
    train_config: TrainConfig
    num_parameters: int
    epochs: int
    train_loss: float
    train_accuracy: float
    test_loss: float
    test_accuracy: float
    train_time_in_seconds: float
    time_per_example_in_micros: float
    train_losses: dict[int, float]
    train_accuracies: dict[int, float]
    test_losses: dict[int, float]
    test_accuracies: dict[int, float]
In [ ]:
def train_and_eval(model_config: Any, m: nn.Module, data: GeneratedData, train_config: TrainConfig):
    m = m.to(device)
    num_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f"Number of parameters: {num_params}")

    optimizer = torch.optim.AdamW(m.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
    train_x = data.train_x.to(device)
    train_y = data.train_y.to(device)
    test_x = data.test_x.to(device)
    test_y = data.test_y.to(device)

    training_losses = {}
    test_losses = {}
    training_accuracies = {}
    test_accuracies = {}

    outer_start = time.time()
    ep = 0
    while ep < train_config.epochs:
        start = time.time()
        optimizer.zero_grad()
        rind = torch.randint(0, train_x.shape[0], (train_config.train_batch_size,))
        X = train_x[rind]
        Y = train_y[rind]
        output = m(X)
        if type(output) is tuple:
            # some models output other stuff besides the logits, let's just use the logits
            logits = output[0][:, -1, :] #get logits for last token
        else:
            logits = output[:, -1, :] #get logits for last token
        loss = F.cross_entropy(logits, Y)
        if ep % train_config.epoch_interval == 0:
            preds = torch.argmax(logits, dim=-1)
            training_losses[ep] = loss.item()
            training_accuracies[ep] = torch.sum(preds == Y).item() / preds.shape[0]
        
        loss.backward()
        optimizer.step()
        elapsed = time.time() - start

        if ep % train_config.epoch_interval == 0:
            with torch.no_grad():
                #calculate test loss
                output = m(test_x)
                if type(output) is tuple:
                    # some models output other stuff besides the logits, let's just use the logits
                    test_logits = output[0][:, -1, :]
                else:
                    test_logits = output[:, -1, :]
                test_loss = F.cross_entropy(test_logits, test_y)
                test_preds = torch.argmax(test_logits, dim=-1)

                test_losses[ep] = test_loss.item()
                test_accuracies[ep] = torch.sum(test_preds == test_y).item() / test_preds.shape[0]
                print(f"Epoch {ep}, train loss {training_losses[ep]: .3E}, test loss {test_losses[ep]: .3f}, " +
                    f"training accuracy {training_accuracies[ep]: .2f}, test accuracy {test_accuracies[ep]: .2f}, " +
                    f"time per example {elapsed * 1e6 / train_config.train_batch_size: .2f} µs")
                if time.time() - outer_start > train_config.time_budget_seconds:
                    print("Time budget exceeded, hence stopping training")
                    break
                if test_accuracies[ep] > 0.995:
                    print("Test accuracy > 99.5%, hence stopping training")
                    break
        ep += 1

    if len(training_losses) is None or len(training_accuracies) is None:
        raise RuntimeError("Training did not run at all")
    if len(test_losses) is None or len(test_accuracies) is None:
        raise RuntimeError("Tests did not run at all")
    
    total_elapsed = time.time() - outer_start
    print(f"Total training time {total_elapsed: .2f} s")
    result_row = ResultRow(
        model_config=model_config,
        train_config=train_config,
        num_parameters=num_params, 
        epochs=ep+1, 
        train_loss=training_losses[max(training_losses.keys())], 
        train_accuracy=training_accuracies[max(training_accuracies.keys())],
        test_loss=test_losses[max(test_losses.keys())],
        test_accuracy=test_accuracies[max(test_accuracies.keys())],
        train_time_in_seconds=total_elapsed,
        time_per_example_in_micros=total_elapsed * 1e6 / ((ep + 1) * train_config.train_batch_size),
        train_losses=training_losses,
        train_accuracies=training_accuracies,
        test_losses=test_losses,
        test_accuracies=test_accuracies,
    )
    return result_row

HuggingFace's GPT implementation¶

In [ ]:
## create wrapper around OpenAIGPTModel, so we get logits and not the last state. 
class OpenAIGPTLMHeadModel(OpenAIGPTModel):
    def __init__(self, config):
        super().__init__(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        #self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
        transformer_outputs = super().forward(input_ids, position_ids, token_type_ids, past, head_mask)
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        return lm_logits
In [ ]:
train_config = TrainConfig(
    epochs=10000,
    train_batch_size=2048,
    lr=1e-3,
    weight_decay=1e-4,
    epoch_interval=50,
    time_budget_seconds=60,
)
#train_config.(lr=1e-2)

dataclasses.replace(train_config, lr=1e-2)
Out[ ]:
TrainConfig(epochs=10000, train_batch_size=2048, lr=0.01, weight_decay=0.0001, epoch_interval=50, time_budget_seconds=60)
In [ ]:
model_config = OpenAIGPTConfig(
        vocab_size=len(chars), 
        n_positions=reversed_add_3digit_separated.train_x.shape[1], 
        n_embd=16, 
        n_layer=4, 
        n_head=4,
    )
train_config = TrainConfig(
    epochs=10000,
    train_batch_size=2048,
    lr=1e-3,
    weight_decay=1e-4,
    epoch_interval=50,
    time_budget_seconds=60,
)
results_dict[
    ("HuggingFaceGPT", f"run_{int(time.time())}t")
] = train_and_eval(
    model_config, OpenAIGPTLMHeadModel(model_config), reversed_add_3digit_separated, train_config
)

results_dict[
    ("HuggingFaceGPT", f"run_{int(time.time())}t")
] = train_and_eval(
    model_config, OpenAIGPTLMHeadModel(model_config), reversed_add_3digit_separated, 
    dataclasses.replace(train_config, lr=1e-4)
)

results_dict[
    ("HuggingFaceGPT", f"run_{int(time.time())}t")
] = train_and_eval(
    model_config, OpenAIGPTLMHeadModel(model_config), reversed_add_3digit_separated, 
    dataclasses.replace(train_config, lr=1e-2)
)
Number of parameters: 13824
Epoch 0, train loss  3.035E+00, test loss  2.953, training accuracy  0.03, test accuracy  0.05, time per example  23.54 µs
Epoch 50, train loss  2.088E+00, test loss  2.081, training accuracy  0.36, test accuracy  0.37, time per example  13.81 µs
Epoch 100, train loss  1.771E+00, test loss  1.738, training accuracy  0.38, test accuracy  0.40, time per example  13.90 µs
Epoch 150, train loss  1.578E+00, test loss  1.550, training accuracy  0.42, test accuracy  0.43, time per example  15.00 µs
Epoch 200, train loss  1.461E+00, test loss  1.449, training accuracy  0.44, test accuracy  0.46, time per example  14.44 µs
Epoch 250, train loss  1.412E+00, test loss  1.424, training accuracy  0.47, test accuracy  0.46, time per example  15.62 µs
Epoch 300, train loss  1.402E+00, test loss  1.412, training accuracy  0.46, test accuracy  0.46, time per example  13.55 µs
Epoch 350, train loss  1.366E+00, test loss  1.405, training accuracy  0.47, test accuracy  0.46, time per example  14.52 µs
Epoch 400, train loss  1.382E+00, test loss  1.399, training accuracy  0.47, test accuracy  0.46, time per example  13.42 µs
Epoch 450, train loss  1.406E+00, test loss  1.397, training accuracy  0.45, test accuracy  0.46, time per example  13.53 µs
Epoch 500, train loss  1.350E+00, test loss  1.396, training accuracy  0.48, test accuracy  0.46, time per example  13.64 µs
Epoch 550, train loss  1.383E+00, test loss  1.392, training accuracy  0.47, test accuracy  0.46, time per example  13.87 µs
Epoch 600, train loss  1.436E+00, test loss  1.389, training accuracy  0.43, test accuracy  0.46, time per example  13.49 µs
Epoch 650, train loss  1.385E+00, test loss  1.391, training accuracy  0.46, test accuracy  0.46, time per example  14.66 µs
Epoch 700, train loss  1.375E+00, test loss  1.389, training accuracy  0.47, test accuracy  0.46, time per example  15.45 µs
Epoch 750, train loss  1.349E+00, test loss  1.389, training accuracy  0.47, test accuracy  0.46, time per example  13.62 µs
Epoch 800, train loss  1.389E+00, test loss  1.387, training accuracy  0.47, test accuracy  0.46, time per example  13.88 µs
Epoch 850, train loss  1.384E+00, test loss  1.387, training accuracy  0.46, test accuracy  0.46, time per example  13.67 µs
Epoch 900, train loss  1.389E+00, test loss  1.386, training accuracy  0.46, test accuracy  0.46, time per example  13.87 µs
Epoch 950, train loss  1.382E+00, test loss  1.385, training accuracy  0.46, test accuracy  0.46, time per example  15.09 µs
Epoch 1000, train loss  1.381E+00, test loss  1.387, training accuracy  0.46, test accuracy  0.46, time per example  14.85 µs
Epoch 1050, train loss  1.369E+00, test loss  1.385, training accuracy  0.47, test accuracy  0.46, time per example  15.62 µs
Epoch 1100, train loss  1.368E+00, test loss  1.385, training accuracy  0.48, test accuracy  0.46, time per example  14.98 µs
Epoch 1150, train loss  1.400E+00, test loss  1.385, training accuracy  0.46, test accuracy  0.46, time per example  15.74 µs
Epoch 1200, train loss  1.410E+00, test loss  1.385, training accuracy  0.45, test accuracy  0.46, time per example  15.64 µs
Epoch 1250, train loss  1.364E+00, test loss  1.383, training accuracy  0.48, test accuracy  0.46, time per example  13.96 µs
Epoch 1300, train loss  1.348E+00, test loss  1.376, training accuracy  0.47, test accuracy  0.46, time per example  13.51 µs
Epoch 1350, train loss  1.402E+00, test loss  1.369, training accuracy  0.46, test accuracy  0.47, time per example  13.80 µs
Epoch 1400, train loss  1.404E+00, test loss  1.325, training accuracy  0.45, test accuracy  0.49, time per example  13.73 µs
Epoch 1450, train loss  1.113E+00, test loss  1.132, training accuracy  0.56, test accuracy  0.55, time per example  13.81 µs
Epoch 1500, train loss  8.473E-01, test loss  0.834, training accuracy  0.65, test accuracy  0.65, time per example  13.79 µs
Epoch 1550, train loss  7.404E-01, test loss  0.717, training accuracy  0.69, test accuracy  0.70, time per example  14.00 µs
Epoch 1600, train loss  6.300E-01, test loss  0.631, training accuracy  0.74, test accuracy  0.74, time per example  14.94 µs
Epoch 1650, train loss  5.645E-01, test loss  0.564, training accuracy  0.77, test accuracy  0.77, time per example  15.62 µs
Epoch 1700, train loss  5.171E-01, test loss  0.520, training accuracy  0.80, test accuracy  0.78, time per example  14.15 µs
Epoch 1750, train loss  4.675E-01, test loss  0.467, training accuracy  0.80, test accuracy  0.81, time per example  13.81 µs
Epoch 1800, train loss  4.360E-01, test loss  0.449, training accuracy  0.83, test accuracy  0.82, time per example  14.00 µs
Epoch 1850, train loss  4.382E-01, test loss  0.424, training accuracy  0.82, test accuracy  0.83, time per example  13.74 µs
Epoch 1900, train loss  3.987E-01, test loss  0.388, training accuracy  0.85, test accuracy  0.84, time per example  13.86 µs
Epoch 1950, train loss  3.442E-01, test loss  0.367, training accuracy  0.87, test accuracy  0.85, time per example  13.94 µs
Time budget exceeded, hence stopping training
Total training time  60.25 s
Number of parameters: 13824
Epoch 0, train loss  2.870E+00, test loss  2.859, training accuracy  0.06, test accuracy  0.06, time per example  14.24 µs
Epoch 50, train loss  2.545E+00, test loss  2.520, training accuracy  0.26, test accuracy  0.26, time per example  13.58 µs
Epoch 100, train loss  2.360E+00, test loss  2.344, training accuracy  0.33, test accuracy  0.33, time per example  14.92 µs
Epoch 150, train loss  2.236E+00, test loss  2.244, training accuracy  0.36, test accuracy  0.35, time per example  15.68 µs
Epoch 200, train loss  2.174E+00, test loss  2.176, training accuracy  0.37, test accuracy  0.36, time per example  15.54 µs
Epoch 250, train loss  2.156E+00, test loss  2.137, training accuracy  0.36, test accuracy  0.36, time per example  13.87 µs
Epoch 300, train loss  2.128E+00, test loss  2.109, training accuracy  0.35, test accuracy  0.36, time per example  14.06 µs
Epoch 350, train loss  2.083E+00, test loss  2.078, training accuracy  0.36, test accuracy  0.37, time per example  13.85 µs
Epoch 400, train loss  2.049E+00, test loss  2.030, training accuracy  0.37, test accuracy  0.37, time per example  13.76 µs
Epoch 450, train loss  1.953E+00, test loss  1.944, training accuracy  0.39, test accuracy  0.39, time per example  14.04 µs
Epoch 500, train loss  1.906E+00, test loss  1.890, training accuracy  0.40, test accuracy  0.41, time per example  13.96 µs
Epoch 550, train loss  1.853E+00, test loss  1.844, training accuracy  0.41, test accuracy  0.41, time per example  15.53 µs
Epoch 600, train loss  1.762E+00, test loss  1.780, training accuracy  0.43, test accuracy  0.41, time per example  14.94 µs
Epoch 650, train loss  1.729E+00, test loss  1.725, training accuracy  0.41, test accuracy  0.41, time per example  15.39 µs
Epoch 700, train loss  1.703E+00, test loss  1.694, training accuracy  0.41, test accuracy  0.41, time per example  16.42 µs
Epoch 750, train loss  1.669E+00, test loss  1.670, training accuracy  0.40, test accuracy  0.42, time per example  15.88 µs
Epoch 800, train loss  1.652E+00, test loss  1.650, training accuracy  0.42, test accuracy  0.42, time per example  14.50 µs
Epoch 850, train loss  1.625E+00, test loss  1.631, training accuracy  0.43, test accuracy  0.43, time per example  14.20 µs
Epoch 900, train loss  1.632E+00, test loss  1.614, training accuracy  0.43, test accuracy  0.43, time per example  13.94 µs
Epoch 950, train loss  1.604E+00, test loss  1.597, training accuracy  0.42, test accuracy  0.44, time per example  13.64 µs
Epoch 1000, train loss  1.584E+00, test loss  1.581, training accuracy  0.44, test accuracy  0.44, time per example  13.94 µs
Epoch 1050, train loss  1.553E+00, test loss  1.564, training accuracy  0.45, test accuracy  0.45, time per example  13.98 µs
Epoch 1100, train loss  1.539E+00, test loss  1.549, training accuracy  0.45, test accuracy  0.45, time per example  14.77 µs
Epoch 1150, train loss  1.533E+00, test loss  1.534, training accuracy  0.45, test accuracy  0.45, time per example  15.14 µs
Epoch 1200, train loss  1.518E+00, test loss  1.514, training accuracy  0.45, test accuracy  0.46, time per example  15.82 µs
Epoch 1250, train loss  1.472E+00, test loss  1.502, training accuracy  0.47, test accuracy  0.46, time per example  13.87 µs
Epoch 1300, train loss  1.498E+00, test loss  1.493, training accuracy  0.45, test accuracy  0.46, time per example  14.19 µs
Epoch 1350, train loss  1.508E+00, test loss  1.483, training accuracy  0.46, test accuracy  0.46, time per example  13.91 µs
Epoch 1400, train loss  1.488E+00, test loss  1.476, training accuracy  0.46, test accuracy  0.46, time per example  13.92 µs
Epoch 1450, train loss  1.447E+00, test loss  1.469, training accuracy  0.46, test accuracy  0.46, time per example  13.94 µs
Epoch 1500, train loss  1.444E+00, test loss  1.464, training accuracy  0.47, test accuracy  0.46, time per example  13.91 µs
Epoch 1550, train loss  1.424E+00, test loss  1.458, training accuracy  0.48, test accuracy  0.46, time per example  14.01 µs
Epoch 1600, train loss  1.456E+00, test loss  1.454, training accuracy  0.46, test accuracy  0.46, time per example  14.68 µs
Epoch 1650, train loss  1.426E+00, test loss  1.451, training accuracy  0.46, test accuracy  0.46, time per example  18.00 µs
Epoch 1700, train loss  1.453E+00, test loss  1.445, training accuracy  0.46, test accuracy  0.46, time per example  14.04 µs
Epoch 1750, train loss  1.473E+00, test loss  1.443, training accuracy  0.45, test accuracy  0.46, time per example  14.18 µs
Epoch 1800, train loss  1.440E+00, test loss  1.441, training accuracy  0.46, test accuracy  0.46, time per example  13.72 µs
Epoch 1850, train loss  1.423E+00, test loss  1.437, training accuracy  0.47, test accuracy  0.46, time per example  14.47 µs
Epoch 1900, train loss  1.420E+00, test loss  1.435, training accuracy  0.47, test accuracy  0.46, time per example  14.02 µs
Epoch 1950, train loss  1.420E+00, test loss  1.431, training accuracy  0.46, test accuracy  0.46, time per example  13.88 µs
Time budget exceeded, hence stopping training
Total training time  60.90 s
Number of parameters: 13824
Epoch 0, train loss  2.934E+00, test loss  2.483, training accuracy  0.02, test accuracy  0.22, time per example  14.37 µs
Epoch 50, train loss  1.503E+00, test loss  1.482, training accuracy  0.42, test accuracy  0.42, time per example  13.97 µs
Epoch 100, train loss  1.381E+00, test loss  1.397, training accuracy  0.47, test accuracy  0.46, time per example  14.94 µs
Epoch 150, train loss  1.383E+00, test loss  1.391, training accuracy  0.47, test accuracy  0.46, time per example  14.93 µs
Epoch 200, train loss  1.386E+00, test loss  1.388, training accuracy  0.46, test accuracy  0.46, time per example  15.31 µs
Epoch 250, train loss  1.434E+00, test loss  1.387, training accuracy  0.44, test accuracy  0.46, time per example  15.94 µs
Epoch 300, train loss  1.360E+00, test loss  1.387, training accuracy  0.47, test accuracy  0.46, time per example  16.37 µs
Epoch 350, train loss  1.388E+00, test loss  1.386, training accuracy  0.45, test accuracy  0.46, time per example  14.01 µs
Epoch 400, train loss  1.393E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  13.82 µs
Epoch 450, train loss  1.417E+00, test loss  1.383, training accuracy  0.45, test accuracy  0.46, time per example  14.33 µs
Epoch 500, train loss  1.367E+00, test loss  1.384, training accuracy  0.47, test accuracy  0.46, time per example  14.47 µs
Epoch 550, train loss  1.402E+00, test loss  1.387, training accuracy  0.45, test accuracy  0.46, time per example  14.08 µs
Epoch 600, train loss  1.350E+00, test loss  1.384, training accuracy  0.47, test accuracy  0.46, time per example  15.24 µs
Epoch 650, train loss  1.397E+00, test loss  1.383, training accuracy  0.46, test accuracy  0.46, time per example  15.65 µs
Epoch 700, train loss  1.392E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  15.69 µs
Epoch 750, train loss  1.365E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  16.24 µs
Epoch 800, train loss  1.403E+00, test loss  1.383, training accuracy  0.45, test accuracy  0.46, time per example  14.10 µs
Epoch 850, train loss  1.417E+00, test loss  1.384, training accuracy  0.45, test accuracy  0.46, time per example  14.08 µs
Epoch 900, train loss  1.413E+00, test loss  1.385, training accuracy  0.45, test accuracy  0.46, time per example  14.37 µs
Epoch 950, train loss  1.361E+00, test loss  1.383, training accuracy  0.46, test accuracy  0.46, time per example  14.30 µs
Epoch 1000, train loss  1.409E+00, test loss  1.384, training accuracy  0.45, test accuracy  0.46, time per example  14.16 µs
Epoch 1050, train loss  1.354E+00, test loss  1.383, training accuracy  0.48, test accuracy  0.46, time per example  14.05 µs
Epoch 1100, train loss  1.439E+00, test loss  1.384, training accuracy  0.43, test accuracy  0.46, time per example  14.91 µs
Epoch 1150, train loss  1.381E+00, test loss  1.383, training accuracy  0.46, test accuracy  0.46, time per example  15.09 µs
Epoch 1200, train loss  1.361E+00, test loss  1.384, training accuracy  0.47, test accuracy  0.46, time per example  16.11 µs
Epoch 1250, train loss  1.400E+00, test loss  1.383, training accuracy  0.46, test accuracy  0.46, time per example  14.04 µs
Epoch 1300, train loss  1.350E+00, test loss  1.386, training accuracy  0.47, test accuracy  0.46, time per example  13.99 µs
Epoch 1350, train loss  1.376E+00, test loss  1.384, training accuracy  0.47, test accuracy  0.46, time per example  13.97 µs
Epoch 1400, train loss  1.394E+00, test loss  1.383, training accuracy  0.45, test accuracy  0.46, time per example  13.87 µs
Epoch 1450, train loss  1.400E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  14.07 µs
Epoch 1500, train loss  1.349E+00, test loss  1.383, training accuracy  0.46, test accuracy  0.46, time per example  13.78 µs
Epoch 1550, train loss  1.395E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  13.93 µs
Epoch 1600, train loss  1.370E+00, test loss  1.385, training accuracy  0.46, test accuracy  0.46, time per example  16.47 µs
Epoch 1650, train loss  1.394E+00, test loss  1.385, training accuracy  0.45, test accuracy  0.46, time per example  19.23 µs
Epoch 1700, train loss  1.364E+00, test loss  1.425, training accuracy  0.46, test accuracy  0.45, time per example  24.51 µs
Epoch 1750, train loss  1.391E+00, test loss  1.418, training accuracy  0.46, test accuracy  0.45, time per example  16.25 µs
Epoch 1800, train loss  1.383E+00, test loss  1.399, training accuracy  0.47, test accuracy  0.46, time per example  14.36 µs
Epoch 1850, train loss  1.381E+00, test loss  1.389, training accuracy  0.46, test accuracy  0.46, time per example  14.05 µs
Epoch 1900, train loss  1.352E+00, test loss  1.385, training accuracy  0.47, test accuracy  0.46, time per example  14.29 µs
Time budget exceeded, hence stopping training
Total training time  60.35 s

MLP¶

In [ ]:
@dataclass
class MLPForSeq2SeqConfig:
    vocab_size: int
    input_len: int
    n_embed: int
    n_hidden: int
    output_len: int

## MLP for sequence to sequence problems.
## It operates in 3 steps: Embed each input token into a vector, concatenate all the vectors, and then pass through an MLP. 
## The result of the MLP are the logits. 
class MLPForSeq2Seq(nn.Module):
    def __init__(self, cfg: MLPForSeq2SeqConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.n_embed)
        self.mlp = nn.Sequential(
            nn.Linear(cfg.n_embed * cfg.input_len, cfg.n_hidden),
            nn.ReLU(),
            nn.Linear(cfg.n_hidden, cfg.vocab_size * cfg.output_len)
        )
        
    def forward(self, x):
        # x is of shape (batch_size, input_len)
        x = self.embed(x)
        # now x is of shape (batch_size, input_len, n_embed)
        # reshape x to have shape (batch_size, input_len * n_embed)
        x = x.view(-1, self.cfg.n_embed * self.cfg.input_len)
        x = self.mlp(x)
        # now x is of shape (batch_size, vocab_size * output_len)
        # reshape x to have shape (batch_size, output_len, vocab_size)
        x = x.view(-1, self.cfg.output_len, self.cfg.vocab_size)
        return x

debugging, ignore¶

In [ ]:
model_config = MLPForSeq2SeqConfig(
    vocab_size=len(chars),
    n_embed=8,
    n_hidden=128,
    input_len=reversed_add_3digit_separated.train_x.shape[1],
    output_len=1,
)
m = MLPForSeq2Seq(model_config)
m(torch.randint(0, len(chars), (10, reversed_add_3digit_separated.train_x.shape[1])))[:, -1, :].shape
Out[ ]:
torch.Size([10, 16])
In [ ]:
type(m(torch.randint(0, len(chars), (10, reversed_add_3digit_separated.train_x.shape[1]))))
Out[ ]:
torch.Tensor

run for real¶

In [ ]:
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=2048,
    lr=1e-3,
    weight_decay=1e-4,
    epoch_interval=500,
    time_budget_seconds=60,
)
model_config = MLPForSeq2SeqConfig(
    vocab_size=len(chars),
    n_embed=8,
    n_hidden=128,
    input_len=reversed_add_3digit_separated.train_x.shape[1],
    output_len=1,
)
results_dict[("MLP", f"run_{int(time.time())}t")] = train_and_eval(
    model_config, MLPForSeq2Seq(model_config), reversed_add_3digit_separated, train_config
)
results_dict[("MLP", f"run_{int(time.time())}t")] = train_and_eval(
    model_config, MLPForSeq2Seq(model_config), reversed_add_3digit_separated, 
    dataclasses.replace(train_config, lr=1e-4)
)
results_dict[("MLP", f"run_{int(time.time())}t")] = train_and_eval(
    model_config, MLPForSeq2Seq(model_config), reversed_add_3digit_separated, 
    dataclasses.replace(train_config, lr=1e-2)
)
Number of parameters: 14608
Epoch 0, train loss  2.796E+00, test loss  2.749, training accuracy  0.10, test accuracy  0.13, time per example  1.47 µs
Epoch 500, train loss  3.596E-01, test loss  0.355, training accuracy  0.89, test accuracy  0.90, time per example  1.04 µs
Epoch 1000, train loss  3.558E-02, test loss  0.035, training accuracy  1.00, test accuracy  1.00, time per example  1.15 µs
Test accuracy > 99.5%, hence stopping training
Total training time  2.17 s
Number of parameters: 14608
Epoch 0, train loss  2.755E+00, test loss  2.745, training accuracy  0.06, test accuracy  0.08, time per example  1.52 µs
Epoch 500, train loss  1.565E+00, test loss  1.552, training accuracy  0.45, test accuracy  0.45, time per example  1.03 µs
Epoch 1000, train loss  1.434E+00, test loss  1.428, training accuracy  0.47, test accuracy  0.48, time per example  1.41 µs
Epoch 1500, train loss  1.346E+00, test loss  1.345, training accuracy  0.51, test accuracy  0.53, time per example  1.28 µs
Epoch 2000, train loss  1.243E+00, test loss  1.228, training accuracy  0.57, test accuracy  0.57, time per example  1.58 µs
Epoch 2500, train loss  1.111E+00, test loss  1.102, training accuracy  0.62, test accuracy  0.63, time per example  1.46 µs
Epoch 3000, train loss  9.809E-01, test loss  0.974, training accuracy  0.68, test accuracy  0.69, time per example  1.09 µs
Epoch 3500, train loss  8.241E-01, test loss  0.845, training accuracy  0.75, test accuracy  0.73, time per example  1.02 µs
Epoch 4000, train loss  7.244E-01, test loss  0.715, training accuracy  0.78, test accuracy  0.79, time per example  1.10 µs
Epoch 4500, train loss  5.737E-01, test loss  0.589, training accuracy  0.84, test accuracy  0.83, time per example  1.15 µs
Epoch 5000, train loss  4.787E-01, test loss  0.473, training accuracy  0.88, test accuracy  0.88, time per example  1.12 µs
Epoch 5500, train loss  3.490E-01, test loss  0.372, training accuracy  0.93, test accuracy  0.92, time per example  1.04 µs
Epoch 6000, train loss  2.905E-01, test loss  0.284, training accuracy  0.95, test accuracy  0.95, time per example  1.04 µs
Epoch 6500, train loss  2.122E-01, test loss  0.211, training accuracy  0.98, test accuracy  0.98, time per example  1.03 µs
Epoch 7000, train loss  1.420E-01, test loss  0.152, training accuracy  0.99, test accuracy  0.99, time per example  1.06 µs
Epoch 7500, train loss  1.012E-01, test loss  0.108, training accuracy  1.00, test accuracy  1.00, time per example  1.25 µs
Test accuracy > 99.5%, hence stopping training
Total training time  17.04 s
Number of parameters: 14608
Epoch 0, train loss  2.790E+00, test loss  2.352, training accuracy  0.06, test accuracy  0.33, time per example  1.51 µs
Epoch 500, train loss  5.589E-04, test loss  0.001, training accuracy  1.00, test accuracy  1.00, time per example  1.31 µs
Test accuracy > 99.5%, hence stopping training
Total training time  1.31 s
In [ ]:
 

NanoGPT¶

In [ ]:
model_config = GPTConfig(
    block_size=reversed_add_3digit_separated.train_x.shape[1], 
    vocab_size=len(chars), 
    n_layer=6, n_head=4, n_embd=16, dropout=0.1
)
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=2048,
    lr=1e-3,
    weight_decay=1e-4,
    epoch_interval=50,
    time_budget_seconds=60,
)
results_dict[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval(
    model_config, GPT(model_config), reversed_add_3digit_separated, train_config
)
results_dict[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval(
    model_config, GPT(model_config), reversed_add_3digit_separated, 
    dataclasses.replace(train_config, lr=1e-4)
)
results_dict[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval(
    model_config, GPT(model_config), reversed_add_3digit_separated, 
    dataclasses.replace(train_config, lr=1e-2)
)
fatal: destination path 'nanoGPT' already exists and is not an empty directory.
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
number of parameters: 0.02M
Number of parameters: 20160
Epoch 0, train loss  2.780E+00, test loss  2.728, training accuracy  0.06, test accuracy  0.20, time per example  33.63 µs
Epoch 50, train loss  2.224E+00, test loss  2.219, training accuracy  0.28, test accuracy  0.29, time per example  20.93 µs
Epoch 100, train loss  1.890E+00, test loss  1.867, training accuracy  0.37, test accuracy  0.39, time per example  22.87 µs
Epoch 150, train loss  1.747E+00, test loss  1.686, training accuracy  0.38, test accuracy  0.40, time per example  20.56 µs
Epoch 200, train loss  1.587E+00, test loss  1.604, training accuracy  0.41, test accuracy  0.40, time per example  21.32 µs
Epoch 250, train loss  1.526E+00, test loss  1.540, training accuracy  0.43, test accuracy  0.43, time per example  21.29 µs
Epoch 300, train loss  1.429E+00, test loss  1.469, training accuracy  0.47, test accuracy  0.45, time per example  22.66 µs
Epoch 350, train loss  1.412E+00, test loss  1.433, training accuracy  0.46, test accuracy  0.46, time per example  23.59 µs
Epoch 400, train loss  1.428E+00, test loss  1.418, training accuracy  0.45, test accuracy  0.46, time per example  22.90 µs
Epoch 450, train loss  1.409E+00, test loss  1.409, training accuracy  0.47, test accuracy  0.46, time per example  23.32 µs
Epoch 500, train loss  1.352E+00, test loss  1.404, training accuracy  0.48, test accuracy  0.46, time per example  21.11 µs
Epoch 550, train loss  1.417E+00, test loss  1.401, training accuracy  0.45, test accuracy  0.46, time per example  21.14 µs
Epoch 600, train loss  1.406E+00, test loss  1.398, training accuracy  0.46, test accuracy  0.46, time per example  23.29 µs
Epoch 650, train loss  1.379E+00, test loss  1.397, training accuracy  0.46, test accuracy  0.46, time per example  21.46 µs
Epoch 700, train loss  1.402E+00, test loss  1.395, training accuracy  0.45, test accuracy  0.46, time per example  20.79 µs
Epoch 750, train loss  1.449E+00, test loss  1.393, training accuracy  0.44, test accuracy  0.46, time per example  21.24 µs
Epoch 800, train loss  1.358E+00, test loss  1.392, training accuracy  0.47, test accuracy  0.46, time per example  20.87 µs
Epoch 850, train loss  1.398E+00, test loss  1.390, training accuracy  0.46, test accuracy  0.46, time per example  21.09 µs
Epoch 900, train loss  1.345E+00, test loss  1.390, training accuracy  0.48, test accuracy  0.46, time per example  22.85 µs
Epoch 950, train loss  1.377E+00, test loss  1.390, training accuracy  0.46, test accuracy  0.46, time per example  24.05 µs
Epoch 1000, train loss  1.388E+00, test loss  1.388, training accuracy  0.46, test accuracy  0.46, time per example  21.32 µs
Epoch 1050, train loss  1.415E+00, test loss  1.387, training accuracy  0.45, test accuracy  0.46, time per example  20.80 µs
Epoch 1100, train loss  1.384E+00, test loss  1.387, training accuracy  0.46, test accuracy  0.46, time per example  20.86 µs
Epoch 1150, train loss  1.381E+00, test loss  1.387, training accuracy  0.46, test accuracy  0.46, time per example  22.35 µs
Epoch 1200, train loss  1.356E+00, test loss  1.386, training accuracy  0.46, test accuracy  0.46, time per example  22.89 µs
Epoch 1250, train loss  1.415E+00, test loss  1.386, training accuracy  0.45, test accuracy  0.46, time per example  23.45 µs
Epoch 1300, train loss  1.365E+00, test loss  1.387, training accuracy  0.47, test accuracy  0.46, time per example  21.26 µs
Time budget exceeded, hence stopping training
Total training time  61.96 s
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
number of parameters: 0.02M
Number of parameters: 20160
Epoch 0, train loss  2.788E+00, test loss  2.785, training accuracy  0.05, test accuracy  0.05, time per example  21.03 µs
Epoch 50, train loss  2.683E+00, test loss  2.678, training accuracy  0.24, test accuracy  0.25, time per example  20.64 µs
Epoch 100, train loss  2.607E+00, test loss  2.606, training accuracy  0.35, test accuracy  0.34, time per example  22.48 µs
Epoch 150, train loss  2.539E+00, test loss  2.538, training accuracy  0.35, test accuracy  0.35, time per example  23.84 µs
Epoch 200, train loss  2.480E+00, test loss  2.473, training accuracy  0.34, test accuracy  0.35, time per example  26.10 µs
Epoch 250, train loss  2.407E+00, test loss  2.411, training accuracy  0.36, test accuracy  0.35, time per example  23.87 µs
Epoch 300, train loss  2.356E+00, test loss  2.352, training accuracy  0.35, test accuracy  0.35, time per example  20.95 µs
Epoch 350, train loss  2.286E+00, test loss  2.297, training accuracy  0.37, test accuracy  0.35, time per example  20.96 µs
Epoch 400, train loss  2.238E+00, test loss  2.243, training accuracy  0.36, test accuracy  0.36, time per example  20.79 µs
Epoch 450, train loss  2.205E+00, test loss  2.193, training accuracy  0.36, test accuracy  0.37, time per example  21.65 µs
Epoch 500, train loss  2.134E+00, test loss  2.146, training accuracy  0.39, test accuracy  0.38, time per example  23.42 µs
Epoch 550, train loss  2.104E+00, test loss  2.101, training accuracy  0.38, test accuracy  0.38, time per example  23.71 µs
Epoch 600, train loss  2.056E+00, test loss  2.058, training accuracy  0.39, test accuracy  0.39, time per example  21.78 µs
Epoch 650, train loss  2.018E+00, test loss  2.019, training accuracy  0.40, test accuracy  0.40, time per example  21.38 µs
Epoch 700, train loss  1.967E+00, test loss  1.983, training accuracy  0.40, test accuracy  0.40, time per example  20.99 µs
Epoch 750, train loss  1.916E+00, test loss  1.949, training accuracy  0.42, test accuracy  0.40, time per example  24.10 µs
Epoch 800, train loss  1.914E+00, test loss  1.917, training accuracy  0.40, test accuracy  0.40, time per example  22.57 µs
Epoch 850, train loss  1.883E+00, test loss  1.888, training accuracy  0.40, test accuracy  0.40, time per example  24.72 µs
Epoch 900, train loss  1.863E+00, test loss  1.860, training accuracy  0.40, test accuracy  0.41, time per example  21.32 µs
Epoch 950, train loss  1.821E+00, test loss  1.836, training accuracy  0.40, test accuracy  0.40, time per example  20.65 µs
Epoch 1000, train loss  1.799E+00, test loss  1.813, training accuracy  0.41, test accuracy  0.41, time per example  21.26 µs
Epoch 1050, train loss  1.802E+00, test loss  1.792, training accuracy  0.40, test accuracy  0.41, time per example  21.10 µs
Epoch 1100, train loss  1.811E+00, test loss  1.772, training accuracy  0.39, test accuracy  0.41, time per example  29.70 µs
Epoch 1150, train loss  1.760E+00, test loss  1.754, training accuracy  0.41, test accuracy  0.41, time per example  24.86 µs
Epoch 1200, train loss  1.702E+00, test loss  1.737, training accuracy  0.43, test accuracy  0.41, time per example  25.73 µs
Epoch 1250, train loss  1.728E+00, test loss  1.720, training accuracy  0.40, test accuracy  0.41, time per example  21.66 µs
Time budget exceeded, hence stopping training
Total training time  60.69 s
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
number of parameters: 0.02M
Number of parameters: 20160
Epoch 0, train loss  2.782E+00, test loss  2.696, training accuracy  0.09, test accuracy  0.16, time per example  21.20 µs
Epoch 50, train loss  1.565E+00, test loss  1.555, training accuracy  0.38, test accuracy  0.38, time per example  21.16 µs
Epoch 100, train loss  1.380E+00, test loss  1.403, training accuracy  0.46, test accuracy  0.46, time per example  21.58 µs
Epoch 150, train loss  1.372E+00, test loss  1.390, training accuracy  0.46, test accuracy  0.46, time per example  21.23 µs
Epoch 200, train loss  1.430E+00, test loss  1.387, training accuracy  0.44, test accuracy  0.46, time per example  22.43 µs
Epoch 250, train loss  1.434E+00, test loss  1.386, training accuracy  0.44, test accuracy  0.46, time per example  23.99 µs
Epoch 300, train loss  1.379E+00, test loss  1.389, training accuracy  0.46, test accuracy  0.46, time per example  21.07 µs
Epoch 350, train loss  1.398E+00, test loss  1.389, training accuracy  0.45, test accuracy  0.46, time per example  21.58 µs
Epoch 400, train loss  1.378E+00, test loss  1.388, training accuracy  0.46, test accuracy  0.46, time per example  21.29 µs
Epoch 450, train loss  1.423E+00, test loss  1.384, training accuracy  0.45, test accuracy  0.46, time per example  20.87 µs
Epoch 500, train loss  1.444E+00, test loss  1.383, training accuracy  0.44, test accuracy  0.46, time per example  22.98 µs
Epoch 550, train loss  1.368E+00, test loss  1.372, training accuracy  0.47, test accuracy  0.47, time per example  23.73 µs
Epoch 600, train loss  1.265E+00, test loss  1.275, training accuracy  0.51, test accuracy  0.50, time per example  21.27 µs
Epoch 650, train loss  9.377E-01, test loss  0.900, training accuracy  0.59, test accuracy  0.63, time per example  21.47 µs
Epoch 700, train loss  5.093E-01, test loss  0.538, training accuracy  0.80, test accuracy  0.77, time per example  21.60 µs
Epoch 750, train loss  4.023E-01, test loss  0.386, training accuracy  0.84, test accuracy  0.85, time per example  21.73 µs
Epoch 800, train loss  3.261E-01, test loss  0.316, training accuracy  0.88, test accuracy  0.88, time per example  22.41 µs
Epoch 850, train loss  2.294E-01, test loss  0.254, training accuracy  0.91, test accuracy  0.91, time per example  24.94 µs
Epoch 900, train loss  2.458E-01, test loss  0.229, training accuracy  0.91, test accuracy  0.91, time per example  23.72 µs
Epoch 950, train loss  1.855E-01, test loss  0.185, training accuracy  0.94, test accuracy  0.93, time per example  21.34 µs
Epoch 1000, train loss  1.861E-01, test loss  0.167, training accuracy  0.93, test accuracy  0.94, time per example  22.70 µs
Epoch 1050, train loss  1.632E-01, test loss  0.167, training accuracy  0.94, test accuracy  0.94, time per example  21.32 µs
Epoch 1100, train loss  1.790E-01, test loss  0.164, training accuracy  0.94, test accuracy  0.94, time per example  23.76 µs
Epoch 1150, train loss  1.296E-01, test loss  0.158, training accuracy  0.95, test accuracy  0.94, time per example  23.91 µs
Epoch 1200, train loss  1.128E-01, test loss  0.134, training accuracy  0.96, test accuracy  0.95, time per example  21.23 µs
Epoch 1250, train loss  1.174E-01, test loss  0.122, training accuracy  0.96, test accuracy  0.96, time per example  21.77 µs
Time budget exceeded, hence stopping training
Total training time  60.06 s

TLTransformer¶

TLTransformer code (it's a bit long, hence hidden)

In [ ]:
import einops
from fancy_einsum import einsum
import math

@dataclass
class TLConfig:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))
    
    def forward(self, residual):
        # residual: [batch, position, d_model]
        if self.cfg.debug: print("Residual:", residual.shape)
        residual = residual - einops.reduce(residual, "batch position d_model -> batch position 1", "mean")
        # Calculate the variance, square root it. Add in an epsilon to prevent divide by zero.
        scale = (einops.reduce(residual.pow(2), "batch position d_model -> batch position 1", "mean") + self.cfg.layer_norm_eps).sqrt()
        normalized = residual / scale
        normalized = normalized * self.w + self.b
        if self.cfg.debug: print("Normalized:", residual.shape)
        return normalized

class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        if self.cfg.debug: print("Tokens:", tokens.shape)
        embed = self.W_E[tokens, :] # [batch, position, d_model]
        if self.cfg.debug: print("Embeddings:", embed.shape)
        return embed

class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        if self.cfg.debug: print("Tokens:", tokens.shape)
        pos_embed = self.W_pos[:tokens.size(1), :] # [position, d_model]
        pos_embed = einops.repeat(pos_embed, "position d_model -> batch position d_model", batch=tokens.size(0))
        if self.cfg.debug: print("pos_embed:", pos_embed.shape)
        return pos_embed

class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))
    
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_pre:", normalized_resid_pre.shape)
        
        q = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K
        
        attn_scores = einsum("batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos", q, k)
        attn_scores = attn_scores / math.sqrt(self.cfg.d_head)
        attn_scores = self.apply_causal_mask(attn_scores)

        pattern = attn_scores.softmax(dim=-1) # [batch, n_head, query_pos, key_pos]

        v = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V

        z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", pattern, v)

        attn_out = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_O) + self.b_O
        return attn_out

    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))
    
    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_mid:", normalized_resid_mid.shape)
        pre = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in
        post = gelu_new(pre)
        mlp_out = einsum("batch position d_mlp, d_mlp d_model -> batch position d_model", post, self.W_out) + self.b_out
        return mlp_out

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
    
    def forward(self, resid_pre):
        # resid_pre [batch, position, d_model]
        normalized_resid_pre = self.ln1(resid_pre)
        attn_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attn_out
        
        normalized_resid_mid = self.ln2(resid_mid)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        return resid_post

class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab), requires_grad=False))
    
    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_final:", normalized_resid_final.shape)
        logits = einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U
        return logits
    
class TLTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)
    
    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed
        for block in self.blocks:
            residual = block(residual)
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

Using TLTransformer¶

In [ ]:
tlconfig = TLConfig(
    d_model = 16, 
    debug = False, 
    d_vocab = len(chars), 
    n_ctx = reversed_add_3digit_separated.train_x.shape[1], 
    d_head = 4, n_heads = 4, n_layers = 4, d_mlp = 64)
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=2048,
    lr=1e-3,
    weight_decay=1e-4,
    epoch_interval=50,
    time_budget_seconds=60,
)
results_dict[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval(
    tlconfig, TLTransformer(tlconfig), reversed_add_3digit_separated, train_config
)
results_dict[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval(
    tlconfig, TLTransformer(tlconfig), reversed_add_3digit_separated, 
    dataclasses.replace(train_config, lr=1e-4)
)
results_dict[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval(
    tlconfig, TLTransformer(tlconfig), reversed_add_3digit_separated, 
    dataclasses.replace(train_config, lr=1e-2)
)
Number of parameters: 13872
Epoch 0, train loss  2.778E+00, test loss  2.730, training accuracy  0.03, test accuracy  0.16, time per example  20.18 µs
Epoch 50, train loss  2.171E+00, test loss  2.161, training accuracy  0.40, test accuracy  0.39, time per example  13.98 µs
Epoch 100, train loss  1.790E+00, test loss  1.791, training accuracy  0.43, test accuracy  0.43, time per example  14.48 µs
Epoch 150, train loss  1.585E+00, test loss  1.598, training accuracy  0.47, test accuracy  0.46, time per example  15.03 µs
Epoch 200, train loss  1.496E+00, test loss  1.484, training accuracy  0.46, test accuracy  0.46, time per example  14.75 µs
Epoch 250, train loss  1.455E+00, test loss  1.440, training accuracy  0.45, test accuracy  0.46, time per example  19.40 µs
Epoch 300, train loss  1.421E+00, test loss  1.421, training accuracy  0.46, test accuracy  0.46, time per example  24.74 µs
Epoch 350, train loss  1.474E+00, test loss  1.465, training accuracy  0.43, test accuracy  0.44, time per example  14.37 µs
Epoch 400, train loss  1.423E+00, test loss  1.409, training accuracy  0.45, test accuracy  0.46, time per example  13.88 µs
Epoch 450, train loss  1.439E+00, test loss  1.402, training accuracy  0.45, test accuracy  0.47, time per example  14.26 µs
Epoch 500, train loss  1.369E+00, test loss  1.398, training accuracy  0.46, test accuracy  0.46, time per example  13.60 µs
Epoch 550, train loss  1.404E+00, test loss  1.395, training accuracy  0.45, test accuracy  0.46, time per example  14.09 µs
Epoch 600, train loss  1.384E+00, test loss  1.393, training accuracy  0.46, test accuracy  0.46, time per example  14.54 µs
Epoch 650, train loss  1.366E+00, test loss  1.391, training accuracy  0.47, test accuracy  0.46, time per example  20.68 µs
Epoch 700, train loss  1.417E+00, test loss  1.389, training accuracy  0.45, test accuracy  0.46, time per example  22.56 µs
Epoch 750, train loss  1.371E+00, test loss  1.389, training accuracy  0.47, test accuracy  0.46, time per example  20.36 µs
Epoch 800, train loss  1.370E+00, test loss  1.388, training accuracy  0.46, test accuracy  0.46, time per example  23.01 µs
Epoch 850, train loss  1.382E+00, test loss  1.387, training accuracy  0.46, test accuracy  0.46, time per example  14.48 µs
Epoch 900, train loss  1.353E+00, test loss  1.387, training accuracy  0.48, test accuracy  0.46, time per example  14.04 µs
Epoch 950, train loss  1.400E+00, test loss  1.388, training accuracy  0.45, test accuracy  0.46, time per example  14.23 µs
Epoch 1000, train loss  1.375E+00, test loss  1.386, training accuracy  0.47, test accuracy  0.46, time per example  14.25 µs
Epoch 1050, train loss  1.423E+00, test loss  1.387, training accuracy  0.44, test accuracy  0.46, time per example  20.35 µs
Epoch 1100, train loss  1.414E+00, test loss  1.386, training accuracy  0.44, test accuracy  0.46, time per example  22.99 µs
Epoch 1150, train loss  1.395E+00, test loss  1.386, training accuracy  0.46, test accuracy  0.46, time per example  15.31 µs
Epoch 1200, train loss  1.381E+00, test loss  1.385, training accuracy  0.45, test accuracy  0.46, time per example  14.67 µs
Epoch 1250, train loss  1.392E+00, test loss  1.386, training accuracy  0.46, test accuracy  0.46, time per example  13.73 µs
Epoch 1300, train loss  1.398E+00, test loss  1.385, training accuracy  0.47, test accuracy  0.46, time per example  13.96 µs
Epoch 1350, train loss  1.351E+00, test loss  1.384, training accuracy  0.47, test accuracy  0.45, time per example  14.18 µs
Epoch 1400, train loss  1.405E+00, test loss  1.384, training accuracy  0.45, test accuracy  0.46, time per example  14.48 µs
Epoch 1450, train loss  1.371E+00, test loss  1.384, training accuracy  0.48, test accuracy  0.46, time per example  23.40 µs
Epoch 1500, train loss  1.382E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  25.16 µs
Epoch 1550, train loss  1.373E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  14.13 µs
Epoch 1600, train loss  1.358E+00, test loss  1.383, training accuracy  0.47, test accuracy  0.46, time per example  14.47 µs
Epoch 1650, train loss  1.397E+00, test loss  1.383, training accuracy  0.44, test accuracy  0.46, time per example  14.04 µs
Epoch 1700, train loss  1.382E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  14.62 µs
Time budget exceeded, hence stopping training
Total training time  60.25 s
Number of parameters: 13872
Epoch 0, train loss  2.789E+00, test loss  2.785, training accuracy  0.05, test accuracy  0.05, time per example  14.73 µs
Epoch 50, train loss  2.659E+00, test loss  2.654, training accuracy  0.30, test accuracy  0.30, time per example  14.08 µs
Epoch 100, train loss  2.582E+00, test loss  2.583, training accuracy  0.36, test accuracy  0.35, time per example  13.97 µs
Epoch 150, train loss  2.513E+00, test loss  2.507, training accuracy  0.40, test accuracy  0.41, time per example  19.63 µs
Epoch 200, train loss  2.443E+00, test loss  2.434, training accuracy  0.43, test accuracy  0.43, time per example  22.57 µs
Epoch 250, train loss  2.365E+00, test loss  2.366, training accuracy  0.46, test accuracy  0.45, time per example  14.14 µs
Epoch 300, train loss  2.310E+00, test loss  2.301, training accuracy  0.45, test accuracy  0.46, time per example  13.73 µs
Epoch 350, train loss  2.246E+00, test loss  2.242, training accuracy  0.46, test accuracy  0.46, time per example  22.13 µs
Epoch 400, train loss  2.197E+00, test loss  2.186, training accuracy  0.45, test accuracy  0.46, time per example  23.04 µs
Epoch 450, train loss  2.138E+00, test loss  2.133, training accuracy  0.46, test accuracy  0.46, time per example  13.21 µs
Epoch 500, train loss  2.099E+00, test loss  2.084, training accuracy  0.45, test accuracy  0.46, time per example  13.59 µs
Epoch 550, train loss  2.024E+00, test loss  2.037, training accuracy  0.48, test accuracy  0.46, time per example  19.00 µs
Epoch 600, train loss  1.991E+00, test loss  1.994, training accuracy  0.47, test accuracy  0.46, time per example  23.00 µs
Epoch 650, train loss  1.968E+00, test loss  1.953, training accuracy  0.44, test accuracy  0.46, time per example  13.40 µs
Epoch 700, train loss  1.917E+00, test loss  1.915, training accuracy  0.46, test accuracy  0.46, time per example  13.68 µs
Epoch 750, train loss  1.864E+00, test loss  1.879, training accuracy  0.47, test accuracy  0.46, time per example  13.62 µs
Epoch 800, train loss  1.845E+00, test loss  1.845, training accuracy  0.45, test accuracy  0.46, time per example  13.94 µs
Epoch 850, train loss  1.831E+00, test loss  1.814, training accuracy  0.46, test accuracy  0.46, time per example  13.95 µs
Epoch 900, train loss  1.767E+00, test loss  1.786, training accuracy  0.46, test accuracy  0.46, time per example  18.74 µs
Epoch 950, train loss  1.788E+00, test loss  1.758, training accuracy  0.44, test accuracy  0.46, time per example  19.80 µs
Epoch 1000, train loss  1.724E+00, test loss  1.733, training accuracy  0.47, test accuracy  0.46, time per example  24.58 µs
Epoch 1050, train loss  1.690E+00, test loss  1.709, training accuracy  0.47, test accuracy  0.46, time per example  14.11 µs
Epoch 1100, train loss  1.670E+00, test loss  1.687, training accuracy  0.47, test accuracy  0.46, time per example  13.99 µs
Epoch 1150, train loss  1.635E+00, test loss  1.667, training accuracy  0.48, test accuracy  0.46, time per example  14.32 µs
Epoch 1200, train loss  1.630E+00, test loss  1.647, training accuracy  0.47, test accuracy  0.46, time per example  14.52 µs
Epoch 1250, train loss  1.613E+00, test loss  1.630, training accuracy  0.47, test accuracy  0.46, time per example  18.26 µs
Epoch 1300, train loss  1.602E+00, test loss  1.614, training accuracy  0.47, test accuracy  0.46, time per example  14.09 µs
Epoch 1350, train loss  1.556E+00, test loss  1.599, training accuracy  0.48, test accuracy  0.46, time per example  13.91 µs
Epoch 1400, train loss  1.577E+00, test loss  1.585, training accuracy  0.47, test accuracy  0.46, time per example  19.98 µs
Epoch 1450, train loss  1.623E+00, test loss  1.571, training accuracy  0.44, test accuracy  0.46, time per example  23.34 µs
Epoch 1500, train loss  1.549E+00, test loss  1.560, training accuracy  0.46, test accuracy  0.46, time per example  15.32 µs
Epoch 1550, train loss  1.526E+00, test loss  1.549, training accuracy  0.47, test accuracy  0.46, time per example  14.06 µs
Epoch 1600, train loss  1.558E+00, test loss  1.538, training accuracy  0.45, test accuracy  0.46, time per example  14.50 µs
Epoch 1650, train loss  1.529E+00, test loss  1.528, training accuracy  0.46, test accuracy  0.46, time per example  13.68 µs
Epoch 1700, train loss  1.524E+00, test loss  1.519, training accuracy  0.44, test accuracy  0.46, time per example  19.25 µs
Time budget exceeded, hence stopping training
Total training time  60.11 s
Number of parameters: 13872
Epoch 0, train loss  2.749E+00, test loss  2.689, training accuracy  0.18, test accuracy  0.27, time per example  19.48 µs
Epoch 50, train loss  1.536E+00, test loss  1.644, training accuracy  0.42, test accuracy  0.34, time per example  22.88 µs
Epoch 100, train loss  1.383E+00, test loss  1.392, training accuracy  0.47, test accuracy  0.46, time per example  18.39 µs
Epoch 150, train loss  1.378E+00, test loss  1.359, training accuracy  0.46, test accuracy  0.47, time per example  22.33 µs
Epoch 200, train loss  1.339E+00, test loss  1.387, training accuracy  0.49, test accuracy  0.46, time per example  13.97 µs
Epoch 250, train loss  1.351E+00, test loss  1.384, training accuracy  0.47, test accuracy  0.46, time per example  13.96 µs
Epoch 300, train loss  1.377E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  13.93 µs
Epoch 350, train loss  1.350E+00, test loss  1.388, training accuracy  0.48, test accuracy  0.46, time per example  15.97 µs
Epoch 400, train loss  1.392E+00, test loss  1.384, training accuracy  0.46, test accuracy  0.46, time per example  13.52 µs
Epoch 450, train loss  1.384E+00, test loss  1.383, training accuracy  0.47, test accuracy  0.47, time per example  14.41 µs
Epoch 500, train loss  1.386E+00, test loss  1.373, training accuracy  0.47, test accuracy  0.47, time per example  20.66 µs
Epoch 550, train loss  1.279E+00, test loss  1.292, training accuracy  0.50, test accuracy  0.50, time per example  22.88 µs
Epoch 600, train loss  1.188E+00, test loss  1.121, training accuracy  0.54, test accuracy  0.54, time per example  13.41 µs
Epoch 650, train loss  7.683E-01, test loss  0.825, training accuracy  0.68, test accuracy  0.63, time per example  13.63 µs
Epoch 700, train loss  5.573E-01, test loss  0.616, training accuracy  0.78, test accuracy  0.73, time per example  13.49 µs
Epoch 750, train loss  1.453E-01, test loss  0.137, training accuracy  0.97, test accuracy  0.98, time per example  14.06 µs
Epoch 800, train loss  2.238E-02, test loss  0.014, training accuracy  1.00, test accuracy  1.00, time per example  14.01 µs
Test accuracy > 99.5%, hence stopping training
Total training time  28.83 s
In [ ]:
 

Results for n-digit addition¶

In [ ]:
simple_results = []
for key, value in results_dict.items():
    simple_results.append((key[0], f"{value.train_config.lr: .1e}", value.train_loss, value.test_loss, value.test_accuracy, value.num_parameters, value.time_per_example_in_micros, value.train_time_in_seconds))
import pandas as pd
pd.DataFrame.from_records(simple_results, 
    columns=["model", "Learning Rate", "train_loss", "test_loss", "test_accuracy", 
             "num_parameters", "time_per_example_in_micros", "train_time_in_seconds"]
).round(3)
Out[ ]:
model Learning Rate train_loss test_loss test_accuracy num_parameters time_per_example_in_micros train_time_in_seconds
0 HuggingFaceGPT 1.0e-03 0.344 0.367 0.852 13824 15.079 60.249
1 HuggingFaceGPT 1.0e-04 1.420 1.431 0.461 13824 15.241 60.898
2 HuggingFaceGPT 1.0e-02 1.352 1.385 0.460 13824 15.501 60.349
3 MLP 1.0e-03 0.036 0.035 1.000 14608 1.058 2.169
4 MLP 1.0e-04 0.101 0.108 0.999 14608 1.109 17.040
5 MLP 1.0e-02 0.001 0.001 1.000 14608 1.278 1.311
6 NanoGPT 1.0e-03 1.365 1.387 0.460 20160 23.255 61.963
7 NanoGPT 1.0e-04 1.728 1.720 0.412 20160 23.687 60.688
8 NanoGPT 1.0e-02 0.117 0.122 0.956 20160 23.442 60.061
9 TLTransformer 1.0e-03 1.382 1.384 0.463 13872 17.296 60.254
10 TLTransformer 1.0e-04 1.524 1.519 0.462 13872 17.255 60.110
11 TLTransformer 1.0e-02 0.022 0.014 1.000 13872 17.572 28.826
In [ ]:
 

Shakespeare Data¶

In [ ]:
@dataclass
class BardData:
    train: torch.Tensor
    test: torch.Tensor
    vocab_size: int
    stoi: dict[str, int]
    itos: dict[int, str]
    
In [ ]:
def make_shakespeare_data():
    import os
    import requests
    
    filename = 'shakespeare.txt'
    if not os.path.exists(filename):
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        with open(filename, 'w') as f:
            f.write(requests.get(url).text)
    with open(filename, 'r') as f:
        text = f.read()
    print("length of dataset in characters: ", len(text))
    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    print("all the unique characters:", ''.join(chars))
    print("vocab size:", vocab_size)

    # create a mapping from characters to integers
    stoi = { ch:i for i,ch in enumerate(chars) }
    itos = { i:ch for i,ch in enumerate(chars) }

    def encode(s):
        return [stoi[c] for c in s] 

    n = len(text)
    train_data = text[:int(n*0.9)]
    val_data = text[int(n*0.9):]
    train_ids = encode(train_data)
    val_ids = encode(val_data)
    print(f"train has {len(train_ids)} tokens")
    print(f"val has {len(val_ids)} tokens")

    # don't make the mistake of setting dtype to uint8 in an effort to save memory
    # somehow the code doesn't work with dtypes other than long
    train_ids = torch.tensor(train_ids, dtype=torch.long)
    val_ids = torch.tensor(val_ids, dtype=torch.long)

    return BardData(train_ids, val_ids, vocab_size, stoi, itos)
In [ ]:
bd = make_shakespeare_data()
length of dataset in characters:  1115394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1003854 tokens
val has 111540 tokens
In [ ]:
def get_batch(data: BardData, is_train: bool, batch_size: int, block_size: int, is_y_single_token: bool = False):
    d = data.train if is_train else data.test
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i+block_size] for i in ix])
    if is_y_single_token:
        y = torch.stack([d[i+block_size] for i in ix])
    else:
        y = torch.stack([d[i+1:i+1+block_size] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

def train_and_eval_bard(model_config: Any, m: nn.Module, data: BardData,
                        train_config: TrainConfig, block_size: int, 
                        is_y_single_token: bool = True, is_nano_gpt: bool = False):
    m = m.to(device)
    num_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f"Number of parameters: {num_params}")

    optimizer = torch.optim.AdamW(m.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)

    training_losses = {}
    test_losses = {}
    training_accuracies = {}
    test_accuracies = {}

    outer_start = time.time()
    ep = 0
    while ep < train_config.epochs:
        start = time.time()
        optimizer.zero_grad()
        x, y = get_batch(data, is_train=True, batch_size=train_config.train_batch_size, 
                         block_size=block_size, is_y_single_token=is_y_single_token)
        if is_nano_gpt:
          output, loss = m(x, y)
          logits = einops.rearrange(output, "b s v -> b v s")
        else:
          output = m(x)
          if type(output) is tuple:
              output = output[0]
          if is_y_single_token:
              logits = output.squeeze()
          else:
              logits = einops.rearrange(output, "b s v -> b v s")
          loss = F.cross_entropy(logits, y)
        if ep % train_config.epoch_interval == 0:
            training_losses[ep] = loss.item()
            if is_y_single_token:
                preds = torch.argmax(logits, dim=-1)
                next_tokens = y
            else:
                preds = torch.argmax(logits[:, -1, :], dim=-1)
                next_tokens = y[:, -1]
            training_accuracies[ep] = torch.sum(preds == next_tokens).item() / preds.shape[0]
        
        loss.backward()
        optimizer.step()
        elapsed = time.time() - start

        if ep % train_config.epoch_interval == 0:
            with torch.no_grad():
                #calculate test loss
                test_x, test_y = get_batch(data, is_train=False, batch_size=train_config.train_batch_size, 
                         block_size=block_size, is_y_single_token=is_y_single_token)
                if is_nano_gpt:
                  output, test_loss = m(x, y)
                  test_logits = einops.rearrange(output, "b s v -> b v s")
                else:
                  output = m(test_x)
                  if type(output) is tuple:
                      output = output[0]
                  if is_y_single_token:
                      test_logits = output.squeeze()
                  else:
                      test_logits = einops.rearrange(output, "b s v -> b v s")
                  test_loss = F.cross_entropy(test_logits, test_y)
                test_losses[ep] = test_loss.item()

                if is_y_single_token:
                    test_preds = torch.argmax(test_logits, dim=-1)
                    next_tokens = test_y
                else:
                    test_preds = torch.argmax(test_logits[:, -1, :], dim=-1)
                    next_tokens = test_y[:, -1]
                
                test_accuracies[ep] = torch.sum(test_preds == next_tokens).item() / test_preds.shape[0]
                print(f"Epoch {ep}, train loss {training_losses[ep]: .3E}, test loss {test_losses[ep]: .3f}, " +
                    f"training accuracy {training_accuracies[ep]: .2f}, test accuracy {test_accuracies[ep]: .2f}, " +
                    f"time per example {elapsed * 1e6 / train_config.train_batch_size: .2f} µs")
                if time.time() - outer_start > train_config.time_budget_seconds:
                    print("Time budget exceeded, hence stopping training")
                    break
        ep += 1

    if len(training_losses) is None or len(training_accuracies) is None:
        raise RuntimeError("Training did not run at all")
    if len(test_losses) is None or len(test_accuracies) is None:
        raise RuntimeError("Tests did not run at all")
    
    total_elapsed = time.time() - outer_start
    print(f"Total training time {total_elapsed: .2f} s")
    result_row = ResultRow(
        model_config=model_config,
        train_config=train_config,
        num_parameters=num_params, 
        epochs=ep+1, 
        train_loss=training_losses[max(training_losses.keys())], 
        train_accuracy=training_accuracies[max(training_accuracies.keys())],
        test_loss=test_losses[max(test_losses.keys())],
        test_accuracy=test_accuracies[max(test_accuracies.keys())],
        train_time_in_seconds=total_elapsed,
        time_per_example_in_micros=total_elapsed * 1e6 / ((ep + 1) * train_config.train_batch_size),
        train_losses=training_losses,
        train_accuracies=training_accuracies,
        test_losses=test_losses,
        test_accuracies=test_accuracies,
    )
    return result_row
In [ ]:
bard_results = {}

Using TL Transformer¶

In [ ]:
tlconfig = TLConfig(
    d_model = 64, 
    debug = False, 
    d_vocab = bd.vocab_size, 
    n_ctx = 128, 
    d_head = 16, 
    n_heads = 4, 
    n_layers = 4, 
    d_mlp = 256)
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=1024,
    lr=1e-2,
    weight_decay=1e-4,
    epoch_interval=50,
    time_budget_seconds=300,
)
bard_results[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval_bard(
    tlconfig, TLTransformer(tlconfig), bd, train_config, block_size=tlconfig.n_ctx, is_y_single_token=False
)
tlconfig = TLConfig(
    d_model = 96, 
    debug = False, 
    d_vocab = bd.vocab_size, 
    n_ctx = 128, 
    d_head = 24, 
    n_heads = 4, 
    n_layers = 4, 
    d_mlp = 384
)
bard_results[("TLTransformer", f"run_{int(time.time())}t")] = train_and_eval_bard(
    tlconfig, TLTransformer(tlconfig), bd, train_config, block_size=tlconfig.n_ctx, is_y_single_token=False
)
Number of parameters: 216641
Epoch 0, train loss  4.221E+00, test loss  4.210, training accuracy  0.02, test accuracy  0.00, time per example  86.75 µs
Epoch 50, train loss  3.300E+00, test loss  3.335, training accuracy  0.00, test accuracy  0.00, time per example  103.44 µs
Epoch 100, train loss  2.947E+00, test loss  2.962, training accuracy  0.00, test accuracy  0.01, time per example  103.55 µs
Epoch 150, train loss  2.789E+00, test loss  2.784, training accuracy  0.01, test accuracy  0.01, time per example  103.28 µs
Epoch 200, train loss  2.644E+00, test loss  2.639, training accuracy  0.01, test accuracy  0.01, time per example  104.05 µs
Epoch 250, train loss  2.569E+00, test loss  2.556, training accuracy  0.01, test accuracy  0.01, time per example  103.91 µs
Epoch 300, train loss  2.486E+00, test loss  2.509, training accuracy  0.01, test accuracy  0.01, time per example  103.95 µs
Epoch 350, train loss  2.435E+00, test loss  2.458, training accuracy  0.01, test accuracy  0.01, time per example  103.47 µs
Epoch 400, train loss  2.369E+00, test loss  2.418, training accuracy  0.00, test accuracy  0.00, time per example  123.16 µs
Epoch 450, train loss  2.375E+00, test loss  2.403, training accuracy  0.01, test accuracy  0.00, time per example  100.34 µs
Epoch 500, train loss  2.197E+00, test loss  2.266, training accuracy  0.01, test accuracy  0.01, time per example  104.06 µs
Epoch 550, train loss  2.097E+00, test loss  2.179, training accuracy  0.01, test accuracy  0.00, time per example  103.48 µs
Epoch 600, train loss  2.054E+00, test loss  2.143, training accuracy  0.00, test accuracy  0.01, time per example  103.09 µs
Epoch 650, train loss  1.995E+00, test loss  2.085, training accuracy  0.01, test accuracy  0.01, time per example  103.53 µs
Epoch 700, train loss  1.933E+00, test loss  2.075, training accuracy  0.01, test accuracy  0.01, time per example  102.97 µs
Epoch 750, train loss  1.779E+00, test loss  1.926, training accuracy  0.01, test accuracy  0.01, time per example  103.28 µs
Epoch 800, train loss  1.736E+00, test loss  1.894, training accuracy  0.01, test accuracy  0.01, time per example  103.34 µs
Epoch 850, train loss  1.629E+00, test loss  1.829, training accuracy  0.01, test accuracy  0.01, time per example  104.22 µs
Epoch 900, train loss  1.609E+00, test loss  1.779, training accuracy  0.01, test accuracy  0.01, time per example  103.54 µs
Epoch 950, train loss  1.553E+00, test loss  1.756, training accuracy  0.01, test accuracy  0.01, time per example  103.18 µs
Epoch 1000, train loss  1.523E+00, test loss  1.753, training accuracy  0.01, test accuracy  0.01, time per example  103.07 µs
Epoch 1050, train loss  1.517E+00, test loss  1.746, training accuracy  0.00, test accuracy  0.01, time per example  102.80 µs
Epoch 1100, train loss  1.469E+00, test loss  1.709, training accuracy  0.01, test accuracy  0.00, time per example  103.87 µs
Epoch 1150, train loss  1.461E+00, test loss  1.692, training accuracy  0.01, test accuracy  0.01, time per example  104.11 µs
Epoch 1200, train loss  1.420E+00, test loss  1.663, training accuracy  0.01, test accuracy  0.01, time per example  103.55 µs
Epoch 1250, train loss  1.423E+00, test loss  1.657, training accuracy  0.01, test accuracy  0.00, time per example  111.39 µs
Epoch 1300, train loss  1.400E+00, test loss  1.658, training accuracy  0.01, test accuracy  0.01, time per example  93.21 µs
Epoch 1350, train loss  1.389E+00, test loss  1.650, training accuracy  0.00, test accuracy  0.00, time per example  104.23 µs
Epoch 1400, train loss  1.787E+00, test loss  1.917, training accuracy  0.01, test accuracy  0.00, time per example  103.62 µs
Epoch 1450, train loss  1.392E+00, test loss  1.640, training accuracy  0.01, test accuracy  0.00, time per example  103.04 µs
Epoch 1500, train loss  1.363E+00, test loss  1.653, training accuracy  0.00, test accuracy  0.01, time per example  103.73 µs
Epoch 1550, train loss  1.356E+00, test loss  1.621, training accuracy  0.01, test accuracy  0.01, time per example  102.62 µs
Epoch 1600, train loss  1.356E+00, test loss  1.640, training accuracy  0.00, test accuracy  0.01, time per example  102.62 µs
Epoch 1650, train loss  1.342E+00, test loss  1.618, training accuracy  0.01, test accuracy  0.01, time per example  103.59 µs
Epoch 1700, train loss  1.336E+00, test loss  1.627, training accuracy  0.01, test accuracy  0.00, time per example  103.35 µs
Epoch 1750, train loss  1.326E+00, test loss  1.613, training accuracy  0.00, test accuracy  0.00, time per example  103.23 µs
Epoch 1800, train loss  1.315E+00, test loss  1.622, training accuracy  0.00, test accuracy  0.00, time per example  104.13 µs
Epoch 1850, train loss  1.323E+00, test loss  1.615, training accuracy  0.00, test accuracy  0.00, time per example  105.23 µs
Epoch 1900, train loss  1.310E+00, test loss  1.617, training accuracy  0.00, test accuracy  0.00, time per example  103.34 µs
Epoch 1950, train loss  1.311E+00, test loss  1.611, training accuracy  0.00, test accuracy  0.01, time per example  103.51 µs
Epoch 2000, train loss  1.305E+00, test loss  1.636, training accuracy  0.01, test accuracy  0.00, time per example  103.68 µs
Epoch 2050, train loss  1.289E+00, test loss  1.598, training accuracy  0.01, test accuracy  0.01, time per example  104.22 µs
Epoch 2100, train loss  1.300E+00, test loss  1.623, training accuracy  0.01, test accuracy  0.01, time per example  103.89 µs
Epoch 2150, train loss  1.296E+00, test loss  1.582, training accuracy  0.01, test accuracy  0.01, time per example  104.03 µs
Epoch 2200, train loss  1.283E+00, test loss  1.607, training accuracy  0.00, test accuracy  0.00, time per example  103.20 µs
Epoch 2250, train loss  1.278E+00, test loss  1.609, training accuracy  0.01, test accuracy  0.00, time per example  103.32 µs
Epoch 2300, train loss  1.286E+00, test loss  1.607, training accuracy  0.01, test accuracy  0.01, time per example  103.30 µs
Epoch 2350, train loss  1.269E+00, test loss  1.613, training accuracy  0.01, test accuracy  0.01, time per example  103.16 µs
Epoch 2400, train loss  1.263E+00, test loss  1.617, training accuracy  0.01, test accuracy  0.01, time per example  103.18 µs
Epoch 2450, train loss  1.272E+00, test loss  1.590, training accuracy  0.01, test accuracy  0.01, time per example  103.24 µs
Epoch 2500, train loss  1.255E+00, test loss  1.625, training accuracy  0.01, test accuracy  0.01, time per example  103.02 µs
Epoch 2550, train loss  1.257E+00, test loss  1.618, training accuracy  0.01, test accuracy  0.01, time per example  103.09 µs
Epoch 2600, train loss  1.259E+00, test loss  1.633, training accuracy  0.01, test accuracy  0.01, time per example  102.88 µs
Epoch 2650, train loss  1.255E+00, test loss  1.649, training accuracy  0.01, test accuracy  0.01, time per example  103.66 µs
Epoch 2700, train loss  1.256E+00, test loss  1.634, training accuracy  0.01, test accuracy  0.00, time per example  103.62 µs
Epoch 2750, train loss  1.258E+00, test loss  1.623, training accuracy  0.00, test accuracy  0.01, time per example  103.01 µs
Epoch 2800, train loss  1.237E+00, test loss  1.628, training accuracy  0.00, test accuracy  0.01, time per example  102.92 µs
Epoch 2850, train loss  1.247E+00, test loss  1.625, training accuracy  0.00, test accuracy  0.01, time per example  103.67 µs
Epoch 2900, train loss  1.236E+00, test loss  1.613, training accuracy  0.01, test accuracy  0.00, time per example  104.43 µs
Time budget exceeded, hence stopping training
Total training time  300.81 s
Number of parameters: 472385
Epoch 0, train loss  4.199E+00, test loss  4.143, training accuracy  0.00, test accuracy  0.01, time per example  99.80 µs
Epoch 50, train loss  3.158E+00, test loss  3.144, training accuracy  0.00, test accuracy  0.00, time per example  130.82 µs
Epoch 100, train loss  2.906E+00, test loss  2.905, training accuracy  0.01, test accuracy  0.00, time per example  131.30 µs
Epoch 150, train loss  2.671E+00, test loss  2.659, training accuracy  0.03, test accuracy  0.03, time per example  130.64 µs
Epoch 200, train loss  2.590E+00, test loss  2.595, training accuracy  0.02, test accuracy  0.03, time per example  130.83 µs
Epoch 250, train loss  2.531E+00, test loss  2.523, training accuracy  0.02, test accuracy  0.01, time per example  133.76 µs
Epoch 300, train loss  2.487E+00, test loss  2.490, training accuracy  0.01, test accuracy  0.01, time per example  131.56 µs
Epoch 350, train loss  2.424E+00, test loss  2.452, training accuracy  0.02, test accuracy  0.02, time per example  130.67 µs
Epoch 400, train loss  2.347E+00, test loss  2.398, training accuracy  0.01, test accuracy  0.01, time per example  131.11 µs
Epoch 450, train loss  2.273E+00, test loss  2.311, training accuracy  0.01, test accuracy  0.00, time per example  130.74 µs
Epoch 500, train loss  2.254E+00, test loss  2.298, training accuracy  0.01, test accuracy  0.01, time per example  131.04 µs
Epoch 550, train loss  2.120E+00, test loss  2.195, training accuracy  0.01, test accuracy  0.01, time per example  131.05 µs
Epoch 600, train loss  2.090E+00, test loss  2.181, training accuracy  0.01, test accuracy  0.01, time per example  131.11 µs
Epoch 650, train loss  2.001E+00, test loss  2.142, training accuracy  0.01, test accuracy  0.01, time per example  131.90 µs
Epoch 700, train loss  1.888E+00, test loss  2.029, training accuracy  0.00, test accuracy  0.01, time per example  132.29 µs
Epoch 750, train loss  1.843E+00, test loss  1.989, training accuracy  0.00, test accuracy  0.00, time per example  131.19 µs
Epoch 800, train loss  1.757E+00, test loss  1.926, training accuracy  0.01, test accuracy  0.01, time per example  130.61 µs
Epoch 850, train loss  1.720E+00, test loss  1.923, training accuracy  0.01, test accuracy  0.01, time per example  131.67 µs
Epoch 900, train loss  1.654E+00, test loss  1.845, training accuracy  0.01, test accuracy  0.01, time per example  130.85 µs
Epoch 950, train loss  1.682E+00, test loss  1.879, training accuracy  0.01, test accuracy  0.01, time per example  130.97 µs
Epoch 1000, train loss  1.549E+00, test loss  1.768, training accuracy  0.01, test accuracy  0.01, time per example  131.76 µs
Epoch 1050, train loss  1.506E+00, test loss  1.736, training accuracy  0.01, test accuracy  0.01, time per example  132.47 µs
Epoch 1100, train loss  1.475E+00, test loss  1.701, training accuracy  0.01, test accuracy  0.01, time per example  131.71 µs
Epoch 1150, train loss  1.452E+00, test loss  1.723, training accuracy  0.01, test accuracy  0.01, time per example  130.95 µs
Epoch 1200, train loss  1.413E+00, test loss  1.698, training accuracy  0.01, test accuracy  0.00, time per example  131.56 µs
Epoch 1250, train loss  1.396E+00, test loss  1.676, training accuracy  0.00, test accuracy  0.00, time per example  130.85 µs
Epoch 1300, train loss  1.382E+00, test loss  1.663, training accuracy  0.01, test accuracy  0.01, time per example  131.45 µs
Epoch 1350, train loss  1.369E+00, test loss  1.632, training accuracy  0.01, test accuracy  0.01, time per example  130.92 µs
Epoch 1400, train loss  1.352E+00, test loss  1.625, training accuracy  0.01, test accuracy  0.00, time per example  131.19 µs
Epoch 1450, train loss  1.339E+00, test loss  1.648, training accuracy  0.00, test accuracy  0.01, time per example  131.13 µs
Epoch 1500, train loss  1.318E+00, test loss  1.625, training accuracy  0.01, test accuracy  0.01, time per example  131.40 µs
Epoch 1550, train loss  1.295E+00, test loss  1.605, training accuracy  0.01, test accuracy  0.00, time per example  131.12 µs
Epoch 1600, train loss  1.297E+00, test loss  1.608, training accuracy  0.01, test accuracy  0.01, time per example  132.11 µs
Epoch 1650, train loss  1.292E+00, test loss  1.616, training accuracy  0.01, test accuracy  0.01, time per example  131.17 µs
Epoch 1700, train loss  1.279E+00, test loss  1.619, training accuracy  0.01, test accuracy  0.01, time per example  130.97 µs
Epoch 1750, train loss  1.270E+00, test loss  1.613, training accuracy  0.01, test accuracy  0.01, time per example  130.65 µs
Epoch 1800, train loss  1.266E+00, test loss  1.621, training accuracy  0.01, test accuracy  0.01, time per example  130.82 µs
Epoch 1850, train loss  1.252E+00, test loss  1.629, training accuracy  0.01, test accuracy  0.01, time per example  132.94 µs
Epoch 1900, train loss  1.242E+00, test loss  1.626, training accuracy  0.01, test accuracy  0.01, time per example  130.14 µs
Epoch 1950, train loss  1.245E+00, test loss  1.624, training accuracy  0.00, test accuracy  0.01, time per example  130.73 µs
Epoch 2000, train loss  1.232E+00, test loss  1.622, training accuracy  0.00, test accuracy  0.01, time per example  130.76 µs
Epoch 2050, train loss  1.227E+00, test loss  1.613, training accuracy  0.01, test accuracy  0.01, time per example  130.94 µs
Epoch 2100, train loss  1.218E+00, test loss  1.660, training accuracy  0.01, test accuracy  0.00, time per example  131.15 µs
Epoch 2150, train loss  1.209E+00, test loss  1.647, training accuracy  0.01, test accuracy  0.00, time per example  134.50 µs
Epoch 2200, train loss  1.194E+00, test loss  1.659, training accuracy  0.00, test accuracy  0.00, time per example  130.24 µs
Epoch 2250, train loss  1.203E+00, test loss  1.648, training accuracy  0.00, test accuracy  0.00, time per example  130.75 µs
Time budget exceeded, hence stopping training
Total training time  303.39 s
In [ ]:
 

Using MLP¶

In [ ]:
cfg1 = MLPForSeq2SeqConfig(
    vocab_size=bd.vocab_size,
    input_len=128,
    n_embed=16,
    n_hidden=256,
    output_len=1)
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=2048,
    lr=1e-2,
    weight_decay=1e-4,
    epoch_interval=200,
    time_budget_seconds=300,
)
bard_results[("MLP", f"run_{int(time.time())}t")] = train_and_eval_bard(
    cfg1, MLPForSeq2Seq(cfg1), bd, train_config, block_size=cfg1.input_len, is_y_single_token=True
)
cfg2 = MLPForSeq2SeqConfig(
    vocab_size=bd.vocab_size,
    input_len=128,
    n_embed=8,
    n_hidden=128,
    output_len=1)
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=2048,
    lr=1e-2,
    weight_decay=1e-4,
    epoch_interval=200,
    time_budget_seconds=300,
)
bard_results[("MLP", f"run_{int(time.time())}t")] = train_and_eval_bard(
    cfg2, MLPForSeq2Seq(cfg2), bd, train_config, block_size=cfg2.input_len, is_y_single_token=True
)
Number of parameters: 542289
Epoch 0, train loss  4.216E+00, test loss  5.999, training accuracy  0.01, test accuracy  0.16, time per example  19.51 µs
Epoch 200, train loss  2.293E+00, test loss  2.493, training accuracy  0.33, test accuracy  0.29, time per example  18.44 µs
Epoch 400, train loss  2.301E+00, test loss  2.315, training accuracy  0.33, test accuracy  0.32, time per example  17.59 µs
Epoch 600, train loss  2.148E+00, test loss  2.259, training accuracy  0.37, test accuracy  0.35, time per example  27.56 µs
Epoch 800, train loss  2.176E+00, test loss  2.232, training accuracy  0.37, test accuracy  0.36, time per example  18.03 µs
Epoch 1000, train loss  2.117E+00, test loss  2.216, training accuracy  0.37, test accuracy  0.38, time per example  17.59 µs
Epoch 1200, train loss  2.098E+00, test loss  2.210, training accuracy  0.40, test accuracy  0.37, time per example  18.21 µs
Epoch 1400, train loss  2.049E+00, test loss  2.166, training accuracy  0.41, test accuracy  0.39, time per example  17.85 µs
Epoch 1600, train loss  2.038E+00, test loss  2.206, training accuracy  0.40, test accuracy  0.38, time per example  17.80 µs
Epoch 1800, train loss  2.033E+00, test loss  2.208, training accuracy  0.40, test accuracy  0.38, time per example  17.91 µs
Epoch 2000, train loss  1.944E+00, test loss  2.181, training accuracy  0.43, test accuracy  0.38, time per example  17.90 µs
Epoch 2200, train loss  1.977E+00, test loss  2.215, training accuracy  0.42, test accuracy  0.37, time per example  18.36 µs
Epoch 2400, train loss  1.966E+00, test loss  2.184, training accuracy  0.42, test accuracy  0.38, time per example  17.77 µs
Epoch 2600, train loss  1.969E+00, test loss  2.240, training accuracy  0.44, test accuracy  0.37, time per example  17.69 µs
Epoch 2800, train loss  1.950E+00, test loss  2.210, training accuracy  0.43, test accuracy  0.39, time per example  17.81 µs
Epoch 3000, train loss  1.978E+00, test loss  2.164, training accuracy  0.42, test accuracy  0.39, time per example  19.78 µs
Epoch 3200, train loss  1.917E+00, test loss  2.166, training accuracy  0.44, test accuracy  0.40, time per example  18.15 µs
Epoch 3400, train loss  1.908E+00, test loss  2.263, training accuracy  0.44, test accuracy  0.36, time per example  18.26 µs
Epoch 3600, train loss  1.895E+00, test loss  2.189, training accuracy  0.44, test accuracy  0.39, time per example  18.25 µs
Epoch 3800, train loss  1.931E+00, test loss  2.216, training accuracy  0.43, test accuracy  0.39, time per example  17.69 µs
Epoch 4000, train loss  1.920E+00, test loss  2.156, training accuracy  0.44, test accuracy  0.39, time per example  18.02 µs
Epoch 4200, train loss  1.856E+00, test loss  2.158, training accuracy  0.44, test accuracy  0.40, time per example  18.22 µs
Epoch 4400, train loss  1.918E+00, test loss  2.265, training accuracy  0.44, test accuracy  0.38, time per example  17.73 µs
Epoch 4600, train loss  1.930E+00, test loss  2.173, training accuracy  0.43, test accuracy  0.40, time per example  17.80 µs
Epoch 4800, train loss  1.900E+00, test loss  2.227, training accuracy  0.44, test accuracy  0.36, time per example  18.13 µs
Epoch 5000, train loss  1.847E+00, test loss  2.218, training accuracy  0.45, test accuracy  0.39, time per example  17.68 µs
Epoch 5200, train loss  1.874E+00, test loss  2.182, training accuracy  0.45, test accuracy  0.38, time per example  26.37 µs
Epoch 5400, train loss  1.841E+00, test loss  2.208, training accuracy  0.45, test accuracy  0.39, time per example  17.81 µs
Epoch 5600, train loss  1.850E+00, test loss  2.178, training accuracy  0.45, test accuracy  0.41, time per example  18.08 µs
Epoch 5800, train loss  1.860E+00, test loss  2.261, training accuracy  0.44, test accuracy  0.38, time per example  17.93 µs
Epoch 6000, train loss  1.889E+00, test loss  2.233, training accuracy  0.44, test accuracy  0.38, time per example  17.84 µs
Epoch 6200, train loss  1.875E+00, test loss  2.121, training accuracy  0.44, test accuracy  0.41, time per example  17.65 µs
Epoch 6400, train loss  1.848E+00, test loss  2.098, training accuracy  0.45, test accuracy  0.42, time per example  18.00 µs
Epoch 6600, train loss  1.835E+00, test loss  2.122, training accuracy  0.46, test accuracy  0.38, time per example  19.32 µs
Epoch 6800, train loss  1.821E+00, test loss  2.099, training accuracy  0.47, test accuracy  0.40, time per example  18.07 µs
Epoch 7000, train loss  1.874E+00, test loss  2.208, training accuracy  0.46, test accuracy  0.39, time per example  17.72 µs
Epoch 7200, train loss  1.824E+00, test loss  2.192, training accuracy  0.45, test accuracy  0.41, time per example  17.69 µs
Epoch 7400, train loss  1.888E+00, test loss  2.162, training accuracy  0.45, test accuracy  0.40, time per example  19.13 µs
Epoch 7600, train loss  1.848E+00, test loss  2.182, training accuracy  0.46, test accuracy  0.39, time per example  18.02 µs
Time budget exceeded, hence stopping training
Total training time  302.81 s
Number of parameters: 140105
Epoch 0, train loss  4.215E+00, test loss  3.657, training accuracy  0.01, test accuracy  0.14, time per example  21.02 µs
Epoch 200, train loss  2.320E+00, test loss  2.374, training accuracy  0.33, test accuracy  0.33, time per example  106.10 µs
Epoch 400, train loss  2.138E+00, test loss  2.230, training accuracy  0.37, test accuracy  0.34, time per example  18.01 µs
Epoch 600, train loss  2.061E+00, test loss  2.218, training accuracy  0.40, test accuracy  0.36, time per example  22.09 µs
Epoch 800, train loss  1.994E+00, test loss  2.136, training accuracy  0.42, test accuracy  0.39, time per example  17.73 µs
Epoch 1000, train loss  1.926E+00, test loss  2.165, training accuracy  0.44, test accuracy  0.38, time per example  17.38 µs
Epoch 1200, train loss  1.967E+00, test loss  2.192, training accuracy  0.41, test accuracy  0.37, time per example  17.61 µs
Epoch 1400, train loss  1.972E+00, test loss  2.099, training accuracy  0.42, test accuracy  0.39, time per example  17.68 µs
Epoch 1600, train loss  1.919E+00, test loss  2.155, training accuracy  0.43, test accuracy  0.38, time per example  17.81 µs
Epoch 1800, train loss  1.938E+00, test loss  2.146, training accuracy  0.43, test accuracy  0.39, time per example  17.54 µs
Epoch 2000, train loss  1.948E+00, test loss  2.110, training accuracy  0.43, test accuracy  0.40, time per example  18.01 µs
Epoch 2200, train loss  1.978E+00, test loss  2.128, training accuracy  0.42, test accuracy  0.39, time per example  21.03 µs
Epoch 2400, train loss  1.901E+00, test loss  2.104, training accuracy  0.44, test accuracy  0.39, time per example  18.00 µs
Epoch 2600, train loss  1.926E+00, test loss  2.126, training accuracy  0.43, test accuracy  0.40, time per example  17.99 µs
Epoch 2800, train loss  1.862E+00, test loss  2.117, training accuracy  0.47, test accuracy  0.38, time per example  17.66 µs
Epoch 3000, train loss  1.865E+00, test loss  2.098, training accuracy  0.45, test accuracy  0.41, time per example  17.76 µs
Epoch 3200, train loss  1.895E+00, test loss  2.087, training accuracy  0.44, test accuracy  0.42, time per example  17.57 µs
Epoch 3400, train loss  1.872E+00, test loss  2.025, training accuracy  0.45, test accuracy  0.43, time per example  17.48 µs
Epoch 3600, train loss  1.846E+00, test loss  2.057, training accuracy  0.45, test accuracy  0.41, time per example  18.18 µs
Epoch 3800, train loss  1.904E+00, test loss  2.157, training accuracy  0.44, test accuracy  0.38, time per example  18.20 µs
Epoch 4000, train loss  1.820E+00, test loss  2.060, training accuracy  0.46, test accuracy  0.41, time per example  17.69 µs
Epoch 4200, train loss  1.847E+00, test loss  2.148, training accuracy  0.45, test accuracy  0.40, time per example  18.02 µs
Epoch 4400, train loss  1.885E+00, test loss  2.066, training accuracy  0.44, test accuracy  0.42, time per example  17.57 µs
Epoch 4600, train loss  1.877E+00, test loss  2.058, training accuracy  0.45, test accuracy  0.41, time per example  17.71 µs
Epoch 4800, train loss  1.795E+00, test loss  2.097, training accuracy  0.47, test accuracy  0.40, time per example  17.78 µs
Epoch 5000, train loss  1.871E+00, test loss  2.150, training accuracy  0.44, test accuracy  0.39, time per example  17.54 µs
Epoch 5200, train loss  1.852E+00, test loss  2.071, training accuracy  0.47, test accuracy  0.40, time per example  17.56 µs
Epoch 5400, train loss  1.853E+00, test loss  2.201, training accuracy  0.46, test accuracy  0.36, time per example  17.37 µs
Epoch 5600, train loss  1.840E+00, test loss  2.063, training accuracy  0.46, test accuracy  0.41, time per example  17.35 µs
Epoch 5800, train loss  1.844E+00, test loss  2.103, training accuracy  0.44, test accuracy  0.41, time per example  17.73 µs
Epoch 6000, train loss  1.856E+00, test loss  2.212, training accuracy  0.46, test accuracy  0.39, time per example  17.63 µs
Epoch 6200, train loss  1.804E+00, test loss  2.143, training accuracy  0.46, test accuracy  0.41, time per example  17.56 µs
Epoch 6400, train loss  1.832E+00, test loss  2.145, training accuracy  0.45, test accuracy  0.39, time per example  17.54 µs
Epoch 6600, train loss  1.790E+00, test loss  2.162, training accuracy  0.47, test accuracy  0.38, time per example  17.70 µs
Epoch 6800, train loss  1.809E+00, test loss  2.100, training accuracy  0.46, test accuracy  0.41, time per example  17.56 µs
Epoch 7000, train loss  1.799E+00, test loss  2.115, training accuracy  0.47, test accuracy  0.41, time per example  17.38 µs
Epoch 7200, train loss  1.809E+00, test loss  2.154, training accuracy  0.46, test accuracy  0.39, time per example  18.01 µs
Epoch 7400, train loss  1.874E+00, test loss  2.168, training accuracy  0.45, test accuracy  0.38, time per example  17.74 µs
Epoch 7600, train loss  1.795E+00, test loss  2.068, training accuracy  0.46, test accuracy  0.42, time per example  17.78 µs
Epoch 7800, train loss  1.831E+00, test loss  2.135, training accuracy  0.46, test accuracy  0.40, time per example  17.61 µs
Time budget exceeded, hence stopping training
Total training time  307.70 s
In [ ]:
cfg = MLPForSeq2SeqConfig(
    vocab_size=bd.vocab_size,
    input_len=128,
    n_embed=16,
    n_hidden=64,
    output_len=1)
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=2048,
    lr=1e-2,
    weight_decay=1e-4,
    epoch_interval=25,
    time_budget_seconds=600,
)
bard_results[("MLP", f"run_{int(time.time())}t")] = train_and_eval_bard(
    cfg, MLPForSeq2Seq(cfg), bd, train_config, block_size=cfg.input_len, is_y_single_token=True
)
Number of parameters: 136401
Epoch 0, train loss  4.149E+00, test loss  4.607, training accuracy  0.02, test accuracy  0.17, time per example  51.37 µs
Epoch 25, train loss  3.201E+00, test loss  3.214, training accuracy  0.18, test accuracy  0.18, time per example  85.03 µs
Epoch 50, train loss  2.968E+00, test loss  2.920, training accuracy  0.21, test accuracy  0.22, time per example  88.08 µs
Epoch 75, train loss  2.727E+00, test loss  2.719, training accuracy  0.24, test accuracy  0.24, time per example  45.28 µs
Epoch 100, train loss  2.569E+00, test loss  2.584, training accuracy  0.28, test accuracy  0.27, time per example  50.96 µs
Epoch 125, train loss  2.512E+00, test loss  2.555, training accuracy  0.29, test accuracy  0.28, time per example  62.23 µs
Epoch 150, train loss  2.437E+00, test loss  2.460, training accuracy  0.29, test accuracy  0.30, time per example  82.49 µs
Epoch 175, train loss  2.460E+00, test loss  2.432, training accuracy  0.30, test accuracy  0.29, time per example  52.60 µs
Epoch 200, train loss  2.372E+00, test loss  2.405, training accuracy  0.33, test accuracy  0.32, time per example  50.60 µs
Epoch 225, train loss  2.392E+00, test loss  2.345, training accuracy  0.30, test accuracy  0.32, time per example  51.51 µs
Epoch 250, train loss  2.376E+00, test loss  2.399, training accuracy  0.33, test accuracy  0.31, time per example  53.17 µs
Epoch 275, train loss  2.361E+00, test loss  2.300, training accuracy  0.33, test accuracy  0.35, time per example  89.10 µs
Epoch 300, train loss  2.300E+00, test loss  2.345, training accuracy  0.35, test accuracy  0.32, time per example  54.12 µs
Epoch 325, train loss  2.310E+00, test loss  2.331, training accuracy  0.35, test accuracy  0.32, time per example  52.23 µs
Epoch 350, train loss  2.251E+00, test loss  2.405, training accuracy  0.35, test accuracy  0.32, time per example  56.03 µs
Epoch 375, train loss  2.283E+00, test loss  2.366, training accuracy  0.35, test accuracy  0.32, time per example  83.41 µs
Epoch 400, train loss  2.235E+00, test loss  2.322, training accuracy  0.35, test accuracy  0.33, time per example  62.75 µs
Epoch 425, train loss  2.297E+00, test loss  2.283, training accuracy  0.35, test accuracy  0.35, time per example  84.59 µs
Epoch 450, train loss  2.267E+00, test loss  2.300, training accuracy  0.34, test accuracy  0.34, time per example  58.65 µs
Epoch 475, train loss  2.259E+00, test loss  2.359, training accuracy  0.35, test accuracy  0.32, time per example  80.48 µs
Epoch 500, train loss  2.256E+00, test loss  2.327, training accuracy  0.35, test accuracy  0.33, time per example  51.15 µs
Epoch 525, train loss  2.210E+00, test loss  2.213, training accuracy  0.37, test accuracy  0.36, time per example  53.42 µs
Epoch 550, train loss  2.243E+00, test loss  2.329, training accuracy  0.36, test accuracy  0.35, time per example  52.59 µs
Epoch 575, train loss  2.205E+00, test loss  2.285, training accuracy  0.36, test accuracy  0.36, time per example  54.17 µs
Epoch 600, train loss  2.220E+00, test loss  2.285, training accuracy  0.36, test accuracy  0.34, time per example  75.96 µs
Epoch 625, train loss  2.194E+00, test loss  2.276, training accuracy  0.38, test accuracy  0.34, time per example  53.92 µs
Epoch 650, train loss  2.209E+00, test loss  2.340, training accuracy  0.36, test accuracy  0.33, time per example  52.06 µs
Epoch 675, train loss  2.155E+00, test loss  2.271, training accuracy  0.38, test accuracy  0.36, time per example  57.34 µs
Epoch 700, train loss  2.185E+00, test loss  2.316, training accuracy  0.38, test accuracy  0.36, time per example  75.45 µs
Epoch 725, train loss  2.133E+00, test loss  2.260, training accuracy  0.38, test accuracy  0.36, time per example  55.67 µs
Epoch 750, train loss  2.135E+00, test loss  2.308, training accuracy  0.37, test accuracy  0.35, time per example  51.90 µs
Epoch 775, train loss  2.186E+00, test loss  2.280, training accuracy  0.36, test accuracy  0.35, time per example  52.92 µs
Epoch 800, train loss  2.158E+00, test loss  2.344, training accuracy  0.38, test accuracy  0.33, time per example  57.37 µs
Epoch 825, train loss  2.176E+00, test loss  2.240, training accuracy  0.37, test accuracy  0.35, time per example  95.01 µs
Epoch 850, train loss  2.102E+00, test loss  2.281, training accuracy  0.38, test accuracy  0.34, time per example  85.03 µs
Epoch 875, train loss  2.121E+00, test loss  2.286, training accuracy  0.40, test accuracy  0.33, time per example  59.99 µs
Epoch 900, train loss  2.132E+00, test loss  2.226, training accuracy  0.39, test accuracy  0.36, time per example  51.90 µs
Epoch 925, train loss  2.161E+00, test loss  2.249, training accuracy  0.37, test accuracy  0.36, time per example  73.39 µs
Epoch 950, train loss  2.121E+00, test loss  2.279, training accuracy  0.39, test accuracy  0.37, time per example  79.24 µs
Epoch 975, train loss  2.193E+00, test loss  2.186, training accuracy  0.37, test accuracy  0.37, time per example  52.89 µs
Epoch 1000, train loss  2.092E+00, test loss  2.274, training accuracy  0.40, test accuracy  0.35, time per example  52.59 µs
Epoch 1025, train loss  2.105E+00, test loss  2.209, training accuracy  0.39, test accuracy  0.37, time per example  51.29 µs
Epoch 1050, train loss  2.143E+00, test loss  2.237, training accuracy  0.38, test accuracy  0.35, time per example  83.31 µs
Epoch 1075, train loss  2.122E+00, test loss  2.244, training accuracy  0.38, test accuracy  0.35, time per example  52.49 µs
Epoch 1100, train loss  2.161E+00, test loss  2.269, training accuracy  0.38, test accuracy  0.35, time per example  52.72 µs
Epoch 1125, train loss  2.085E+00, test loss  2.250, training accuracy  0.40, test accuracy  0.36, time per example  49.41 µs
Epoch 1150, train loss  2.137E+00, test loss  2.284, training accuracy  0.37, test accuracy  0.35, time per example  48.66 µs
Epoch 1175, train loss  2.104E+00, test loss  2.195, training accuracy  0.38, test accuracy  0.37, time per example  84.92 µs
Epoch 1200, train loss  2.121E+00, test loss  2.253, training accuracy  0.40, test accuracy  0.36, time per example  48.26 µs
Epoch 1225, train loss  2.103E+00, test loss  2.269, training accuracy  0.39, test accuracy  0.36, time per example  83.12 µs
Epoch 1250, train loss  2.113E+00, test loss  2.240, training accuracy  0.40, test accuracy  0.36, time per example  44.94 µs
Epoch 1275, train loss  2.059E+00, test loss  2.322, training accuracy  0.41, test accuracy  0.34, time per example  80.86 µs
Epoch 1300, train loss  2.188E+00, test loss  2.246, training accuracy  0.36, test accuracy  0.35, time per example  52.28 µs
Epoch 1325, train loss  2.114E+00, test loss  2.241, training accuracy  0.40, test accuracy  0.35, time per example  57.34 µs
Epoch 1350, train loss  2.111E+00, test loss  2.221, training accuracy  0.38, test accuracy  0.36, time per example  57.76 µs
Epoch 1375, train loss  2.105E+00, test loss  2.229, training accuracy  0.39, test accuracy  0.36, time per example  60.24 µs
Epoch 1400, train loss  2.142E+00, test loss  2.230, training accuracy  0.37, test accuracy  0.37, time per example  76.50 µs
Epoch 1425, train loss  2.067E+00, test loss  2.222, training accuracy  0.40, test accuracy  0.36, time per example  50.90 µs
Epoch 1450, train loss  2.101E+00, test loss  2.256, training accuracy  0.37, test accuracy  0.35, time per example  59.67 µs
Epoch 1475, train loss  2.119E+00, test loss  2.245, training accuracy  0.39, test accuracy  0.36, time per example  58.59 µs
Epoch 1500, train loss  2.164E+00, test loss  2.272, training accuracy  0.37, test accuracy  0.34, time per example  83.59 µs
Epoch 1525, train loss  2.091E+00, test loss  2.249, training accuracy  0.40, test accuracy  0.36, time per example  52.30 µs
Epoch 1550, train loss  2.117E+00, test loss  2.156, training accuracy  0.39, test accuracy  0.37, time per example  65.28 µs
Epoch 1575, train loss  2.106E+00, test loss  2.270, training accuracy  0.39, test accuracy  0.36, time per example  51.39 µs
Epoch 1600, train loss  2.137E+00, test loss  2.280, training accuracy  0.39, test accuracy  0.35, time per example  53.02 µs
Epoch 1625, train loss  2.126E+00, test loss  2.318, training accuracy  0.38, test accuracy  0.33, time per example  85.76 µs
Epoch 1650, train loss  2.060E+00, test loss  2.220, training accuracy  0.40, test accuracy  0.36, time per example  83.03 µs
Epoch 1675, train loss  2.048E+00, test loss  2.279, training accuracy  0.40, test accuracy  0.33, time per example  54.12 µs
Epoch 1700, train loss  2.113E+00, test loss  2.218, training accuracy  0.39, test accuracy  0.37, time per example  46.88 µs
Epoch 1725, train loss  2.091E+00, test loss  2.210, training accuracy  0.40, test accuracy  0.36, time per example  81.26 µs
Epoch 1750, train loss  2.110E+00, test loss  2.257, training accuracy  0.38, test accuracy  0.36, time per example  59.73 µs
Epoch 1775, train loss  2.062E+00, test loss  2.277, training accuracy  0.40, test accuracy  0.37, time per example  52.60 µs
Epoch 1800, train loss  2.083E+00, test loss  2.197, training accuracy  0.39, test accuracy  0.37, time per example  51.79 µs
Epoch 1825, train loss  2.072E+00, test loss  2.267, training accuracy  0.40, test accuracy  0.36, time per example  53.23 µs
Epoch 1850, train loss  2.060E+00, test loss  2.199, training accuracy  0.42, test accuracy  0.37, time per example  77.29 µs
Epoch 1875, train loss  2.038E+00, test loss  2.240, training accuracy  0.39, test accuracy  0.36, time per example  52.67 µs
Epoch 1900, train loss  2.071E+00, test loss  2.282, training accuracy  0.40, test accuracy  0.36, time per example  58.44 µs
Epoch 1925, train loss  2.116E+00, test loss  2.239, training accuracy  0.39, test accuracy  0.37, time per example  51.54 µs
Epoch 1950, train loss  2.077E+00, test loss  2.214, training accuracy  0.40, test accuracy  0.36, time per example  90.86 µs
Epoch 1975, train loss  2.069E+00, test loss  2.260, training accuracy  0.41, test accuracy  0.36, time per example  56.73 µs
Epoch 2000, train loss  2.117E+00, test loss  2.180, training accuracy  0.39, test accuracy  0.38, time per example  54.01 µs
Epoch 2025, train loss  2.118E+00, test loss  2.194, training accuracy  0.39, test accuracy  0.37, time per example  53.36 µs
Epoch 2050, train loss  2.124E+00, test loss  2.220, training accuracy  0.40, test accuracy  0.38, time per example  89.01 µs
Epoch 2075, train loss  2.061E+00, test loss  2.199, training accuracy  0.40, test accuracy  0.36, time per example  85.78 µs
Epoch 2100, train loss  2.059E+00, test loss  2.146, training accuracy  0.40, test accuracy  0.39, time per example  54.12 µs
Epoch 2125, train loss  2.068E+00, test loss  2.237, training accuracy  0.40, test accuracy  0.36, time per example  58.74 µs
Epoch 2150, train loss  2.010E+00, test loss  2.205, training accuracy  0.41, test accuracy  0.36, time per example  212.56 µs
Epoch 2175, train loss  2.060E+00, test loss  2.193, training accuracy  0.40, test accuracy  0.38, time per example  83.75 µs
Epoch 2200, train loss  2.005E+00, test loss  2.242, training accuracy  0.41, test accuracy  0.36, time per example  51.65 µs
Epoch 2225, train loss  2.006E+00, test loss  2.221, training accuracy  0.42, test accuracy  0.36, time per example  52.76 µs
Epoch 2250, train loss  2.030E+00, test loss  2.235, training accuracy  0.41, test accuracy  0.36, time per example  51.50 µs
Epoch 2275, train loss  2.104E+00, test loss  2.198, training accuracy  0.40, test accuracy  0.36, time per example  52.65 µs
Epoch 2300, train loss  2.064E+00, test loss  2.186, training accuracy  0.39, test accuracy  0.36, time per example  82.65 µs
Epoch 2325, train loss  1.985E+00, test loss  2.273, training accuracy  0.43, test accuracy  0.34, time per example  55.74 µs
Epoch 2350, train loss  2.097E+00, test loss  2.284, training accuracy  0.39, test accuracy  0.35, time per example  58.31 µs
Epoch 2375, train loss  2.031E+00, test loss  2.244, training accuracy  0.41, test accuracy  0.36, time per example  67.23 µs
Epoch 2400, train loss  2.059E+00, test loss  2.237, training accuracy  0.40, test accuracy  0.36, time per example  89.22 µs
Epoch 2425, train loss  2.065E+00, test loss  2.233, training accuracy  0.41, test accuracy  0.35, time per example  53.10 µs
Epoch 2450, train loss  2.066E+00, test loss  2.263, training accuracy  0.40, test accuracy  0.36, time per example  70.09 µs
Epoch 2475, train loss  2.038E+00, test loss  2.234, training accuracy  0.41, test accuracy  0.36, time per example  56.24 µs
Epoch 2500, train loss  2.015E+00, test loss  2.195, training accuracy  0.41, test accuracy  0.38, time per example  72.66 µs
Epoch 2525, train loss  2.027E+00, test loss  2.191, training accuracy  0.40, test accuracy  0.36, time per example  82.10 µs
Epoch 2550, train loss  2.070E+00, test loss  2.207, training accuracy  0.41, test accuracy  0.36, time per example  52.18 µs
Epoch 2575, train loss  2.043E+00, test loss  2.227, training accuracy  0.40, test accuracy  0.36, time per example  43.97 µs
Epoch 2600, train loss  2.075E+00, test loss  2.157, training accuracy  0.40, test accuracy  0.39, time per example  51.23 µs
Epoch 2625, train loss  2.021E+00, test loss  2.216, training accuracy  0.41, test accuracy  0.37, time per example  77.24 µs
Epoch 2650, train loss  2.087E+00, test loss  2.256, training accuracy  0.38, test accuracy  0.36, time per example  44.60 µs
Epoch 2675, train loss  2.064E+00, test loss  2.231, training accuracy  0.41, test accuracy  0.36, time per example  58.65 µs
Epoch 2700, train loss  2.070E+00, test loss  2.240, training accuracy  0.40, test accuracy  0.35, time per example  53.43 µs
Epoch 2725, train loss  2.090E+00, test loss  2.190, training accuracy  0.39, test accuracy  0.38, time per example  51.50 µs
Epoch 2750, train loss  2.055E+00, test loss  2.208, training accuracy  0.41, test accuracy  0.37, time per example  81.13 µs
Epoch 2775, train loss  1.991E+00, test loss  2.209, training accuracy  0.42, test accuracy  0.36, time per example  51.20 µs
Epoch 2800, train loss  2.017E+00, test loss  2.253, training accuracy  0.42, test accuracy  0.38, time per example  44.74 µs
Epoch 2825, train loss  2.000E+00, test loss  2.257, training accuracy  0.41, test accuracy  0.36, time per example  49.46 µs
Epoch 2850, train loss  2.044E+00, test loss  2.245, training accuracy  0.41, test accuracy  0.37, time per example  47.12 µs
Epoch 2875, train loss  2.035E+00, test loss  2.176, training accuracy  0.42, test accuracy  0.37, time per example  95.37 µs
Epoch 2900, train loss  2.008E+00, test loss  2.221, training accuracy  0.43, test accuracy  0.35, time per example  52.43 µs
Epoch 2925, train loss  2.068E+00, test loss  2.213, training accuracy  0.41, test accuracy  0.37, time per example  48.14 µs
Epoch 2950, train loss  2.028E+00, test loss  2.276, training accuracy  0.42, test accuracy  0.35, time per example  63.92 µs
Epoch 2975, train loss  2.044E+00, test loss  2.222, training accuracy  0.41, test accuracy  0.38, time per example  54.70 µs
Epoch 3000, train loss  2.042E+00, test loss  2.272, training accuracy  0.40, test accuracy  0.36, time per example  75.18 µs
Epoch 3025, train loss  2.025E+00, test loss  2.220, training accuracy  0.41, test accuracy  0.37, time per example  53.82 µs
Epoch 3050, train loss  2.044E+00, test loss  2.180, training accuracy  0.40, test accuracy  0.36, time per example  51.23 µs
Epoch 3075, train loss  2.041E+00, test loss  2.262, training accuracy  0.39, test accuracy  0.37, time per example  56.45 µs
Epoch 3100, train loss  2.085E+00, test loss  2.246, training accuracy  0.39, test accuracy  0.36, time per example  104.63 µs
Epoch 3125, train loss  2.065E+00, test loss  2.256, training accuracy  0.41, test accuracy  0.36, time per example  52.41 µs
Epoch 3150, train loss  2.028E+00, test loss  2.169, training accuracy  0.41, test accuracy  0.37, time per example  50.84 µs
Epoch 3175, train loss  2.023E+00, test loss  2.211, training accuracy  0.40, test accuracy  0.38, time per example  53.56 µs
Epoch 3200, train loss  2.076E+00, test loss  2.223, training accuracy  0.40, test accuracy  0.36, time per example  65.03 µs
Epoch 3225, train loss  2.005E+00, test loss  2.211, training accuracy  0.42, test accuracy  0.37, time per example  81.99 µs
Epoch 3250, train loss  2.070E+00, test loss  2.209, training accuracy  0.39, test accuracy  0.36, time per example  52.90 µs
Epoch 3275, train loss  2.061E+00, test loss  2.171, training accuracy  0.40, test accuracy  0.37, time per example  81.10 µs
Epoch 3300, train loss  2.036E+00, test loss  2.243, training accuracy  0.39, test accuracy  0.36, time per example  50.79 µs
Epoch 3325, train loss  1.998E+00, test loss  2.253, training accuracy  0.42, test accuracy  0.38, time per example  82.06 µs
Epoch 3350, train loss  2.002E+00, test loss  2.291, training accuracy  0.42, test accuracy  0.35, time per example  47.06 µs
Epoch 3375, train loss  1.991E+00, test loss  2.199, training accuracy  0.41, test accuracy  0.39, time per example  45.69 µs
Epoch 3400, train loss  2.047E+00, test loss  2.242, training accuracy  0.40, test accuracy  0.37, time per example  45.34 µs
Epoch 3425, train loss  2.023E+00, test loss  2.260, training accuracy  0.42, test accuracy  0.36, time per example  51.17 µs
Epoch 3450, train loss  1.984E+00, test loss  2.218, training accuracy  0.42, test accuracy  0.37, time per example  81.51 µs
Epoch 3475, train loss  1.973E+00, test loss  2.148, training accuracy  0.42, test accuracy  0.37, time per example  51.27 µs
Epoch 3500, train loss  2.072E+00, test loss  2.197, training accuracy  0.39, test accuracy  0.36, time per example  53.87 µs
Epoch 3525, train loss  2.063E+00, test loss  2.213, training accuracy  0.41, test accuracy  0.37, time per example  53.11 µs
Epoch 3550, train loss  2.040E+00, test loss  2.226, training accuracy  0.40, test accuracy  0.37, time per example  82.27 µs
Epoch 3575, train loss  1.959E+00, test loss  2.233, training accuracy  0.42, test accuracy  0.37, time per example  59.19 µs
Epoch 3600, train loss  1.972E+00, test loss  2.270, training accuracy  0.43, test accuracy  0.36, time per example  50.42 µs
Epoch 3625, train loss  2.041E+00, test loss  2.223, training accuracy  0.40, test accuracy  0.37, time per example  52.94 µs
Epoch 3650, train loss  2.049E+00, test loss  2.208, training accuracy  0.40, test accuracy  0.37, time per example  58.56 µs
Epoch 3675, train loss  2.011E+00, test loss  2.230, training accuracy  0.40, test accuracy  0.37, time per example  93.26 µs
Epoch 3700, train loss  1.992E+00, test loss  2.214, training accuracy  0.42, test accuracy  0.36, time per example  88.48 µs
Epoch 3725, train loss  2.044E+00, test loss  2.241, training accuracy  0.43, test accuracy  0.36, time per example  56.74 µs
Epoch 3750, train loss  2.011E+00, test loss  2.162, training accuracy  0.42, test accuracy  0.38, time per example  58.43 µs
Epoch 3775, train loss  1.965E+00, test loss  2.210, training accuracy  0.43, test accuracy  0.38, time per example  52.24 µs
Epoch 3800, train loss  1.982E+00, test loss  2.208, training accuracy  0.42, test accuracy  0.37, time per example  82.27 µs
Epoch 3825, train loss  2.039E+00, test loss  2.182, training accuracy  0.41, test accuracy  0.37, time per example  52.78 µs
Epoch 3850, train loss  2.042E+00, test loss  2.307, training accuracy  0.41, test accuracy  0.36, time per example  55.92 µs
Epoch 3875, train loss  1.999E+00, test loss  2.226, training accuracy  0.41, test accuracy  0.38, time per example  158.04 µs
Epoch 3900, train loss  1.992E+00, test loss  2.299, training accuracy  0.42, test accuracy  0.36, time per example  56.08 µs
Epoch 3925, train loss  2.075E+00, test loss  2.191, training accuracy  0.40, test accuracy  0.37, time per example  80.45 µs
Epoch 3950, train loss  2.014E+00, test loss  2.221, training accuracy  0.42, test accuracy  0.37, time per example  52.09 µs
Epoch 3975, train loss  2.029E+00, test loss  2.260, training accuracy  0.42, test accuracy  0.35, time per example  51.94 µs
Epoch 4000, train loss  1.946E+00, test loss  2.207, training accuracy  0.44, test accuracy  0.37, time per example  51.97 µs
Epoch 4025, train loss  2.004E+00, test loss  2.247, training accuracy  0.41, test accuracy  0.35, time per example  85.34 µs
Epoch 4050, train loss  2.007E+00, test loss  2.161, training accuracy  0.41, test accuracy  0.38, time per example  51.78 µs
Epoch 4075, train loss  2.014E+00, test loss  2.179, training accuracy  0.41, test accuracy  0.38, time per example  83.09 µs
Epoch 4100, train loss  2.018E+00, test loss  2.202, training accuracy  0.41, test accuracy  0.36, time per example  61.74 µs
Epoch 4125, train loss  2.014E+00, test loss  2.212, training accuracy  0.43, test accuracy  0.38, time per example  86.39 µs
Epoch 4150, train loss  1.993E+00, test loss  2.190, training accuracy  0.42, test accuracy  0.38, time per example  157.08 µs
Epoch 4175, train loss  1.963E+00, test loss  2.199, training accuracy  0.44, test accuracy  0.39, time per example  59.74 µs
Epoch 4200, train loss  2.027E+00, test loss  2.252, training accuracy  0.41, test accuracy  0.37, time per example  52.88 µs
Epoch 4225, train loss  1.978E+00, test loss  2.256, training accuracy  0.42, test accuracy  0.37, time per example  50.34 µs
Epoch 4250, train loss  1.966E+00, test loss  2.182, training accuracy  0.42, test accuracy  0.38, time per example  81.53 µs
Epoch 4275, train loss  2.010E+00, test loss  2.202, training accuracy  0.42, test accuracy  0.38, time per example  51.86 µs
Epoch 4300, train loss  2.054E+00, test loss  2.258, training accuracy  0.40, test accuracy  0.36, time per example  52.43 µs
Epoch 4325, train loss  1.994E+00, test loss  2.180, training accuracy  0.41, test accuracy  0.39, time per example  52.35 µs
Epoch 4350, train loss  2.020E+00, test loss  2.241, training accuracy  0.42, test accuracy  0.38, time per example  72.38 µs
Epoch 4375, train loss  2.018E+00, test loss  2.247, training accuracy  0.39, test accuracy  0.36, time per example  78.79 µs
Epoch 4400, train loss  2.052E+00, test loss  2.308, training accuracy  0.40, test accuracy  0.37, time per example  54.31 µs
Epoch 4425, train loss  1.978E+00, test loss  2.165, training accuracy  0.42, test accuracy  0.39, time per example  160.59 µs
Epoch 4450, train loss  2.023E+00, test loss  2.202, training accuracy  0.41, test accuracy  0.36, time per example  60.88 µs
Time budget exceeded, hence stopping training
Total training time  600.99 s
In [ ]:
 

Using NanoGPT¶

In [ ]:
model_config = GPTConfig(
    block_size=128, 
    vocab_size=65, 
    n_layer=4, n_head=4, n_embd=64, dropout=0.1
)
train_config = TrainConfig(
    epochs=20000,
    train_batch_size=2048,
    lr=1e-2,
    weight_decay=1e-4,
    epoch_interval=50,
    time_budget_seconds=300,
)
bard_results[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval_bard(
    model_config, GPT(model_config), bd, train_config, block_size=128, is_y_single_token=False, is_nano_gpt=True
)
cfg2 = dataclasses.replace(model_config, n_embd=96)
bard_results[("NanoGPT", f"run_{int(time.time())}t")] = train_and_eval_bard(
    cfg2, GPT(cfg2), bd, train_config, block_size=128, is_y_single_token=False, is_nano_gpt=True
)
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
number of parameters: 0.20M
Number of parameters: 212416
Epoch 0, train loss  4.193E+00, test loss  3.763, training accuracy  0.01, test accuracy  0.01, time per example  136.27 µs
Epoch 50, train loss  3.315E+00, test loss  3.315, training accuracy  0.01, test accuracy  0.01, time per example  126.46 µs
Epoch 100, train loss  3.312E+00, test loss  3.312, training accuracy  0.01, test accuracy  0.01, time per example  125.70 µs
Epoch 150, train loss  3.308E+00, test loss  3.307, training accuracy  0.01, test accuracy  0.00, time per example  126.57 µs
Epoch 200, train loss  3.117E+00, test loss  3.116, training accuracy  0.01, test accuracy  0.01, time per example  125.18 µs
Epoch 250, train loss  2.947E+00, test loss  2.941, training accuracy  0.01, test accuracy  0.01, time per example  125.28 µs
Epoch 300, train loss  2.744E+00, test loss  2.741, training accuracy  0.01, test accuracy  0.01, time per example  125.31 µs
Epoch 350, train loss  2.626E+00, test loss  2.625, training accuracy  0.01, test accuracy  0.01, time per example  125.22 µs
Epoch 400, train loss  2.563E+00, test loss  2.569, training accuracy  0.01, test accuracy  0.01, time per example  125.77 µs
Epoch 450, train loss  2.500E+00, test loss  2.500, training accuracy  0.01, test accuracy  0.01, time per example  125.50 µs
Epoch 500, train loss  2.511E+00, test loss  2.509, training accuracy  0.01, test accuracy  0.01, time per example  125.82 µs
Epoch 550, train loss  2.452E+00, test loss  2.452, training accuracy  0.01, test accuracy  0.01, time per example  125.27 µs
Epoch 600, train loss  2.418E+00, test loss  2.415, training accuracy  0.01, test accuracy  0.01, time per example  126.10 µs
Epoch 650, train loss  2.378E+00, test loss  2.378, training accuracy  0.01, test accuracy  0.00, time per example  128.21 µs
Epoch 700, train loss  2.344E+00, test loss  2.345, training accuracy  0.01, test accuracy  0.01, time per example  125.52 µs
Epoch 750, train loss  2.308E+00, test loss  2.306, training accuracy  0.01, test accuracy  0.01, time per example  125.64 µs
Epoch 800, train loss  2.281E+00, test loss  2.270, training accuracy  0.01, test accuracy  0.01, time per example  125.77 µs
Epoch 850, train loss  2.267E+00, test loss  2.261, training accuracy  0.01, test accuracy  0.01, time per example  128.29 µs
Epoch 900, train loss  2.158E+00, test loss  2.147, training accuracy  0.01, test accuracy  0.01, time per example  130.54 µs
Epoch 950, train loss  2.062E+00, test loss  2.065, training accuracy  0.01, test accuracy  0.01, time per example  125.06 µs
Epoch 1000, train loss  1.994E+00, test loss  1.992, training accuracy  0.01, test accuracy  0.01, time per example  125.06 µs
Epoch 1050, train loss  1.921E+00, test loss  1.920, training accuracy  0.01, test accuracy  0.01, time per example  125.04 µs
Epoch 1100, train loss  1.879E+00, test loss  1.875, training accuracy  0.01, test accuracy  0.01, time per example  125.37 µs
Epoch 1150, train loss  1.815E+00, test loss  1.812, training accuracy  0.01, test accuracy  0.01, time per example  125.81 µs
Time budget exceeded, hence stopping training
Total training time  302.89 s
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0
number of parameters: 0.45M
Number of parameters: 466080
Epoch 0, train loss  4.188E+00, test loss  3.797, training accuracy  0.01, test accuracy  0.03, time per example  149.12 µs
Epoch 50, train loss  2.971E+00, test loss  2.969, training accuracy  0.01, test accuracy  0.01, time per example  147.15 µs
Epoch 100, train loss  2.744E+00, test loss  2.729, training accuracy  0.01, test accuracy  0.01, time per example  147.11 µs
Epoch 150, train loss  2.595E+00, test loss  2.591, training accuracy  0.00, test accuracy  0.00, time per example  147.27 µs
Epoch 200, train loss  2.510E+00, test loss  2.511, training accuracy  0.01, test accuracy  0.01, time per example  147.33 µs
Epoch 250, train loss  2.402E+00, test loss  2.398, training accuracy  0.00, test accuracy  0.00, time per example  147.21 µs
Epoch 300, train loss  2.286E+00, test loss  2.287, training accuracy  0.01, test accuracy  0.01, time per example  147.28 µs
Epoch 350, train loss  2.198E+00, test loss  2.193, training accuracy  0.01, test accuracy  0.01, time per example  147.10 µs
Epoch 400, train loss  2.101E+00, test loss  2.095, training accuracy  0.01, test accuracy  0.01, time per example  146.94 µs
Epoch 450, train loss  2.008E+00, test loss  2.009, training accuracy  0.01, test accuracy  0.01, time per example  147.14 µs
Epoch 500, train loss  1.924E+00, test loss  1.916, training accuracy  0.01, test accuracy  0.01, time per example  147.40 µs
Epoch 550, train loss  1.847E+00, test loss  1.859, training accuracy  0.01, test accuracy  0.01, time per example  147.47 µs
Epoch 600, train loss  1.769E+00, test loss  1.761, training accuracy  0.01, test accuracy  0.01, time per example  146.95 µs
Epoch 650, train loss  1.706E+00, test loss  1.701, training accuracy  0.01, test accuracy  0.01, time per example  147.16 µs
Epoch 700, train loss  1.658E+00, test loss  1.654, training accuracy  0.01, test accuracy  0.00, time per example  146.91 µs
Epoch 750, train loss  1.623E+00, test loss  1.620, training accuracy  0.01, test accuracy  0.01, time per example  147.35 µs
Epoch 800, train loss  1.591E+00, test loss  1.587, training accuracy  0.01, test accuracy  0.01, time per example  147.44 µs
Epoch 850, train loss  1.555E+00, test loss  1.560, training accuracy  0.01, test accuracy  0.01, time per example  147.65 µs
Epoch 900, train loss  1.521E+00, test loss  1.518, training accuracy  0.01, test accuracy  0.01, time per example  147.25 µs
Epoch 950, train loss  1.489E+00, test loss  1.489, training accuracy  0.00, test accuracy  0.01, time per example  147.07 µs
Epoch 1000, train loss  1.474E+00, test loss  1.470, training accuracy  0.01, test accuracy  0.00, time per example  147.37 µs
Time budget exceeded, hence stopping training
Total training time  308.73 s

Results¶

In [ ]:
sb_results = []
for key, value in bard_results.items():
    sb_results.append((key[0], f"{value.train_config.lr: .1e}", value.train_loss, value.test_loss, value.num_parameters, value.time_per_example_in_micros, value.train_time_in_seconds))
import pandas as pd
pd.DataFrame.from_records(sb_results, 
    columns=["model", "Learning Rate", "train_loss", "test_loss",  
             "num_parameters", "time_per_example_in_micros", "train_time_in_seconds"]
).round(3)
Out[ ]:
model Learning Rate train_loss test_loss num_parameters time_per_example_in_micros train_time_in_seconds
0 MLP 1.0e-02 1.848 2.182 542289 19.452 302.810
1 MLP 1.0e-02 1.831 2.135 140105 19.259 307.696
2 TLTransformer 1.0e-02 1.236 1.613 216641 101.261 300.808
3 TLTransformer 1.0e-02 1.203 1.648 472385 131.621 303.390
4 NanoGPT 1.0e-02 1.815 1.812 212416 128.491 302.885
5 NanoGPT 1.0e-02 1.474 1.470 466080 150.595 308.727
In [ ]: