speaker 1: So this is lecture ten. We're going to take a brief respite from scaling laws, and we're going to talk about inference. The question inference asks is a very simple one: given a fixed model that we've trained, generate responses given prompts. First, we're going to start by understanding the implications of inference and the workload it entails, and then we're going to talk about ways of making inference faster. Throughout this lecture, you're going to see that inference is a very deep topic. We didn't do inference in lecture last year, so this is the first year we're doing it, and there are many, many topics that could span multiple lectures, which I'll try to condense into one. So inference shows up in multiple different places. The most obvious place is if you actually want to use a model: you want to chat with it, you're using Cursor or something to do code completion, or you're running a batch data processing job with your language model. All of these cases demand inference, because you need to generate tokens from your actual model. But it also shows up in other contexts. If you want to even evaluate your model, say on instruction following, you need to do inference. There's a lot of interest in test-time compute, which means thinking more before you output the final answer, and that's also more inference, because thinking is basically generating tokens. And finally, even training itself: if you're using reinforcement learning, you need to sample responses and then evaluate them based on some reward, and that also requires inference. So inference isn't just "I want to put up a chatbot demo." Inference underlies many of the basic functions of a language model, and even though it's one lecture, I want to stress how important it is for many things. We'll probably come back to this when we talk about alignment later in the class. So inference is important. The theme of this class is efficiency, and efficiency clearly matters: training is a one-time cost, but inference you repeat many times. Here are some anecdotal stats on why inference is a big deal. Sam Altman says OpenAI generates 100 billion words a day, which is quite a lot. And even Cursor, which is a relatively new product, is allegedly generating a billion lines of accepted code each day. That just gives you an idea of how much inference is happening, and the cost of inference compared to training is definitely increasing. So how do you measure what good inference looks like? There's time to first token (TTFT): how long an individual user needs to wait before any generation happens at all. This matters for interactive applications; if you have a big prompt and you have to wait there for 10 seconds, that may not be a good user experience. Latency is how fast tokens are arriving after the first token; this also matters for interactive applications. Throughput is a bit different: it's how many tokens are generated per second overall, across all users. This is particularly relevant for batch processing applications. Note that high throughput doesn't imply low latency, because some individual request might take a very long time while overall throughput stays high; latency is more like the worst case over any individual user.
So what do you need to think about when you think about the efficiency of inference? In training, the key idea is that you get to see all the tokens, at least in supervised training, which means you can parallelize over the sequence. This is exploited heavily in the transformer. You've implemented transformer training: you construct these tensors over the entire sequence, it's just big tensor-tensor matmuls, and then you get your output. But the key defining feature of inference, at least for transformers, is that you have to generate sequentially. You can't parallelize, because the generation of a token depends on all of the past. This is the key thing that makes inference a lot harder. In particular, it's going to be harder to utilize all the compute that's available, and it's going to be memory-limited, as we'll see in detail later. A lot of people are doing inference. Anyone who actually has a product or platform quickly realizes that the cost of serving large models is going to go up, so they spend a lot of time and engineering effort trying to reduce it. Both providers serving closed models and providers serving open-weight models pay a lot of attention to inference, more so than, I think, the average academic, because we're not actually serving any models; we're just training, getting a score, and putting it in the paper. But people who are actually serving models pay a lot of attention to inference, and there are also a bunch of open-source packages that are interesting to look at. Okay. So I want to understand the inference workload in detail. I'm going to briefly review the transformer math that you did in assignment one and that we talked about a little bit during the first week of class. This is from the JAX scaling book, which is something you should really take a look at; I think it does an excellent job of outlining many of the key concepts here. They have a really nice diagram that shows the computation graph taking an input and having it go through the attention and MLP layers. In particular, we're going to use this notation. Just to review it quickly: B is the number of sequences in your batch, L is the number of layers, T is the sequence length — think of it as the number of tokens you're generating or querying with. S is also a sequence length, but it's how many tokens you're conditioning on in your prompt. V is the vocabulary size. D is the dimensionality of your model. F is the MLP hidden dimension, which is usually four times D. H is the attention head dimension, N is the number of query heads, so generally N times H equals D. And then in GQA, grouped-query attention, you have a different number of key-value heads than query heads: usually K is smaller than N, and G is the number of groups, so K times G equals N. This diagram shows that you take your x, you feed it through the Q, K, V matrices, and you do a bunch of things. Remember that the FLOPs required for a forward pass is roughly two times the number of tokens (B times T) times the number of parameters — six times if you also include the backward pass during training — plus, for attention, there's another factor of T, so a T-squared dependence. Okay, so let's also review arithmetic intensity, which is going to help us characterize when something is compute-limited versus memory-limited. Let's start with a basic matmul: take a matrix X, which is B by D, and a matrix W, which is D by F.
And just to give some color to this computation, B is the batch size, D is the hidden dimension, and F is the up-projection dimension in the gated MLP. So let's count the number of FLOPs and memory reads/writes for just doing X times W. We read X from HBM, so that incurs a memory cost of two times B times D bytes, assuming everything is in bf16. We also read W, so that's two times D times F. Then we do the matmul, which incurs two times B times D times F FLOPs — remember, this is from the first lecture, so hopefully this is review — and then we have to write the result back out, which is another transfer. So the total number of FLOPs is just the matmul, and the number of bytes transferred is essentially the size of all the matrices that are read and written. Arithmetic intensity is the ratio of the two, which is this expression. In general, just to simplify things, the batch size B is much less than D and F: B might be in the hundreds, while D and F might be thousands or tens of thousands. I'm using sympy here just to keep myself from making silly mistakes: I let c go to infinity with D and F scaling as c times B, and that gives a simplified answer of B. So the arithmetic intensity is B for this particular matrix multiplication. The way to interpret this is how many FLOPs are done per byte transferred. The second part is you look at the accelerator: for an H100, that's 989 teraFLOPs per second of bf16 compute and 3.35 terabytes per second of memory bandwidth, and you divide, which gives you what's called the accelerator intensity, about 295. If your computational intensity, which is B, is greater than the accelerator intensity, that means you're compute-limited, which means you're able to use all of the GPU's FLOPs. If you're less than that, then you're memory-limited, which is bad. So you're compute-limited, in this matrix multiplication case, if B is greater than about 295 on an H100. All of this is a bit idealized; the actual details differ, but it gives you a first-order approximation. So generally, if you use batches of size, say, 300, then you'll be able to saturate the GPU. But what happens if your batch is really small? In particular, B equals one, which essentially corresponds to a matrix-vector product. Then the arithmetic intensity is basically one, and that is really, really bad — it means you're going to be memory-limited. Which kind of makes sense, because you're reading this D times F matrix and performing essentially the same number of FLOPs as reads, so the ratio between the FLOPs and the reads is about one. And one is bad: you want a lot of FLOPs to be done for every memory read, because memory reads are slow. But this is, in essence, what happens with generation: because you're proceeding token by token, we'll see that your arithmetic intensity is going to be about one, and that's why generation is going to be memory-limited and not compute-limited. So this is a very simple example that I think gets at the core of why generation is going to be slow.
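To make that concrete, here's a minimal sketch of the intensity calculation in Python. The H100 figures are the ones quoted above; the D and F values are just illustrative assumptions.

```python
def matmul_intensity(B: int, D: int, F: int, bytes_per_elem: int = 2) -> float:
    """Arithmetic intensity of X (B x D) @ W (D x F) in bf16."""
    flops = 2 * B * D * F                       # one multiply-add per (b, d, f) triple
    bytes_moved = bytes_per_elem * (B * D       # read X
                                    + D * F     # read W
                                    + B * F)    # write the output
    return flops / bytes_moved

H100_FLOPS = 989e12       # bf16 TFLOP/s quoted above
H100_BANDWIDTH = 3.35e12  # HBM bytes/s quoted above
accelerator_intensity = H100_FLOPS / H100_BANDWIDTH  # ~295

for B in [1, 64, 256, 1024]:
    intensity = matmul_intensity(B, D=4096, F=16384)
    regime = "compute-limited" if intensity > accelerator_intensity else "memory-limited"
    print(f"B={B:5d}  intensity ~ {intensity:7.1f}  -> {regime}")
```

For B = 1 the intensity comes out to about one, and for B much smaller than D and F it tracks B, which is the same conclusion as the sympy derivation above.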
Maybe I'll pause and take any questions on this, just to make sure everyone's clear. [Student question, partly inaudible, about using a batch size larger than one.] So I think I heard the question: why don't we have a batch size of more than one? We'll get to why you can. But note that batch size here means batch size; I'll get to sequence length later. Okay. So in summary, matrix multiplications are the core computation. We just studied a matrix multiplication, counted the number of FLOPs it requires over the number of reads and writes, and showed that that ratio, the arithmetic intensity, depends on one of the dimensions, in this case the batch dimension. That's why big matrices are good: they can saturate your compute. Whereas if you have a thin matrix, B equals one, that's really bad, because you're spending a lot of time reading from memory and not doing much compute. Okay, so now let's talk about the arithmetic intensity of inference, and get more into the weeds of what inference looks like. The naive thing you can imagine doing — and all these nice pictures are taken from this book — is that you have a transformer, you give it the prompt, it gives you logits over the vocabulary for the next token, and you sample from that. Once you get that token, you attach it to the prompt, feed it through the transformer, look at the logits, sample again, and repeat. That's the most naive thing to do, and the complexity here is pretty bad, because each token you generate is a T-squared computation through the transformer. Okay, so that's no good. But if you look at this closely, you'll notice that you're doing a lot of redundant work: the work of encoding the prefix basically stays the same. For a bidirectional transformer this would be different, but for an autoregressive, causal transformer you should be able to share a lot between prefixes. So the solution is you cache, and you keep the cache in HBM, because that's where you have enough space to store it. This is what it looks like schematically with a KV cache. You take your prompt; the prefill step feeds it through the transformer, computes the KV cache, and generates the logits over the next token. Then you take that generated token and the cache and feed them through the transformer, but you've already computed those entries, so you don't have to do that again — you just need to compute the new KV vector for this token. That allows you to more quickly generate the next token, and so on. So basically you're filling up this KV cache, which corresponds to the tokens you've either prefilled with or generated so far. So instead of T squared per token, it's going to be more like T. Concretely, the KV cache is: for every sequence in your batch, for every token in your sequence, for every layer of the transformer, for every head, you store an H-dimensional vector. You might think that this is going to take a lot of memory, and you wouldn't be wrong. So there are two stages of inference. Prefill: you're given your prompt and you encode it. This is just like what you do in training; it's parallelizable, it's fast, you're compute-limited, life is good. And then there's generation, which is generating response tokens one by one, sequentially. This is the part that is going to give us a lot of trouble in terms of efficiency.
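Just to get a feel for how fast the KV cache grows, here's a rough back-of-the-envelope sketch using the B, S, L, K, H notation from the slides. The specific 40-layer, 40-head config is an illustrative assumption in the spirit of the 13B example we'll use later, not an exact architecture.

```python
def kv_cache_bytes(B: int, S: int, L: int, K: int, H: int, bytes_per_elem: int = 2) -> int:
    """KV cache size: the extra 2 accounts for storing both a key and a value per (token, layer, head)."""
    return B * S * L * K * H * 2 * bytes_per_elem

# e.g. a 13B-ish model: 40 layers, 40 KV heads of dim 128, 1,000-token sequences, bf16
bytes_per_seq = kv_cache_bytes(B=1, S=1000, L=40, K=40, H=128)
print(f"KV cache per sequence:   {bytes_per_seq / 1e9:.2f} GB")       # ~0.8 GB
print(f"KV cache for a batch of 64: {64 * bytes_per_seq / 1e9:.1f} GB")
```

Even at modest sequence lengths, the cache for a decent-sized batch is tens of gigabytes, which is why it dominates the memory story below.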
So now let's compute the FLOPs and memory I/O for the transformer, broken down into the MLP layers and the attention layers. Notation-wise, we'll do this computation with S being the number of tokens we're conditioning on — think of it as the length of the prompt — and T being the number of tokens we're generating or querying with. In prefill, T is going to equal S: we're not generating T tokens, but we're querying with each of those tokens. In generation, T is just one. Okay. So hopefully the matrix multiplication example is still fresh in your head, because this is going to be essentially that, just a little more complicated because it's a transformer. We're going to count the FLOPs and bytes transferred. First we take x, which is a B by T by D tensor (I think maybe these T's should be S's, but anyway), and that involves a bunch of transfers: basically the size of that tensor times two, because bf16. Then there are the three weight matrices — the up projection, the gate, and the down projection — which are all the same size up to transposition, so you need to transfer those. Then you do the up projection, which is two times B times T times D times F FLOPs. Whenever you multiply two tensors, the contracting dimension only gets counted once, whereas the other dimensions just get gathered together. Then you write it out. You also have the gate, which is the same thing, and you write it out; you compute your nonlinearity, you multiply, and then you down-project, which is another B times T times D times F worth of FLOPs, and you write out the result. If you look at the counting — you can check the results, although you don't really need to, because this is sympy and it's guaranteed to be correct — and again assume that B times T is much smaller than D and F, you get that the intensity is B times T. This is analogous to the matrix multiplication case: the arithmetic intensity, which we want to be high, depends on how large your batch is and how many tokens you're querying with. So now look at the two stages. In prefill, life is good, because we can make B times T large enough — even a batch size of one is maybe okay if you have a long enough sequence. So that's not a problem. Generation is where it gets harder, because you're generating one token at a time, so T is one. If T is one, then for B times T to be large, you need B to be large, and B is essentially the number of concurrent requests. This is kind of interesting, because your efficiency depends on having large batch sizes. Intuitively it makes sense: if you can batch a lot of requests together, you get better efficiency, at least in throughput. But this also depends on your traffic, because if you're only getting a few requests at a time, then you're not going to be able to use your hardware very efficiently. And this speaks to the very dynamic aspect of inference, which we'll come back to later in the lecture. Okay, so now what about attention? It turns out attention is even worse, for reasons I'll try to get into. Let's do the accounting of FLOPs and bytes transferred. I'm going to read the Q, K, V matrices from HBM.
I'm going to compute the attention scores, which is the matrix Q times K transpose, and the number of FLOPs is on the order of B times S times T times D. Remember, S and T are the same during prefill, so that's your sequence length squared, times B times D. I'm only looking at the matrix multiplications, because the FLOPs from the other steps don't really matter. Then you take a combination of the values V — actually, this is mathematically incomplete because there's a softmax in there, but the essence of the matmuls is the same — so that's the same number of FLOPs, and then you write to HBM. Here I'm assuming flash attention: there would be more bytes transferred if you didn't use it, since flash attention means you don't keep writing intermediate steps back to HBM, but the order isn't really affected. So qualitatively it doesn't matter whether you use flash attention or not, though the constants do change. So let's look at the FLOPs and the bytes transferred. If you divide and simplify, you get this rather nice expression — nice in that it's simple, not nice in that it's good efficiency — which is S times T divided by S plus T. Let's try to interpret this a bit. In prefill, T equals S, so your prefill intensity is order S. That's good, because as long as you have long enough sequences, you're good to go, and generally the sequences are long enough. During generation, however, the intensity is essentially S over S plus one, which is basically one. And remember, one is really bad. Notice also that there's no dependence on B at all. Unlike the MLPs — remember, in the MLPs the prefill intensity was B times T, which is great, and the generation intensity was B, which was not great because it depends on the whims of your users and workloads, but could still be larger than one — for attention it's always about one, no matter how long your sequences are or how many users there are. So why is there no dependence on B, the batch dimension, intuitively? The reason is that in the MLP layers, every sequence hits the same MLP weights, whereas in the attention layers, each sequence has its own KV cache, because the KV cache is sequence-specific. In the MLP case, you can read the weights once and then process a whole batch. In the attention case, every sequence requires its own additional memory; you don't get any savings from batching. Mathematically, you can see it here: the number of FLOPs has a B in it, which is expected, but the number of bytes transferred also scales with B, so when you divide, the B cancels. Whereas over here, there is a B, but we're assuming that D times F dominates, so when you divide, there's essentially no B left in the denominator. So you can look at it mathematically, or you can reason about it intuitively: for attention, each sequence's KV cache is its own unique snowflake. So the summary is: prefill is compute-limited, whereas generation is memory-limited. The MLP arithmetic intensity is B, so to make it good enough you need a bunch of concurrent requests. But the attention intensity is one, and it's essentially impossible to improve that by batching.
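Here's a small sympy sketch, in the spirit of the lecture's approach, that reproduces the S·T/(S+T) expression by keeping only the leading matmul terms (the softmax and constant factors are ignored, as above).

```python
import sympy as sp

B, S, T, D = sp.symbols("B S T D", positive=True)

# FLOPs from the two big matmuls: scores Q K^T, then the weighted sum with V.
flops = 2 * B * S * T * D + 2 * B * S * T * D
# Bytes moved in bf16: read Q (B,T,D), read K and V (B,S,D each), write the output (B,T,D).
bytes_moved = 2 * (B * T * D + 2 * B * S * D + B * T * D)

intensity = sp.simplify(flops / bytes_moved)
print(intensity)                                  # S*T/(S + T)
print(sp.simplify(intensity.subs(T, S)))          # prefill: S/2, i.e. order S
print(sp.limit(intensity.subs(T, 1), S, sp.oo))   # generation: tends to 1
```

Note that B cancels entirely, which is the formal version of the "every sequence's KV cache is its own snowflake" argument.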
Okay, I'll pause a bit for any questions. Okay, so let's move on. Now we know that inference, thanks to generation, is memory-limited. Let's try to study the throughput and latency, at least in theory. We're going to make some assumptions. All of this napkin math is a little bit stylized, but it gives you roughly the right scaling and the right way to think about things. We're going to assume that communication and compute can be perfectly overlapped, which is obviously false, but it's good enough for making these qualitative estimates. What we're going to do is instantiate the latency and throughput for Llama 2 13B on an H100. For the 13B model, here are the values: let's set the sequence length to 1,000, the model dimension to 5120, and F to some multiple of that (I don't think it's exactly four times, but anyway), plus the number of query heads and the number of key-value heads, which for Llama 2 13B are the same — we'll come back to that point later — and so on. And for the memory bandwidth of the H100, that's the number shown. So that's the config, and we're going to compute the memory, latency, and throughput. First, let's quickly get the number of parameters. You did this in assignment one, so I won't belabor it, but it's some expression that depends on all the different variables. To store the parameters we use bf16, because inference is generally done in 16 bits rather than 32, so we multiply by two. That's the memory the parameters take. We don't need gradients and we don't need optimizer states, because we're not training, but we do have to store the KV cache, which is a subset of the activations — not all of them — for every sequence of length S. Per sequence, the cache is the sequence length, times the number of key-value heads, times the dimension of each head, times the number of layers, times two for the key and the value, times two for bf16. That's the cache size. So the total memory is the batch size times the cache per sequence, plus the parameter size. Now, latency is going to be determined by memory I/O — remember, it's memory-limited — so we just compute how much memory needs to be transferred to do this computation. Latency is simply memory over memory bandwidth, and throughput is essentially the inverse of latency, scaled up by B, because we're generating B tokens in parallel. So now, if we substitute in our Llama 2 config, the number of parameters checks out: it's roughly 13 billion. The memory, latency, and throughput have these expressions. Memory grows with B: this term is the parameter size, and this term is the KV cache size times B. Latency also goes up as a function of B. Throughput increases, but only up to a point — B shows up in both the numerator and the denominator, so there are limits to how much you can stretch throughput even if you could fit everything in memory. So those are the expressions for latency, throughput, and memory for this particular model. Now let's instantiate with different batch sizes. If B equals one, then the latency is about eight milliseconds.
So every eight milliseconds you generate a token, and the throughput is about 124 tokens per second. That's the 13B model on an H100 with a batch size of one. Now what happens if you use a batch size of 64? You'll see that the memory usage increases, because you need to store the KV cache for all 64 sequences. The latency goes up, because instead of processing one sequence you have to wait for everything to finish, but the throughput also goes up, actually quite a lot. So you're seeing an immediate trade-off between latency and throughput: if you want low latency, you use B equals one; if you want high throughput, you want a larger B. What happens if you use an even larger batch size, say 256? The latency goes up, the throughput goes up, but the throughput isn't going up by that much anymore, because you get diminishing returns after a while. And you actually can't do this on an H100, because if you look at the memory, it's 240 gigabytes, which doesn't even fit. So you can only increase the batch size to a certain point because of memory. Just to recap, there's a trade-off between latency and throughput: smaller batch sizes give you better latency, larger batch sizes yield better throughput. And finally, last week we talked about parallelism for training, and it was kind of complicated and annoying. At least one type of parallelism for inference is really, really nice and simple: you just launch M copies of the model. No communication, because you don't need to update the models; the latency is the same and the throughput increases by a factor of M. So that's pretty good — don't forget the easy things. Now, there are cases where the model is large enough that it doesn't fit on a single GPU and you need to shard it, and in that case you also want to shard the KV cache to get better efficiency; for more details, check out this book chapter. Okay. The time to first token, the metric I mentioned earlier, is essentially a function of the prefill: it's basically how long it takes to encode the prompt, and it's usually compute-limited, so you're going about as fast as you can, and there's not much you can do about it given a fixed architecture. Well, okay, you can improve it if you reduce the batch size, but if you want to improve throughput, you have to increase the batch size. Any questions about that? So this was computing throughput and latency, and because of the memory-limited argument from the previous part, I just focused on memory and computed how many bytes need to be transferred, which gives a rough bound on the latency. In practice there are some regimes where compute does matter, but I'm ignoring that just to keep things simple. Okay, a question. [Student question: this is assuming a single GPU while serving requests from multiple users; each of those arriving separately will make things more complicated.] Yeah. So the question is, if you have multiple users and you're batching them together, they might arrive at different times and finish at different times. We're going to get to that — it's a special issue we'll have to deal with. Any other questions? Okay.
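Before moving on, here's that napkin math written out as code. The config is a Llama-2-13B-ish approximation (the layer, head, and sequence-length numbers are assumptions for illustration); the point is how memory, latency, and throughput scale with batch size, and the numbers land roughly where the slides do.

```python
PARAMS = 13e9                  # ~13B parameters
BYTES = 2                      # bf16
L_LAYERS, KV_HEADS, HEAD_DIM = 40, 40, 128
SEQ_LEN = 1000
H100_BW = 3.35e12              # HBM bytes/s

param_bytes = PARAMS * BYTES
kv_bytes_per_seq = SEQ_LEN * L_LAYERS * KV_HEADS * HEAD_DIM * 2 * BYTES  # K and V

for B in [1, 64, 256]:
    total_bytes = param_bytes + B * kv_bytes_per_seq   # moved once per decode step
    latency = total_bytes / H100_BW                    # seconds per generated token
    throughput = B / latency                           # tokens/s across the batch
    print(f"B={B:4d}  memory={total_bytes/1e9:6.1f} GB  "
          f"latency={latency*1e3:6.1f} ms/token  throughput={throughput:8.0f} tok/s")
```

At B = 1 this gives roughly 8 ms per token and ~125 tokens per second, and at B = 256 the memory estimate is around 240 GB, which is why that batch size doesn't fit on a single H100.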
So now we have a good handle on what the inference workload looks like. We looked at arithmetic intensity, and we looked at transformer inference through that lens. We saw that it was memory-limited, thanks to attention, where the KV cache is specific to every sequence. And using that, we can compute throughput and latency, which are the main inference metrics we care about. Now, how do we make things better? There are some things you can do that are lossless: you can write better kernels, you can improve your systems. But there's a lot more you can do if you're willing to take shortcuts. And these are really interesting, because technically this lecture is on inference, but secretly it's on model architectures: a lot of the changes in model architecture have a direct impact on inference and were actually inspired by the need to do inference quickly. The big bottleneck here is the KV cache. Remember, we're memory-limited, which means the less memory things take, the faster you go — not primarily because of FLOPs, but because of memory transfers. If there's one thing to take away from this lecture, it's that speed is all about memory. The problem is that if you just start hacking away at the KV cache, you might lose accuracy. So how can you keep the KV cache small without losing too much accuracy? There are a bunch of ideas I'm going to go through that all essentially change the architecture to reduce the KV cache. Some of these you've seen, but I'll go through them in a more systematic way. First, there's grouped-query attention (GQA). Multi-head attention, the vanilla transformer, keeps the same number of keys, values, and queries per head. At one point there was multi-query attention, where you have only one key-value head; it turned out that was not very expressive. So GQA is an intermediate point where you have a reduced number of key-value heads and more query heads. Why are we doing this? Remember, we want to reduce the KV cache size, so the fewer keys and values there are, the better. The batch size and the sequence length don't change, and the dimensionality of the vectors doesn't change; it's the number of key-value heads that we're reducing. That's the idea, and this paper shows that you do get latency and throughput improvements, measured in time per sample. As you increase the number of KV groups up to eight or so, the cost is basically negligible — it's really fast compared to full multi-head attention — and as you keep increasing the number of groups, you obviously end up back at the original. So those are the latency and throughput improvements. To do this a bit more rigorously, take our Llama 2 13B model. If we compute the statistics — this is using a batch size of 64 — this is what we got before (I should probably show latency here as well). If you run it with GQA, you see that the memory is reduced and the throughput goes way up. So this is what happens if I take the Llama 2 13B architecture and, for every key-value head, I have five query heads — that's what the one-to-five ratio means.
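Concretely, here's the back-of-the-envelope effect on the KV cache of going from one KV head per query head to the one-to-five ratio above. The config is the same illustrative 13B-ish example, not the exact Llama architecture.

```python
def kv_bytes(B, S, L, kv_heads, head_dim, bytes_per_elem=2):
    """KV cache size in bytes; the x2 accounts for storing both K and V."""
    return B * S * L * kv_heads * head_dim * 2 * bytes_per_elem

mha = kv_bytes(B=64, S=1000, L=40, kv_heads=40, head_dim=128)  # 40 query heads, 40 KV heads
gqa = kv_bytes(B=64, S=1000, L=40, kv_heads=8,  head_dim=128)  # 40 query heads, 8 KV heads (1:5)
print(f"MHA KV cache: {mha/1e9:.1f} GB   GQA KV cache: {gqa/1e9:.1f} GB   "
      f"({mha/gqa:.0f}x smaller)")
```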
This also means we can use a larger batch size. Remember, last time we tried 256, it didn't even fit in an H100's memory; now we can comfortably fit it, and then we can further improve the throughput by using a larger batch size. So you can see a lot of different effects here. By reducing the number of key-value heads, the memory of the KV cache goes down, which means the latency and throughput automatically improve because there are fewer memory transfers. And furthermore, as a secondary effect, I can raise the batch size within the GPU's memory, which further improves the throughput. Okay, so that's wonderful, but we also have to make sure the accuracy doesn't drop. The original paper shows full attention versus GQA: the time is much less, but the accuracy is basically the same. And here's what actually happened historically: Llama 2 13B did not use this ratio, but Llama 3 did pick up GQA, probably motivated by inference cost. (Actually, the large Llama 2 70B model did have GQA, just not the smaller ones.) Okay, so that's GQA. There's another way to reduce the KV cache, and this comes from DeepSeek — it's from the DeepSeek-V2 paper — called multi-head latent attention (MLA), which Tatsu lectured about previously, but I'll talk about it in the context of inference and its implications. The basic idea: here's full attention; GQA says I'm going to use fewer keys and values; MLA says I'm not going to change the number of key-value heads, I'm going to project them into a lower-dimensional space. So it's another way of shrinking the KV cache, just along a different dimension. Instead of using N times H dimensions for the KV cache of each token, I project down to C dimensions. This is what DeepSeek did, and it's quite an aggressive reduction, from around 16,000 dimensions down to 512. The only wrinkle is that this is not compatible with RoPE, so they need to add a few extra dimensions to put RoPE back in. But overall this is quite promising from a KV reduction perspective. I'm not going to do the math, but you can see how the KV cache would be reduced a lot, and you get the same kind of latency and throughput advantages. In terms of accuracy, they compare against GQA — actually, maybe I'm showing the wrong table here; I meant to show that MLA actually improves on it, so I'll have to dig that up later — but anyway, MLA does preserve accuracy as well. Okay, there's another idea. GQA, you can think of it as sharing key-value vectors across heads within a token. But we can also look at something called cross-layer attention (CLA). There's a paper on this, though I think many people have been thinking about and doing this, so I don't know if it's actually the first. If you look at the transformer diagram, you have the key-value projection of one layer and then the next layer, and these key-value vectors are usually separate. The idea with CLA is that we're going to use the same key-value projection across layers — that's why it's called cross-layer attention. Just as GQA shares across heads, CLA shares across layers.
Here they show that they empirically improve the Pareto frontier of accuracy versus KV cache size. The KV cache size, which relates to throughput and latency, you want to be small, and you want the perplexity also to be small, and they're able to improve that frontier. Notice that, for example, with 64 heads the cache size gets reduced, but the validation perplexity does go up a little bit; overall, though, there's an advantage in making that trade-off. Okay, so there's yet another way to do things: local attention, which has been explored quite a bit — there's Longformer, there's an OpenAI paper, and then Mistral and I think many others use this as well. It's a very natural idea. If you look at the full attention diagram, it's dense and quadratic, and that's where a lot of your complexity comes from. The idea is that you attend to only the past k tokens, which means that in the KV cache, as you're generating the sequence, you don't have to remember everything: as soon as a token falls outside your attention window, you can just throw it away. So with local attention, the KV cache size remains constant instead of growing with the sequence length. That's really good, because even for long sequences you can have quite a small cache. But the problem is that this hurts accuracy, because if you think about it, the reason we do attention instead of RNNs is that we need to model long-range dependencies, and local attention only looks at the local context, which is not very expressive. So what you do is interleave local attention with full global attention — hybrid layers. For example, Character AI used, for every six layers, one global attention layer and five local layers, and on top of that they have KV cache sharing across layers, both for the local attention and the global attention. So it looks something like this: with full attention at every layer, you have to store the full KV cache everywhere; instead, for every six layers you have one full attention layer, in between you have local attention, and you add cross-layer KV sharing. This is many of the tricks combined together. So in summary, these are a few ways to reduce the KV cache size. Remember, inference is memory-limited, so you want to reduce the cache size, but you don't want to hurt accuracy too much, and there are many ways to do it: you can have fewer KV heads, you can reduce the dimensionality of the KV vectors, you can share the KV cache across layers, and you can use local attention on some of the layers.
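To make the local-attention piece concrete, here's a tiny sketch of a sliding-window KV cache: the cache only ever holds the last window's worth of tokens, so it stays constant size no matter how long the generation runs. The window size is an arbitrary assumption for illustration.

```python
from collections import deque

class SlidingWindowKVCache:
    """Per-layer KV cache for local attention: capped at `window` entries."""
    def __init__(self, window: int):
        self.keys = deque(maxlen=window)     # oldest entries fall off automatically
        self.values = deque(maxlen=window)

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def __len__(self):
        return len(self.keys)                # never exceeds `window`

cache = SlidingWindowKVCache(window=4)
for t in range(10):
    cache.append(f"k{t}", f"v{t}")
print(len(cache), list(cache.keys))          # 4 ['k6', 'k7', 'k8', 'k9']
```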
Any questions about this set of tricks for reducing the KV cache? [Student question about cross-layer attention: are the weights also shared across layers — do you just have one set of KV projection weights?] Yeah. So the question is whether the weights are shared as well as the KV cache. What happens is that the weights for doing the KV projection need to be shared, so there's consistency there. Yeah, there's another question. [Student question, roughly: if the context is very long, doesn't that blow up the KV cache regardless?] Yeah. So the question is, if you have a really long context — let's say your prompt is huge — that's going to intrinsically take a lot of KV cache. All of these tricks can try to reduce that, and you can do more aggressive things: there are ideas like gist tokens, or other ways to summarize the prompt, which we're not going to talk about in this class, but there are ways to address the long-prompt situation as well. Okay, so now I'm going to talk about even more radical ways of making inference go faster by changing the transformer. The KV cache tricks are basically variants of the transformer, but maybe you can go outside the transformer and do better, because the transformer wasn't really designed with heavy inference workloads in mind — it was mostly about training a good model efficiently. And autoregression, as we pointed out, is really what's causing the bottleneck here: autoregression plus full attention. So we're going to talk about two directions, state space models and diffusion models. This is going to be fairly quick. The idea of state space models draws on ideas from signal processing and control theory. Initially, the motivation was to model long-context sequences without suffering the N-squared blowup, so it wasn't necessarily about inference speed, but it turns out that if you solve that problem, you get faster inference too. There's an early paper on S4 which takes classical state space models — these linear dynamical systems that have been used to model long-context signals — and shoehorns them into a modern neural setup. This work is nice in that it has an RNN interpretation, due to the linearity, and also a convolutional interpretation. They published the paper and showed that it worked really well on long-context synthetic tasks. But what was discovered is that these models don't really work well for language modeling, which is obviously a disappointment, because a lot of the value of transformers is being able to do language well. So in a series of papers, people identified a set of synthetic tasks that captured the essence of why these models weren't working well, and those are basically associative recall tasks. Here's a synthetic task where you're given a sequence of key-value pairs, and the goal is to look up a key and output its value. In some sense it's a logically trivial task, but the sequence is long, because there can be a lot of key-value pairs and you may have to look far back — there can be arbitrarily long dependencies. You can see that local attention is not going to work very well, because it only remembers the last few tokens. And the problem with state space models is that they were good for signal-processing-style tasks, but here you need to isolate a particular key-value pair and pull out the answer, and they weren't good at that. So there's a bunch of follow-up work addressing this.
There's Hyena, H3, and then Mamba, which basically tweaked or changed the SSM to handle these associative recall tasks, and eventually it worked better, matching transformers up to around the 1B scale. The Mamba idea has been popular and scaled up, even to a 52B MoE by the AI21 folks. Notice that in that case they still had to use transformer layers, but only about one in every eight layers; the rest were Mamba layers, and that still led to fairly big savings and speedups. More recently, there's been a revival of an older idea called linear attention. It's actually a very simple idea. You know what local or sliding-window attention is; linear attention says: in the attention computation, there's a key and a query, you dot-product them and take the exp of that, which is basically a kernel. You can take a Taylor expansion of that and write the computation as dot products of some nonlinear feature map. So what you essentially have is, for every position, you apply some nonlinearity that blows the key up into some space, and then you do a linear computation over it. And because it's linear, it behaves like an RNN, and it's linear in the sequence length rather than quadratic. I know that was a little fast, but I just want to give you a taste of it. This idea has actually been scaled up quite successfully. There's an organization called MiniMax that's training pretty legitimate models, up to 456-billion-parameter MoEs, and they use this linear attention idea. Now, they still have to use full attention once in a while, it seems — I don't think people have been able to get around having some full attention — but people have been able to get rid of most of it: most of the layers are not full attention anymore, they're either linear attention layers or local attention layers, which are much, much more efficient. So linear plus local attention is now yielding serious state-of-the-art models. I don't know exactly what the closed-model providers are doing, but I would suspect they're at least as advanced in leveraging this kind of sparsity. So it's an interesting question when people ask, well, is attention all you need — are transformers it? Yes and no. In some sense the S-squared attention is still there, and maybe we'll eventually be able to get rid of it, but most of the transformer has been pretty radically changed by swapping in much lighter-weight components, and you're still able to get much of the same accuracy. And all of this is really helpful for inference, because on the non-full-attention layers, you're replacing the order-T KV cache, which grows with the sequence length, with something that's constant size.
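Here's a minimal sketch of the linear-attention-as-RNN view: replace exp(q·k) with a feature map phi, and then decoding only needs a constant-size recurrent state instead of a growing KV cache. The feature map below is a placeholder, not the one any particular paper uses.

```python
import numpy as np

def phi(x):
    """Placeholder feature map standing in for the kernel approximation."""
    return np.maximum(x, 0.0) + 1e-6

class LinearAttentionState:
    """Constant-size recurrent state for one head of one sequence."""
    def __init__(self, d: int):
        self.kv = np.zeros((d, d))   # running sum of outer(phi(k), v)
        self.norm = np.zeros(d)      # running sum of phi(k)

    def step(self, q, k, v):
        """One decode step: update the state, return the attention output."""
        self.kv += np.outer(phi(k), v)
        self.norm += phi(k)
        return phi(q) @ self.kv / (phi(q) @ self.norm)

d = 64
state = LinearAttentionState(d)
for _ in range(10):                  # ten decode steps, but the state never grows
    out = state.step(*np.random.randn(3, d))
print(out.shape)                     # (64,) -- O(d^2) state instead of an O(T) cache
```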
And there are follow-ups, I think on the Based paper, where either in that paper or in follow-up work they analyze the trade-off between the recurrent state (or KV) size and the ability to do various types of recall tasks, which makes sense: if you don't store very much, you won't be able to solve certain tasks. But there's a trade-off curve you can play with. Okay, that's all I'll say about state space models. Now let's talk about a completely different style of generation: diffusion models. Diffusion models have been very popular for image generation, but they turn out to be fairly tricky to get working on text, although recently there have been some advances. The idea of diffusion is that instead of generating autoregressively, you generate every token in parallel. Obviously, if you only do that in a single pass, it's not going to be very good — you can't generate all the words in parallel and expect them to be coherent. So instead you iterate, and you keep refining the generation until you get to the final output. The appeal is that you're no longer autoregressively bound: generating all the tokens can be done in parallel, so you get to saturate your GPUs relatively easily as long as your context length is large enough. Recently, Inception Labs has produced some pretty interesting models. There's not much written about them, but you can see a demo of the generation process: it generates code essentially instantaneously, but it's obviously broken at first, and then it refines over time. And this is one of their benchmarks showing that, at least on coding — I'm not sure about other tasks — in terms of tokens per second these models are way out ahead compared to anything transformer-based; even Jamba — remember, the hybrid Mamba-transformer architecture — is quite slow compared to these diffusion models. Now, whether diffusion models will be general-purpose and powerful enough remains to be seen. But you have such a lead on token speed here that, even if there are accuracy losses, I think you can put in more compute and recover some of them if you need to. So the summary is that this whole novel-architecture direction is really exciting for inference, because it lets you sidestep fundamental obstacles. If you're dealing with attention, you have this fundamental KV cache obstacle: you can quantize it, you can optimize it, but it's still there. By moving to a state space model, you shrink that to a constant size, and as long as you can keep up the accuracy — which is a big if — you win big time. Same with diffusion models: autoregressive generation is a key bottleneck, and if you generate things in parallel, all of a sudden you've completely changed the game. So there's much more work to be done here in improving inference. As you can see, the inference game is much broader than it seems at first sight. It's not really just about systems optimizations to make things fast, although you obviously need those; I think the real gains come from more radical changes in architecture. Okay.
So with about ten minutes left, I'll go through these quickly: quantization and model pruning. For quantization, the key idea is just to reduce the precision of the numbers. Very easy to do, and the thought is that less memory means fewer bytes transferred, and therefore lower latency and higher throughput. You do have to worry about accuracy, of course — that's the trade-off. If you look at the different formats: fp32 is used for training, not really for inference; bf16 is sort of the default for inference; and you can go down to fp8 or int8, which is less accurate but much cheaper. People do a lot of inference in int8, which, if you look at the range, is an integer between -128 and 127 — pretty low precision — and people are even going down to int4, which is really low. So once you decide you want to quantize, you can do several things. You can train with quantization, but obviously that means you need to retrain the whole model. More commonly, you do post-training quantization, where you take an existing model, quantize it, and try not to screw things up too much. There's a paper called LLM.int8(), which I'll talk through briefly. In quantization, what happens is that you take your vector, which is, say, fp16, and you figure out the dynamic range: if you want to pack it into int8, you need to know what the largest absolute value is. Once you know that, you divide by it, multiply by 127, and round to get your integers. If you need to dequantize, you go the other way. So quantization means that — remember, memory bandwidth is the bottleneck — all your transfers happen in int8, but when you actually do the arithmetic, you sometimes have to upcast to floating point. The problem with int8 is that not everything fits nicely: there are outliers, which appear in larger networks, that screw things up. So what this paper does is take the matrix, identify the really large outlier values, handle them separately in full 16-bit precision, and do the vast majority in int8. This works well, but it's actually a bit slower; the motivation here wasn't inference speed so much as being able to fit the model into memory at all. There's another paper called activation-aware quantization (AWQ). Here the idea is that you're quantizing the weights, but you figure out which weights to protect based on the activations. Here you're going down to around int3, which reduces memory by quite a bit and leads to roughly a 3x speedup. The general idea in both is that you take a trained model, and it just happens that some of the weights or activations are abnormally large; you handle those separately, and everything else you can work with in low precision. Okay.
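Here's a minimal sketch of the absmax int8 quantize/dequantize step described above, done per-tensor for simplicity (LLM.int8() itself works at a finer granularity and splits out the outliers, as discussed).

```python
import numpy as np

def quantize_int8(x: np.ndarray):
    """Absmax quantization: scale by the largest magnitude, round to int8."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -128, 127).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Upcast back to floating point to do arithmetic."""
    return q.astype(np.float32) * scale

w = np.random.randn(4, 4).astype(np.float32)
q, s = quantize_int8(w)
print(np.abs(w - dequantize(q, s)).max())   # small reconstruction error
```

A single large outlier in `w` would stretch the scale and crush everything else toward zero, which is exactly the failure mode the mixed-precision outlier handling is designed to avoid.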
Now, model pruning. The basic idea is very simple, like quantization: you rip out parts of an expensive model to make it cheaper, and then you fix it up. In this NVIDIA paper, what they do is first identify important layers, heads, or hidden dimensions using a small calibration set; they use some simple importance scores to compute that. Then you remove the unimportant layers, hidden units, or heads. Now, if you just take that model, it's going to be clearly worse, right? So the last step is to distill the original model into the pruned model. You're repairing the model starting from the pruned initialization — you're not starting from scratch, you're starting from something that's worse, but hopefully not too much worse, and hopefully it retains a lot of the structural properties of the original model; it's just maybe not calibrated in some sense. And the results are pretty good: they have a 15-billion-parameter model that they're able to reduce to 8B with hardly any drop, at least according to MMLU, and then down to 4B with some drop — but you're also going all the way down to a 4B model. Okay, so to summarize the taking-shortcuts idea: you can reduce inference cost without hurting accuracy too much. You can do it from scratch, where you define a fresh architecture that's fast by construction and just train it; or you can distill, where you define the architecture, take a slow model, figure out some scheme to initialize the new model from the old one, and then do distillation. Okay. Now, all of these are a little bit unsatisfying, because they're lossy: you get massive speedups, but you always wonder whether the model is really as good as the original. Speculative decoding, or speculative sampling, lets you have your cake and eat it too. Recall there are two stages of inference. Prefill: you're given a sequence and encode all the tokens in parallel; it's compute-limited, which is great. Notice that this also gives you log probabilities for each of the tokens. And then there's generation, one token at a time: it's memory-limited, it's slow. In other words, checking is faster than generating. Intuitively this makes sense, but hopefully now you also appreciate the math behind why. The speculative sampling idea is really elegant; it was proposed in parallel by two independent teams at Google. The idea is to use a cheap draft model p to run ahead and generate some tokens, and then evaluate those tokens with the target model. Because evaluating given tokens is just prefill, you can do it in parallel, which is fast, and then you accept the tokens if they look good. So this is what it looks like in real life: if you're using the big model to generate one token at a time, that's slow; with speculative decoding, you have a draft model racing ahead, generating a lot of tokens, and the big model essentially verifies, sometimes rejecting and sometimes accepting. The acceptance rate basically determines how much of a speedup you get. Here's the more formal algorithm. You have a lookahead of K: you use your draft model to generate K tokens autoregressively — this is hopefully fast because your draft model is small — and then, given these K tokens, you score them by computing their probabilities under the target model q. Now I have to decide whether to accept them or not.
So I go through each token, and I accept it with probability min(1, q over p); the min with one is just to make sure the probability is between zero and one. If people are familiar with Metropolis-Hastings, that's kind of where this comes from. Intuitively, you sampled from p but you want q, so you need to divide p out — it's kind of an importance weight. If you accept, great: you move on and look at the next draft token, and so on. If you don't accept, then you sample from the target model, the slow one, but with a correction: you've already tried sampling from p, so you don't sample from q directly — you sample from the residual distribution, which is proportional to max(q minus p, zero). So this is basically rejection sampling with proposal p and target q. The only difference is that in rejection sampling, if you reject, you just try again and again; here we don't want to keep looping forever, so if we reject, we just bite the bullet and sample once from the more expensive model. The cool thing is that you're guaranteed to get an exact sample from the target model. For those of you familiar with sampling, this shouldn't be too surprising: you're using prior information to speed up sampling. But in the language modeling context, it's really nice. I'm going to skip this part — it's not really a proof, just a derivation showing, for the case of a vocabulary of size two, why these formulas give you the right unbiased sampling procedure. And it works pretty well. The accuracy should actually be the same, since it's the same model, modulo some randomness, but you're getting roughly a factor-of-two speedup. In practice, your draft model is much, much smaller than the target: you might have something like a 70B target with a much smaller draft, and if your target model is 8B, your draft model might be 1B. You generally want to make the draft model as close to the target as possible, so if you do some distillation, that can make it even better. This is a pretty hot area of inference research, and there are lots of ways to improve the process. You can use Medusa, which has the draft model sample multiple tokens in parallel instead of autoregressively, or EAGLE, where you take high-level features of the target model and feed them into the draft model, so the draft model doesn't have to stand alone — it can be glommed onto the target model to help it generate. So, in summary: exact sampling from the target model, thanks to math. This exploits the asymmetry between checking and generation — or prefill and generation — and there's a lot of room for innovation in the draft model: everything we've talked about before, radical architectures, different ways of quantizing, all of those apply, and you still get a guarantee of exact samples from the target. Okay.
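Here's a sketch of the accept/reject rule for a single drafted token, following the description above. p_draft and q_target are the draft and target next-token distributions at that position; the function and variable names are mine, not from any particular implementation.

```python
import numpy as np

def verify_token(token: int, p_draft: np.ndarray, q_target: np.ndarray, rng) -> int:
    """Return the accepted token: either the draft token or a corrected resample."""
    accept_prob = min(1.0, q_target[token] / p_draft[token])
    if rng.random() < accept_prob:
        return token                               # keep the draft model's token
    # Rejected: resample from the residual distribution max(q - p, 0), renormalized.
    residual = np.maximum(q_target - p_draft, 0)
    residual /= residual.sum()
    return int(rng.choice(len(q_target), p=residual))

rng = np.random.default_rng(0)
p = np.array([0.7, 0.2, 0.1])   # draft distribution
q = np.array([0.5, 0.4, 0.1])   # target distribution
print(verify_token(0, p, q, rng))   # token 0 is kept with probability 0.5/0.7
```

Marginalizing over the draft sample and this accept/reject step recovers exactly the target distribution q, which is the unbiasedness property the lecture refers to.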
Now I'm about out of time, but I'll quickly go through the question that came up earlier: in practice, when you're serving live traffic, requests come in at different times, they finish at different times, some of them have shared prefixes and some don't, and they have different lengths. So it's very heterogeneous, in contrast to training, where you get a dense block of tokens and you just push it through your GPU at full speed. So what do you do in this case? There's a series of papers that explore this — these last two parts are more of a systems-level contribution. The first idea is that you don't wait for the current batch to finish: the train leaves, the train doesn't wait for you. When a new request comes in, you just put it in, which means the worker that's generating tokens needs to hand control back to the scheduler at every step: you generate a token, come back to the scheduler, and if there are new requests, they get slotted in, and then it continues. So you're not wasting any time waiting around. Now there's a problem with batching, which I think is behind the earlier question: batching works when everything has the same dimensionality, but every request might be a different length. So there's this idea of selective batching, where you break up the computation: attention has to be handled per sequence, but for your MLPs — which, remember, are the bulk of the computation — you can take tensors of different lengths and just flatten them together, because tokens don't interact there; they can just come along for the ride in the batch dimension. Okay, I know that was fast, but I'll quickly go over paged attention. This is the paper behind vLLM, which some of you have probably used, and it addresses the memory usage problem. If you have a KV cache and prompts are coming in and finishing, your cache is going to get fragmented: you allocate a bunch of space for a request, but you don't know how many tokens it's going to generate, so there's internal fragmentation, and there's also external fragmentation, where there's padding between the requests and responses. That's no good. Paged attention says: remember operating systems and how virtual memory works? We divide the KV cache into a sequence of fixed-size blocks, and we put the blocks wherever we find free space. So if you have two requests coming in, the first request's blocks might be here, here, and here, and the second request's might be here and here. The blocks are the things you keep contiguous, and that lets you coalesce your memory and use it efficiently. You can also play tricks where, if prefixes are shared, you use another idea from operating systems, copy-on-write: you maintain reference counts for how many sequences are using a particular block, and if sequences need to diverge, you copy the block and decrement the reference count. There are a bunch of other vLLM optimizations, which I won't go through, but the summary is: remember your operating systems classes — you can apply them to inference as well.
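Here's a toy sketch of the paged-KV-cache bookkeeping: fixed-size blocks, a free list, and a per-sequence block table. No actual tensors, just the allocator logic; the block size is an arbitrary assumption.

```python
BLOCK_SIZE = 16                               # tokens per block (assumed)

class PagedKVCache:
    def __init__(self, num_blocks: int):
        self.free = list(range(num_blocks))   # physical block ids not in use
        self.tables = {}                      # seq_id -> list of physical block ids

    def append_token(self, seq_id: str, pos: int):
        """Reserve KV space for one more token of this sequence."""
        table = self.tables.setdefault(seq_id, [])
        if pos % BLOCK_SIZE == 0:             # current block is full, grab a new one
            table.append(self.free.pop())
        return table[-1], pos % BLOCK_SIZE    # (physical block, offset within block)

    def free_sequence(self, seq_id: str):
        """Return all of a finished sequence's blocks to the free list."""
        self.free.extend(self.tables.pop(seq_id, []))

cache = PagedKVCache(num_blocks=8)
for t in range(20):
    cache.append_token("req-A", t)
print(cache.tables["req-A"])                  # physical blocks backing this sequence
```

A copy-on-write extension would add a reference count per physical block and copy a block before modifying it whenever that count is greater than one, mirroring the prefix-sharing trick described above.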
Okay, quick summary. Inference is really, really important, and its characteristics are distinct from training: you're memory-limited, and it's also dynamic, which leads to a bunch of new challenges. We saw a whole host of different techniques: new architectures, quantization, pruning, distillation, speculative decoding. There are ideas from systems that let you better use your memory, overlap communication and compute, and so on. But I would say there's probably even more opportunity on the modeling and architecture side, because if you think about it, inference is nominally about a particular model — how do I run this particular model fast? — but who cares about that particular model? What you care about is delivering good accuracy given your resource budget. So a lot of these ideas for reducing the KV cache or changing the transformer are basically ways of sidestepping the problem and saying: well, I have something that's more efficient, and if I can train it in a way that gets me comparable accuracy, then I win. Okay, so that's all I have, and I'll see you next time. Then we're back to scaling laws.