Dwarkesh Podcast
Reiner Pope on Dwarkesh Patel: Why Token Cost Tracks Batch
Weight fetches dominate token cost until batch crosses 300 times MoE sparsity; past that crossover, compute binds and cost per token hits its lower bound.
- 0:00 – 31:59
How batch size affects token cost and speed
- DPDwarkesh Patel
Today I'm interviewing Reiner Pope, who is CEO of MatX, which is a new chip startup. Previously, he was doing TPU architecture and many other things at Google. This is a very different format from my usual interviews. This is gonna be a blackboard lecture. We're gonna get up in a second. We, in fact, built this whole new studio with specifically this format in mind. Um, and so it's a pleasure to get to inaugurate it with you. We're gonna be talking about model architecture, ML infra, many other things. And, um, the reason I think it's an important topic is because once you actually understand how training and inference actually work in a cluster, as we'll see a lot of things about why AI is the way it is, why AI architectures are the way they are, why, um, API prices are the way they are. Fundamentally also h-how-- why AI progress is the way it is, start making sense, and you need to understand the details to get there, and you need a blackboard to understand the details. So Reiner, thank you so much for doing this.
- RPReiner Pope
Yeah. Very happy to be here.
- DPDwarkesh Patel
Okay. Uh, full disclosure, I am an angel investor in MatX, but that's unrelated to this podcast. Um, Reiner, maybe to kick us off, I'll ask this question. So we have a couple of companies like Claude and Codex and Cursor are offering something like, uh, fast mode, where for 6x the price, they'll give-- stream you tokens at 2.5x the speed. Mechanically, I'm curious what's going on here. Why-- Like, why is it the case that you can pay more to get faster latency? And two, could you keep going? Could you pay 100x more and somehow get even faster speeds or much, much faster speeds? Um, and three, could you go the other way? Could you have something like, uh, Claude Code Slow mode, where if you are willing to wait for minutes on end, you could get, um, even cheaper prices. So maybe this will help motivate the kind of analysis that you'll be doing through the lecture.
- RPReiner Pope
Great. I mean, to jump to the-- a little bit to jump to the conclusion, the big effect is batch size, but what, what we're gonna do now is quantify exactly what that looks like-
- DPDwarkesh Patel
Yeah
- RPReiner Pope
... and what its implications are on latency and cost. Uh, there's gonna be s- another effect, which is, um, you can call it speculative decoding or multi-token prediction. We can maybe come back to that later, but I think the first thing that we'll talk through is batch size. So what I'd like to introduce is, um, sort of the two principles of analysis. Firstly, we're gonna look at a roofline analysis of how we run a transformer model on, on a cluster of chips. Um, we'll take a sort of a, let's say a, a Blackwell NVL 72, uh, cluster, so a rack of 72 GPUs. Um, and so the roofline analysis means we look at, uh, memory bandwidth and, and c-compute, uh, performance. And then the other side of that is that we're going to look at just two simple factors of the model, which are the time to operate on the weights, and then the time to operate on the context, the KV cache. So let's jump in. What we're gonna try and do is we're gonna try and estimate the time that it takes, uh, to, to run an inference of a certain shape. Now, we're not perfect here. We can't, uh, exactly predict the time, and so in-instead we're gonna approximate, and so we're gonna say that the time must be greater than or equal to a certain quantity. And so we're gonna consider two different, um, aspects. We're gonna look at the time for, uh, it takes to, uh, do the memory fetches, uh, and then the time it takes to do the compute. And it'll turn out that this actually gives us a very strong predictive power, even with a simple model. So one by one, what is the time that it takes to do the compute? So there are really two things I need to do in the compute. I need to, um, multiply by all of the active parameters, um, and then I need to do some work on the attention. So multiplying by all the active parameters. I have a certain batch size that I'm running, and then I've got a number of, uh, active parameters in my model. 
And then, um, and then I'm just gonna divide this by the compute throughput, which is, uh, the flops of the chip. So this is a hardware concern. So th-this, this actually accounts for all of the compute time for all of the weight matrix multiplies. Um, there's a little caveat here. We, we've sort of ignored the time to do any of the attention computation, but that in general can be, will be quite small in comparison to this.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
So, so we'll ignore this.
- DPDwarkesh Patel
May-maybe I'll just interrupt from time to time to ask some very naive questions or to clarify some, uh, basic points. But just for the audience, you're not serving one user at a time. The batch refers to the fact that you're serving many different users at the same time.
- RPReiner Pope
Yeah.
- DPDwarkesh Patel
Um, and that's a whole batch.
- RPReiner Pope
Yeah. So I can motivate the batch at least a little bit.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
So, um, I mean, we will see exactly why batch is such a favorable optimization.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
But what will turn out to be the case is that, uh, if you do not batch together many users, um, the cost and the economics you get is, can be like a thousand times worse than-
- DPDwarkesh Patel
Mm.
- RPReiner Pope
than if you do batch many users together.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
Um, and, and we'll be able to see that quite explicitly.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
And then, uh, number of active parameters. This is saying, like if I think-- look at, for example, a DeepSeek model, uh, the DeepSeek V3 model has about thirty, thirty-seven billion active parameters, and then s- uh, seven hundred billion total parameters. So this is-- We're, we're focusing on just the ones that are active for a single, I mean, a token. Okay. So we're modeling the compute performance. I'm gonna keep writing equals, but in, in all of these cases, you can think of this time as being at least this much, and, and maybe there'll be some terms we ignored. Um, on the memory side, um, what do we need to do, uh, with memory? We, we need to fetch, um, we need to fetch all of the weights, and so there is some time to fetch all of the, the total number of parameters, not just the active parameters. Um, so there's weight fetch time, and then in addition, uh, there's a KV cache fetch time. So there is, um, this actually depends on batch size. Uh, so for every element of the batch, we have to fetch, uh, an entire context length worth of tokens, and then there's a size per token, so, uh, um, like bytes, bytes per, for, for one token. Um, uh, and so this is a model parameter.
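[Editor's sketch, not on the blackboard verbatim.] The two bounds described so far can be written out as below. The symbols are assumptions chosen for this note: $B$ is batch size, $C$ is the chip's compute throughput, $W$ is memory bandwidth, $L_{\text{ctx}}$ is context length, and $b_{\text{tok}}$ is KV cache bytes per token. The factor $b_{w}$ (bytes per weight) is not in the spoken version; setting it to 1 recovers what is said above.

```latex
t \;\ge\; \max\left(t_{\text{compute}},\; t_{\text{mem}}\right)

t_{\text{compute}} \;=\; \frac{B \cdot N_{\text{active}}}{C}

t_{\text{mem}} \;=\; \frac{N_{\text{total}} \cdot b_{w} \;+\; B \cdot L_{\text{ctx}} \cdot b_{\text{tok}}}{W}
```

Only the KV term of the memory time grows with batch size; the weight-fetch term is paid once per step regardless of how many users share it.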
- DPDwarkesh Patel
And, and maybe, um, just back in, let's, let's just explain what the KV cache is real quick.
- RPReiner Pope
Yeah. So when I do a forward pass, uh, let me draw actually a, um, how the autoregressive inference works.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
So this is during decode. Um- So if I think I have a bunch of tokens, uh, uh, text, I'm drawing a tensor because, uh, ultimately the tokens are represented as some, like, tensor of, uh, in some, uh, embedding dimension, and then in this direction, I have the sequence length. Um, the work of running a decode is I, I have to run each token through a, um, uh, through a whole bunch of matrix multiplies over a bunch of different layers. Um, and I have, I have-- in general, I'm gonna have to do that work over, uh, all of these, uh, tokens. But then one step of decode is actually to produce just this one additional token up here.
- DPDwarkesh Patel
Yep.
- RPReiner Pope
And so what I'm gonna do there is I'm gonna run a full forwards pass of, uh, multiplying by all of the weight matrices i-in the entire model. Um, but then I've got this attention mechanism where this token sort of-- it's, it's, like, looking at all of the past tokens, um, in this way, and what is it looking at specifically? It is looking at some internal representation that the model has produced of the tokens, and we call that the KV cache. So this process of attending, this, this single token attending to all of the history of tokens, um, that's attention. It is mostly, uh, dominated by memory fetches rather than, um, than matrix multiplies.
- DPDwarkesh Patel
Mm-hmm.
- RPReiner Pope
So we've got the amount of memory that we're fetching shown over up here, and then this, of course, just then divided by the, uh, memory bandwidth. Um, so, uh, so the memory bytes per second. So in fact, th-this, these equations here are actually, uh, enough for us to now some, draw some fit lines. And so the things that we'd like to look at are sensitivity to batch, and then also, um, which we'll draw separately, to context length. So we said that the big, big effects you can get is, like, some, some trade-off in latency versus, uh, versus cost, um, in, in batch size. So, so let's draw them out. I think there's just really two graphs that we want to draw. Um, we'll first just draw, um, batch size versus, uh, time here. So when we look at the shape of this, we've got a maximum of, well, the sum and then, and then another term. Um, so let's look at these terms one by one and how they scale, uh, uh, the time for compute and, and memory, uh, and how they show up. So let's first look at this compute time. Uh, this is just purely linearly, linear in batch size with no-
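[Editor's sketch.] The batch-size graph being described can be reproduced numerically with a few lines of Python. This is a minimal roofline model under stated assumptions, not the speaker's own code: the class name, field names, and hardware rates are hypothetical, one weight is treated as one unit of both compute and memory traffic (as in the spoken formula), and the parameter counts are the approximate DeepSeek V3 figures mentioned earlier.

```python
from dataclasses import dataclass


@dataclass
class Roofline:
    """Lower-bound step time for one decode step (illustrative only)."""
    n_active: float          # active parameters per token
    n_total: float           # total parameters (all experts)
    compute_rate: float      # weight-multiplies per second (hypothetical)
    memory_rate: float       # weights (or bytes) fetched per second (hypothetical)
    kv_per_token: float = 0.0  # KV cache size per token, same unit as memory_rate
    context_len: float = 0.0

    def compute_time(self, batch: float) -> float:
        # Linear in batch: every sequence multiplies by all active weights.
        return batch * self.n_active / self.compute_rate

    def memory_time(self, batch: float) -> float:
        # Weights are fetched once per step; KV cache is fetched per sequence.
        kv = batch * self.context_len * self.kv_per_token
        return (self.n_total + kv) / self.memory_rate

    def step_time(self, batch: float) -> float:
        return max(self.compute_time(batch), self.memory_time(batch))

    def time_per_token(self, batch: float) -> float:
        # Proxy for cost per token: one step produces `batch` tokens.
        return self.step_time(batch) / batch

    def crossover_batch(self) -> float:
        # Batch where compute time equals weight-fetch time (KV ignored).
        return (self.compute_rate / self.memory_rate) * (self.n_total / self.n_active)


# Hypothetical chip with a compute-to-bandwidth ratio of 300.
model = Roofline(n_active=37e9, n_total=700e9, compute_rate=3e15, memory_rate=1e13)
for batch in (1, 64, 1024, model.crossover_batch(), 16384):
    print(f"batch={batch:10.1f}  time/token={model.time_per_token(batch):.3e}")
```

With the KV term set to zero, the crossover batch is the hardware ratio times the sparsity ratio `n_total / n_active`, and time per token stops improving past that point.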
- 31:59 – 47:02
How MoE models are laid out across GPU racks
- DPDwarkesh Patel
Yeah. Where were we?
- RPReiner Pope
Uh, Sparse Mixture of Experts.
- DPDwarkesh Patel
Yes.
- RPReiner Pope
Um, h- maybe how we lay that on, out on a GPU.
- DPDwarkesh Patel
Yep.
- RPReiner Pope
So, um, let's zoom in on the Mixture of Experts layer first and, and, and sort of draw what that looks like. So we typically, um, will have a, some kind of a router layer-
- DPDwarkesh Patel
Mm-hmm
- RPReiner Pope
... um, which is making the decision of where we route, uh, the experts, uh, the, the tokens to. So we have tokens coming in here. They go through a router layer, and then we have a bunch of different, um, experts. Uh, and I'll draw, draw a few more, um, to line some up. Um, and then the router will make a decision on which experts am I going to route to, and it'll be a small fraction of them, maybe one in 32. So maybe it'll make a decision to route to this one, um, uh, maybe this one, and maybe this one.
- DPDwarkesh Patel
Mm-hmm.
- RPReiner Pope
Uh, these experts, so these, each expert itself is a normal MLP. It has a up projection and then a down projection-
- DPDwarkesh Patel
Yeah
- RPReiner Pope
... with a non-linearity in between. Um, and then finally, we sort of do the inverse operation. So where we were broadcasting things out here, um, we're gonna bring them back in and sum them up. So, uh, bringing them in like this. Uh, and then finally, we have our residual connection. So the, the token is also passed through here, and it gets added to the result of the MoE layer. So, so this is a normal MoE layer. Um, what I want to talk through is how this is mapped to a, like a GPU rack, um, and what this means for communication, uh, because I think this will, will start to show some of the, the limits of how sparse we can go.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
So, um, the standard practice here, and it, it is the best solution, is to use, um, expert parallelism. So that means different experts go on different GPUs. So if we take something like a DeepSeek model, um, they have two hundred and fifty-six experts. Um, let's say we want to run that on a Blackwell rack. Um, so there are seventy-two GPUs. Um, we have a divisibility problem. This is not a power of two. Um, so we'll just, like, simplify and say we're only gonna use sixty-four of them. Um, just ignore the other eight. It's not a big deal. Um, and so we, we have four experts per GPU. Uh, very simple. Um, uh, for the sake of the diagram, I'll actually just say let's, let's say we have two experts per GPU. So we, um, we end up just putting, uh... These are the GPU boundaries. Every pair of experts is on its own GPU. Um, and then we can look at the communication cost. We had some experts stored-- some tokens stored centrally here. They get routed to all of these experts. Um, and so, uh, there is some communication cost paid here. There's the same communication cost paid on the output. Um, and then the hope is that, uh, this does not become communication limited. Um, now, what is the traffic pattern here? Um, the traffic pattern here is that any GPU, in fact, will be talking to any other GPU, depending on, um, the, the decisions made by the model. So this is an all-to-all traffic pattern.
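[Editor's sketch.] The placement just described (256 experts on 64 GPUs, four per GPU) can be written down as a simple mapping. The contiguous layout and the function names are assumptions for illustration; the transcript does not say which experts share a GPU.

```python
NUM_EXPERTS = 256   # DeepSeek-style expert count from the discussion
NUM_GPUS = 64       # 72-GPU rack with 8 GPUs left idle for divisibility
EXPERTS_PER_GPU = NUM_EXPERTS // NUM_GPUS


def gpu_for_expert(expert_id: int) -> int:
    """Assumed contiguous placement: experts 0-3 on GPU 0, 4-7 on GPU 1, ..."""
    if not 0 <= expert_id < NUM_EXPERTS:
        raise ValueError(f"expert_id out of range: {expert_id}")
    return expert_id // EXPERTS_PER_GPU


def destination_gpus(routed_experts: list[int]) -> list[int]:
    """GPUs one token must be sent to, given the experts the router picked."""
    return sorted({gpu_for_expert(e) for e in routed_experts})


# A token routed to experts 0, 1, 5 and 255 touches three distinct GPUs.
print(EXPERTS_PER_GPU, destination_gpus([0, 1, 5, 255]))
```

Because the router's choice differs per token, any source GPU may need to reach any destination GPU, which is the all-to-all pattern.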
- DPDwarkesh Patel
So when you say any GPU in the pretense-
- RPReiner Pope
Yeah
- DPDwarkesh Patel
... the router is more than one GPU?
- RPReiner Pope
Yeah, the rou- So I, I drew this as one router. Uh, in reality, you would actually have many copies of the router, and so you would have, um, as, as many routers as, as GPUs, in fact.
- DPDwarkesh Patel
As, as, as the incoming G- incoming traffic.
- RPReiner Pope
Yeah. So these are, these are, the-these are sixty-four GPUs. These are sixty-four GPUs.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
It's actually the same GPUs. We just, like, draw them as, as separate because they're serving different purposes.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
So at this point, any GPU can be sending to any, any other GPU.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
So this all-to-all pattern, um, of communication that shows up, uh, and how, how the Blackwell racks are configured, um, is a, is a perfect fit for the, um, the communication pattern that the, uh, MoE actually wants to do. Um, however, if you think, "Maybe I want to do, like maybe one rack is too slow and I want to do two racks," um, then I have this challenge that, like, maybe I've got some sort of rack boundary drawn outside here like this. Um, and I no longer, in fact, have all-to-all communication between all the GPUs in two racks. Um, and so the rack-to-rack communication ends up being a substantial bottleneck. So, uh, this sort of, like, the fundamental thing here is that one rack is actually the, bounds the size of an expert layer you can do. And so, uh, this has been part of what's been driving towards, um, larger and larger interconnect domains.
- DPDwarkesh Patel
Yeah. Um, but before we... It may be worth you explaining what exactly a rack is.
- RPReiner Pope
Mm-hmm.
- DPDwarkesh Patel
The differences in bandwidth between a rack and within a rack-
- RPReiner Pope
Yeah
- 47:02 – 1:03:27
How pipeline parallelism spreads model layers across racks
- DPDwarkesh Patel
When you're operating within a single scale-up domain, is that a consideration specifically for either forward or backward, or specifically for prefill versus decode? Or is it, is it preferred to always be within a scale up-
- RPReiner Pope
Yeah
- DPDwarkesh Patel
... whatever kind of workload you have, whether you're doing a pre-training run or whether you're doing RL generation or whether you're doing inference for users?
- RPReiner Pope
Yeah. Really interesting. Um, so, okay. So, uh, to answer that question, we're gonna need to talk about the communication patterns, um, that... So we've talked about the mixture of expert communication pattern.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
That is this all to all. Um, uh, there's all to all. All to all. Um, all to all very strongly favors, um, uh, f-full connectivity, which is what, what we've kind of just shown here, and it favors being within one rack. Um, there are other kinds of parallelism, uh, besides expert parallelism, which, which, which we just showed here. In the literature there is tensor parallelism. This is, um, m- with the trend towards smaller experts, this has become much less relevant, so we can ignore that. Um, but the other two things that we have available are data parallelism and pipeline parallelism. Um, and they are actually much, they can be a much better fit for, uh, using multiple racks. So let's focus on pipeline parallelism specifically. Um, this is one layer of MoE. Um, I'm gonna have like 100 more layers up above. Um- I could decide at this point, for example, to move to a different rack, uh, change rack. Now, is that gonna become a communication bottleneck? So we can actually just solve for when this becomes a communication bottleneck. Um, but before we do that algebraically, like let's just sort of visualize it out and, and sketch the path. So we're gonna have a bunch... This is another MoE layer, and we're gonna have another MoE layer here, and so on. Um, uh, so let's say I change rack here, and then some number of layers later, I change rack here as well. Um, so our, our, our methodology that we're gonna use to determine whether we have a communication bottleneck in this, like, in this point where we change rack, um, is we're gonna compare the-- this, this is the scale out, um, scale out, um, bandwidth requirements to the scale up bandwidth requirements.
- DPDwarkesh Patel
Mm-hmm.
- RPReiner Pope
So let's write this. Uh, and, and I mean, the, the hint is gonna be that, um, there's a lot more sends here. Like, we're sending many things here, whereas we're only sending one thing here, and then we're also maybe doing it many times. That's, so that's going to be the, uh, what, what makes the difference.
- DPDwarkesh Patel
Uh, c-can I try to guess?
- RPReiner Pope
Yeah.
- DPDwarkesh Patel
Just out of curiosity to see if I'm actually un-understanding. Um, it seems like you're, you're sending, like, batch size into the rack.
- RPReiner Pope
In here?
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
Yes.
- DPDwarkesh Patel
Uh, but the communication within a rack is sort of batch size times number of GPUs.
- RPReiner Pope
Yeah. So number of activated GPUs, right? So like I, I don't send to this GPU at all, right? So there's an explosion from one to, like, it's three times larger here in, in this diagram.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
Um, the key thing is that I, I didn't even need to send to this GPU at all, and so that's a big saving.
- DPDwarkesh Patel
I see. Yeah.
- RPReiner Pope
Okay. So we're gonna talk through, um, uh, sort of how much more, uh, w-what is the slowdown of-- to what extent is scale up, uh, a bottleneck over scale, uh, o-over scale out? So, uh, we will directly jump to the ratio of the time spent on, uh, scale up, time on scale up over the time spent on scale out. So this is, this is the quantity we're talking about. Um, and the first consideration is that the scale up is like, um, uh, scale up is, is eight times faster than scale out generally. And so, uh, at a baseline, if the bandwidths were the same, we would have this one o-one over eight, which is coming from bandwidth, bandwidth. But then we have some amount of expansion in, in, in how much data we're sending. So if one token comes in here-
- DPDwarkesh Patel
Mm
- RPReiner Pope
... then this one token gets routed to, in the DeepSeek case, it'll get routed to maybe 32 experts or, or 16 experts. It gets routed to some number of experts. So this is the number of activated experts, number of activated experts. Um, and then it also... The s- the same thing applies on multiple different layers. So maybe I'm gonna run two layers. So, um, there's also multiple times, um, number of layers, uh, per stage.
- DPDwarkesh Patel
And don't you need to multiply the whole thing by two for the, um-
- RPReiner Pope
For the up and down
- DPDwarkesh Patel
... for the overall, yeah.
- RPReiner Pope
Yes, yes, and there's a factor of two. Thank you. Um, so what we would like is the, for the scale up time to be greater than the scale out time, um, because, like the scale up time is the more important and precious resource. And so we just, we want this one, we would like this number to be greater than or equal to one. Um, and this really doesn't seem hard. Like we've, we've-- there's just a factor of eight that we need to overcome, so we need the product of these three things to be bigger than eight. Um, typically, we have a fairly large number of activated experts. It could be eight, um, by itself. Um, and then we can increase the number of layers per stage a lot until, until we satisfy this.
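[Editor's sketch.] The ratio worked out in this exchange is a product of four factors, which can be checked with a small helper. Function and parameter names are hypothetical; the default of 8 for the bandwidth advantage is the rough figure quoted above, not a measured value.

```python
import math


def scale_up_over_scale_out(activated_experts: int,
                            layers_per_stage: int,
                            bandwidth_advantage: float = 8.0,
                            up_and_down: int = 2) -> float:
    """Ratio of time spent on scale-up traffic to time spent on scale-out traffic.

    Each token crosses the rack boundary once per pipeline stage, but inside the
    rack it is sent to every activated expert, out and back, on every layer.
    """
    expansion = activated_experts * layers_per_stage * up_and_down
    return expansion / bandwidth_advantage


def min_layers_per_stage(activated_experts: int,
                         bandwidth_advantage: float = 8.0,
                         up_and_down: int = 2) -> int:
    """Smallest layers-per-stage for which the ratio reaches at least 1."""
    return max(1, math.ceil(bandwidth_advantage / (activated_experts * up_and_down)))


# With 8 activated experts, a single layer per stage already gives a ratio of 2.
print(scale_up_over_scale_out(activated_experts=8, layers_per_stage=1))
print(min_layers_per_stage(activated_experts=8))
```

A ratio of at least 1 means the in-rack traffic takes at least as long as the rack-to-rack hop, so the slower scale-out link is not the limiting resource.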
- DPDwarkesh Patel
I see.
- RPReiner Pope
Um, so what this ends up looking like is that I can, in fact, have an entire pipeline of racks where one rack does one layer, and then I move on to the next rack, and I do another layer, and then I move on to the next rack. I can do another layer.
- DPDwarkesh Patel
It's interesting to me that the best parallelism, uh, strategy in practice ends up being one which physically resembles the actual architecture. It's not some galaxy brain thing. You know, it's like, "Oh, we have experts. We're gonna put them on G- different GPUs. Oh, we have different layers. We're just gonna put them on different racks." Isn't that... I feel like that's interesting that the physical and-
- RPReiner Pope
The, the, the model architecture matches, like the, the, the cutting matches the model architecture.
- 1:03:27 – 1:18:49
Why Ilya said, “As we now know, pipelining is not wise.”
- DPDwarkesh Patel
to get started. [upbeat music ends] So macro question, everybody's talking about the memory wall right now.
- RPReiner Pope
Mm-hmm.
- DPDwarkesh Patel
Memory's getting super expensive. There's not enough memory. Smartphone volume will go down thirty percent because there's not enough memory. Hyperscalers are spending-- This is shocking. If I'm-- Dylan said they're spending fifty percent of their CapEx this year.
- RPReiner Pope
On memory?
- DPDwarkesh Patel
On memory.
- RPReiner Pope
Oh, that's believable. Yeah.
- DPDwarkesh Patel
It's so-- But it's so-- Like what is hyperscaler CapEx? That's like high hundreds of billions-
- RPReiner Pope
Yeah
- DPDwarkesh Patel
... maybe a trillion, and they're spending half of that on memory. Okay, so that, that is a huge constraint. That's why we're not gonna get new laptops and phones this year.
- RPReiner Pope
Yeah.
- DPDwarkesh Patel
Um, but at the same time, we're, we have too much memory? Like, people are willing to put too much memory into these systems?
- RPReiner Pope
Right. So, um, this is-
- DPDwarkesh Patel
Like what, why, why is Jensen shoving all this memory into these racks if-
- RPReiner Pope
Yeah, so-
- DPDwarkesh Patel
... if you don't need it.
- RPReiner Pope
Yeah. So we've-- Like, in, in the, um, equations we had here before we erased them, we were doing memory time, so memory bandwidth and, and compute bandwidth. Let's now start looking at, uh, memory capacity.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
So we'll start off with just, like memory capacity without even thinking about parallelism scheme. Um, and so the, um, uh, like the capacity of memory, um, or the, or, or the, the demand on memory is, um, the number of total parameters, um, plus... So, so this is what we need to fit the weights in some system that we are using.
- DPDwarkesh Patel
Mm-hmm.
- RPReiner Pope
Um, and then we need to fit the KVs as well. So KVs go as batch size times the length of the context, um, times, uh, times the bytes, bytes per, bytes per tok. Um, okay. So, um, what I was arguing about in this context, and the case I was making, uh, for pipelining, is that, um, we will actually-- there are some techniques that allow us to solve this. Are there techniques that allow us to solve this? So let's, let's consider... So we're gonna run this on some number of GPUs, and, and we-we're gonna say, um, we're gonna have one extent, which is, um, uh, E is gonna be the expert parallelism. So how many-- When we had this sharding of, uh, an expert layer across many GPUs, how much of that, uh, to what extent do we do that? How many GPUs? Um, so we're gonna say that this is, in fact, for example, sixty-four. And then P is going to be the extent of pipeline, pipelining. Um, and so this is the number of racks, which, who knows, maybe, maybe we'll, m-maybe we'll pick four or something like that. What we want to calculate-- So this is the, this is like the total, um, total memory re-requirement across the system. Um, but now I'm going to calculate a, um, a memory requirement per GPU. So per, per GPU memory requirement, uh, we're gonna have... I guess I'll use a lowercase cmem. Um, and well, obviously, we just take all of these numbers and divide it by E and P. Really easy. So, um, uh, it's this N total, um, plus the batch times length of context times bytes per tok. Um, all of this is divided by E times P. Okay, so this is-- Like why is this correct as divided this way? Um, well- Where, where we're saying we, we knew that the parameters were perfectly divided amongst all the, the GPUs in a rack. They're als- the layers are perfectly divided amongst the, the, the different racks. 
So that works here, and somehow we're going to arrange, I'll hand wave exactly how, somehow we can arrange the same perfect sharding of, of the contexts across GPUs in a rack and, and, and then based on layer across, uh, racks, um-
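[Editor's sketch.] The per-GPU memory requirement just derived is the total demand divided by E times P. The helper below assumes the perfect sharding the speaker hand-waves, and it expresses the weights in bytes so the two terms share a unit; the function name and the example numbers are hypothetical.

```python
def memory_per_gpu(weight_bytes: float,
                   batch: float,
                   context_len: float,
                   kv_bytes_per_token: float,
                   expert_parallel: int,
                   pipeline_stages: int) -> float:
    """c_mem = (weights + batch * context * bytes_per_token) / (E * P).

    Assumes weights and KV cache are both sharded perfectly: across the E GPUs
    in a rack, and by layer across the P racks.
    """
    total = weight_bytes + batch * context_len * kv_bytes_per_token
    return total / (expert_parallel * pipeline_stages)


# E = 64 GPUs per rack, P = 4 racks, as in the example above.
# Weight and KV sizes here are placeholders, not figures from the talk.
per_gpu = memory_per_gpu(weight_bytes=700e9, batch=8000, context_len=32_000,
                         kv_bytes_per_token=70e3, expert_parallel=64,
                         pipeline_stages=4)
print(f"{per_gpu / 1e9:.1f} GB per GPU")
```

Adding racks shrinks the weight share per GPU, but the KV term scales with the global batch, which is why the batch size is revisited next.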
- DPDwarkesh Patel
And, and sorry, four is the number of racks?
- RPReiner Pope
Uh, yeah, for example.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
Yeah. Um, so, um, this is the place where we actually need to go back and analyze this batch size B, and you were making this comment that there's micro-batching versus global batching. So, um, let's come back to this pipelining diagram here. Um, we've got one batch going forward here, and then as I drew it, it kind of just, like, disappeared. That's not really correct. If you think about, um, how decode is working, I have a bunch of tokens that I have generated already. I do one forwards pass where I generate a new token, and then, and then I c- push, like, then I write that to my KV cache-
- DPDwarkesh Patel
Yeah
- RPReiner Pope
... and then I do another forwards pass that generates the next token. So I'm actually gonna be running this batch zero in a loop. So in fact, I go forwards. Once I finish, I can start the next iteration of the loop up here.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
And so we'll just fill this in. We'll d- we'll have the... Oh, uh-
- DPDwarkesh Patel
Mess [laughs] .
- RPReiner Pope
Yes. Um, yeah, so we've got the two, three, um, two and three, uh, two, three. Uh, so let's split this batch. This batch will be the global batch size, so B is going to be the, um, number of, number of micro-batches times the batch of, like, the batch size per micro-batch. So how many micro-batches do we need? So the number of micro-batches in this diagram is four, zero, one, two, three. Um, and then the batch size per, um, like, the, the micro-batch size, this is still this, like, 2,000-ish number. Um, this is the one that is, like, um-
- 1:18:49 – 1:32:52
Because of RL, models may be 100x over-trained beyond Chinchilla-optimal
- DPDwarkesh Patel
Yeah. Okay. A super tangential question. There's Chinchilla scaling, which tells you how, how big should a model be relative to the amount of data you're gonna train it on. Um- But now obviously you're not just trying to optimize for the highest quality model you could get with training compute. You want the best results a user can get-
- RPReiner Pope
Yes
- DPDwarkesh Patel
... with a mixture of training and inference compute.
- RPReiner Pope
Mm-hmm.
- DPDwarkesh Patel
So then there's a question of how much should you over-train a model-
- RPReiner Pope
Mm-hmm
- DPDwarkesh Patel
... such that that compute amortized over training and inference is minimized to get a certain performance. But now with RL inference, there's-- or R-RL, there's another consideration, which is you're gonna do some amount of pre-training.
- RPReiner Pope
Yeah.
- DPDwarkesh Patel
That pre-training will be used both for RL generation and then for inference for the final user. A-and by over-training here, I mean while it would've been more efficient just from a training compute perspective to have a bigger model-
- RPReiner Pope
Mm-hmm
- DPDwarkesh Patel
... that you train for less time because it can learn faster, maybe you, you get a smaller model, you spend more compute training it than you otherwise would've, but now it's cheaper to give it to users. Like basically, okay, may-maybe, may-- let me make the q-q-question more concrete. How much more than Chinchilla-optimal are models over-trained?
- RPReiner Pope
Hmm. Yeah.
- DPDwarkesh Patel
And has that changed as a result of RL generation?
- RPReiner Pope
This is a place where we have to do a bit of guesswork because, like, the, um, the updated scaling laws and, and the us- and the model traffics are not reported, and so we have to guess there. Um, but, uh, one way to look at it, um... Let me firs-first just make a sort of a general heuristic claim. If I am, if I have some, like, cost, and I've got a total cost, which is a sum of, like, cost A and cost B, like maybe this is the training cost and this is the inference cost.
- DPDwarkesh Patel
Yep.
- RPReiner Pope
Um, and so I want to minimize this sum. For many, uh, for many curves that end up being the case, the minimum tends to be where these are, where the costs are equalized. Um, that's something of a heuristic claim, but, uh, you can, you can-- it tends-- i-- like there are many examples where it's true, like, uh, where one is one over X and the other one is, is X, for example. Um, they tend to be-
- DPDwarkesh Patel
Mm-hmm
- RPReiner Pope
... uh, minimized at, at, at the point where, uh, they equal each other. Um, it's also true for like, um, E to the X and like E to the minus X and all, all kinds of other things.
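[Editor's sketch.] The two examples given for the equal-cost heuristic can be checked numerically. The grid search below is a deliberately simple stand-in for calculus, and the function name is hypothetical.

```python
import math


def argmin_on_grid(total, lo: float, hi: float, steps: int = 100_001) -> float:
    """Return the grid point in [lo, hi] where total(x) is smallest."""
    best_x, best_val = lo, total(lo)
    for i in range(1, steps):
        x = lo + (hi - lo) * i / (steps - 1)
        val = total(x)
        if val < best_val:
            best_x, best_val = x, val
    return best_x


# Example 1: cost A = 1/x falls, cost B = x rises.
x1 = argmin_on_grid(lambda x: 1 / x + x, 0.1, 10.0)
print(x1, 1 / x1, x1)

# Example 2: cost A = e^(-x) falls, cost B = e^x rises.
x2 = argmin_on_grid(lambda x: math.exp(x) + math.exp(-x), -5.0, 5.0)
print(x2, math.exp(-x2), math.exp(x2))
```

In both cases the minimum of the sum lands where the two terms are equal. For general power laws such as x^(-a) + x^b the terms at the minimum are in the ratio b/a rather than exactly equal, so they are comparable in size, which is consistent with this being offered as a heuristic.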
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
Um, uh, like so basically, I've got some, I've got some curve that's going down, some other curve that's going up, and they tend to be minimized at, at this e-e-equal point. Um, heuristically, I will conjecture that that is true, um, for the setup you described as well. Um, uh, like actually showing that that would be true would require looking at the scaling laws and, um, and like fitting these like weird exponents. Um, but, but things that do follow power laws tend to, tend to have this property. So I'll just make that claim and move on. Um, so we're gonna say that the, uh, cost of training, um, plus the cost of inference, um, we want to equalize these. Um, uh, we'll do pre-training only first 'cause it's a little-- Well, actually we, we can do all of it in general, so, so actually, we'll, we'll cost optimize. Um, cost of pre-training, so number of, uh, so number of, number of active params, um, times the data on pre-training. So that's the cost of pre-training. There's a factor of six out here, which is the number of flops. Um, this is the famous six ND formula. Um, and then in, in RL, we have approximately the same thing. We've got like the same number of active parameters, um, but now it's, uh, the amount of data is the RL data. Um, there's this extra like efficiency multiplier, which is, um, or inefficiency, like the, um, the inefficiency, um, uh-
- DPDwarkesh Patel
Wh-wh-which is the fact that you're not training on all your rollouts.
- RPReiner Pope
Well, yeah. Th-th-th-there's that. Um, and then the other perhaps even bigger inefficiency is that, um, this involves a substantial amount of decode, and often decode runs at, uh, less MFU than, than, than training.
- DPDwarkesh Patel
I see. Okay. Uh, so if you're doing a backward pass on every single generation in RL, it would be six ND.
- RPReiner Pope
Yeah. So this could be a smaller number, right?
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
Like, this could be somewhere, so, um-
- DPDwarkesh Patel
It would at least be two, 'cause that's a forward-
- RPReiner Pope
Yeah, somewhere in the range of two to six.
- DPDwarkesh Patel
Yeah. [laughs]
- RPReiner Pope
So we'll just like, we'll say somewhere in the range of two to six and leave it at that.
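A minimal sketch of the cost accounting described above, assuming the 6ND formula for pre-training and a per-parameter, per-token factor between two and six for RL. The function names, the `inefficiency` multiplier, and the example token counts are hypothetical illustrations, not figures from the conversation.

```python
def pretraining_flops(n_active: float, d_pretrain: float) -> float:
    """6ND: roughly 6 FLOPs per active parameter per pre-training token."""
    return 6.0 * n_active * d_pretrain


def rl_flops(n_active: float, d_rl: float,
             flops_per_param_token: float = 6.0,
             inefficiency: float = 1.0) -> float:
    """RL cost with a factor between 2 (forward only) and 6 (forward + backward).

    `inefficiency` (>= 1) is a hypothetical knob standing in for decode
    running at lower MFU than training.
    """
    if not 2.0 <= flops_per_param_token <= 6.0:
        raise ValueError("factor is expected to lie between 2 and 6")
    if inefficiency < 1.0:
        raise ValueError("inefficiency multiplier should be >= 1")
    return flops_per_param_token * n_active * d_rl * inefficiency


if __name__ == "__main__":
    n_active = 100e9      # active parameters (the guess used later in the lecture)
    d_pretrain = 10e12    # hypothetical pre-training tokens
    d_rl = 1e12           # hypothetical RL tokens
    print(f"pre-training: {pretraining_flops(n_active, d_pretrain):.2e} FLOPs")
    print(f"RL (factor 2): {rl_flops(n_active, d_rl, 2.0):.2e} FLOPs")
    print(f"RL (factor 6): {rl_flops(n_active, d_rl, 6.0):.2e} FLOPs")
```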
- 1:32:52 – 2:03:52
Deducing long context memory costs from API pricing
- DPDwarkesh Patel
Um, in the spirit of trying to deduce things, we can publicly look up the prices of the APIs of these models, and, um, maybe we can learn something from that. So, uh, first, with, uh, longer context, um, [clears throat] Gemini three point one is, um, fifty percent more expensive if you go over two hundred K tokens than if we're below two hundred K tokens. W- I mean, at a high level, I understand why that might be, but why specifically fifty percent?
- RPReiner Pope
Yeah. Um, so I mean, why specifically fifty percent? Let's, let's sort of, um... So, so the high level-
- DPDwarkesh Patel
Yeah
- RPReiner Pope
... uh, even in the first place is, um, there is some amount of, uh, increasing cost with, with context length.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
Um, and, and, uh, w-we can bring that back up. That was the, um, the, the memory time versus the compute time. So, um, okay. So we, we've put up these same equations from before of the, the time for memory fetches, which is the weights and, and the KV cache, um, and then the, the time for the compute, which is just the, uh, matrix multiplications for the weights. I, I will also draw the, um, the, the cost curve. Um, but this time I'll do it as a function of context length, uh, instead of as a function of batch size. Um, so this is time over, uh, yeah, just, just time. Uh, and so this is the cost curve as a function of context length. Um, we'll draw the compute. Um, the com-- the, the cost of the compute is actually constant as a function of context length. There's no dependence here on context length. In reality, there is some dependence, but it is very mild dependence, so we'll ignore it. Um, so this is the, um, time for the compute. This one. Uh, and then we'll also draw the dependence, uh, of the memory fetch on, on context length. And this starts at a large number for the weights and then grows gradually with, um, with the context length. So, uh, maybe here, um, and then grow gradually with context length. And so you take the maximum, and you see there is this inflection point here. So now-- so this is the costs that, uh, that, that, for example, Gemini might be paying. Um, and then you think how, how, how might you put a pricing structure on top of that? Um, you would like to ensure that no matter what the context length is, you are, you are still profitable. So-
- DPDwarkesh Patel
Interesting
- RPReiner Pope
... and so we've got a two-tier pricing structure. Maybe we've got something that looks like this up to some max context length.
- DPDwarkesh Patel
Yeah. Fascinating.
- RPReiner Pope
So I think it says something about, um, given that the bump is at two hundred K, it probably means that this is somewhat aligned with this con-- crossover point.
- DPDwarkesh Patel
Ah.
- RPReiner Pope
Maybe not exactly aligned with.
- DPDwarkesh Patel
Fascinating.
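The cost curve sketched on the blackboard can be approximated as the maximum of a flat compute time and a memory time that grows with context length. The following is a rough sketch under stated assumptions: constant factors are dropped as in the lecture, and the hardware numbers (`flops`, `mem_bw`), the KV bytes per token, and the zero weight-fetch cost are hypothetical placeholders.

```python
def decode_step_time(len_context: int, batch: int, n_active: float,
                     weight_bytes: float, kv_bytes_per_token: float,
                     flops: float, mem_bw: float) -> float:
    """Time for one decode step: whichever of compute or memory fetch is slower."""
    # Compute: matrix multiplications against the active weights (flat in context).
    t_compute = batch * n_active / flops
    # Memory: fetch the weights once, plus each sequence's KV cache.
    t_memory = (weight_bytes + batch * len_context * kv_bytes_per_token) / mem_bw
    return max(t_compute, t_memory)


if __name__ == "__main__":
    flops = 1e15                 # hypothetical accelerator FLOP/s
    mem_bw = flops / 300         # ratio of about 300, as quoted in the lecture
    for length in (50_000, 100_000, 200_000, 400_000, 800_000):
        t = decode_step_time(length, batch=1, n_active=100e9,
                             weight_bytes=0.0,          # treated as negligible
                             kv_bytes_per_token=1667.0,  # hypothetical
                             flops=flops, mem_bw=mem_bw)
        print(f"context {length:>7}: {t * 1e6:8.1f} microseconds per step")
```

With these placeholder numbers the curve stays flat up to roughly two hundred K tokens and grows linearly afterwards, which is the shape a two-tier price would sit on top of.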
- RPReiner Pope
Um, so we can actually probably even complete that calculation just to see where it lands out. Um, we can solve for the number of bytes per token-
- DPDwarkesh Patel
Huh
- RPReiner Pope
... um, if, if, if we sort of make some assumptions about the number of active parameters. So solving for the number of bytes per token, um, we're gonna assume like the, the point where we equalize, um, the time of memory and the time of compute is at, let's say, two hundred K, uh, tokens. Um, so we equalize these two. Um, we're also gonna just, uh, assume that the batch size is large enough that the, um, the memory time spent on weights is, is negligible. So we'll forget about this, and we'll focus on the actual memory time spent on KV cache. So that ends up saying copying this term over, batch times len context, um, times, uh, bytes per token, um, over mem bandwidth, is gonna be equal to, uh, number of activated params over FLOPs. And then we're gonna solve for bytes per token. Um... Batch size was missing here. Shows up here, and then it cancels out by the time we get to here. And, uh, and I, I dropped the len context.
- DPDwarkesh Patel
Yeah
- RPReiner Pope
So we can plug in numbers. This number, this is, this is this-- well, is the reciprocal of the number that we saw before. It's, yeah, this is like one over three hundred, um, which is reasonably stable across many, um, different hardware platforms. We conjecturally said that maybe number of activated params is like a hundred billion. And length of the context we said was 200K. Um, something is wrong here, though. Length of the context should be on the denominator, not the numerator. Um.
- DPDwarkesh Patel
Um, one six six seven. Like about one, one kilo-- almost two kilobyte.
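The arithmetic behind that figure can be reproduced in a few lines. This is a sketch of the calculation as stated in the conversation: set KV-cache memory time equal to compute time, cancel the batch size, and solve for bytes per token with context length in the denominator. The function and parameter names are illustrative, and the inputs are the conjectured values from the discussion, not measured ones.

```python
def kv_bytes_per_token_at_crossover(n_active: float, len_context: float,
                                    mem_bw_over_flops: float) -> float:
    """Solve  B * L * bytes / mem_bw  =  B * N_active / FLOPs  for bytes.

    The batch size B cancels, leaving
        bytes = N_active * (mem_bw / FLOPs) / L
    """
    return n_active * mem_bw_over_flops / len_context


if __name__ == "__main__":
    bytes_per_token = kv_bytes_per_token_at_crossover(
        n_active=100e9,             # conjectured active parameters
        len_context=200e3,          # assumed crossover context length
        mem_bw_over_flops=1 / 300,  # reciprocal of the ~300 hardware ratio
    )
    print(f"{bytes_per_token:.0f} bytes per token")
```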
- RPReiner Pope
That's, uh, that, that is plausible actually. Um, so we said around two kilobytes. Um, so, um, so l-let's just do a, a sanity check for this, um, for what this could be. Um, there are two mechanisms that people do, uh, attention with a small number of bytes per token. Um, one is, uh, dense attention with a lot of reuse across layers. Um, so Character AI has a blog post talking about that, alternating long and short context. Um, and like in the Character AI kind of model, uh, which also showed up in the Gemma models, the global context, which is really what we're talking about here, global context, um, was shared across all the layers. And so to get this two kilobytes, you could get that, for example, as, um, a d head of a hundred and twenty-eight, um, is, is typical. Um, and then, like the number of bytes is typically, um, number of attention layers, um, uh, times two times d head, uh, times, uh, number of, uh, Q heads. So, um, this is the number of unique contexts per layer. Do you ha- do you share the, the context across many layers, or do you use it only once? Um, uh, so in Character AI-like models, uh, this number is one. Um, we said this is a hundred, hundred twenty-eight. Um, and, uh, this is a choice which typically ranges from one... Uh, sorry, this is KV heads, I meant. Um-
- DPDwarkesh Patel
The difference between a head and a KV head is that-
- RPReiner Pope
The KV heads are the heads that are stored in memory, like store the contents of the previous tokens. The Q heads are the, um, the retrieval heads. They're, they're only used temporarily, and they're, they're used by the attending token. So, um, in this autoregressive context-
- DPDwarkesh Patel
Yep
- RPReiner Pope
... I've got KV heads associated with all of the contexts.
- DPDwarkesh Patel
Yep.
- RPReiner Pope
And then Q heads associated with this new token here.
- DPDwarkesh Patel
But, but, but this head, the one twenty-eight.
- RPReiner Pope
Oh, uh, this is, um, it, it-- this, this number is actually the same for... Oh, so this d head is the dimension of the vector.
- DPDwarkesh Patel
Ah, yeah. Oh, so it's just-
- RPReiner Pope
Yeah. Uh, and number of KV heads is typically in the range of one to eight.
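As a sanity check on the two-kilobyte estimate, the formula described above can be written out directly. This is a sketch only: the one-byte element size is an assumption added here (the conversation does not state a precision), and the function name is hypothetical.

```python
def kv_cache_bytes_per_token(n_unique_contexts: int, d_head: int,
                             n_kv_heads: int, bytes_per_element: int = 1) -> int:
    """KV-cache bytes stored per token of context.

    n_unique_contexts: attention layers that keep their own (unshared) global KV
    2:                 one key vector and one value vector
    d_head:            dimension of each head's vector
    n_kv_heads:        KV heads stored in memory
    bytes_per_element: storage precision (assumed 1 byte here)
    """
    return n_unique_contexts * 2 * d_head * n_kv_heads * bytes_per_element


if __name__ == "__main__":
    # Global context shared across all layers, d_head of 128, KV heads from 1 to 8.
    for n_kv_heads in (1, 2, 4, 8):
        size = kv_cache_bytes_per_token(1, 128, n_kv_heads)
        print(f"{n_kv_heads} KV heads: {size} bytes per token")
```

Under these assumptions, eight KV heads give 2048 bytes per token, which lands near the figure solved for from the pricing crossover.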
- 2:03:52 – 2:13:39
Convergent evolution between neural nets and cryptography
- DPDwarkesh Patel
So we're sitting down because I wanna ask you some questions that, uh, I guess don't need a blackboard. Um, you have this ex-extremely interesting blog post where you talk about how, at a high level, the architecture of different cryptographic protocols looks a lot like neural networks, and there's this convergent evolution where they both need to jumble information across all their inputs. For cryptographic protocols, it's to make sure that there's, like, each new input into a hash function will totally scramble what happens. For neural networks, of course, they need to consider informa- how this piece of information changes what you should make of this other piece of information. And that is a extremely interesting point. I guess the, uh... at a high level, the, the difference in what they're trying to do, in some sense, they're trying to do the inverse thing.
- RPReiner Pope
Right.
- DPDwarkesh Patel
Which is, um, cryptographic protocols are trying to take information which has structure and make it look indistinguishable from randomness.
- RPReiner Pope
Yeah.
- DPDwarkesh Patel
And, uh, neural networks are trying to take things which are, look like random, protein sequences, DNA, garbled text, and extract higher level structure from it. So they have similar high-level mechanisms, but they're actually kind of trying to do the opposite things. Um, yeah, I wonder what you make of that.
- RPReiner Pope
Yeah. Um, so I mean, the, like, the mixing, like, uh, I tried to look for other examples where mixing, like scrambling, mixing shows up as well. There's actually almost even like a physical example where, like, you're stirring something, you're making a cake and you want to stir the batter. And like, literally the idea, like first stir it this way and then stir it this way is like actually not too bad of an approach.
- DPDwarkesh Patel
No.
- RPReiner Pope
Um, but beyond that, like in, back to the digital world, um, th-there are some differences, and the one you talk, uh, call out is, is a pretty strong difference. Um, the way it shows up, um, like what makes neural nets, uh, like if you just randomly initialize a neural network, actually, maybe it's a reasonable cryptograph, like a, a, a cipher as well, because like the random initialization is gonna jumble stuff in a complicated way.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
It may even like do what you want, who knows? Um, uh, the thing that makes it interpretable is the gradient descent. So you can differentiate a neural network, um, and get a meaningful derivative. Um, and we do a lot of work to, like, not overcomplicate the derivative. So the residual connection keeps it, like, con-contained and simple. Um, and the, uh, and so does like the Layer Norm, uh, stuff that we do. Um, one of the biggest attacks against, uh, cryptographic ciphers is also to differentiate the cipher. Um, ciphers run in a different number field. They run in, um, uh, the field of two elements, so just binary. Um, whereas neural nets run like in theory, in the field of real numbers. Um, uh, and so you have to differentiate with respect to like binary numbers. Um, but you can absolutely differentiate a cipher, and this is called differential cryptanalysis. And, uh, like, basically what it says is that if you take a small difference of the input, how, like, uh-
- DPDwarkesh Patel
Mm.
- RPReiner Pope
It's quite difficult to make, uh, the difference of the output be small. Like, uh, like, uh, the whole job of a, of a well-designed cipher-
- DPDwarkesh Patel
Yeah
- RPReiner Pope
... is to make the difference in output very large. Um, so I, I guess the distinction is that the, the optimization goals at that point are about complexifying. They, they don't have the same residual connections or, um, or like Layer Norms that would-
- DPDwarkesh Patel
Yeah. I mean, I, I guess a place where the, the two merge is backdoors.
- RPReiner Pope
Mm.
- DPDwarkesh Patel
Um, okay, so with a backdoor neural LM, you're trying to hide, um... What do you consider an input? It's not an input into the forward pass, but it's an input into the backward pass. But you're trying to hide an input into the backward pass.
- RPReiner Pope
Like you're-- Like, this is like an adversarial, uh-
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
Uh, context.
- DPDwarkesh Patel
Yeah.
- RPReiner Pope
So yeah, I mean, in fact, this is, like, this is actually a place where you get exactly the, um, sort of avalanche property that ciphers have as well. Um, like adversarial attacks on typically like image classification models, right, are, can I find a perturbation of the image that, a very, very small perturbation of the image-
- DPDwarkesh Patel
Yeah
- RPReiner Pope
... that totally changes the classification-
- DPDwarkesh Patel
Yeah
- RPReiner Pope
... totally changes the output? Th- that is the common case in ciphers, whereas it-- that's the, like, undesired case in, in neural nets, for sure. Yeah.
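The avalanche property mentioned here is easy to observe empirically. The sketch below flips a single input bit and counts how many output bits change. It uses SHA-256 from Python's standard `hashlib` as a stand-in, which is a hash function rather than a cipher, and the message and helper names are hypothetical.

```python
import hashlib


def flip_bit(data: bytes, bit_index: int) -> bytes:
    """Return a copy of `data` with exactly one bit inverted."""
    out = bytearray(data)
    out[bit_index // 8] ^= 1 << (bit_index % 8)
    return bytes(out)


def hamming_distance(a: bytes, b: bytes) -> int:
    """Number of bit positions at which two equal-length byte strings differ."""
    return sum(bin(x ^ y).count("1") for x, y in zip(a, b))


message = b"a very, very small perturbation"
perturbed = flip_bit(message, 0)

input_diff = hamming_distance(message, perturbed)
output_diff = hamming_distance(hashlib.sha256(message).digest(),
                               hashlib.sha256(perturbed).digest())

if __name__ == "__main__":
    print(f"input bits changed:  {input_diff}")
    print(f"output bits changed: {output_diff} of 256")
```

A well-mixed function is expected to change about half of the 256 output bits for a one-bit input change, which is the behavior a cipher designer wants and a classifier designer does not.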
- DPDwarkesh Patel
Okay. So I was asking you, uh, has... have neural networks actually been used for cryptography? And, um, we realized it might be better to just do this on the blackboard.
- RPReiner Pope
Yeah.
- DPDwarkesh Patel
Um, so I'm curious, are, are they actually being used for cryptography?
- RPReiner Pope
Yeah. So using neural nets for cryptography-- Well, in general, cryptography, like creating a new cipher, is a very, very dangerous proposition.
Episode duration: 2:13:40
Transcript of episode xmkSf5IS-zw