Stanford CS336 Language Modeling from Scratch | Spring 2025 | 05 GPUs
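The speaker first goes over the assignment schedule and says that the core topic of this lecture is the graphics processing unit (GPU). GPUs are critical to running language models, and the lecture aims to demystify CUDA and GPUs: how they work and why their performance fluctuates, for example why a GPU slows down at certain matrix-multiply sizes. The learning goals are to make the audience comfortable with GPUs and able to accelerate algorithms with tools such as CUDA, for instance by understanding the building blocks behind efficient algorithms like FlashAttention. The speaker stresses the importance of hardware progress, noting that advances in deep learning have come from faster hardware, better utilization, and parallelization. The lecture then reviews the history of compute scaling: early gains in CPU single-core performance relied on Dennard scaling, and once that trend saturated, the shift to parallel computing became necessary, which is key to the rise of the GPU. The speaker contrasts CPU and GPU design philosophies: CPUs optimize for low latency, with complex control units that finish a single task quickly, while GPUs optimize for high throughput, using large numbers of parallel compute units (such as ALUs) to process many tasks at once, so that even if each task's latency is higher, overall throughput is greater. Finally, the speaker introduces the internal structure of a GPU. The core concept is the streaming multiprocessor (SM); each SM contains many streaming processors (SPs). The SM handles control logic and task dispatch, while the SPs execute the same instruction on different data, enabling massively parallel computation.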
Media details
- Upload date
- 2025-05-13 16:31
- Latest LLM Model
- gemini-2.5-pro-exp-03-25
Transcript
speaker 1: So hopefully everyone's having a good time with assignment one. It's due tonight. Let us know if you need an extension. Assignment two is coming out soon. We're putting the finishing touches on some of the Triton stuff. Hopefully you'll enjoy it. You'll get to implement FlashAttention-2, or parts of FlashAttention-2, which I think will be nice. So today we're going to talk about GPUs. GPUs are the thing that makes our language models go, so they're pretty critical to get right. And if you haven't really studied the hardware that makes your models run, it can seem pretty mysterious. So my goal today is to try to make CUDA and GPUs less magic. One of the things that I want to demystify (you don't have to understand the plot; there's a lot on the slide, I know) is why GPUs get slow. And they get slow in very mysterious ways. I will try to talk through this plot towards the end of lecture. As you increase the size of your matrix multiplies, you might expect them to just get slower or faster or whatever. Instead you get these very unpredictable-looking, wave-like patterns, and you're like, why is my GPU fast at certain multiples of certain numbers and slow at others? That's very mysterious. We'll try to understand that. The other thing is that we would like to understand how to make fast algorithms. I think almost all of you have heard of FlashAttention. It's the thing that makes much longer contexts possible by very cleverly computing the attention operation inside a transformer. So maybe you would like to come up with new algorithms or new implementations like FlashAttention. What primitives and what components do we need to understand in order to be able to do that? So those are the two learning goals for today. The first one is that by the end of the lecture, you should feel comfortable with GPUs. You should understand how they work.
And the second one is that you should feel comfortable accelerating certain parts of your algorithms. If you make a new architecture, you should hopefully feel like you can try to accelerate it with CUDA. And because hardware is not necessarily the domain in which I work, there are special resources that I have to give a lot of credit to, especially Horace He's blog, where he's got a lot of fun GPU facts that you can learn about. For example, why are matrix multiplies that are filled with zeros faster than ones that are not? You can learn that by going to his blog. There are also other resources that I've drawn from, like the CUDA MODE group and the nice TPU book from Google. If this topic interests you, I'd encourage you to go and look at those resources to learn more, because this is in some ways a shallow but hopefully complete coverage of the hardware. Today we're only going to focus on the non-parallel parts of the hardware stack. So we're going to study the GPU as a single accelerator in depth: how it works and some of its important parts. I'm also going to talk very, very briefly about TPUs, because in some ways they're conceptually very similar to a GPU, so my discussion here is going to carry over. Then, once we understand the hardware and execution model of the GPU, we're going to try to understand what makes GPUs go fast on certain workloads and what makes them slow. We're going to understand the performance. And the last part is going to be almost like a hands-on piece. I'm going to try to walk through FlashAttention. I'm going to take all the lessons that we've learned and walk you through FlashAttention, saying, see, here's how it all comes together. So that's the last part of today's lecture. Many of you have taken an NLP course, and these days I think an NLP course teaches some amount of scaling laws. So you've probably seen this.
This is just setting the context. We know that having more compute is helpful for training large language models. This is a pre-training scaling chart, but you could replace it with an inference scaling chart if you would like. It's generally agreed that the more compute you have, the more processing you can do on your data. You can ingest more data, you can train larger models, and all of those lead to improved performance. So of course, you might think deep learning research itself is really important, but what's really driven performance is faster hardware, better utilization, and improved parallelization. So that's setting the stage for why hardware is important to understand. And of course, once you think about compute scaling, you ask, okay, how do we get compute scaling? How do we get our models to train faster? In the early days of semiconductor scaling, if you were thinking about how CPUs get faster, they would scale under something called Dennard scaling. With Moore's law, you would roughly double the number of transistors on a chip every two years. And if you have this doubling, what you end up with is Dennard scaling, where smaller and smaller transistors can be driven at faster and faster clock speeds with lower and lower power, which in turn gives you more performance. And then, as you go from the 1980s into the 2000s, this sort of tapped out. You can see in this chart by Hennessy and Patterson that single-thread performance, the blue dots here, basically started to taper off. Of course, the number of transistors didn't really start falling off. You did have chips with higher and higher transistor densities, but that wasn't helpful; it wasn't giving you higher throughput on single threads. So this means that we can't just do computation faster in absolute terms. What we have to make up for it with is parallel scaling.
So the story of scaling for deep learning and neural networks is going from single-thread scaling, which is doing your computation faster in absolute terms, to parallel scaling, where you have a lot of workloads that are all computed at once. And this is one of my favorite compute scaling charts, from Bill Dally's keynote, where he's showing the super-exponential increase in the number of integer operations per second going from the earliest K20s to the H100. It's a really remarkable exponential or super-exponential curve. So we have to really understand how to take advantage of this curve in order to get the most out of our language models. That's going to be our goal. I've already hinted at this important difference. The CPU is something that I think everyone's familiar with once you start programming. It's this execution model: if you have a program, it goes through, and in a single thread it executes step by step what's happening. In order to support that kind of execution model, what do you need? Well, you need big control units, and you need to run these things very quickly, because you have a lot of branching and a lot of conditional control logic. So the CPU (this is an abstracted diagram) is going to dedicate a lot of its chip to large control and branch prediction, and it's going to run these very quickly because it doesn't have that many threads. There are CPUs with lots and lots of cores now, but compared to a GPU it's almost nothing. In contrast, the GPU has tons and tons of compute units, ALUs. Those are the little green boxes, and a much smaller amount of the chip is dedicated to control. So there's a little bit of control logic orchestrating tons and tons of compute units operating in parallel.
So mentally, this is the picture of what is being emphasized in a CPU versus a GPU. And if you look at the design goals, they are designed for very different goals. You can think of CPUs as optimizing for latency: I want to finish my tasks as quickly as possible. So if I have tasks T1 through T4 here on the right side, in a CPU I'm going to try to finish each task as quickly as possible, and if you want any one of these tasks to finish quickly, T1 is going to complete really quickly. On a GPU, you're optimizing for high throughput. I don't care about latency; I just want all of my tasks in aggregate to complete as quickly as possible. To support that, maybe you have lots of threads, and these threads can go to sleep and wake up very quickly. In the end, you finish all of your workload, T1 through T4, before the CPU does, even though individually all of these tasks have higher latency. So they have different design principles and design goals. Okay. So a GPU has a pretty different anatomy. I don't know if you've ever looked at what a GPU layout diagram looks like; I'll actually show you the chip figures in a moment. But the core idea, and this is the important concept behind a GPU, is that a GPU contains many, many SMs, streaming multiprocessors. You can think of a streaming multiprocessor as an atomic unit: when you're programming in something like Triton, you're going to operate at the level of an SM. Each SM contains many SPs, streaming processors, and a streaming processor is going to execute a whole bunch of threads in parallel. So one way to think about it is that the SM has a bunch of control logic. It can decide what to execute; it can do, for example, branching. The SPs are going to take the same instruction and apply it to many different pieces of data.
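The latency-versus-throughput contrast in the T1 to T4 example can be made concrete with a toy schedule. This is a minimal sketch under made-up numbers: the CPU is assumed to run one task at a time at 1 time unit per task, and the GPU to run all four tasks at once at 2.5 time units each. Neither figure describes real hardware.

```python
# Toy model of 4 independent tasks (all timings are hypothetical).

def cpu_schedule(n_tasks, task_time):
    """Run tasks one after another; return each task's finish time."""
    return [task_time * (i + 1) for i in range(n_tasks)]

def gpu_schedule(n_tasks, task_time):
    """Run all tasks concurrently; every task finishes at task_time."""
    return [task_time] * n_tasks

cpu_finish = cpu_schedule(4, 1.0)   # fast per task, but serialized
gpu_finish = gpu_schedule(4, 2.5)   # slower per task, but all at once

print("CPU: first task done at", cpu_finish[0], "all done at", max(cpu_finish))
print("GPU: first task done at", gpu_finish[0], "all done at", max(gpu_finish))
```

Each simulated GPU task has higher latency than the CPU's first task, yet the whole batch finishes sooner, which is the trade described above.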
So you can do tons and tons of parallel computation under this model. An SM is each granular unit of control, and the SPs can do a lot of computation individually. If you look at an A100, which is the previous-generation GPU at this point, you've got 128 SMs on the full GA100 die (the shipping A100 enables 108 of them). That's a lot more than most CPUs have cores. And each of these SMs is going to have a very large number of SPs and specialized matrix multiply units inside them. So that's the compute model. speaker 2: Sorry, going back a slide: is this the same as the diagram before? speaker 1: So the question was, is this GPU the same as that diagram? Yes. This is a kind of cartoon version of it. You can think of each row as being an SM. It's got its own control units. Each green block might be one of these green blocks here, like an FP32 processing unit inside of it. And each SM can operate various pieces that it owns, like the tensor cores, to do computation. Cool. Okay. speaker 3: And there's going to be two important things. speaker 1: You think of GPUs as computers; they compute. But actually, computation is only one of the two important things we have to keep track of. Memory is arguably more important at this point, and it will continue to become more important in terms of the performance profiles of how we run our programs on the GPU. To understand memory, you have to understand the physical layout of the GPU and the chip, because when you're operating at such fast speeds, the physical proximity of the memory starts to matter quite a bit. So I will show you how things are physically laid out and how that relates to how you should think about memory access and performance. The closer a piece of memory is to each SM, the faster it's going to be.
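Why physical proximity matters can be sketched with a simple weighted average of access cost. The cycle counts below are the rough figures given in this part of the lecture (about 20 cycles for on-SM memory, 200 to 300 for L2 or global memory, here taken as 250). This is only an illustration: it ignores caching effects and the latency hiding that many resident warps provide.

```python
# Rough cycle counts quoted in the lecture (approximate, not measured).
ON_SM_CYCLES = 20        # L1 / shared memory inside the SM
OFF_SM_CYCLES = 250      # assumed midpoint of "200 to 300" for L2 or global memory

def avg_access_cycles(frac_off_sm):
    """Average cycles per access when a fraction of accesses leave the SM."""
    return (1 - frac_off_sm) * ON_SM_CYCLES + frac_off_sm * OFF_SM_CYCLES

for frac in (0.0, 0.1, 0.5, 1.0):
    print(f"{frac:>4.0%} off-SM -> {avg_access_cycles(frac):6.1f} cycles/access")
```

Even sending 10% of accesses off the SM roughly doubles the average cost in this model, which is why keeping hot data in on-SM memory matters so much.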
So there are certain very, very fast kinds of memory, like L1 and shared memory, and those live inside the SM. That's going to be really fast. Things like registers, things you're reading and writing very frequently, you're going to want to put into L1 and shared memory. Then there's the L2 cache. As you can see, there are these green areas, which are SMs, and then there are these blue areas. This is on the GPU chip. These are L2 memory that's right next to the SMs. They're not inside the SMs, but they're physically still quite close, and they're still pretty fast. They're a factor of ten slower, but still reasonably fast. And then there's memory outside of the chip itself. I think this is a 3090 card or something like that, or maybe a PCIe A100. Oh, this is a PCIe A100. You've got your GPU here, and you've got DRAM living next to the chip, so an access actually has to go physically outside of the chip. You can see on this chip diagram these yellow connectors at the edges. These are HBM connectors, connecting to the DRAM chips that are outside of the actual GPU. And you can see the time it takes to access these. The on-SM memory is much, much faster, like 20 clock cycles to access something, whereas it's going to take something like 200 to 300 clock cycles to access something from the L2 cache or global memory. And this factor of ten is going to hurt you real bad. If you have a piece of computation that requires you to access global memory, it might mean that you actually run out of work to do on your SM: you've multiplied all the matrices, you've run out, and now you just have to idle. So utilization won't be good. This will be a really key theme: thinking about memory is in some sense the key to thinking about how GPUs work. speaker 3: And in assignment two,
speaker 1: you're going to actually be writing high-performance code for a GPU. So you have to think about the execution model of how a GPU actually executes things. This is somewhat complicated, but not insanely so. There are three granularities of things that you need to think about: there are blocks, there are warps, and there are threads, and that's the order in which the granularity narrows down. Blocks are big groups of threads, and each block is going to be assigned to an SM. So think of each SM as a worker, its own autonomous unit, and a block is assigned to an SM to process. That's the coarsest granular unit. Within these blocks are a whole bunch of threads. Each thread is a piece of a task that needs to be done. When these threads execute, they're going to execute in groups, and this is a thing called a warp. So you take a block, which is a collection of threads, and the threads from that block are going to execute in groups of 32 consecutively numbered threads at a time. Those groups are called warps. You can see in this diagram what's happening. You've got a bunch of blocks. Each block is assigned to a different SM, and within each block there are going to be many different warps. Each warp consists of a whole bunch of threads, and all of these threads are going to execute the same instruction on different data. So this is the execution model. Right now it probably seems mysterious what these blocks, warps, and threads are, but they will have important implications for performance when we design things like CUDA kernels later. So hopefully you can remember this; I'll refresh your memory as we go. Hopefully that's clear. So that was the logical execution model of a GPU, and if you understand that, you understand how GPUs execute things.
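The block, warp, and thread hierarchy can be illustrated by decomposing a flat thread id. This sketch assumes a 1-D launch and the 32-thread warp size of current NVIDIA GPUs; `locate` is a hypothetical helper, not part of CUDA, and the block size of 128 is an arbitrary example.

```python
WARP_SIZE = 32  # threads per warp on current NVIDIA GPUs

def locate(global_tid, block_dim):
    """Map a flat thread id to (block index, warp index in block, lane in warp).

    Inverts the usual 1-D CUDA convention
    global_tid = blockIdx.x * blockDim.x + threadIdx.x,
    then splits threadIdx.x into consecutive groups of 32 (warps).
    """
    block_idx, thread_idx = divmod(global_tid, block_dim)
    warp_idx, lane = divmod(thread_idx, WARP_SIZE)
    return block_idx, warp_idx, lane

# Hypothetical launch: blocks of 128 threads, i.e. 4 warps per block.
for tid in (0, 31, 32, 127, 128, 300):
    print(tid, "->", locate(tid, block_dim=128))
```

Threads 0 to 31 share warp 0 of block 0, thread 32 starts the next warp, and thread 128 is the first thread of block 1.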
There's also a logical memory model of a GPU. Now I'm not showing you the physical hardware; this is just how you think about programming a GPU. There are registers; these are really fast storage for single numbers. You've got local memory, shared memory, and global memory, and as you go up the memory hierarchy, things get slower and slower. Your (host) code can write to global memory and can also write to constant memory, which is not something that's used too often. Each thread can access its own registers and its block's shared memory, but information that goes across blocks needs to be written to global memory. This is actually quite important. It means that whenever you write a thread that executes something, ideally it's operating on the same small amount of data. You load that small amount of data into shared memory, all the threads are very happy accessing that shared memory, it terminates, and it's done. That would be a great execution model. If instead you have a thread that needs to access data all over the place, it's going to have to access global memory, and that's very, very slow. This will come back as we talk about different ways of operating on a GPU. Hopefully that's clear. That's the very high-level, four-slide overview of a GPU. If you have questions about how any of that works, feel free to ask me as I go on. Okay. So here's a side thread. Last year I didn't cover this because I think resources on TPUs were a little thin, but the nice TPU book, the website that I mentioned at the start of the lecture, came out, and it has a lot of nice details. I also talked to a few Google people about the TPU, and at a high level it's very, very similar to a GPU. So I want to talk for a moment about TPUs.
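The rule that threads share data cheaply within a block but must go through global memory to communicate across blocks is why GPU reductions are usually written in two stages. The following is a pure-Python simulation of that structure, not CUDA: the list named `shared` stands in for one block's shared memory, `global_partials` for global memory, and the blocks run one after another here rather than in parallel.

```python
def block_reduce_sum(data, block_dim):
    """Simulate a two-stage sum structured like a GPU reduction kernel.

    Stage 1: each block copies its slice into its own "shared memory",
    reduces it with a tree reduction, and writes ONE partial sum to
    "global memory". Blocks never read each other's shared memory.
    Stage 2: the partial sums in global memory are combined.
    """
    global_partials = []                                  # stands in for global memory
    for start in range(0, len(data), block_dim):
        shared = list(data[start:start + block_dim])      # this block's shared memory
        stride = 1
        while stride < len(shared):                       # tree reduction inside the block
            for i in range(0, len(shared) - stride, 2 * stride):
                shared[i] += shared[i + stride]
            stride *= 2
        global_partials.append(shared[0])                 # only channel across blocks
    return sum(global_partials), len(global_partials)

total, n_blocks = block_reduce_sum(list(range(1000)), block_dim=256)
print(total, n_blocks)
```

Only one value per block ever touches the simulated global memory; all the intermediate traffic stays in the per-block list.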
You may never operate on a TPU, but I think it's important to understand that these alternative accelerators operate in many ways very similarly. So here's a diagram of what a TPU looks like. There's something called a tensor core, and mentally you can think of a tensor core as being similar to an SM, a streaming multiprocessor. Each of these is its own atomic unit that can operate on data. There's a scalar unit, which is basically a control unit, and it can also do arbitrary CPU-like things. You've got a vector unit that can operate on vectors, so if you've got a vector and you want to operate entrywise on it, that's a good place to do it. And then it's got a very big, specialized part of the chip dedicated to just doing matrix multiplies, called the MXU. Then it's got very fast memory, vector memory (VMEM) and SMEM. Both of these are very fast on-chip, or on-tensor-core, memory. And then there's high-bandwidth memory that lives outside of the chip. So hopefully you see the similarities to an SM: there's slow memory outside, very fast memory inside, and there's specialized hardware to do matrix multiplication. The core structure is very much the same. The difference, which I'll talk about in the parallelism lecture next week, is that how the accelerators are networked together is a little bit different. I'll also mention that I didn't talk about warps or any of that other stuff. Tensor cores are in some ways very simple, because they're optimized to just do matrix multiplies. The TPU tensor core, unlike the GPU, doesn't attempt to do anything but that. So it's in some ways much simpler in architecture, but conceptually doing the same thing. Yes? speaker 2: Is it called a tensor core because it's in some ways optimized for tensors? speaker 1: Yeah, so the question was, is it called a tensor core because it can operate on arbitrary tensors? It can operate on arbitrary tensors.
It can do the indexing. But the operation that the MXU performs is a matrix multiply, so it would always be something like a batched matrix multiply operating on a tensor. So it's both a yes and a no answer, if that makes sense. They operate on tensors, but the operations they perform are always matrix multiplies, not the more complicated tensor operations that you can do. Cool. The reason the GPU has been so successful is that it scales up really easily. If you want more processing power, just add more SMs. You don't have to worry about driving the clock faster and getting more heat-dissipation problems. Programming-wise, CUDA is intimidating, but it's actually not as horrendous to program because of its programming model. The way it works is that within each SM you have threads executing the same instruction on a bunch of different pieces of data. That's conceptually easy to reason about; you can think through what that means. It's especially nice if you're operating over a matrix and doing very simple operations, which is exactly this kind of SIMT model. Finally, each of these threads is very lightweight and can be stopped and started at any time. So if you need to wait for another thread, or if you need to evict something and start another process, that's cheap. This means there's not much state associated with the threads, and they can be stopped and started, which allows GPUs to get high utilization within each SM. GPUs are, obviously, graphics processing units, and for much of their life, in the early days, they were not used for scientific computing. But because they were programmable, researchers figured out how to use early NVIDIA GPUs to do fast matrix multiplies. This is one of the early papers on doing fast matrix multiplies with graphics hardware.
And it shows how you can hack things like the texture buffer and so on to get the hardware to do matrix multiplies. So even without specific support for matmuls, researchers figured out how to do it. But now, especially in this day and age, NVIDIA and others have realized that matrix multiplies are special. If you're doing deep learning, most of your workload is matrix multiplies. So matrix multiplies are in some sense blessed operations. This is a chart showing the number of teraflops per second across different generations of NVIDIA GPUs. The orange line is your matmul FLOPs, the performance you can get if you're doing matmuls. The blue line is your non-matmul FLOPs. And you see this big, big gap starting at the V100s, when they started putting in tensor cores, which are specialized hardware to do matrix multiplies. You see this gigantic gap in matrix multiply performance relative to non-matmul performance. So if you're going to design any sort of neural architecture (I was saying this in the architecture part as well), you have to have most of your workload be matrix multiplies, because that's the thing that's orders of magnitude faster than any other operation you're going to be able to do on a GPU. So if you make a non-matmul-based neural network, speaker 3: you're going to be in big speaker 1: big trouble. And the last thing that I want you to understand as general facts: matmuls being fast is one thing, but the other thing that's important to remember is the relative scaling of the different components of the GPU. This is a very nice chart that shows how quickly different components of the GPU, or let's call it the LLM training stack, are scaling. The blue line is the connectivity from the GPU to the host, the server that it's attached to.
You can use PCIe, you can use NVLink, you can use all these fancy interconnects. They are growing, but somewhat slowly. This chart shows bandwidth scaling normalized relative to the first generation of interconnects. The green line is global memory bandwidth: you go from GDDR to HBM, and that's much, much faster. This is log scale, so up to 100x faster, but this is still relatively slow scaling. And the gray line here is compute scaling. This is the number of floating-point operations; if you're considering matmul FLOPs, this is how fast compute has been scaling, and it's astoundingly fast, going from 1 to something like 100,000 times faster. So in the early days of this scaling, maybe your problems were FLOPs-bound: you just didn't have enough FLOPs to do your matrix multiplications. But now, all the way to the right with the H100s, these are astoundingly fast GPUs, and your bottlenecks are probably going to end up being memory, because memory is not growing as fast. And as we go into the future, this is not really going to change. DRAM is very hard to scale, so you're going to keep getting this bigger and bigger gap. If you're ever designing hardware-efficient algorithms, you're going to have to think more and more about memory. So we're going to keep a lookout for that; I'm going to keep emphasizing it. It's one of the important themes in GPUs. I've been throwing lots of GPU facts at you, and especially if you haven't seen this recently, it may be kind of new. So just to recap: GPUs are these massively parallel processing systems. They have the same instruction applied across many different threads, and they have these things called SMs, which are kind of like cores, and there are many, many of them in a GPU. Compute, and matrix multiplies in particular, have scaled really fast, faster than memory.
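The advice to keep most of the workload in matrix multiplies can be checked with Amdahl-style arithmetic. The two throughputs below are hypothetical values chosen only to give a gap of about 17x between matmul and non-matmul FLOPs; real ratios depend on the GPU and the number format.

```python
# Hypothetical throughputs, for illustration only.
MATMUL_TFLOPS = 1000.0       # tensor-core matmul rate (assumed)
NON_MATMUL_TFLOPS = 60.0     # everything else (assumed)

def runtime_share(total_tflop, matmul_fraction):
    """Return (seconds, fraction of runtime spent on non-matmul work)."""
    mm = total_tflop * matmul_fraction / MATMUL_TFLOPS
    other = total_tflop * (1 - matmul_fraction) / NON_MATMUL_TFLOPS
    return mm + other, other / (mm + other)

for frac in (0.99, 0.9, 0.5):
    seconds, non_mm_share = runtime_share(1000.0, frac)
    print(f"matmul FLOP share {frac:.0%}: {seconds:5.2f}s, "
          f"non-matmul share of time {non_mm_share:.0%}")
```

With these assumed rates, non-matmul work that is only 10% of the FLOPs takes roughly two thirds of the runtime.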
That is an important part of the characteristics you should think about with GPUs. But there is some fast memory; it's not like everything is slow and there's nothing we can do. There's the memory hierarchy: some kinds of memory are very, very fast, and other kinds of memory are slow. So if we exploit this hierarchy, maybe we can get things that are really, really fast. Those are the things to remember about the GPU, and if you remember these facts, you're going to be able to think pretty cleanly about the performance components that I'm going to talk about next. Any questions before I move on to the next part? Okay, cool. So now you're all GPU experts, and what we would like to do is make machine learning workloads go very fast on a GPU. I'm going to start with this chart, and one of our goals will be to understand exactly what this chart is. I think it'll be a good puzzle to get us motivated. What we are doing here is multiplying square matrices together. The x axis is the size of my square matrix multiplies, and the y axis is the number of operations per second that I'm doing, so you can think of the y axis as hardware utilization. As I get bigger and bigger matrices, I'm going to get better and better hardware utilization because I have more work to do, and that overwhelms the overhead of launching jobs and things like that. But there are all these weird things happening. You see one, two, three, four different lines, and each of these lines is wavy in a way that looks very unpredictable. So we would like to understand what exactly is going on with these lines. By the end of this section, my promise is that you will understand each one of these phenomena. You'll be able to say, yeah, that plot looks totally normal.
That is a natural thing for a GPU to do. Okay. So the very first part: if you look at that plot, you will notice that it looks a little bit like this. If you've taken a systems or hardware course, you should remember this as the roofline model. The roofline model basically says that if we're looking at throughput or utilization, we're going to find two regimes. There's a regime that is memory-limited, on the left side of this curve, in green over here, and there's a part that is throughput-limited, on the right side. In some sense you can think of it this way: on the right side, we are fully utilizing our compute units. All the matrix multiply units are multiplying all the time. On the diagonal here, we have some sort of memory bottleneck, so our ability to do computation is limited by the arithmetic intensity we have, the number of FLOPs per byte. So we want to avoid being in the left-side region where we're memory-bound, and we would like to be on the right side, where we're getting, in some sense, full utilization of all of our compute units. That's the goal. And hopefully the plot looks something like this roofline model: we've got this diagonal part, and then we've got this flat part all the way at the top. So that's one part of the mystery. This turns out to be kind of complex. The simple way to say it is: let's make sure we're not accessing memory unnecessarily, and have as few accesses to slow global memory as possible. But it turns out that in order to do that, we need a large array of tricks. There are a lot of different things you could do that would mess you up and make you very slow. The first one is not a memory bottleneck; I'll just mention it. It doesn't come up too often.
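The roofline model, and why bigger square matmuls climb toward the roof, can be sketched numerically. Assumptions: a hypothetical accelerator with a 1000 TFLOP/s compute roof and 3 TB/s of memory bandwidth, 16-bit elements, and ideal reuse so that each of A, B, and C crosses global memory exactly once. Under those assumptions the intensity of an n x n matmul is n/3 FLOPs per byte. The sketch reproduces only the rising-then-flat shape, not the wavy lines in the lecture's plot.

```python
# Hypothetical accelerator, for illustration only (not a specific GPU).
PEAK_TFLOPS = 1000.0                  # compute roof, TFLOP/s
BANDWIDTH_TBPS = 3.0                  # memory bandwidth, TB/s
RIDGE = PEAK_TFLOPS / BANDWIDTH_TBPS  # FLOPs/byte where the two regimes meet

def attainable_tflops(intensity):
    """Roofline model: performance is capped by compute or by memory traffic."""
    return min(PEAK_TFLOPS, BANDWIDTH_TBPS * intensity)

def square_matmul_intensity(n, bytes_per_elem=2):
    """FLOPs per byte for an n x n matmul, assuming A, B, C each cross memory once."""
    flops = 2 * n ** 3                    # one multiply and one add per term
    bytes_moved = 3 * n * n * bytes_per_elem
    return flops / bytes_moved            # equals n / 3 for 16-bit elements

for n in (64, 256, 1024, 4096):
    ai = square_matmul_intensity(n)
    regime = "memory-bound" if ai < RIDGE else "compute-bound"
    print(f"n={n:5d}  intensity={ai:7.1f}  "
          f"attainable={attainable_tflops(ai):7.1f} TFLOP/s  {regime}")
```

Small matrices sit on the diagonal (memory-bound) and large ones hit the flat roof, matching the two regimes described above.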
We'll get it out of the way, and then we'll talk about the remaining five items that, in some sense, are really core to thinking about GPU performance. Okay. So the first thing I want to talk about is conditionals. As I said before, the GPU's execution model is something called SIMT: single instruction, multiple threads. Every thread in a warp is going to execute the same instruction, and it's going to do so on different data. So what happens if I write a piece of code that looks like this? I have an if statement: if the thread index is less than four, do something; if the thread index is greater than or equal to four, do something else. A very simple conditional. If I run this on the GPU, what's going to happen is that I run the A instruction on four of my threads and actually pause my other four threads, which are supposed to be executing the else branch. Then these other four threads come alive and execute X, my original four threads go to sleep, and I just alternate executing each of these instructions. Why is that? I can't execute A and X at the same time on these different threads, because, as I said, every thread has to execute the same instruction. So conditional statements within a single warp can be really, really damaging, because they force you to pause any threads that are not following the main control flow. Okay, so that was the only non-memory thing I wanted to mention, and it should be fairly obvious that you should probably not be putting conditionals into your massively parallel compute units. Once we've gotten that out of the way, the other tricks we need to consider are all memory-based. The first one I want to mention is lower precision. This is a big trick, an important trick. You should do it all the time.
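The if/else example can be mimicked with a small simulation of one warp under SIMT. This is a simplified model: it uses the 8 threads from the slide rather than a real 32-thread warp, assumes the taken branch runs first (the hardware does not promise an order), and counts one issued pass per distinct branch.

```python
def simt_passes(lane_ids, predicate):
    """Simulate one warp running `if predicate(lane): A else: X` under SIMT.

    Returns the passes the warp issues. Each pass runs ONE instruction;
    lanes that did not take that branch are masked off (paused) meanwhile.
    """
    taken = [lane for lane in lane_ids if predicate(lane)]
    not_taken = [lane for lane in lane_ids if not predicate(lane)]
    passes = []
    if taken:
        passes.append(("A", taken))       # lanes in not_taken sit idle
    if not_taken:
        passes.append(("X", not_taken))   # lanes in taken sit idle
    return passes

lanes = list(range(8))                           # 8 threads, as on the slide
divergent = simt_passes(lanes, lambda t: t < 4)  # warp splits into two groups
uniform = simt_passes(lanes, lambda t: t < 100)  # every lane takes the same branch
print(len(divergent), len(uniform))
```

A uniform branch costs one pass, while a split warp pays for both sides with half its lanes idle each time.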
Going back to the plot from Bill Dally, there's a sleight of hand here. It looks really good because the numbers keep going up and up. But if you look at what's driving GPU progress over all these years, a lot of it is number representation. You go from FP32 to FP16 to INT8 and so on, and you get large gains just from using lower and lower precision in your GPU operations. Let me clarify why that's so important. If you have fewer bits in everything you compute, your weights and so on, you have far fewer bits to move. So even if you're accessing these bits from global memory, they become much less of a concern. Let's take a simple example and think about the arithmetic intensity of a simple elementwise operation. I'm going to do a ReLU, x = max(0, x), on a vector of size n, and, naively, I'm going to do it in float32. How many memory accesses do I have? I have to read x and I have to write the result, and both are float32, so that's eight bytes per element. How many operations do I do? I have to compute x < 0, which is one comparison, so one flop. So I move eight bytes per floating-point operation. If I do this in float16, I haven't changed the flop count, but I've halved the memory traffic, so now I have four bytes per flop. In some sense, I've gotten double the memory bandwidth for free, assuming I can get away with float16. This is a key part of how a lot of things are designed. Part of the assignment will be to play with mixed-precision and low-precision training and related things. And a key point is that not all parts of your network and your training algorithm should be put into low precision.
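The ReLU arithmetic above, written as a small helper. Like the lecture, it assumes one read and one write per element and counts the comparison as a single flop.

```python
def elementwise_intensity(n, bytes_per_elem, flops_per_elem=1,
                          reads_per_elem=1, writes_per_elem=1):
    """Arithmetic intensity (flops per byte) of an elementwise op over n elements."""
    flops = n * flops_per_elem
    bytes_moved = n * bytes_per_elem * (reads_per_elem + writes_per_elem)
    return flops / bytes_moved


n = 1_000_000
print(elementwise_intensity(n, bytes_per_elem=4))  # ReLU in float32
print(elementwise_intensity(n, bytes_per_elem=2))  # ReLU in float16
```

Both values sit far to the left of any realistic ridge point, so ReLU stays memory-bound either way; halving the bytes simply doubles how fast it can go.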
Let me give you an example with matrix multiplies. For a mixed-precision matrix multiply, your inputs would be 16-bit, so they're low precision, and then you do the multiply-accumulate in full 32-bit. That's useful because for the intermediate computations, as you accumulate partial sums, you'd like those to be in high precision. So you accumulate with an FP32 accumulator, and your tensor core returns an FP32 result, which you can downcast back to 16-bit if you like. So our inputs are in 16-bit, but things like accumulation we might want to do in 32-bit. There are lots of different cases. There are operations that can use 16-bit storage. There are operations that need more precision, so you want to keep them in higher precision, like FP32. And there are operations that need more range, like exp functions: if you don't have enough dynamic range, they might blow up or go to zero, so you might want to put those in BF16. A lot of careful engineering has to happen to make sure these models are actually stable when they're trained in lower precision. But if you can do it, it's really great, because going from 32-bit to 16-bit basically doubles the throughput of your bottleneck, if memory is your bottleneck. Okay, the next trick, and I think this is what a lot of people have in mind when they say they're going to write a CUDA kernel, is operator fusion. It's both very intuitive and a fun, natural one to think about. One mental model of how a GPU and its memory work is this fun diagram of a factory from Horace He. Imagine you have a factory, and your factory is your compute part. It takes in little box widgets and outputs little triangle widgets.
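To see why the accumulator should stay in higher precision, here is a pure-Python simulation that rounds a running dot-product sum to half or single precision after every addition. It relies on the standard-library `struct` format characters `'e'` (IEEE half) and `'f'` (IEEE single) to do the rounding. The vector length and the value 0.01 are arbitrary illustrative choices, and real tensor cores accumulate in hardware, not like this.

```python
import struct


def round_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_fp32(x):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot(a, b, round_acc):
    """Dot product whose running sum is rounded to the accumulator format after every add."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_acc(acc + x * y)
    return acc


n = 10_000
a = [round_fp16(0.01)] * n  # inputs stored in 16-bit
b = [round_fp16(0.01)] * n
exact = sum(x * y for x, y in zip(a, b))  # reference sum in float64

fp16_acc = dot(a, b, round_fp16)  # 16-bit accumulator: small addends get rounded away
fp32_acc = dot(a, b, round_fp32)  # 32-bit accumulator, as in mixed-precision matmul
print(exact, fp16_acc, fp32_acc)
```

With a 16-bit accumulator, once the running sum grows large enough, each new product is smaller than half the spacing between representable values and is rounded away, so the sum stalls far below the true value. The 32-bit accumulator stays within a tiny fraction of it, even though the inputs themselves are 16-bit.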
If you grow your compute but the conveyor belt that carries things from memory to compute has finite bandwidth, you're not going to be able to use your second factory. You're still capped by the speed at which you can transfer things from memory to compute, so you've got a bottleneck. Of course, you already knew that; I've been hammering on the memory bottleneck. But one insidious way to incur a ton of overhead without realizing it is the computation pattern on the left-hand side. Imagine the left side of this diagram is where the memory is and the right side is your compute unit. To do a computation, I start with a square and move my squares from memory to compute. I do some operation and turn them into triangles. Then I ship the triangles back to memory. Then I realize I need the triangles again, so I ship them back into the compute unit. Now the triangles become circles, and so on and so forth. I send my data back and forth between compute and memory. You might call this a very naive approach, and if you just run operations naively on the GPU and ship every result straight back to global memory, this is what you end up with. If you count the number of times a piece of data went back and forth, it's pretty terrible: you've incurred tons of memory overhead. Now look at the right side. Nothing forces the intermediates back to memory, so I should be able to go square to triangle to circle to rectangle and ship only the rectangle back, keeping everything in the compute unit the whole time. That's the right-hand diagram, and it's the mental model of a fused kernel: a sequence of operations applied to a piece of data without writing the intermediate results back to storage.
Instead, I do as much of the computation as I can in one place, and ship the result back to memory only when I have to. That's the idea of kernel fusion. Okay, here are some very simple examples of how naive code can give you a naive set of kernel launches. Say I write a little neural network module that takes in x and computes sin²(x) + cos²(x). Simple code. If I run this, the computation graph in PyTorch looks something like this, and it launches a whole bunch of CUDA kernels: it takes in x and launches one kernel to compute sin(x), one to compute cos(x), then one each for sin²(x) and cos²(x), and one for sin²(x) + cos²(x). So there's a lot of back and forth needed for this computation; it's exactly the left-hand figure I showed you before. But if you're a little smarter, and you either write your own CUDA kernel or use something like torch.compile, you can easily see that those five operations don't depend on very much: they use only a little bit of memory. So you can fuse them into a single kernel in which each thread does all five operations for its element, without sending intermediates back to global memory. Easy fusions like this can be done automatically by compilers; I just mentioned torch.compile. If you aren't already using it, you should strongly consider using torch.compile everywhere. We'll show you torch.compile in the assignment as well; it's pretty nice. Okay, so I've gone through precision and fusion. If anyone has questions, let me know before I move on to recomputation and other tricks we can use on the GPU. Okay, good. Another thing we can do is called recomputation.
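A back-of-the-envelope count of the global-memory traffic for the sin²(x) + cos²(x) example. It assumes each unfused kernel reads its inputs from global memory and writes its output back, while the fused kernel reads x once and writes the result once. The kernel names are hypothetical labels, not PyTorch's actual kernel names.

```python
import math

# Hypothetical kernel lists: (name, arrays read from global memory, arrays written back).
NAIVE_KERNELS = [
    ("sin",    ["x"],        ["s"]),
    ("cos",    ["x"],        ["c"]),
    ("square", ["s"],        ["s2"]),
    ("square", ["c"],        ["c2"]),
    ("add",    ["s2", "c2"], ["out"]),
]
FUSED_KERNELS = [
    ("sin2_plus_cos2", ["x"], ["out"]),  # all five ops happen in registers
]


def global_traffic(kernels, n, bytes_per_elem=4):
    """Return (kernel launches, bytes moved to/from global memory) for n-element arrays."""
    arrays_touched = sum(len(reads) + len(writes) for _, reads, writes in kernels)
    return len(kernels), arrays_touched * n * bytes_per_elem


def fused_elementwise(x):
    """What each thread of the fused kernel computes for its single element."""
    s, c = math.sin(x), math.cos(x)
    return s * s + c * c


n = 1_000_000
print(global_traffic(NAIVE_KERNELS, n))  # 5 launches
print(global_traffic(FUSED_KERNELS, n))  # 1 launch
```

For one million float32 elements this works out to 44 MB of traffic for the five-kernel version versus 8 MB for the fused one, a 5.5× reduction before even counting launch overhead.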
Recomputation is the idea of spending more compute to avoid memory accesses. Remember your original backpropagation lecture; this figure is actually from CS221. What do we do? We take our inputs at the very bottom, the yellow ones, and propagate activations upward; those are also the yellow values on the tree. Then we compute the Jacobians backward; those are the green values on the edges. To compute my gradients, I multiply the Jacobians with the activations and propagate the gradients backward. If you think about it, those yellow values from the forward pass have to be stored. Then they have to be fetched from global memory, where I stored them, and brought into the compute unit. Mechanically, that's how it has to happen. But that can mean a ton of memory reads and writes, and you might be able to avoid them. Let me give you an example of how recomputation can speed things up. Here's another somewhat silly function I might write: I'm going to stack three sigmoids on top of each other. On the left is the forward graph, which should match your mental model of three stacked sigmoids. Now, for the computation graph: I compute the sigmoids, store s1 and s2, which are the activations of the first two sigmoids, and produce my output. That's my forward pass. The backward pass here is pretty bad. In my backward graph, I need to fetch s1 and s2, take the gradient coming backward into the out box, push all of that into the backward computation, and I get the gradient of x. So I need three memory reads and one memory write to compute the backward pass.
And for the forward pass, I need one memory read of x and three memory writes for s1, s2, and out. Hopefully that's clear. That's a decent number of memory reads and writes: eight in total. And I have very low arithmetic intensity, because there are no matrix multiplies at all. The idea of recomputation is to say: I don't want to store those activations at all. I'm not going to put them into memory; I'll just recompute them on the fly in my backward pass. So in my new forward pass, I don't store s1 and s2. I take x as input, compute my sigmoids, and get my output. That's one memory read for x and one memory write for out. Now, in my backward pass, I don't have the activations anymore. So I fetch both dout, the backward signal coming in from above, and x, my input. That's two memory reads. Then, on the fly in my SM, in local memory, I recompute s1, s2, and out and feed them into the backward graph, so no additional global memory reads are needed. Then I have one memory write, dx. If you compare the two, I have five-eighths of the memory accesses for the exact same computation. The price is that I have to recompute these three sigmoids, but if you were sitting idle anyway because you were memory-bound, this is a great tradeoff. You'd be very happy with it, because you've traded compute, which you have too much of, for memory bandwidth, which you had too little of. So this is a great way of trading something you have for something you need.
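The five-eighths figure can be checked with a toy Python model. `GlobalMemory` is a hypothetical stand-in that counts whole-array reads and writes, with scalars standing in for arrays. Both versions compute the same gradient; only the memory traffic differs. One assumption worth flagging: in the stored version, the last sigmoid's derivative is evaluated locally from `s2`, which is how the backward pass gets by with the lecture's three reads.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def dsig(s):
    """Sigmoid derivative written in terms of the sigmoid's output s."""
    return s * (1.0 - s)


class GlobalMemory:
    """Toy stand-in for GPU global memory that counts array reads and writes."""

    def __init__(self, **arrays):
        self.data = dict(arrays)
        self.reads = 0
        self.writes = 0

    def read(self, name):
        self.reads += 1
        return self.data[name]

    def write(self, name, value):
        self.writes += 1
        self.data[name] = value


def run_with_stored_activations(mem):
    # Forward pass: 1 read (x), 3 writes (s1, s2, out).
    x = mem.read("x")
    s1 = sigmoid(x)
    s2 = sigmoid(s1)
    mem.write("s1", s1)
    mem.write("s2", s2)
    mem.write("out", sigmoid(s2))
    # Backward pass: 3 reads (s1, s2, dout), 1 write (dx).
    s1, s2, dout = mem.read("s1"), mem.read("s2"), mem.read("dout")
    out = sigmoid(s2)  # evaluated locally from s2, so `out` is not read back
    mem.write("dx", dout * dsig(out) * dsig(s2) * dsig(s1))


def run_with_recomputation(mem):
    # Forward pass: 1 read (x), 1 write (out); no activations stored.
    x = mem.read("x")
    mem.write("out", sigmoid(sigmoid(sigmoid(x))))
    # Backward pass: 2 reads (x, dout), 1 write (dx); activations recomputed locally.
    x, dout = mem.read("x"), mem.read("dout")
    s1 = sigmoid(x)
    s2 = sigmoid(s1)
    out = sigmoid(s2)
    mem.write("dx", dout * dsig(out) * dsig(s2) * dsig(s1))


stored = GlobalMemory(x=0.5, dout=1.0)
run_with_stored_activations(stored)
recomputed = GlobalMemory(x=0.5, dout=1.0)
run_with_recomputation(recomputed)
print(stored.reads + stored.writes, recomputed.reads + recomputed.writes)
```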
This is the same trick as gradient checkpointing, recomputing activations to save memory, but here it's done for a different reason: execution speed, not just because you're running out of memory. Same technique, different goal. Okay. This next one I think is really interesting, and I didn't know about it until I started looking into the hardware model of the GPU and DRAM. The slow global memory on a GPU is called DRAM, and it's actually very, very slow. To make it faster, certain optimizations are done at the hardware level. One of them is that when you read a piece of memory from DRAM, you don't get just that value back; you get a whole chunk of memory back. This is called burst mode. Say I try to read the very first value of this big memory block. Instead of giving me back just 0, the memory gives me back 0, 1, 2, 3: four values at once. It's as if it's saying, here you go, I'm sure you'll need 1, 2, and 3 in the future. So each address space is cut up into what are called burst sections, and you're given the entire burst section rather than just what you asked for. This might seem very mystifying: why would the memory give you three extra values for free when you only asked for one? There's an interesting hardware reason. When you address into the memory, the data has to be moved to an amplifier before it can be sent out. That's the slow step, and once you've done it, you can get many bytes for free. That's why burst sections exist.
They mask the more expensive step of moving the stored data to the amplifier. Regardless of the reason, this means we may be able to significantly accelerate our memory accesses if the access pattern is good. If I want to read this entire block and I access it in random order, I'll have to issue roughly as many queries as there are elements. But if I check the very first value, I get the entire first burst section at once. Then if I check element number four, I get the second burst section at once. So I can basically get four times the throughput if I'm clever about my memory accesses and arrange them so that I use everything in each burst section I fetch. This is called memory coalescing. If all the threads in a warp fall within the same burst, the hardware and programming model will group those queries. Instead of querying 0, 1, 2, 3 separately, it groups them and effectively asks just for 0, and I can read out 0, 1, 2, 3 all at once from the burst-mode DRAM. Remember that a warp is 32 consecutively numbered threads, and memory accesses from a warp happen together. So when a warp reads into these burst sections, the hardware can fetch all four values at once rather than one at a time, which quadruples the throughput of your memory. These are simple things, but they're very important. Imagine I'm doing matrix multiplications, a core operation you'd have to do a ton of if you were to implement, say, a neural network from scratch in CUDA.
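A toy model of burst-mode reads that uses the lecture's burst size of four elements. Real DRAM bursts are defined in bytes and are larger, so the number here is illustrative only. The helper counts how many distinct burst sections a warp's 32 addresses fall into, treating each distinct burst as one DRAM transaction.

```python
def bursts_touched(addresses, burst_elems=4):
    """Number of distinct aligned burst sections that a set of element addresses falls into."""
    return len({addr // burst_elems for addr in addresses})


WARP = 32
contiguous = [tid for tid in range(WARP)]   # thread i reads element i
strided = [tid * 4 for tid in range(WARP)]  # thread i reads element 4*i
print(bursts_touched(contiguous), bursts_touched(strided))
```

The contiguous pattern touches a quarter as many bursts as the strided one, which is the 4× coalescing win described above.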
Imagine I read my matrices, stored row by row, in one of two ways. Either each thread traverses a row, or each thread goes down a column. It turns out that the left pattern, where each thread walks across the columns of its own row, is going to be quite slow, because the memory reads won't be coalesced. On the right side, where each thread goes down a column, incrementing the row index, the memory reads will be coalesced. Think about why for a moment. When I first looked at this diagram, I thought, isn't this reversed? It's not; this is correct. The way to think about it is to look at which elements are read at the same time step. Suppose a group of threads each walks left to right along its own row. At the first time step, each thread loads the first element of its row; at the next time step, the element in the second column; then the third column, the fourth column, and so on. So what happens at time step one? My first thread loads this point, the second thread loads this point, then this one and that one. Those can't be coalesced at all, because they fall in different burst sections, so I have to read this entire chunk of memory to perform a single step. Instead, if each thread goes in the column direction, all the threads at a given step read within a single burst section, so only one memory read needs to be performed and you get all the values at once. This is a very low-level optimization, but it's very important.
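The row-versus-column puzzle can be checked the same way. This sketch assumes a row-major layout and the toy four-element bursts. It counts the distinct bursts touched at each time step when thread i walks along row i versus down column i.

```python
def bursts_per_step(n, threads_walk_rows, burst_elems=4):
    """Distinct bursts touched per time step when n threads scan an n x n row-major matrix.

    threads_walk_rows=True : thread i reads M[i][t] at step t (walks along its row)
    threads_walk_rows=False: thread i reads M[t][i] at step t (walks down its column)
    """
    counts = []
    for t in range(n):
        if threads_walk_rows:
            addrs = [i * n + t for i in range(n)]  # one element from each row: far apart
        else:
            addrs = [t * n + i for i in range(n)]  # adjacent elements of one row: contiguous
        counts.append(len({a // burst_elems for a in addrs}))
    return counts


print(bursts_per_step(8, threads_walk_rows=True))   # uncoalesced
print(bursts_per_step(8, threads_walk_rows=False))  # coalesced
```

This is why the "threads go down columns" picture is the fast one: at every instant, neighboring threads touch neighboring addresses.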
If your memory traversal order is wrong, your memory accesses will be much slower than you want. Okay. So that brings us to the very last, and big, one: tiling. Tiling is the idea of grouping memory accesses together to minimize the amount of global memory access we have to do. To explain it, I'll go through the example of a matrix multiply. First I'll explain why a naive matrix multiply algorithm is very problematic, and then I'll give you a tiled version of the same idea, and you'll see why it reduces the number of global memory reads. Let's start with a very simple matrix multiply algorithm. I've got the M matrix on the left side and the N matrix on the top. To compute the matrix-matrix product, I traverse the rows of M and the columns of N, take the inner products, and store them into the P matrix at the corresponding positions. I've written out each of the threads, thread (0,0), (0,1), (1,0), and (1,1), corresponding to where they store their outputs, along with the order in which they access each of the individual elements. Notice that the memory access here is not coalesced: the row matrices here are going to be accessed in a non-coalesced order. And I have repeated memory accesses: M(0,0) is accessed in the first thread and again here, and N(0,1) is accessed in two different threads. These values are read over and over from global memory into many different threads, so this is potentially very slow.
So the question is whether we can avoid so many global memory reads and writes. Let me explain the ideal outcome first, and then the algorithm. Ideally, I'd spend one chunk of time loading pieces from global memory into shared memory, where things are fast, do a ton of computation in shared memory, and then be done with that piece of data. That minimizes my global memory accesses. How can I do this for matrix multiply? I take both matrices, M and N, and cut them up into tiles. Here I've cut them into 2×2 tiles, so I have a 2×2 M tile and a 2×2 N tile: smaller submatrices within each matrix. Now imagine my shared memory is big enough to fit these submatrices within each SM. That gives a very simple algorithm. I first load, say, the M tile at the top left, and I also load the top-left N tile into shared memory. Now I have partial sums I can compute: I can take the row M(0,0), M(0,1) and the column N(0,0), N(1,0), and add their inner product into P(0,0). I can do the same for every entry of the output submatrix I can fill in right now. Once I'm completely done processing these two tiles, I load the next pair of tiles and repeat the computation with the next M tile and the next N tile in shared memory, adding into my partial sums in P. So I've consolidated and reduced the amount of global memory access I have to do.
I load as much as I can at once into shared memory, do all the submatrix computations I can on that tile, and then move on to the next one. The other nice thing is that because I'm loading an entire tile, I can traverse these submatrices in whatever order I want, column-major or row-major, so I can coalesce all the memory accesses whenever I load a tile from global to shared memory. So there are wins all around when we tile our accesses. We can do a little bit of tiling math. Say we have matrices A, B, and C, all square matrices of size n, and I have a tile of size T. Oh, yes, question? speaker 2: [Partially inaudible] In the second step, why are we loading M(0,0) again? speaker 1: I just wrote that for completeness. M(0,0) would just stay stored in shared memory; let's keep it cached, and I won't load it again. It's only there for completeness, not because you would actually discard and reload the matrix. That would be kind of insane. Cool. Okay. So we can do very simple tiling math to think about what's happening. Say I'm doing an n×n matrix multiply. If I do a non-tiled matrix multiply, just going over rows and columns, then every time I process an input, it has to come from global memory, so each input is read n times from global memory. If I do a tiled matrix multiply, the global reads operate over tiles, so I read each input n/T times from global memory and T times within each tile. Of course, I'm doing a matrix-matrix multiply, so I can't reduce the total number of reads: I have to read all the matrix elements. But I can shift the reads into fast shared memory.
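A pure-Python sketch of the naive and tiled algorithms that counts global-memory reads. The "shared memory" here is just a local copy of a tile. The loop structure, with one output tile per (bi, bj) and phases over bk, is one common way to organize a tiled matmul and not necessarily the exact kernel on the slides. For an n × n problem with tile size T, the counts come out to 2n³ and 2n³/T, matching the n versus n/T reads per input element.

```python
def matmul_naive(M, N):
    """Every multiply-add fetches both operands from 'global memory'."""
    n = len(M)
    P = [[0] * n for _ in range(n)]
    global_reads = 0
    for i in range(n):
        for j in range(n):
            acc = 0
            for k in range(n):
                acc += M[i][k] * N[k][j]
                global_reads += 2  # M[i][k] and N[k][j]
            P[i][j] = acc
    return P, global_reads


def matmul_tiled(M, N, T):
    """Tiled matmul: each phase copies one T x T tile of M and of N into 'shared memory'."""
    n = len(M)
    assert n % T == 0, "sketch assumes the tile size divides n"
    P = [[0] * n for _ in range(n)]
    global_reads = 0
    for bi in range(0, n, T):          # output tile row
        for bj in range(0, n, T):      # output tile column
            for bk in range(0, n, T):  # phases along the shared dimension
                Ms = [row[bk:bk + T] for row in M[bi:bi + T]]  # tile load from global memory
                Ns = [row[bj:bj + T] for row in N[bk:bk + T]]
                global_reads += 2 * T * T
                for i in range(T):
                    for j in range(T):
                        acc = 0
                        for k in range(T):
                            acc += Ms[i][k] * Ns[k][j]  # reads hit the fast local copy only
                        P[bi + i][bj + j] += acc        # accumulate partial sums across phases
    return P, global_reads


n = 4
M = [[i * n + j for j in range(n)] for i in range(n)]
N = [[(i + j) % 3 for j in range(n)] for i in range(n)]
print(matmul_naive(M, N)[1], matmul_tiled(M, N, 2)[1])
```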
So each input is read T times from shared memory and n/T times from global memory. That's great, because if we have a big shared memory that can store big tiles, that's a factor of T reduction in the total amount of data that has to come from global memory. So tiling can be a really powerful idea when you're operating over matrices and can move things into shared memory. Tiling is also quite complex, and it's the source of many confusing things about GPU and matrix multiply performance. Once we start tiling, we have to think about discretization. Imagine I have a tile size of 128. That seems like a nice round tile size. When I have a full matrix of size 256, that's great: it's a 2×2 grid of tiles, and everything loads nicely. Now say the matrix is 257 on the column side. Now it's a bad time, because I need six tiles to cover the matrix, and the two tiles on the right are very sparse: there's hardly anything in them. The problem is that each tile is assigned to an SM: each tile is a block, and each thread operates within a tile. So those two tiles on the right won't do much at all, and their SMs will basically sit idle. If you were compute-bound, you would have wanted to distribute the load more evenly across SMs. So you have to optimize your tile sizes to avoid these scenarios. In reality, a lot of complex things go into setting the tile size. Remember, you have to coalesce your memory accesses, so you have to think carefully about that. You must not exceed your shared memory size, so the tiles can't be too big. And ideally the tile size divides the matrix dimension evenly, or as close to evenly as possible.
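The discretization example can be quantified with a small helper that counts tiles and the fraction of tiled area holding real data. It assumes 128 × 128 tiles and ignores everything else a real kernel does. For 256 × 257, six tiles are needed and only about 67% of the tiled area is real data, because each of the two right-hand tiles holds a single real column.

```python
def ceil_div(a, b):
    return (a + b - 1) // b


def tile_grid(rows, cols, tile_rows, tile_cols):
    """Return (number of tiles, fraction of total tile area that holds real matrix entries)."""
    n_tiles = ceil_div(rows, tile_rows) * ceil_div(cols, tile_cols)
    useful = rows * cols / (n_tiles * tile_rows * tile_cols)
    return n_tiles, useful


print(tile_grid(256, 256, 128, 128))  # evenly divisible
print(tile_grid(256, 257, 128, 128))  # one extra column forces a third tile column
```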
That way, you don't end up with an underutilized SM at the very end. Yes? speaker 2: So if you have, say, smaller tiles, would the GPU do something like [partially inaudible] loading the next tile while the current one is being computed? speaker 1: So you're asking whether you can overlap memory reads and computation. Yes, that's naturally done on GPUs; they're always trying to use the available bandwidth, so as long as shared memory is available, they can put things into it. The issue is that whenever you're effectively utilizing your SMs, you're basically maxed out on shared memory. That's the bottlenecked resource, so there's no place to prefetch into, in some sense. Cool. The other thing that is very complex, and we're getting into the weeds here, is the interaction between tiling and burst sections. Imagine I have a matrix layout like this, where each burst section lines up nicely with a row of a tile. To read this tile, all I have to do is get four different burst sections, and I've gotten the entire tile. Now imagine what happens if I add one extra element to each row, so that, given how the matrix is laid out, my burst sections flow over. When I load my tile, I load the first row, and that's great: I get the entire first row as one burst section. But the second row now spans two different burst sections, so I have to do two reads to get it, and so on for the remaining rows. I've essentially doubled the number of memory accesses, because a single extra element at the end of each row shifted the alignment between my burst sections and my matrix layout.
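The alignment effect can be reproduced with the same toy burst model: load a 4 × 4 tile from a row-major matrix and count distinct four-element bursts. The row widths 8, 9, and 12 are illustrative. Width 12 is 9 padded up to the next multiple of the burst size.

```python
def bursts_for_tile(row_width, tile=4, burst_elems=4):
    """Distinct bursts touched when loading the top-left tile x tile block of a
    row-major matrix whose rows are `row_width` elements long (toy model)."""
    addrs = [r * row_width + c for r in range(tile) for c in range(tile)]
    return len({a // burst_elems for a in addrs})


print(bursts_for_tile(8))   # rows aligned with bursts
print(bursts_for_tile(9))   # one extra element per row breaks the alignment
print(bursts_for_tile(12))  # padding 9 up to 12 restores it
```

In this toy case the misaligned layout needs 7 bursts instead of 4, close to the doubling described in the lecture, and padding brings it back to 4.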
So basically, if your tiles or matrix sizes aren't multiples of your burst section, you can easily end up with rows that don't line up with the burst sections, doubling the amount of memory access you have to do. The way around this is padding, to get nice round matrix sizes so that your burst sections line up with the size of your tiles. This is getting very into the weeds, but if you really want to squeeze all the performance out of your matrix multiplies, these are the kinds of things you have to think about, and you will get bitten by them if you don't. And of course, things like torch.compile and all the CUDA optimizations for matrix multiplies are doing exactly the kinds of things I just talked about; that's how you get better performance. All this matrix complexity leads to situations like this tweet from Andrej Karpathy: the most dramatic optimization to nanoGPT is to increase the vocab size from 50257 to 50304, the nearest multiple of 64, which gives you much higher occupancy. Careful with your powers of two. That's a roughly 25% speedup from adding 47 dimensions to your vocab. How does that happen? That brings us back to the mystery. I've been dragging you through all these GPU details in the hope that you'll fully understand the performance characteristics, and the payoff is that I now get to explain how this chart comes to be. By the end, you won't find matrix multiply performance so mysterious or scary. The very first part is very simple: we understand compute intensity. This is exactly the roofline I pointed out at the very beginning.
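The vocab padding from the tweet is just rounding up to a multiple. A one-function sketch:

```python
def round_up_to_multiple(n, multiple):
    """Smallest value >= n that is divisible by `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


vocab = 50257  # GPT-2 vocabulary size, as used by nanoGPT
padded = round_up_to_multiple(vocab, 64)
print(padded, padded - vocab)  # 50304 47
```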
Up until about 1536 here, there just isn't enough matrix multiply work to do. Simply loading the matrices and doing the basic IO you have to do becomes the bottleneck below this point. So throughput falls through the floor below this size: you just don't have enough memory bandwidth to keep your compute units fed. On the right side, in theory, if I draw the upper envelope, that's the maximum achievable performance. It's possible up there to saturate all of my compute units and get really great performance. But if you mess up your matrix sizing, you can end up in these really weird places, and within each of them, in a weird trough. So let's think about why you can end up in all these different places. The very first thing, this first line here, is a tiling alignment issue. I've now colored each of these points based on the divisibility of the matrix size, that is, the largest number k by which it's divisible. If it's divisible by 32, you're in good shape: you're in these purple dots up here. If you're divisible by 16, you're actually still up here; there are two colors. If you're green, k equals 8, and you're up here. If you're orange, k equals 2, and if k equals 1, you're all the way down here. So don't pick prime dimensions: you won't get very good throughput on your matrix multiplies. A big part of this is that once you get to k = 2 and k = 1, you basically force a situation where you can no longer read tiles in a way that's nicely aligned with your burst reads, and that leads to serious issues. So that's one part of the mystery.
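A helper that reproduces the plot's coloring rule as described. It assumes the buckets are the largest power of two from 32 down to 1 that divides the matrix dimension; the exact buckets on the slide may differ.

```python
def divisibility_class(n, candidates=(32, 16, 8, 4, 2, 1)):
    """Largest candidate k that divides n (the 'k' used to color the plot's points)."""
    for k in candidates:
        if n % k == 0:
            return k


for size in (1792, 1800, 1794, 1793):
    print(size, divisibility_class(size))
```

Note that 1792 lands in the best bucket while 1794, only two larger, drops to k = 2, which sets up the puzzle that follows.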
But another part of the mystery remains. Within this orange line, if you zoom in here, you see a giant drop from this point all the way down to this point, and you wonder: what happened? How could I lose so much performance by increasing my dimension by two? Let's look at the numbers; I think it's a fun puzzle, so I'll walk you through it. This happens when you transition from 1792 to 1793 or 1794; let's say 1794, so that it's still divisible by two. Why does that happen? Let's say we're using a tile size of 256 by 128. That's a pretty natural size: as a fun fact, the matrix multiply units in these GPUs naturally operate on matrices of roughly size 128, so 256 by 128 is a very nice tile size. How many tiles are there? We divide each dimension of the matrix by the corresponding tile dimension, 1792/256 = 7 and 1792/128 = 14, so there are 7 × 14 = 98 tiles. If we increase the size slightly, we have to round up in each coordinate, so we get a lot more tiles: 8 × 15 = 120 of them. Not only did we significantly increase the number of tiles, some of which have lower utilization, which is bad, but, even worse, an A100 has 108 SMs. If you go back to the GPU execution model, SMs execute in parallel; they're the execution units. So when you have 98 tiles, they all get dispatched and run at once. All the SMs are running and you've got great utilization. Once you go to 120 tiles, you have more tiles than SMs, so 108 of them execute first.
And then you'll go back and say, all right, I've still got some tiles left, so you execute the remaining twelve at very, very low utilization and wait for those to complete. And that's going to be really bad. So if you look at your utilization, you'll have good utilization for a while, you'll drop off a cliff, and then you'll sort of finish up your job. This is called wave quantization. Ideally, your number of tiles is either much bigger than the number of SMs, or at least not like this, where you're just barely over the number of SMs and you've caused this quantization error. Cool. All right, I know these are low-level details, but I've been saying through many classes that language models and deep learning are about attention to detail. And these kinds of details are what allow people to scale up LMs to really, really large sizes and get great performance. So it's worth knowing, even if you're not going to do systems engineering. So what were the tricks? Key ideas here. The first one is that you've got to reduce the number of memory accesses, and there are lots of ways to do it. You can do coalescing, so that you can reuse reads you're getting for free. You can do fusion, so that you fuse multiple operations together and avoid unnecessary reads. And you can move data to shared memory, so even if you're going to do reads, they come from much faster memory; that's the tiling trick. And then finally, you can trade memory for other resources that you do have. You can trade it for compute, which is recomputation, or you can trade it for numerical precision or stability, which is quantization. So there's a whole bag of tricks for getting performance out.
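The wave-quantization arithmetic from the 1792 vs. 1794 example can be sketched in a few lines. This is a simplified model that assumes one 256 x 128 tile per SM at a time and 108 SMs (A100); real kernels may run several thread blocks per SM, so treat the numbers as illustrative.

```python
import math


def num_tiles(m: int, n: int, tile_m: int = 256, tile_n: int = 128) -> int:
    # A partially filled tile still needs a full slot, hence ceiling division.
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)


def wave_stats(tiles: int, num_sms: int = 108):
    """Return (number of waves, fraction of SMs busy in the last wave)."""
    waves = math.ceil(tiles / num_sms)
    last_wave_tiles = tiles - (waves - 1) * num_sms
    return waves, last_wave_tiles / num_sms


for size in (1792, 1794):
    t = num_tiles(size, size)
    waves, last_util = wave_stats(t)
    print(f"{size}: {t} tiles, {waves} wave(s), last wave uses {last_util:.0%} of SMs")
```

Going from 98 to 120 tiles adds a second wave in which only 12 of 108 SMs have work, which is the cliff in the plot.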
You just have to be really mindful of the role that memory plays in the performance of a GPU. That's the key thing for getting the most out of it. Cool. Any questions on that before I move to the final part with flash attention? Okay, good. All right. So now I'm going to put it all together. I want to show that all the tricks I taught you aren't random, disconnected facts about GPUs; they're part of the standard performance optimization toolkit. And flash attention and flash attention 2 will hopefully show you how it all comes together to build one of the foundations of modern high-performance transformers. We know that flash attention dramatically accelerates attention, and most of you probably know that's done through some CUDA kernel magic, but maybe you don't know all the details. So here's what the paper says. One part is that if you take attention in an unoptimized PyTorch transformer implementation, fuse the kernel, and do a few other things, you can get significant speedups. And from the paper, they say they apply two established techniques, tiling and recomputation, to overcome the technical challenge of computing exact attention in sub-quadratic HBM accesses. So it's not sub-quadratic computation, because you can't do that; in general you have to compute all of attention. But they get sub-quadratic accesses to the high-bandwidth, or global, memory. And that's really the key. If memory is the bottleneck, you want to make that not quadratic, so that at least you pay the quadratic cost with your compute rather than with your memory. So just for a really quick recap: at this point you've implemented attention many, many times in many classes. It's going to be three different matrix multiplies: you've got K, Q, and V, with a softmax in between.
So the matrix multiplies are pretty simple; they can be done with tiling, and I've shown you examples like that. What's different about attention? Well, there's a softmax, and that's going to be the really tricky bit. Once we can deal with the softmax, all of the matrix multiply things I was talking about just come into play. The matrix multiply, as I said before, is exactly what I taught you. If you look at Figure 1 from the flash attention paper, this is really just a simple tiled matrix multiply. You see the K matrix and the Q matrix cut up into small blocks. Small blocks are copied to SRAM, they're multiplied, and then they're accumulated and sent to HBM, where you do softmaxes and then multiply with V. So this is all really simple in terms of the KQV matrix multiplies. But now we have to think about the softmax. What's going on with the softmax? Sorry, I'm going to roll back one step. So what's the problem with the softmax? It's a global operation. The softmax in attention operates row by row: you have to sum the entire row to compute the normalizing term of the softmax. And that's very problematic if I have tiles. Ideally, I want to do everything within the tiles; I don't ever want to have to write back to the big matrix. So I need a softmax that can be computed online within each tile, because I want to do as much computation within each tile as possible. The key thing here is to use what's called the online softmax. So what is that? Say you have a stream of values. Normally, in the batch version of the softmax, you take all of your x_1 through x_n, exponentiate them, sum them, and divide. That's what you would do in your normal softmax.
And then you would compute the maximum value and subtract it in order to make this numerically stable. So this is the standard numerically stable softmax on the left side. For the online softmax, I've taken this from Milakov and Gimelshein, 2018. You can realize, with a kind of telescoping-sum argument, that you can pull out the current running normalizer term and the current term e^(x_i - max_k x_k). So what you do is maintain the current max that you've seen over x_1 through x_j, where j is my current iteration. And I also maintain this correction term: if my max updated, it basically corrects for the new max, and then I add my new term over here. So this d_j tracks, online, the denominator of this equation, the running sum of exponentials. At the end, I can then compute the normalizer and get the normalized y_i that I want; d_V is itself the normalization term that I need. So the key thing here is that this can be done online. I don't need x_1 through x_n up front; all I need is the stream x_1 through x_j. And that's really key, because I can now compute the softmax tile by tile. Within each tile, I can run this algorithm, and that lets me compute the partial softmax for that tile. Then I can write back, if I need to, all the components I'm keeping track of, and that's all I need to do this computation. So I never have to materialize the full n-squared matrix in order to compute the softmax. And that's basically it. Once you have that, you've put it all together and you get the forward pass of flash attention.
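A minimal scalar sketch of the online softmax described above, in the style of Milakov and Gimelshein: one streaming pass maintains the running max `m` and the running sum `d`, rescaling `d` whenever the max grows. The function names are our own, and `batch_softmax` is only there as the standard reference for comparison.

```python
import math


def online_softmax(xs):
    """Softmax whose statistics (max and normalizer) are built in one streaming pass."""
    m, d = float("-inf"), 0.0  # running max and running sum of exp(x - m)
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new max (exp(-inf) = 0 on the first step), then add the new term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # The final outputs need the completed m and d.
    return [math.exp(x - m) / d for x in xs]


def batch_softmax(xs):
    """Standard numerically stable softmax that needs all inputs up front."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


xs = [1.0, 3.0, -2.0, 1000.0, 999.0]  # large values would overflow a naive exp
print(online_softmax(xs))
```

Because `m` and `d` are only two numbers per row, they can be carried from tile to tile instead of storing the whole row.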
And if you go and look at the flash attention 2 paper, which is what we're going to ask you to implement, you'll be following these steps, and you'll see exactly this idea. First you have your KQ matrix multiply, and this is going to be tiled, so these are little tiled chunks that get multiplied. And how am I going to compute the softmax? I maintain a running value of these exponentiated sums, keep incrementally updating it, and correct for the maximum terms. By doing that, I can compute all of the necessary quantities tile by tile, going from one tile to the next, and then just multiply once again with tiles of V at the end. And that gives me my full softmax output, right?
speaker 2: Until we've computed the QK multiplication across all the tiles, don't we have to double back on them?
speaker 1: So the question is: you can't compute this until you're done with all the tiles, so you have to double back on all the tiles.
speaker 2: For the denominator, until every column...
speaker 1: That's right. Before you can output your softmax, you have to go through all the tiles. That's correct. But let's say I do all the tiles once, all n-squared tiles. At that point, I have all the components I need to directly output the softmax, and I don't have to do recomputation because I already have the normalizer terms. By going through each of these tiles, at the end I've built up l_n, the sum of all of the exponentiated terms. I already have that in shared memory for this last tile, and that allows me to exponentiate, divide, and then return all of the components. Okay.
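The tile-by-tile forward pass can be sketched in NumPy to check that it reproduces ordinary attention. This is an illustration of the idea under stated assumptions, not the FlashAttention 2 kernel: the block sizes are arbitrary, loops stand in for thread blocks, and on a GPU each Q block's tiles would live in SRAM/shared memory.

```python
import numpy as np


def naive_attention(Q, K, V):
    """Reference: materializes the full N x N score and probability matrices."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V


def tiled_attention(Q, K, V, block_q=16, block_k=16):
    """Forward pass that only ever holds one (block_q x block_k) score tile."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]))
    for qs in range(0, N, block_q):
        Qi = Q[qs:qs + block_q]
        rows = Qi.shape[0]
        m = np.full(rows, -np.inf)             # running row max
        l = np.zeros(rows)                     # running row sum of exponentials
        acc = np.zeros((rows, V.shape[1]))     # unnormalized output rows
        for ks in range(0, N, block_k):
            S = Qi @ K[ks:ks + block_k].T * scale   # one score tile
            m_new = np.maximum(m, S.max(axis=1))
            correction = np.exp(m - m_new)          # rescale old statistics to the new max
            P = np.exp(S - m_new[:, None])
            l = l * correction + P.sum(axis=1)
            acc = acc * correction[:, None] + P @ V[ks:ks + block_k]
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]       # normalize once, after all K/V tiles
    return O


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((50, 8)) for _ in range(3))
print(np.abs(tiled_attention(Q, K, V) - naive_attention(Q, K, V)).max())
```

As in the exchange above, the output rows are only normalized after every K/V tile has been visited, but no N x N matrix is ever stored.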
I'm not going to cover the backward pass, but you can do recomputation tile by tile, which lets you avoid storing the softmax. Remember, I always want to avoid storing anything of size n squared. Here I've been clever with the tiles so that I don't have to store any of the n-squared components when I'm computing, for example, the softmax. But in the backward pass, if I store the activations, that's already something of size n squared. So I don't want to store my n-squared activations; I'm going to have to recompute them on the fly, tile by tile, when I do the backward pass. That's another really key trick they use to make the backward pass possible. Otherwise it's fairly standard: you compute the gradients the same way, just tile by tile. Okay, that brings us to the end here. Hopefully you've seen how all the pieces I talked about, tiling, coalescing, and recomputation, come together to give you flash attention and all these really cool things that make your transformers go much faster. So to recap the whole lecture: hardware is the thing that has really powered all of the language models we have today. If you really want to leverage your hardware, you have to understand the low-level details. I think all of the systems advances really engage with a lot of the concepts I taught today. And the current GPU scaling trend (that plot is really the one you should remember) really incentivizes you to think about memory movement. Memory movement is the bottleneck in all of this. So you don't want to just think, "Oh, how do I reduce the number of flops?" That's important too, but you really have to think about how to make your memory movement more efficient.
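One way to make tile-by-tile recomputation concrete: keep only the output and a per-row log-sum-exp `L` (size N) from the forward pass, then rebuild any tile of the softmax matrix on demand as `exp(S_tile - L)`. Saving the log-sum-exp follows the FlashAttention 2 paper; the forward below is left untiled for brevity, and the helper names and block size are assumptions for this sketch. Only the dV gradient is shown.

```python
import numpy as np


def forward_with_lse(Q, K, V):
    """Forward pass keeping only O and the per-row logsumexp L (size N, not N x N)."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    m = S.max(axis=1)
    L = m + np.log(np.exp(S - m[:, None]).sum(axis=1))
    P = np.exp(S - L[:, None])   # used once, then discarded
    return P @ V, L


def recompute_prob_tile(Q, K, L, rows, cols):
    """Rebuild one tile of the softmax matrix P during the backward pass."""
    S_tile = Q[rows] @ K[cols].T / np.sqrt(Q.shape[1])
    return np.exp(S_tile - L[rows][:, None])


def grad_V_tiled(Q, K, L, dO, block=16):
    """dV = P^T dO, accumulated tile by tile from recomputed P tiles."""
    dV = np.zeros((K.shape[0], dO.shape[1]))
    for qs in range(0, Q.shape[0], block):
        rows = slice(qs, qs + block)
        for ks in range(0, K.shape[0], block):
            cols = slice(ks, ks + block)
            P_tile = recompute_prob_tile(Q, K, L, rows, cols)
            dV[cols] += P_tile.T @ dO[rows]
    return dV
```

Storing N numbers instead of N x N is what keeps the backward pass from reintroducing quadratic memory traffic.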
And finally, if you have to do a certain amount of computation anyway, the way to optimize things is to optimize your data movement: avoid as much movement to and from the high-bandwidth memory, or global memory, as possible. You want to reduce that and keep everything in the very, very fast shared memory. And that leads to good performance on things like flash attention. Thanks, everyone.
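Latest Summary (Detailed Summary)
Overview / Executive Summary
This lecture takes an in-depth look at the hardware architecture, execution model, and memory hierarchy of graphics processing units (GPUs), and at their critical role in the performance of large language models. Its central point is that because GPU compute has grown super-exponentially, far outpacing improvements in memory bandwidth, memory access has become the main performance bottleneck for modern GPU workloads, especially language model training and inference. The lecture explains how GPUs and CPUs differ in design goals: GPUs optimize for throughput, CPUs for latency. It focuses on the GPU's streaming multiprocessors (SMs), streaming processors (SPs), and the all-important multi-level memory (on-chip L1/shared memory, L2 cache, and off-chip global HBM), stressing the importance of exploiting the memory hierarchy.
To improve GPU utilization and performance, the lecture lays out several optimization strategies: lower-precision arithmetic (e.g., FP16/INT8) to reduce data movement; operator fusion (kernel fusion) to cut global-memory reads and writes of intermediate results; recomputation to trade compute for memory bandwidth; memory coalescing to read efficiently under DRAM burst mode; and tiling, which loads data into fast shared memory for computation and thereby greatly reduces accesses to slow global memory. By analyzing the causes of the irregular fluctuations in a matrix-multiply performance plot (tiling alignment, wave quantization), the lecture shows the practical impact of these details. Finally, using FlashAttention as a case study, it shows how tiling and recomputation combine to compute Transformer attention efficiently without storing the full N² attention matrix, significantly reducing accesses to high-bandwidth memory (HBM).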
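The Importance of GPUs and Hardware Fundamentals
Hardware Progress and the Need for Parallel Computing
- Compute and model performance: The lecture notes that more compute is critical for training large language models. It cites scaling laws for pre-training and inference, showing that increasing compute, data, and model size all improve performance.
- Hardware drives performance: Performance gains in deep learning come not only from algorithms but also from "faster hardware, better utilization, and improved parallelization".
- From Dennard scaling to parallel scaling:
  - Early on, semiconductors improved single-thread CPU performance through Dennard scaling (transistors becoming smaller, faster, and lower power).
  - After the 1980s-2000s, single-thread performance gains slowed (citing a Hennessy & Patterson chart).
  - Performance gains shifted to parallel scaling, i.e., processing many workloads at once.
- GPU compute growth: Citing a chart from Bill Dally, the lecture shows "super-exponential" growth in integer throughput from the K20 to the H100. Understanding how to exploit this growth is key to optimizing language models.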
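Design Philosophy: CPU vs. GPU
- CPU (Central Processing Unit):
  - Design goal: optimize latency, i.e., finish a single task as quickly as possible.
  - Architecture: large control units and branch-prediction logic for quickly executing single-threaded programs with lots of branching and conditional control logic.
  - Relatively few cores.
- GPU (Graphics Processing Unit):
  - Design goal: optimize throughput, i.e., complete as much total work per unit time as possible, even if the latency of any single task may be higher.
  - Architecture: a huge number of compute units (ALUs) with relatively little control logic, for processing large numbers of threads in parallel.
  - Threads can be put to sleep and woken up quickly.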
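Core GPU Architecture
- Streaming Multiprocessors (SMs):
  - The GPU's core execution units, roughly analogous to CPU cores.
  - The unit that programming models such as Triton operate on.
  - An A100 GPU has 108 SMs (the full GA100 die has 128).
- Streaming Processors (SPs):
  - Each SM contains many SPs.
  - SPs perform the actual parallel computation, executing the same instruction on different data.
  - The SM contains the control logic that decides what to execute; the SPs do the bulk parallel computation.
  - An A100 SM contains many SPs plus dedicated matrix-multiply units.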
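GPU Memory Hierarchy and Why It Matters
- Why memory matters: The lecture stresses that "memory is arguably more important"; memory performance has a huge impact on GPU programs.
- Physical proximity: The closer memory is to the SM, the faster it is to access.
  - L1 cache and shared memory: inside the SM, alongside the registers, and extremely fast (about 20 clock cycles). Used for frequently read and written data.
  - L2 cache: on the GPU die, right next to the SMs; fairly fast (about 200-300 clock cycles), but an order of magnitude slower than L1.
  - Global memory (DRAM/HBM): off the GPU die (connected through the HBM connectors); the slowest.
- Performance bottleneck: The latency of global-memory accesses (about 10x that of L2) can leave SMs idle and lower utilization.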
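GPU Execution Model
- Three levels of granularity:
  - Blocks: large groups of threads; each block is assigned to one SM for execution. SMs are independent "workers".
  - Warps: the unit in which the threads of a block execute. A warp is typically 32 consecutively numbered threads.
  - Threads: the smallest unit of work.
- SIMT (Single Instruction, Multiple Threads): all threads in a warp execute the same instruction at the same time, but on different data.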
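The block/warp/thread hierarchy comes down to simple index arithmetic. The sketch below reproduces in Python the arithmetic a 1-D CUDA kernel typically performs (`blockIdx.x * blockDim.x + threadIdx.x`) and how consecutive thread ids group into 32-thread warps; the function name and the 1-D launch are assumptions for illustration.

```python
import math

WARP_SIZE = 32  # threads per warp on current NVIDIA GPUs


def thread_coordinates(block_idx: int, thread_idx: int, block_dim: int):
    """Map a 1-D (block index, thread index) pair to (global id, warp in block, lane)."""
    global_id = block_idx * block_dim + thread_idx   # which element this thread handles
    warp_in_block = thread_idx // WARP_SIZE          # consecutive thread ids share a warp
    lane = thread_idx % WARP_SIZE                    # position inside the warp
    return global_id, warp_in_block, lane


block_dim = 256
print("warps per block:", math.ceil(block_dim / WARP_SIZE))
print(thread_coordinates(block_idx=2, thread_idx=70, block_dim=block_dim))
```

All 32 lanes of a warp issue the same instruction together, which is why the warp, not the single thread, is the unit to reason about for branching and memory access patterns.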
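GPU Logical Memory Model
- Registers: the fastest; each holds a single value and is private to a thread.
- Local memory: not covered in detail in the lecture.
- Shared memory: shared by the threads of a block; fast.
- Global memory: accessible by all blocks; slow.
- Constant memory: rarely used.
- Key point: "Information passed between blocks must go through global memory." Ideally, data is loaded into shared memory and processed efficiently by the threads within the block.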
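A Brief Introduction to TPUs
- Similarity to GPUs: Tensor Processing Units (TPUs) are conceptually very similar to GPUs at a high level.
- Tensor Core (TPU): roughly analogous to a GPU SM; the atomic unit of operation.
  - Contains a scalar unit (control), a vector unit, an MXU (Matrix Multiply Unit, dedicated matrix-multiply hardware), and fast on-chip memory (the lecture says "vmem" and "sram", i.e., vector memory and SRAM).
  - Backed externally by high-bandwidth memory (HBM).
- Main differences:
  - Accelerators are networked together differently (to be covered in a later lecture).
  - The architecture is simpler: "TPUs don't try to do everything the way GPUs do; they're optimized only for matrix multiplication."
- Speaker 2 asked whether a Tensor Core can operate on arbitrary tensors; Speaker 1 answered that it operates on tensors, but the MXU performs matrix multiplications (e.g., batched matrix multiplies).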
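Keys to GPU Performance
Why GPUs Succeeded
- Easy to scale up (add more SMs).
- The CUDA programming model (SIMT) is relatively intuitive for operations on matrices and similar structures.
- Lightweight threads are easy to schedule, enabling high SM utilization.
The Special Status of Matrix Multiplication
- Early researchers were already using GPUs for fast matrix multiplication.
- Nvidia and other vendors now treat matrix multiplication as a "blessed operation".
- Performance gap: a chart shows that starting with the V100, thanks to the introduction of Tensor Cores, matrix-multiply floating-point throughput (MatMul FLOPs) far exceeds non-matrix-multiply throughput (non-MatMul FLOPs).
- Design implication: "If you're going to design any kind of neural architecture... you have to make most of the workload matrix multiplies."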
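Relative Scaling of GPU Components
- Key trend: compute is scaling far faster than memory bandwidth.
  - GPU-to-host interconnects (PCIe, NVLink): slow growth.
  - Global memory speed (GDDR to HBM): faster growth (about 100x), but still "slow scaling".
  - Compute (MatMul FLOPs): astonishing growth (from 1x to 100,000x on the chart).
- Bottleneck shift: early on, the bottleneck may have been FLOPs; for GPUs like the H100 it is now "probably memory, because memory isn't growing fast enough". Designing hardware-efficient algorithms will increasingly require attention to memory.
Recap of GPU Performance
- GPUs are massively parallel processing systems (SIMT).
- Compute (especially matrix multiplication) scales far faster than memory.
- Exploiting the memory hierarchy (fast vs. slow memory) is key.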
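Understanding GPU Performance Fluctuations: Matrix Multiplication as an Example
- Goal: understand the wavy, seemingly unpredictable patterns in a matrix-multiply performance plot (y-axis: operations per second / hardware utilization; x-axis: square matrix size).
- The roofline model:
  - Performance is bounded in two regions: the memory-bound region (left, low arithmetic intensity) and the compute-bound region (right, high arithmetic intensity, reaching peak throughput).
  - The goal is to be in the compute-bound region and fully use the compute units.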
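The roofline model fits in two lines: attainable throughput is the smaller of the compute roof and bandwidth times arithmetic intensity (FLOPs per byte moved). The sketch assumes an idealized n x n matmul in which each matrix crosses HBM exactly once, and the peak and bandwidth numbers are rough, illustrative A100-class figures, not measurements; real kernels move more bytes than this ideal, so the crossover it predicts need not match the plot.

```python
def matmul_arithmetic_intensity(n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for C = A @ B (n x n), assuming A, B read once and C written once."""
    flops = 2 * n ** 3                       # one multiply and one add per inner-product term
    bytes_moved = 3 * n * n * bytes_per_elem
    return flops / bytes_moved


def roofline(intensity: float, peak_flops: float, bandwidth: float) -> float:
    """Attainable FLOP/s = min(compute roof, memory roof)."""
    return min(peak_flops, bandwidth * intensity)


PEAK = 312e12  # FLOP/s, assumed BF16 dense peak (illustrative)
BW = 1.5e12    # bytes/s, assumed HBM bandwidth (illustrative)

for n in (256, 1024, 4096):
    ai = matmul_arithmetic_intensity(n)
    bound = "memory" if BW * ai < PEAK else "compute"
    print(f"n={n}: {ai:.0f} FLOP/byte, {roofline(ai, PEAK, BW) / 1e12:.0f} TFLOP/s ({bound}-bound)")
```

Because intensity grows linearly with n for matmul, small matrices sit on the sloped memory roof and large ones reach the flat compute roof.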
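GPU Performance Optimization Techniques
Avoid Conditionals (Branches)
- Under the SIMT model, all threads in a warp execute the same instruction.
- An if-else makes the threads that don't take the current branch sit idle, effectively serializing execution and hurting performance.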
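A deliberately simplified cost model of warp divergence: since a warp shares one instruction stream, every branch that at least one lane takes is issued for the whole warp, with the other lanes masked off. The function and the cycle costs are hypothetical; real hardware has more nuance, but the serialization effect is the same.

```python
def warp_cycles(predicates, if_cost: int, else_cost: int) -> int:
    """Cycles for one warp executing `if (p) A else B` under a simplified SIMT model."""
    cycles = 0
    if any(predicates):        # some lane needs the if-branch: the whole warp issues it
        cycles += if_cost
    if not all(predicates):    # some lane needs the else-branch: issued too, serially
        cycles += else_cost
    return cycles


uniform = [True] * 32                                 # every lane agrees
divergent = [lane % 2 == 0 for lane in range(32)]     # even and odd lanes disagree
print(warp_cycles(uniform, 10, 10), warp_cycles(divergent, 10, 10))  # 10 20
```

A divergent warp pays for both branches, which is why branch conditions that are uniform within a warp are much cheaper.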
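Use Lower Precision
- An important driver of GPU performance gains has been progress in number formats (FP32 -> FP16 -> INT8).
- Benefit: fewer bits per value means less data to move.
- Example: a ReLU reads one value and writes one value per FLOP, so FP32 needs 8 bytes/FLOP while FP16 needs only 4 bytes/FLOP, effectively doubling memory bandwidth.
- Mixed precision: in matrix multiplication, inputs can be 16-bit, accumulation is done in 32-bit to preserve accuracy, and the output is converted back to 16-bit.
- Different operations have different precision and range requirements (e.g., the exp function may need BF16 for its larger dynamic range).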
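Two small checks of the points above, as a sketch: the bytes-per-FLOP arithmetic for a ReLU, and why accumulators are kept in 32-bit. `sequential_sum` is a hypothetical helper that rounds the running total to the given dtype after every addition, which mimics a low-precision accumulator.

```python
import numpy as np


def relu_bytes_per_flop(dtype) -> float:
    """ReLU reads one value and writes one value per element, for about 1 FLOP."""
    itemsize = np.dtype(dtype).itemsize
    return float(itemsize + itemsize)


def sequential_sum(values, acc_dtype):
    """Add values one at a time, rounding the running total to acc_dtype each step."""
    acc = acc_dtype(0)
    for v in values:
        acc = acc_dtype(acc + v)
    return float(acc)


print(relu_bytes_per_flop(np.float32), relu_bytes_per_flop(np.float16))  # 8.0 4.0

ones = np.ones(4096, dtype=np.float16)
print(sequential_sum(ones, np.float16))  # 2048.0: 2049 is not representable in FP16
print(sequential_sum(ones, np.float32))  # 4096.0
```

The FP16 accumulator silently stops growing once its spacing exceeds the increment, so 16-bit inputs with 32-bit accumulation keeps the bandwidth savings without this error.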
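Operator Fusion (Kernel Fusion)
- Problem: if every operation writes its result back to global memory and then reads it again, there is a lot of unnecessary memory traffic ("data shuttling back and forth between memory and the compute units").
- Solution: merge a sequence of consecutive operations into a single CUDA kernel, keeping intermediate results in the compute units (e.g., registers or shared memory) and writing back to global memory only at the end.
- Example: computing sin^2(x) + cos^2(x) naively launches several kernels; fused, it needs only one. Compilers such as torch.compile can perform this kind of fusion automatically.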
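A traffic-accounting sketch of the sin^2(x) + cos^2(x) example. In `unfused`, each statement stands in for a separate kernel that reads its inputs from and writes its output to global memory; `fused` makes one pass with intermediates in local variables (the analogue of registers). The element counts follow from that stated assumption; this is a model, not what torch.compile literally emits.

```python
import math

import numpy as np


def unfused(x):
    """Five 'kernels', each round-tripping through global memory. Returns (out, elements moved)."""
    n = x.size
    a = np.sin(x)
    moved = 2 * n          # read x, write a
    b = a * a
    moved += 2 * n         # read a, write b
    c = np.cos(x)
    moved += 2 * n         # read x, write c
    d = c * c
    moved += 2 * n         # read c, write d
    out = b + d
    moved += 3 * n         # read b and d, write out
    return out, moved


def fused(x):
    """One pass: read each x once, keep sin/cos in locals, write each result once."""
    out = np.empty_like(x)
    for i, xi in enumerate(x):
        s, c = math.sin(xi), math.cos(xi)
        out[i] = s * s + c * c
    return out, 2 * x.size


x = np.linspace(-3.0, 3.0, 1000)
print(unfused(x)[1], fused(x)[1])  # 11000 2000
```

Same result, about 5.5x less memory traffic under this model; for a memory-bound elementwise chain, that ratio is roughly the speedup fusion can buy.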
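Recomputation
- Core idea: spend extra compute to reduce memory accesses.
- Background: in standard backpropagation, activations from the forward pass are stored and then read back from global memory during the backward pass.
- Example: sigmoid(sigmoid(sigmoid(x)))
  - Standard approach: store the intermediate activations S1 and S2. 8 memory reads/writes in total.
  - Recomputation approach: don't store S1 and S2 in the forward pass; in the backward pass, recompute them on the fly from the input X and combine them with the upstream gradient d_out. 5 memory reads/writes in total.
- Payoff: if the system is memory-bandwidth bound and the compute units have slack, this can noticeably speed things up. It is similar to gradient checkpointing, but here the goal is faster execution rather than just saving memory.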
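A NumPy sketch of the sigmoid example: one version saves the intermediate activations for the backward pass, the other saves only the input and recomputes them. The function names are our own, and we count saved tensors rather than reproducing the lecture's exact read/write tally; both backward passes must give the same gradient.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward_store(x):
    s1 = sigmoid(x)
    s2 = sigmoid(s1)
    out = sigmoid(s2)
    return out, (s1, s2, out)          # activations written out for backward


def backward_stored(saved, d_out):
    s1, s2, out = saved                # read back from memory
    # d sigmoid(z)/dz = sigmoid(z) * (1 - sigmoid(z)), so each saved output gives a local derivative.
    return d_out * out * (1 - out) * s2 * (1 - s2) * s1 * (1 - s1)


def forward_recompute(x):
    out = sigmoid(sigmoid(sigmoid(x)))
    return out, (x,)                   # only the input is kept


def backward_recompute(saved, d_out):
    (x,) = saved
    s1 = sigmoid(x)                    # recomputed on the fly instead of read from memory
    s2 = sigmoid(s1)
    out = sigmoid(s2)
    return d_out * out * (1 - out) * s2 * (1 - s2) * s1 * (1 - s1)
```

The recompute path spends three extra sigmoid evaluations to avoid writing and re-reading two full intermediate tensors, a good trade when the kernel is memory bound.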
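Memory Coalescing
- DRAM burst mode: when DRAM reads data, it returns not just the requested bytes but a whole "burst section" (e.g., 4 values). The hardware reason: getting the addressed data into the sense amplifiers is the slow step; once that is done, fetching neighboring bytes is cheap.
- How the optimization works: if the addresses accessed by all threads in a warp fall within the same burst section, the hardware can coalesce the requests and read the whole section at once, raising memory throughput.
- Effect on matrix operations: how matrix elements are laid out in memory (row-major or column-major) and how threads access them determines whether reads can be coalesced. The wrong traversal order makes memory access inefficient.
- The lecturer notes that reads are coalesced when each thread walks down a column (threads going down, incrementing in rows). With row-major storage, this means that at each step the threads of a warp touch consecutive elements of the same row, which is the most efficient pattern.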
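A toy model of the two traversal orders for a row-major matrix: count how many distinct burst sections one warp's simultaneous reads fall into. The 4-element burst follows the example above; real burst and transaction sizes are specified in bytes and vary by hardware, so the helper names and sizes here are assumptions.

```python
def bursts_touched(addresses, burst_elems: int = 4) -> int:
    """Number of distinct burst sections that a set of element addresses falls into."""
    return len({addr // burst_elems for addr in addresses})


def warp_addresses(step: int, n_cols: int, pattern: str, warp_size: int = 32):
    """Row-major addresses read by one warp at a given loop step.

    'row_per_thread': thread t walks along row t, reading A[t][step]
    'col_per_thread': thread t walks down column t, reading A[step][t]
    """
    if pattern == "row_per_thread":
        return [t * n_cols + step for t in range(warp_size)]
    if pattern == "col_per_thread":
        return [step * n_cols + t for t in range(warp_size)]
    raise ValueError(f"unknown pattern: {pattern}")


for pattern in ("row_per_thread", "col_per_thread"):
    print(pattern, bursts_touched(warp_addresses(step=5, n_cols=1024, pattern=pattern)))
```

With each thread owning a row, the warp's reads are 1024 elements apart and need 32 separate bursts; with each thread owning a column, the same 32 reads are contiguous and fit in 8.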
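Tiling (Blocking)
- Core idea: divide large matrices into small blocks (tiles) and load them into fast shared memory for computation, minimizing accesses to slow global memory.
- Matrix-multiply example:
  - Naive algorithm: compute an inner product for every row of M and every column of N, so elements (e.g., M[0,0]) are read from global memory over and over.
  - Tiled algorithm:
    - Partition M and N into tiles.
    - Load one tile of M and one tile of N into shared memory.
    - Perform all the sub-matrix computation between these two tiles in shared memory, accumulating into the corresponding tile of the result P.
    - Load the next pair of tiles.
  - Payoff: the number of global-memory reads of each element drops from N to N/T (T is the tile size); the remaining reads (T per load) happen in fast shared memory.
- Complications of tiling:
  - Discretization: if a matrix dimension isn't divisible by the tile size, some tiles are mostly empty and SM utilization drops. For example, a 256 matrix with 128 tiles divides perfectly; a 257 matrix with 128 tiles leaves a final tile that is tiny.
  - Choosing a tile size: consider memory coalescing, shared-memory capacity, and whether the tiles evenly divide the matrix dimensions.
  - Interaction with burst sections: if the tile or matrix size isn't a multiple of the burst-section size, misalignment can double the number of memory accesses. Padding is the usual fix.
- The lecture cites Andrej Karpathy's tweet that the most dramatic optimization to nanoGPT was increasing the vocabulary size from 50257 to 50304 (the nearest multiple of 64), for a 25% speedup, illustrating how much alignment matters.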
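A NumPy sketch of the tiled algorithm that counts how many elements are loaded from "global" memory, to check the N vs. N/T claim. On a GPU, each loaded tile pair would sit in shared memory and one thread block would own each output tile; here loops stand in for that. The sketch assumes square matrices with sizes divisible by T.

```python
import numpy as np


def tiled_matmul(M, N, T):
    """P = M @ N computed tile by tile. Returns (P, elements loaded from global memory)."""
    n = M.shape[0]
    assert M.shape == N.shape == (n, n) and n % T == 0, "sketch assumes square, divisible sizes"
    P = np.zeros((n, n))
    global_loads = 0
    for i in range(0, n, T):
        for j in range(0, n, T):
            acc = np.zeros((T, T))                  # output tile accumulates on chip
            for k in range(0, n, T):
                M_tile = M[i:i + T, k:k + T]        # load one tile of M
                N_tile = N[k:k + T, j:j + T]        # load one tile of N
                global_loads += M_tile.size + N_tile.size
                acc += M_tile @ N_tile              # each loaded element is reused T times
            P[i:i + T, j:j + T] = acc
    return P, global_loads


def naive_global_loads(n: int) -> int:
    # Each of the n*n outputs reads a full row of M and a full column of N.
    return 2 * n ** 3


rng = np.random.default_rng(0)
M, N = rng.standard_normal((64, 64)), rng.standard_normal((64, 64))
P, loads = tiled_matmul(M, N, T=16)
print(loads, naive_global_loads(64))  # 32768 524288
```

The tiled loop moves 2n^3/T elements instead of 2n^3, exactly the factor-of-T reduction described above.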
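Explaining the Fluctuations in the Matrix-Multiply Performance Plot
- The roofline portion: for small matrices, performance is limited by memory I/O and utilization is low.
- The separate families of curves: caused by tiling alignment. Whether a matrix dimension is divisible by a particular number (e.g., 32, 16, 8, 2, 1) affects alignment with DRAM burst reads and hence performance. Prime dimensions perform worst.
- The sudden drops within a curve: caused by wave quantization.
  - Example: an A100 has 108 SMs. Suppose the tile size is 256x128.
    - Matrix size 1792x1792 -> 7x14 = 98 tiles. Since 98 < 108, all tiles run efficiently in a single wave.
    - Increase the size slightly (e.g., 1794x1794, giving 8x15 = 120 tiles). Since 120 > 108, the first wave runs 108 tiles and the remaining 12 run in a second wave with extremely low SM utilization, so overall performance drops.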
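Summary of Optimization Techniques
- Reduce memory accesses: memory coalescing, operator fusion, tiling (moving data into shared memory).
- Trade resources: trade compute for memory (recomputation), trade numerical precision for memory/compute (quantization).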
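FlashAttention Case Study
Core Ideas of FlashAttention
- Goal: dramatically speed up attention, mainly through CUDA kernel optimization.
- The paper's core claim: it applies "two established techniques (tiling, recomputation) to overcome the technical challenge of computing exact attention in sub-quadratic HBM accesses."
  - Note that what is sub-quadratic is the number of memory accesses, not the amount of computation.
- Attention recap: softmax(QK^T/sqrt(d_k))V, consisting of matrix multiplications and a softmax.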
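Solving the Softmax Challenge
- The problem with the standard softmax: softmax is a global operation; it needs a sum over the whole row to compute the normalizer, which conflicts with tiling (ideally, all computation happens within a tile).
- The solution, online softmax:
  - The standard softmax needs all inputs x_1 to x_n at once.
  - Online softmax processes the inputs as a stream, receiving x_i one at a time while maintaining running statistics: the current maximum and the sum of exponentials.
  - Key advantage: the softmax can be computed tile by tile, without needing or storing the entire N² attention matrix up front.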
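The running statistics can be written as a recurrence. The symbols m_j (running max) and d_j (running normalizer) are our labels, chosen to match the online-softmax formulation the lecture cites; by induction, d_j equals the sum over i ≤ j of e^{x_i - m_j}, so after the last input d_n is exactly the softmax denominator.

```latex
\begin{aligned}
m_j &= \max(m_{j-1},\, x_j), && m_0 = -\infty,\\
d_j &= d_{j-1}\, e^{\,m_{j-1} - m_j} + e^{\,x_j - m_j}, && d_0 = 0,\\
y_i &= \frac{e^{\,x_i - m_n}}{d_n}, && i = 1, \dots, n.
\end{aligned}
```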
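FlashAttention Forward Pass
- Steps (as illustrated in the FlashAttention 2 paper):
  - Tiled KQ multiplication: the Q and K matrices are split into small blocks and multiplied.
  - Online softmax: the softmax is computed block by block, maintaining a running sum of exponentials and a max-correction term.
  - Tiled multiplication with V: the softmax results are multiplied with V (also split into tiles).
- Speaker 2 asked whether all tiles must be finished before the softmax denominator can be computed. Speaker 1 confirmed that every tile must be traversed, but thanks to the online updates, once all tiles have been visited you already hold all the accumulated terms needed for the final softmax output, with no need to re-read the original N² data.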
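FlashAttention Backward Pass
- Key trick: recomputation tile by tile.
- Reason: avoid storing the N²-sized softmax matrix or its activations. A standard backward pass that stores activations would use N² memory.
- By not storing the intermediate N² matrix in the forward pass, and instead recomputing the needed attention scores in the backward pass from the inputs and gradients, memory efficiency is preserved.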
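Lecture Conclusions
- Hardware is the foundation of progress in language models, so understanding low-level details is essential.
- Under current GPU scaling trends, memory movement is the core bottleneck.
- The key to optimization is optimizing data movement: minimize accesses to global high-bandwidth memory and make full use of fast shared memory. FlashAttention is a successful application of this idea.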