When Gemma 4 dropped, the timeline started talking about a different model again. Not really Gemma 4 itself, but the older Gemma 3n trick that had already planted a question in everyone's head: how does an 8B-class model run in 2 or 3 GB of RAM on a phone without the whole explanation collapsing into "we quantized it and prayed"?
That question made me curious, because half the time a flashy memory claim turns out to be boring. Better kernels, aggressive quantization, marketing doing cardio. This one felt different. It sounded architectural rather than optimizational, and that was enough to make me stop reading explanations and start trying to build it myself.
the first thing that clicked for me
The intuition is simpler than the name makes it sound. A normal transformer has one big input embedding table. A token enters the model, gets one vector, and then every layer has to work with that same initial package as it moves upward. That always struck me as slightly awkward, because lower layers care about different things than upper ones. Early layers are still close to token identity, spelling, and local patterns. Later layers are doing more composition and meaning. But the architecture makes all of them inherit the same suitcase from the front door.
Per-Layer Embeddings (PLE) change that arrangement. Instead of one fat embedding at the entrance, each layer gets its own small table, and for the current token, that layer does its own lookup and injects a small learned vector into the residual stream. That was the moment the whole thing stopped sounding like a weird Google trick and started sounding like a clean systems idea: not one bag for the whole building, but a small drawer on every floor.
the memory trick that sold me
Then the byte math made it click even harder. If each layer has a 256-dimensional per-token table and those values are stored in INT4, the cost per token per layer is trivial:
```
256 dims × 4 bits = 1024 bits = 128 bytes per layer
30 layers × 128 bytes = 3840 bytes per token
```
That is roughly 4 KB per token, which means the extra knowledge does not have to sit in VRAM at all. It just needs to be fetched quickly enough for the current token, and fetching 4 KB from storage is not a hard problem anymore.
The old mental model said all weights must live in VRAM all the time. The PLE version says the backbone lives in fast memory, but layer-local token knowledge can live on disk and stream in when needed. The model lives in RAM. The knowledge lives on disk. Once I saw it that way, I was in.
my first mistake was trying to retrofit it
Naturally I started with the messy engineer version, which is to say I took a pretrained model, taped the new mechanism onto the side, and hoped the charts would go up. I used SmolLM2-135M as the base, added per-layer tables plus the injection path, trained only the extras, and watched the validation perplexity come down. For a while it looked promising enough that I started thinking the internet had been underplaying this whole idea.
But I was asking the wrong question. The real question was never "does retrofit PLE beat doing nothing?" It was "if I spend the same extra parameter budget somewhere simpler, does PLE still win?" So I ran the control.
the control that killed the story
I matched the extra parameter budget with LoRA, using the same base model, the same training recipe, and the same idea of adding trainable capacity on top. LoRA beat it on the metric, not on philosophy. Here are the numbers from that run:
| Method | Validation PPL |
|---|---|
| Dense student | 48.80 |
| Retrofit PLE | 48.34 |
| LoRA at matched budget | 45.86 |
That killed the comfortable version of the story, because it meant retrofit PLE wasn't proving the architecture was special. It was mostly proving the boring fact that adding trainable parameters tends to help a model, and LoRA happened to use those extra parameters more effectively.
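The post doesn't spell out exactly how the LoRA budget was matched. One common approach is sketched below: count the parameters PLE adds, then pick the smallest LoRA rank that reaches the same number. This is an illustration under stated assumptions. The retrofit gate is assumed to mirror the PLEGate described later in this post, and every size in the usage lines is hypothetical.

```python
def lora_params(rank: int, shapes: list[tuple[int, int]]) -> int:
    """Trainable parameters LoRA adds: each adapted (d_in, d_out) matrix gets A (d_in x r) and B (r x d_out)."""
    return sum(rank * (d_in + d_out) for d_in, d_out in shapes)


def retrofit_ple_params(num_layers: int, vocab_size: int, ple_dim: int, d_model: int) -> int:
    """Extra parameters for retrofit PLE: a (vocab x ple_dim) table per layer plus a gate
    with down/up projections and an RMSNorm weight (assumed to mirror PLEGate)."""
    per_layer = vocab_size * ple_dim + 2 * d_model * ple_dim + d_model
    return num_layers * per_layer


def matched_rank(target: int, shapes: list[tuple[int, int]]) -> int:
    """Smallest LoRA rank whose parameter count reaches the target budget."""
    per_rank = lora_params(1, shapes)
    return max(1, -(-target // per_rank))  # ceiling division


# Hypothetical setup: 30 layers, d_model=576, ple_dim=8, LoRA on two 576x576 projections per layer.
shapes = [(576, 576)] * 60
budget = retrofit_ple_params(num_layers=30, vocab_size=49_152, ple_dim=8, d_model=576)
rank = matched_rank(budget, shapes)
```

In this made-up setup the budget is 12,090,240 parameters and the matched rank comes out at 175. The PLE tables scale with vocabulary size, while LoRA scales with the width of the matrices it adapts.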
The honest conclusion was that retrofitting PLE onto a pretrained model was the wrong premise in the first place. Google's claim was never "staple this onto anything and you win." The claim was architectural: train the backbone and the per-layer tables together from the start so they co-adapt during pretraining. If you skip that step, you don't really have PLE. You have an adapter wearing a PLE costume.
then I got distracted by a more practical idea
After the LoRA control I wandered into a different project, which turned out to be useful but not the same thing. The reasoning was: if the core memory insight behind PLE is that lookup-heavy weights don't always need to live in VRAM, can I apply that idea to existing Hugging Face models without any retraining at all? That turned into a small library I called embed_offload, which does exactly one thing. It moves the existing input embedding table to mmap'd disk, keeps generation working through Hugging Face's standard .generate(), and reduces VRAM with a drop-in API.
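The core move behind embed_offload is swapping a GPU-resident embedding for one backed by a memory-mapped file, and it can be sketched in plain PyTorch and NumPy. This is a minimal illustration, not embed_offload's actual API. The class name MmapEmbedding is hypothetical, and the sketch assumes a float32 or float16 weight, since NumPy has no bfloat16. It gathers rows on the CPU and then copies only those rows to the input's device.

```python
import numpy as np
import torch
import torch.nn as nn


class MmapEmbedding(nn.Module):
    """Hypothetical stand-in for nn.Embedding whose weight lives in a memory-mapped file."""

    def __init__(self, weight: torch.Tensor, path: str):
        super().__init__()
        arr = weight.detach().cpu().numpy()
        # Write the table to disk once.
        mm = np.memmap(path, dtype=arr.dtype, mode="w+", shape=arr.shape)
        mm[:] = arr
        mm.flush()
        del mm
        # Reopen read-only; the OS pages rows in lazily when they are touched.
        self.table = np.memmap(path, dtype=arr.dtype, mode="r", shape=arr.shape)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Fancy indexing copies only the requested rows into a regular array.
        rows = self.table[input_ids.cpu().numpy()]
        return torch.from_numpy(rows).to(input_ids.device)
```

On an untied Hugging Face model, a module like this could be swapped in with model.set_input_embeddings(...). Code that expects a .weight attribute on the embedding may need extra care, and tied models need the special handling described further down.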
On untied models, the results were clean:
| Model | VRAM saved | Output match |
|---|---|---|
| Pythia-410m | 103 MB | bit-exact |
| Pythia-1b | 206 MB | bit-exact |
| Phi-3-mini | 197 MB | bit-exact |
That is a real systems result, but it isn't PLE. It's just relocating the existing embedding table. There are no per-layer tables, no new architecture, no from-scratch training. The core memory insight is the same, but the technique underneath it is completely different. Then came the question I should have pinned down much earlier:
"This is after the per layer embedding technique, right? Please make sure you are not hallucinating and just providing me any random results."
The honest answer was no, so I dropped the detour and went back to the real task.
forget everything, do it properly
With those detours closed, the real experiment got clean. I would train two tiny models from scratch, sharing the same tokenizer, optimizer, data, seed, steps, and scale. One would be a plain dense decoder-only transformer. The other would be the same backbone with Per-Layer Embeddings baked in from step zero. Then I'd compare them under the only condition that actually matters, which is matched inference VRAM.
I built a small LLaMA-style decoder with RMSNorm, RoPE, and SwiGLU, and made two versions:
Dense baseline
- 10.91M parameters
- all of it lives in the normal model weights
PLE model
- 11.01M backbone and gates in VRAM
- 12.58M extra PLE table parameters that can live on disk
That split is the whole point. The PLE model isn't smaller in total parameter count; it's actually larger. It just stores the extra capacity somewhere cheaper. Both models trained on WikiText-103 for 6000 steps, about 12 million tokens total, on my laptop using PyTorch's MPS (Apple GPU) backend, with no synthetic tricks and no hand-picked comparison. Just train both and look at the numbers.
the result that finally counted
This time, PLE actually won.
| Model | VRAM | Disk | Val PPL |
|---|---|---|---|
| Dense baseline | 43.7 MB | 0 | 160.39 |
| PLE with tables on disk | 44.1 MB | 25.2 MB | 149.88 |
At essentially the same inference VRAM, PLE got 6.6% better perplexity. That was the real result I had been after: not "retrofit kind of helps," not "I found a separate offloading trick," but the actual architectural claim that extra capacity can live on disk while the backbone stays in memory, with quality improving at the same VRAM budget.
I also verified the disk-backed path against the in-memory PLE path:
```
max |logit diff| = 1.36e-03
in-memory forward = 1.7 ms
disk-backed forward = 6.0 ms
```
That tiny logit difference is just fp16 save-load noise. The disk-backed path was real: it ran the same computation as the in-memory one, differing only by the fp16 rounding of the stored tables.
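A check like this fits in a few lines. The sketch below is not the author's benchmark script. compare_paths is a hypothetical helper that runs two forward functions on the same input, then reports the max absolute logit difference and the mean latency of each. On GPU or MPS, a sync function such as torch.cuda.synchronize or torch.mps.synchronize has to be passed in, because otherwise the timer measures kernel launches instead of finished work.

```python
import time

import torch


@torch.no_grad()
def compare_paths(forward_a, forward_b, input_ids, warmup=3, iters=20, sync=lambda: None):
    """Max |logit diff| between two forward functions, plus mean latency of each in ms."""
    logits_a, logits_b = forward_a(input_ids), forward_b(input_ids)
    max_diff = (logits_a.float() - logits_b.float()).abs().max().item()

    def mean_ms(fn):
        for _ in range(warmup):
            fn(input_ids)
        sync()  # make sure queued device work is done before starting the clock
        start = time.perf_counter()
        for _ in range(iters):
            fn(input_ids)
        sync()
        return (time.perf_counter() - start) / iters * 1e3

    return max_diff, mean_ms(forward_a), mean_ms(forward_b)
```

For the comparison in the post, forward_a would be the model using its in-memory tables and forward_b the same model reading rows from the disk-backed tables.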
the one line I would tattoo on this experiment
Retrofit doesn't prove the architecture. The architecture has to grow up with the idea. You don't duct-tape a memory layout onto a trained network and call it the same thing.
That was the story, and the story was useful because it's what made the idea real for me. But if you just want to implement PLE, the mechanism is what you actually need, so let me separate it cleanly from the detours.
what PLE actually is
If you want to implement PLE, the architecture is just three pieces:
- a normal transformer backbone
- one small embedding table per layer
- a gate that decides how that layer should use its own table lookup
That's it. The key idea isn't to replace the input embedding; it's to distribute token-specific knowledge across the stack instead of forcing one giant table at the entrance to do everything up front.
the clean mental model
The simplest way to think about one token flowing through a PLE model is as a series of small lookups. The token first enters through the normal input embedding, exactly like in a standard transformer. Then layer 1 does its usual work, but layer 1 also has its own side drawer: it looks up that same token id in its own small table, gets a learned vector, and injects it into the residual stream through a gate. Layer 2 does the same thing with its own table and its own gate. Same token id, different vector at every depth.
So the model isn't carrying one fixed embedding upward and hoping it's enough. It's refreshing token-specific information at every layer. That's why the architecture feels different. A normal transformer says "here is the token, good luck." PLE says "here is the token again, in the language this particular layer cares about."
the per-token flow
In implementation terms, one forward pass looks like this:
```
input_ids -> tok_embed -> hidden states x
for each layer i:
    x = x + attention_i(x)
    x = x + ffn_i(x)
    ple_vec = ple_table_i[input_ids]
    x = x + gate_i(x, ple_vec)
logits = x @ tok_embed^T
```
The line in the middle is the whole story:
```
ple_vec = ple_table_i[input_ids]
```
Every layer has its own ple_table_i, so the same token gets a different learned vector depending on where it sits in the network. That's why this isn't just a weird embedding resize trick. The architecture is explicitly saying that token information should be depth-dependent.
why the gate matters
The dumb version of PLE would just add the layer table vector directly:
```
x = x + ple_table_i[input_ids]
```
That's too crude, because not every token needs the same amount of extra information at every layer, and not every layer should trust its raw table lookup equally. So you add a gate. In my implementation, the gate takes the current hidden state, compresses it to ple_dim, multiplies it elementwise with the PLE lookup, and projects it back up to the model dimension.
In plain text math:
```
h = RMSNorm(x)
g = GELU(W_down h)
delta = W_up (g * ple_vec)
x = x + delta
```
Where:
- `x` is the current hidden state at that layer
- `ple_vec` is the token's lookup from that layer's table
- `W_down` maps hidden size -> ple_dim
- `W_up` maps ple_dim -> hidden size
- `*` is elementwise multiplication
So the layer is doing two things at once: looking up token-specific information from its own table, and deciding from the current hidden state how much of that information should matter right now. That second part is what earns the architecture its keep. Without the gate, you have a hard additive bias per token. With it, the per-layer table becomes conditional information that the network can weigh differently at every depth.
what it looks like in code
The tiny LLaMA-style decoder from Part 1 has a single config flag, use_ple. When it's off, the model is just a dense language model. When it's on, each block gets its own nn.Embedding(vocab_size, ple_dim) table plus a PLEGate. The gate itself is tiny:
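```python
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import RMSNorm  # stand-in for the model's own RMSNorm (torch.nn.RMSNorm needs PyTorch >= 2.4)

class PLEGate(nn.Module):
    def __init__(self, d_model: int, ple_dim: int):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.down = nn.Linear(d_model, ple_dim, bias=False)  # hidden -> ple_dim
        self.up = nn.Linear(ple_dim, d_model, bias=False)    # ple_dim -> hidden

    def forward(self, hidden, ple_vec):
        h = self.norm(hidden)
        gate = F.gelu(self.down(h))     # gate computed from the current hidden state
        return self.up(gate * ple_vec)  # gate the layer's table lookup, project back to d_model
```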
Inside each transformer block it slots in right after attention and FFN:
```python
x = x + self.attn(self.attn_norm(x), cos, sin)
x = x + self.ffn(self.ffn_norm(x))
if self.use_ple and ple_vec is not None:
    x = x + self.ple_gate(x, ple_vec)
```
And the model-level lookup is just one line per layer:
```python
for i, block in enumerate(self.blocks):
    ple_vec = self.ple_tables[i](input_ids)
    x = block(x, cos, sin, ple_vec)
```
That's the full implementation. It really is that small. The complete codebase, including training scripts, benchmarks, the retrofit experiments, and the disk-backed inference path, lives at github.com/imdigitalashish/ple-layerlookup.
why the tables can live on disk
The design choice that unlocks everything isn't the exact gate formula; it's that the PLE tables are indexed by token id, not by hidden state. That makes them pure lookup structures, which have one beautiful property: for a given token at a given layer, you only need one row. You don't need the whole table in accelerator memory to take a step. If this were a dense matrix multiply you had to run on every token, the whole disk-backing story would collapse. Row lookup is different, because it's sparse, predictable, and easy to stream.
The serving budget is also shockingly small. For 30 layers, ple_dim=256, INT4 storage, each token costs 3,840 bytes per step, or about 4 KB. So the backbone weights stay in VRAM, the per-layer tables live on disk, and when token t reaches layer i you fetch row t from table i, send it through the gate, and move on. That is the whole production argument, and it really is just row lookup at 4 KB per token.
what the disk-backed path actually looks like
I saved each layer's table as a raw binary file plus a small meta.json with vocab size, ple_dim, layer count, and storage dtype. Then I mmap each layer file and do row lookups by byte offset. The logic is basically:
```python
row_bytes = ple_dim * itemsize
offset = token_id * row_bytes
row = mm[offset : offset + row_bytes]
```
Wrapped in a tiny helper:
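```python
import numpy as np
import torch

class DiskPLETables:
    # Assumed to be set up in __init__ (omitted here): self.ple_dim, self._np_dtype,
    # self._row_bytes = ple_dim * itemsize, and self._mms, one mmap per layer file.
    def lookup(self, token_ids, layer_idx):
        flat = token_ids.flatten().cpu().numpy()
        out = np.empty((len(flat), self.ple_dim), dtype=self._np_dtype)
        mm = self._mms[layer_idx]
        for i, tok in enumerate(flat):
            off = int(tok) * self._row_bytes  # byte offset of this token's row
            out[i] = np.frombuffer(mm[off:off + self._row_bytes], dtype=self._np_dtype)
        return torch.from_numpy(out).reshape(*token_ids.shape, self.ple_dim)
```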
Then during inference, instead of doing the in-memory lookup:
```python
ple_vec = self.ple_tables[i](input_ids)
```
I pass in the disk-backed version:
```python
ple_vec = disk.lookup(input_ids, i)
```
The model is the same, the gate is the same, and the forward path is the same. The only thing that changes is where each table row came from. That's why the disk-backed verification actually means something: I wasn't benchmarking a second pretend model, I was swapping the data source under the same architecture.
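For a complete round trip, here is a minimal sketch of both halves: writing one raw file per layer plus meta.json, then reading rows back with one vectorized gather per layer instead of a Python loop over tokens. The function and class names, the layer_{i}.bin file names, and the meta.json keys are hypothetical; the post only says the metadata holds vocab size, ple_dim, layer count, and storage dtype. It uses numpy.memmap rather than the raw mmap module, which lets NumPy handle the byte-offset arithmetic.

```python
import json
from pathlib import Path

import numpy as np
import torch


def save_ple_tables(tables, out_dir, dtype="float16"):
    """Write one raw binary file per layer plus meta.json (hypothetical layout)."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i, t in enumerate(tables):
        t.detach().cpu().numpy().astype(dtype).tofile(out / f"layer_{i}.bin")
    vocab_size, ple_dim = tables[0].shape
    meta = {"vocab_size": vocab_size, "ple_dim": ple_dim, "num_layers": len(tables), "dtype": dtype}
    (out / "meta.json").write_text(json.dumps(meta))


class BatchedDiskPLETables:
    """Disk-backed PLE tables that gather a whole batch of rows per layer in one read."""

    def __init__(self, table_dir):
        d = Path(table_dir)
        meta = json.loads((d / "meta.json").read_text())
        self.ple_dim = meta["ple_dim"]
        shape = (meta["vocab_size"], meta["ple_dim"])
        # Read-only memory maps; the OS pages in only the rows that are touched.
        self._tables = [
            np.memmap(d / f"layer_{i}.bin", dtype=meta["dtype"], mode="r", shape=shape)
            for i in range(meta["num_layers"])
        ]

    def lookup(self, token_ids, layer_idx):
        flat = token_ids.flatten().cpu().numpy()
        rows = self._tables[layer_idx][flat]  # fancy indexing copies just these rows
        return torch.from_numpy(rows).reshape(*token_ids.shape, self.ple_dim)
```

The rows come back as fp16 CPU tensors, so they still need a .to(device, dtype) before they reach the gate.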
if you want to build this step by step
If I were doing the implementation again from scratch, I'd do it in this order:
1. Build the dense baseline first
   - tiny decoder-only LM
   - train it end to end
   - make sure loss moves and validation works
2. Add a config flag for PLE
   - `use_ple`
   - `ple_dim`
3. Add per-layer tables
   - `nn.Embedding(vocab_size, ple_dim)` for each layer
4. Add the gate module
   - hidden -> ple_dim -> elementwise multiply with `ple_vec` -> back to hidden
5. Inject after the normal block computation
   - attention
   - FFN
   - then PLE gate injection
6. Train dense and PLE with the exact same recipe
   - same tokenizer
   - same seed
   - same optimizer
   - same steps
   - same data
7. Measure the right thing
   - not total params alone
   - not training loss alone
   - inference-time VRAM versus quality
8. Only then add disk-backed lookup
   - save tables separately
   - mmap them
   - verify max logit diff against in-memory path
That order matters. If you jump straight into the disk tricks before you know the architecture itself works, you'll spend a lot of time optimizing a concept you haven't even validated yet. I know because that's exactly what I did.
run the numbers yourself
The reason this architecture is interesting isn't that it exists on paper. It's that the numbers underneath it are tiny.
Move the per-layer tables from VRAM to disk and the VRAM footprint drops by the full size of the tables. Push the layer count up or drop the quantization to INT4, and the per-token streaming budget stays in the low kilobytes. That's the entire production argument in one knob.
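The trade-off can also be computed directly. ple_memory_map is a hypothetical helper that reports where the table bytes sit and how many bytes one token has to stream per forward step. The 30-layer, 256-dim, INT4 settings come from the byte math earlier in the post. The 262,144-token vocabulary is an assumed value, picked only to show the scale of the full tables.

```python
def ple_memory_map(num_layers, ple_dim, vocab_size, bits, tables_on_disk):
    """Bytes of PLE tables in VRAM vs on disk, and bytes streamed per token per step."""
    table_bytes = num_layers * vocab_size * ple_dim * bits // 8  # every row of every layer
    per_token = num_layers * ple_dim * bits // 8                 # one row per layer
    return {
        "tables_in_vram": 0 if tables_on_disk else table_bytes,
        "tables_on_disk": table_bytes if tables_on_disk else 0,
        "stream_per_token": per_token if tables_on_disk else 0,
    }


m = ple_memory_map(num_layers=30, ple_dim=256, vocab_size=262_144, bits=4, tables_on_disk=True)
print(m)
```

With these settings, about 1.0 GB of tables sits on disk, while each token streams only 3,840 bytes.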
why this matters beyond Gemma
For years ML acted like all weights must live in VRAM and disk was embarrassing. But operating systems solved that kind of problem in the 1960s. Not everything needs to sit in the fastest tier of memory at all times. What matters is what needs to be hot right now, and whether the cold stuff has an access pattern cheap enough to stream.
PLE is that lesson applied to transformer weights, not for every parameter but specifically for lookup-shaped capacity. Once you accept that split, the design questions get sharper. Which weights need to be hot? Which ones can be paged? Which parts of the model are really compute-heavy, and which parts are just memory layout pretending to be compute?
That's why this one grabbed me. It isn't a model trick so much as a memory-systems trick dressed as model design, and one 11M-parameter backbone trained on a laptop was enough to stop it from feeling theoretical.
Dense 160.39. PLE 149.88. Same VRAM. Extra knowledge on disk.
That is enough. 🫡
if you want to go deeper
The main post ends above. Everything below this is the receipts drawer. If you just wanted the story plus the clean implementation, you can stop here. If you want the failed paths, the side experiments, and the numbers that convinced me to stop fooling myself, keep reading.
why the retrofit lost to LoRA
This was the first real reality check. Once I matched the extra trainable budget, the three-way comparison looked like this:
| Method | Validation PPL |
|---|---|
| Dense student | 48.80 |
| Retrofit PLE | 48.34 |
| LoRA at matched budget | 45.86 |
That gap is why I stopped treating retrofit PLE as validation of the architecture. My read is that the pretrained backbone was never taught, during pretraining, to expect these extra per-layer token injections, so when I added them later, the new tables were mostly acting as a generic adapter. Useful but not special. LoRA was doing the same broad thing more directly, which is why LoRA won at equal budget.
The lesson I'd carry forward to any similar architecture experiment is this: if the claim is really about how capacity is distributed through the network, then the network usually has to be born that way. Retrofitting can still be a useful probe, but it isn't the final proof of anything.
the embed_offload detour was real, just not the same thing
I don't want to erase this part, because it was still a real systems result. It just wasn't PLE. I built a drop-in library that moved the existing input embedding table to mmap'd disk, and measured it across a few Hugging Face models on a Colab T4:
| Model | Base VRAM | Offloaded VRAM | Saved | Output match |
|---|---|---|---|---|
| SmolLM2-135M | 269.0 MB | 240.8 MB | 28.2 MB | small drift |
| SmolLM2-360M | 723.6 MB | 676.6 MB | 47.1 MB | small drift |
| Qwen2.5-0.5B | 988.1 MB | 852.2 MB | 135.8 MB | small drift |
| Pythia-410m | 810.7 MB | 707.6 MB | 103.0 MB | bit-exact |
| Pythia-1b | 2023.6 MB | 1817.5 MB | 206.0 MB | bit-exact |
| Phi-3-mini | 7642.2 MB | 7445.2 MB | 197.0 MB | bit-exact |
The tied-versus-untied distinction matters here. For untied models like Pythia and Phi-3, the embedding table can move to disk cleanly and outputs remain bit-exact. For tied models, the input embedding and output head share storage, so to actually save VRAM I had to replace the output side with an int8 copy, which saves memory but introduces small logit drift.
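A quick way to tell which case a model falls into is to check whether the input embedding and the output head point at the same memory. If they do, the head still needs the full matrix for a dense matmul on every step, so it can't be paged like a row lookup. The sketch below uses plain torch modules, and the helper name is hypothetical. On a Hugging Face model, the same check could be run on model.get_input_embeddings() and model.get_output_embeddings().

```python
import torch.nn as nn


def embeddings_are_tied(input_emb: nn.Embedding, output_head: nn.Linear) -> bool:
    """True if the input embedding and the output head share the same weight storage."""
    return input_emb.weight.data_ptr() == output_head.weight.data_ptr()


emb = nn.Embedding(100, 16)
head = nn.Linear(16, 100, bias=False)
untied = embeddings_are_tied(emb, head)
head.weight = emb.weight  # tie the weights, as tied models do
tied = embeddings_are_tied(emb, head)
```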
It's a useful tool using a different technique from PLE, and I'm keeping it in the story because it came from the same underlying insight: lookup-shaped parameters are much more movable than we usually pretend.
the disk-backed PLE benchmark that actually mattered
After the from-scratch run, I verified three deployment modes for the tiny PLE model:
| Mode | VRAM | Disk | Val PPL |
|---|---|---|---|
| Dense baseline | 43.7 MB | 0 | 160.39 |
| In-memory PLE | 94.4 MB | 0 | 149.88 |
| PLE tables on disk | 44.1 MB | 25.2 MB | 149.88 |
This is the table I care about most, because it shows what PLE is actually buying you. If you load everything into VRAM, the bigger PLE model of course wins, and that comparison isn't interesting. The interesting case is the third row: move the PLE tables out to disk and keep only the backbone and gates in VRAM, and the memory footprint matches dense while the quality advantage stays.
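These numbers line up with simple byte accounting if the weights in memory are fp32 (4 bytes per parameter) and the tables on disk are fp16 (2 bytes per parameter). The post says elsewhere that the disk tables were fp16; the fp32 part is an inference from the numbers, not something the post states. The hypothetical helper below reproduces the table to within about 0.1 MB, and the small gaps come from the rounded parameter counts.

```python
def ple_deployment_footprint(dense_params, backbone_params, table_params,
                             vram_bytes_per_param=4, disk_bytes_per_param=2):
    """Byte accounting for the three deployment modes, in MB (1e6 bytes)."""
    def mb(num_params, bytes_per_param):
        return num_params * bytes_per_param / 1e6

    return {
        "dense_vram": mb(dense_params, vram_bytes_per_param),
        "in_memory_ple_vram": mb(backbone_params + table_params, vram_bytes_per_param),
        "disk_ple_vram": mb(backbone_params, vram_bytes_per_param),
        "disk_ple_disk": mb(table_params, disk_bytes_per_param),
    }


fp = ple_deployment_footprint(10.91e6, 11.01e6, 12.58e6)
for name, value in fp.items():
    print(f"{name}: {value:.1f} MB")
```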
The verification numbers were:
```
max |logit diff| = 1.36e-03
in-memory forward = 1.7 ms
disk-backed forward = 6.0 ms
slowdown = +255%
```
That slowdown looks dramatic in percentage terms, but this is a case where absolute numbers matter more than ratios. It's an 11M-parameter model where the dense compute is so cheap that lookup overhead becomes a large fraction of total step time. At realistic model sizes, backbone arithmetic and memory movement dominate much more heavily, so I'd treat the toy-model slowdown as evidence that disk lookup has a cost, not as evidence that the serving idea fails.
the scaling math is why this idea is plausible
The core byte budget is simple:
bytes_per_token = num_layers × ple_dim × bits_per_value / 8
For a 30-layer model, that gives you:
| ple_dim | INT4 bytes/token | INT8 bytes/token |
|---|---|---|
| 128 | 1,920 | 3,840 |
| 256 | 3,840 | 7,680 |
| 512 | 7,680 | 15,360 |
Even the 256-dim INT4 case is only about 3.8 KB per token, which is why this idea stopped sounding insane to me once I ran the numbers. If someone had told me "we're streaming 80 MB of token-specific weights from disk every step," I'd have said no chance. But 4 KB per token is a completely different sentence. It's small enough that storage bandwidth and OS caching stop sounding like science fiction and start sounding like ordinary systems engineering.
what I would test next if I kept pushing this
If I were extending this beyond the tiny laptop proof, I'd do four things:
1. Scale the same experiment to a stronger small baseline
   - something in the 100M to 500M range
   - keep the matched-VRAM comparison honest
2. Quantize the disk tables harder
   - the proof here used fp16 on disk for simplicity
   - the real deployment story wants INT8 or INT4
3. Batch the disk lookups properly
   - my reference code is intentionally simple
   - a production path should avoid Python loops and gather rows in bigger chunks
4. Measure end-to-end generation throughput, not just forward latency
   - especially once KV cache and real sequence growth enter the picture
That's where this gets serious. The interesting question isn't "does the toy demo work?" anymore. It's "where is the crossover point where extra streamed capacity actually beats just making the dense model fatter?"
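On the INT4 item: the 128-bytes-per-row figure in the byte math assumes two 4-bit values packed into each byte. Below is a minimal sketch of symmetric per-row INT4 quantization with that packing. It assumes one fp16 scale stored alongside each row, which adds 2 bytes per row that the byte math ignores. This is an illustration, not the scheme Gemma or the repo uses.

```python
import numpy as np


def quantize_row_int4(row: np.ndarray):
    """Symmetric per-row INT4: scale floats into [-8, 7], then pack two values per byte.

    Assumes an even row length (ple_dim), so values pair up cleanly.
    """
    scale = float(np.abs(row).max()) / 7 or 1.0  # avoid a zero scale for all-zero rows
    q = np.clip(np.round(row / scale), -8, 7).astype(np.int8)
    u = (q + 8).astype(np.uint8)        # shift to unsigned nibbles 0..15
    packed = (u[0::2] << 4) | u[1::2]   # even index -> high nibble, odd index -> low nibble
    return packed.astype(np.uint8), np.float16(scale)


def dequantize_row_int4(packed: np.ndarray, scale) -> np.ndarray:
    """Unpack nibbles back to signed integers and rescale to float32."""
    hi = (packed >> 4).astype(np.int8) - 8
    lo = (packed & 0x0F).astype(np.int8) - 8
    q = np.empty(packed.size * 2, dtype=np.int8)
    q[0::2], q[1::2] = hi, lo
    return q.astype(np.float32) * np.float32(scale)
```

A 256-dim row packs into exactly 128 bytes of codes. Rounding error per value is at most half a quantization step, plus a tiny contribution from storing the scale in fp16.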
my actual conclusion after all the side quests
After all the side quests, my actual conclusion is narrow but solid:
- Retrofit PLE isn't convincing enough to validate the architecture on its own.
- `embed_offload` is useful, but it isn't PLE.
- From-scratch PLE really does buy better quality at matched VRAM, at least in this small reproduction.
That's enough for me to take the idea seriously, not as universal hype but as a genuine architectural bet about where token knowledge should live in a transformer.