Stanford CS336 Language Modeling from Scratch | Spring 2025 | 06 Kernels, Triton
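This lecture focuses on writing high-performance GPU code for language models. It opens with a review of GPU architecture, including streaming multiprocessors (SMs), threads, the memory hierarchy (DRAM, caches, register file), thread blocks, and warps, and it stresses the importance of arithmetic intensity. The lecture emphasizes the central role of benchmarking and profiling in finding and fixing code bottlenecks, and argues for thorough analysis before optimizing. The speaker plans to write kernels in CUDA (C++), in Triton, and with PyTorch's just-in-time (JIT) compiler, and to compare their performance. He also plans to dig into the underlying PTX code and possibly finish with a fast softmax implementation. The lecture also covers the course assignments, in particular the second assignment on GPU kernels and parallelism. A simple multilayer perceptron (MLP) serves as the running example.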
Tags
Media details
- Upload date
- 2025-05-13 17:44
- Processing status
- Completed
- Transcription status
- Completed
- Latest LLM Model
- gemini-2.5-pro-exp-03-25
Transcript
speaker 1: Today we're going to go into detail on writing high-performance code for GPUs. Part of assignment two is that you're going to have to do a bunch of profiling. You'll also have to write your own Triton kernel for FlashAttention-2, and you'll need to make all of this very high performance. So in this lecture we're going to drill down a little bit and try to write some high-performance code for standard components of a language model.

The plan for this lecture is as follows. First, a brief review of GPU basics, so you once again have the basic components of GPUs that we need to follow the rest of the lecture. Then I'm going to show you some really basic things about benchmarking and profiling. Those will help both for the assignment and in general if you want to write high-performance PyTorch deep learning code. Then we're going to write some kernels. We'll write CUDA kernels in C++, and then we'll do the same thing in Triton. Lastly, we'll do the easy but very good thing of using PyTorch's existing JIT compiler to optimize things for us. Then we'll compare all of those, profiling and benchmarking as we go. Throughout, we're going to dig in deep, all the way down to the PTX, pretty close to the machine code, to understand what the GPU is actually doing under the hood. And if we have time, and I think we will, we'll finish by writing a fast Triton implementation of softmax at the very end.

Okay, so assignment one has come to a close. There's still a leaderboard, and you can still submit and update things there. Some of you may be using late days, so please finish up assignment one. Assignment two is now out.
And as I said before, there's going to be a bunch of systems work that you'll need to do. There are fun parts you can do now involving GPU kernels. Next week we're going to talk about parallelism, which is the other half of the assignment: writing fast parallel code, like data parallelism and so on. We'll get to that next week.

All right, so remember how GPUs work. When we have something like an A100 or an H100, we have a whole bunch of SMs, or streaming multiprocessors. Within each SM is a large number of units that can do computation, such as INT32 units and FP32 units, and each SM launches a large number of threads. Then we have the memory hierarchy. There's DRAM, or global memory, which is big and slow, and there are caches that are much faster. In fact, you see here there's a thing called the register file. This is very, very fast memory that each thread can access, and we'll make heavy use of these registers as we write high-performance GPU code today.

The basic structure of the execution model is a collection of thread blocks, where each block is scheduled on a single SM. This is the atomic unit we'll be thinking about, especially when we write code in something like Triton. Within each block there's a whole bunch of threads, and the threads do the actual computation. So if you want to operate over the elements of a vector, you write code where each thread operates on maybe a few elements at once, and all the threads together process the whole vector. So why do we have these things called thread blocks? Why not just have threads in one big global context?
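Here is a minimal plain-Python sketch (no GPU involved) of how a 1-D launch maps (block, thread) pairs onto vector elements. The function name `simulate_grid_add` is hypothetical. Giving each thread exactly one element is a simplifying assumption: a real kernel runs the inner loop in parallel, and a thread may handle several elements.

```python
import math


def simulate_grid_add(x, y, block_size=4):
    """Emulate a 1-D GPU launch computing x + y, sequentially in Python.

    Each (block_id, thread_id) pair stands in for one GPU thread that
    handles exactly one element; the bounds check masks off threads in
    the last block that fall past the end of the vector.
    """
    n = len(x)
    num_blocks = math.ceil(n / block_size)  # enough blocks to cover all n elements
    out = [0.0] * n
    for block_id in range(num_blocks):          # on a GPU, blocks are scheduled onto SMs
        for thread_id in range(block_size):     # on a GPU, these threads run in parallel
            i = block_id * block_size + thread_id  # global element index
            if i < n:                              # mask: the last block may be partially full
                out[i] = x[i] + y[i]
    return out, num_blocks
```

For example, adding two length-5 vectors with `block_size=2` needs three blocks, and the last block has one idle thread.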
Well, threads within a block can communicate with each other through shared memory, which lives in the SM. That's pretty fast. When you do something like matrix multiplication, you need to pass information from thread to thread. Within a thread block that's very fast, but across thread blocks it's very expensive. So you want to keep any data you need within the same thread block, or the same tile. That keeps things very fast, about as fast as an L1 cache, which is a great place to be. You can use shared memory to synchronize across threads. But you can't, for example, synchronize across blocks, because you can't really control what happens there.

And remember the thing I mentioned last week: there's this thing called warps. Warps aren't something you normally think about explicitly, but they matter for performance. When we actually run these things, the threads are grouped into consecutive groups of 32, and each group is a warp that gets executed all at once on an SM. So one thing we'd like is for all the warps to have an equal amount of computation. We can't always do that, but we'd like to when we can. Similarly, we want the number of thread blocks to ideally be a multiple of the number of SMs. That way each wave, meaning each round of blocks scheduled across all the SMs, has an equal amount of work. Ideally we'll have many more thread blocks than SMs, and we'll try to make that happen as we write high-performance code.

Okay. The last concept, and maybe among the most important, is arithmetic intensity. We want to keep arithmetic intensity high, meaning we do more FLOPs than we move bytes of memory.
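The warp and wave bookkeeping above is simple arithmetic. The sketch below simplifies by assuming a fixed number of resident blocks per SM (one by default). Real occupancy depends on the registers and shared memory each block uses. The helper names are hypothetical, and 108 (the A100's SM count) appears only as an example.

```python
import math

WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def warps_per_block(threads_per_block):
    """Warps needed to hold one block; a partial warp still occupies a full warp slot."""
    return math.ceil(threads_per_block / WARP_SIZE)


def waves(num_blocks, num_sms, blocks_per_sm=1):
    """Return (number of waves, fraction of SM slots busy in the last wave).

    Assumes each SM runs `blocks_per_sm` blocks at a time (a simplification).
    """
    slots = num_sms * blocks_per_sm
    full, rem = divmod(num_blocks, slots)
    n_waves = full + (1 if rem else 0)
    last_wave_utilization = (rem / slots) if rem else 1.0
    return n_waves, last_wave_utilization
```

For example, `waves(109, 108)` returns two waves, and the second keeps only 1 of 108 SMs busy. Balancing the blocks against the SM count avoids that kind of tail.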
This is because, if you remember the scaling plot from last lecture, compute has scaled much, much faster than memory bandwidth. So a lot of the time, computations end up memory bound, and we don't get all the work out of the hardware that it could do. As a general rule, matrix multiplication is compute bound if we do it cleverly, and everything else is memory bound. We'll try to cleverly reduce the number of things that are memory bound, or how badly memory bound they are.

Okay, so that's our very brief review of GPUs. Hopefully everyone remembers this and still has a fresh memory of the execution model. Feel free to stop me if you have lingering doubts or questions about how this all works. Yes? What was the function of, sorry, a warp?

A warp is essentially a group of threads that get executed together. Warps exist because they reduce the amount of control machinery needed. Since all these threads execute at the same time, you don't need a control unit for each thread; you need one per group of 32. So you see, for example, that there are a lot more compute units than warp schedulers. That lets you do a lot more parallel work without worrying about control. This is one of the tradeoffs versus CPUs. CPUs dedicate a lot more silicon area to control, branch prediction, and things like that. GPUs put much more emphasis on computation, with simpler control.

Now we're going to get into newer content. If there's one high-level thing to remember, it's this: if you want to write high-performance code, benchmark and profile your code. That seems very obvious.
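To see why matrix multiplication can be compute bound while elementwise operations are memory bound, compare FLOPs per byte for an n × n float32 matmul and an n × n elementwise add. This sketch assumes ideal data movement: each input is read from DRAM once and the output is written once. It therefore gives an upper bound on achievable intensity. The function names are illustrative.

```python
def matmul_intensity(n, bytes_per_elem=4):
    """FLOPs per byte for C = A @ B with n x n matrices, assuming ideal reuse."""
    flops = 2 * n**3                            # n^2 outputs, each needs n multiply-adds
    bytes_moved = 3 * n * n * bytes_per_elem    # read A, read B, write C once each
    return flops / bytes_moved


def add_intensity(n, bytes_per_elem=4):
    """FLOPs per byte for C = A + B with n x n matrices."""
    flops = n * n                               # one add per output element
    bytes_moved = 3 * n * n * bytes_per_elem    # read A, read B, write C
    return flops / bytes_moved


for n in (128, 1024, 8192):
    print(n, round(matmul_intensity(n), 2), round(add_intensity(n), 4))
```

Under these assumptions, matmul intensity grows as n/6 FLOPs per byte. The add stays at 1/12 regardless of size, so it can never keep the compute units busy.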
But I've seen a lot of cases where people say, "I think this is the bottleneck, I'm going to spend three hours optimizing it," and it turns out it wasn't the bottleneck at all. I'm sure it was fun, but the time was misallocated. With a high-performance, very detailed profiler, you can see exactly where your bottlenecks are and exactly what the machine is doing. Then you can spend your effort on the parts of your code's execution that matter most. That's the high-level thing I want to get across. Some of the details, like how GPUs execute code and how you write a softmax kernel, are going to change. Maybe you'll even want to rely on torch.compile's auto-JIT instead. But the fact that you should profile won't change, no matter what the tools are. So I want you to internalize the idea that you should always be profiling if you want to write high-performance code.

And really, there's a limit to the theory. Systems is a part of this course that you can reason about pretty well. Architecture is somewhat hard to reason about, though you can think about the roofline model and so on. But how fast does your matrix multiply actually run? That may depend on the library version, your hardware, and which things are bottlenecking for what reason. There are all sorts of microcode details that you don't fully know. So in the end you have to do end-to-end benchmarking whenever you're developing these things.

Okay, so here's an example computation. It's the simplest thing we can run, compared to everything you're doing in assignment one: a very simple MLP with 128 dimensions, 16 layers, some batch size, and five steps.
I'm just going to do forward and backward passes for five steps. To make the code clear, it looks something like this. I define an MLP model, which I'll show you in a moment, and a random Gaussian input. Then I run it for five steps in that last part: I compute the forward pass, then the backward pass, and return the result. The result is just the mean of the output of my MLP. There isn't even a loss; it's that simple. You run the MLP forward and average-pool at the end. The MLP itself is about the simplest thing you can imagine: a bunch of linear layers stacked on top of each other, which is this bit, with a GeLU in between. So it's just linear, GeLU, linear, GeLU, and so on. Everything is nice and square, so hopefully this is an MLP you all feel comfortable with. Now let's go back. Yes. Oh, sorry, I want to go back up to here. Okay, good.

So I have this MLP code that I want to run, and I'm going to do two things. First, I'll benchmark, meaning I'll do some timings to find out how long this function takes to run. Then I'll profile, which means going inside the function and asking where all my time is spent. Let's start with benchmarking. Benchmarking just measures the wall-clock time of performing these operations. Here I'm only looking at the end-to-end execution time of my MLP function. There are some subtleties, though. You might be wondering why you're being told how to invoke the timeit function. But you do have to be a little careful about how you measure times. If you're not paying attention, you'll run into these pitfalls in assignment two. So what are we doing this for?
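The MLP described above can be sketched in plain Python. It has square linear layers with a GeLU after each, and uses the mean of the output in place of a loss. This is a forward-only stand-in under stated assumptions. The lecture's version uses PyTorch modules and also runs a backward pass. The helper names (`make_mlp`, `mlp_forward`, `run_step`) and the tanh form of GeLU are choices made for this illustration.

```python
import math
import random


def gelu(x):
    # tanh approximation of GeLU (an illustrative choice for this sketch)
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def make_mlp(dim, num_layers, seed=0):
    """Random Gaussian dim x dim weight matrices, one per layer (no biases)."""
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(dim)
    return [[[rng.gauss(0.0, scale) for _ in range(dim)] for _ in range(dim)]
            for _ in range(num_layers)]


def linear(W, x):
    # Matrix-vector product W @ x
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def mlp_forward(layers, x):
    # linear -> GeLU, repeated for every layer
    for W in layers:
        x = [gelu(v) for v in linear(W, x)]
    return x


def run_step(layers, batch):
    # Forward pass for every example, then average all outputs ("average pool")
    outputs = [mlp_forward(layers, x) for x in batch]
    flat = [v for out in outputs for v in out]
    return sum(flat) / len(flat)
```

Because every layer is dim × dim, the work per step scales with the number of layers and the batch size. That is what the benchmarks below vary.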
We're going to compare implementations later: our Triton kernel, our handwritten C++, PyTorch's implementation, and torch.compile. We want to know whether writing that CUDA kernel was worth it. We'd also like to know how much slower things get as the matrix multiplies grow, so we need some empirical benchmarking of those.

Throughout this lecture I'll use this benchmark function, which is a wrapper. I'll step through it. It takes a function to benchmark, `run`, does some number of warm-up iterations, and then does some number of trials. You might wonder what this warm-up is for. When you first run your PyTorch code and it dispatches something to the GPU, it might look very fast and seamless. But the very first time something executes, machine code is being compiled in the background. That code might be getting sent to the GPU, and all sorts of other things happen to initialize your code. So you always want some warm-up iterations, so that you don't measure the startup speed. You want to measure the steady-state speed instead. If you're running thousands and thousands of iterations, that's what you care about, not how fast you can compile your CUDA code on the fly. That's why we have warm-up, and you should always include a bit of it.

Another thing that's really important, which I'll come back to with the profiler, is calling torch.cuda.synchronize(). What is that? Well, the GPU and the CPU are basically two independent compute units in your computer, right?
They can basically run independently. The execution model is that this Python code lives on the CPU. When I run something, it dispatches a bunch of CUDA kernels to the GPU, saying "please run these for me." The GPU goes off and executes them, and the CPU keeps running; it doesn't wait for those CUDA executions to finish. That's great for writing high-performance code. But you can probably see the immediate problem for benchmarking. If the GPU runs off on the side while your CPU does something else, you're not actually measuring GPU execution time. torch.cuda.synchronize() makes sure the GPU and CPU are in the same state. There's nothing queued, and both are at the same point in the code being executed.

Now that the GPU and CPU are in the same state, I can time it for real. I run the computation, which in this case is the sleep command, three times. Since I'm sleeping for 50 milliseconds, that's the time I get at the end. I also call torch.cuda.synchronize() at the end of run, so that if the CPU has run ahead, it waits for the GPU execution to actually finish. Then I average, because a single measurement might fluctuate because of things like the GPU's thermal state. So you take multiple replicates, take the mean, and return that. That's our benchmarking code. Very simple.
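Here is a minimal version of the benchmark wrapper described above. It assumes a pluggable `synchronize` callback: on a GPU you would pass `torch.cuda.synchronize`. The default no-op lets the sketch run on a CPU-only machine. The parameter names and defaults are assumptions, not the course's actual helper.

```python
import statistics
import time


def benchmark(run, num_warmups=1, num_trials=3, synchronize=lambda: None):
    """Mean wall-clock seconds per call of run(), measured after warm-up.

    synchronize is called after the warm-up and after every timed call so
    that queued asynchronous work is included in the measurement (with
    PyTorch on a GPU, pass torch.cuda.synchronize; default is a no-op).
    """
    for _ in range(num_warmups):      # exclude one-time costs such as compilation and caching
        run()
    synchronize()                     # drain anything the warm-up queued
    times = []
    for _ in range(num_trials):
        start = time.perf_counter()
        run()
        synchronize()                 # wait for async work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.mean(times)     # average out run-to-run fluctuation
```

`benchmark(lambda: time.sleep(0.05))` should report roughly 0.05 seconds, matching the sleep example. Leaving out the synchronize call in GPU code is what produces the implausibly fast timings the lecture warns about.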
But remember the two important pieces: always do a warm-up, and always call torch.cuda.synchronize(). Both are very simple. If you forget them, you'll get pretty crazy numbers, like a big matrix multiply finishing instantly, which is definitely not right.

Okay, so now we can benchmark some matrix multiplies. These results just put numbers to things we already know, but I want to walk through them so we're on the same page. I ran this on the class H100s. I do matrix multiplies over these sizes and collect timings for each dimension, stepping through this benchmark result. As expected, runtimes scale superlinearly as the matrix size increases. At the smallest sizes, like 1024 and 2048, the times don't grow at all. That's because there's constant-factor overhead in just doing these matrix multiplies: the numbers have to be shipped from the CPU to the GPU, and launching the kernel has overhead. So the scaling isn't superlinear all the way down to zero. But once the matrices get big enough, we see exactly the scaling we expect from matrix multiplies.

Okay, hopefully straightforward. Now let's benchmark our MLP. We'll make it bigger: 256 dimensions, four layers, a batch size of 256, and two steps. How long does that take? About 6.2 seconds. Now I can do some basic things. I can scale the number of steps from two to five and benchmark each, so I get timings for two, three, four, and five steps.
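The shape of the matmul curve follows from a simple cost model: a fixed launch and transfer overhead plus 2n³ FLOPs divided by throughput. The 1 ms overhead and 100 TFLOP/s throughput below are made-up parameters for illustration, not measurements from the lecture's H100 runs. The function names are hypothetical.

```python
def matmul_flops(n):
    """FLOPs for an n x n by n x n matrix multiply (one multiply and one add per term)."""
    return 2 * n**3


def modeled_time(n, overhead_s, flops_per_s):
    """Toy cost model: fixed overhead plus compute time at a constant throughput."""
    return overhead_s + matmul_flops(n) / flops_per_s


OVERHEAD_S = 1e-3    # hypothetical fixed cost per call
FLOPS_PER_S = 1e14   # hypothetical sustained throughput

for n in (1024, 2048, 4096, 8192, 16384, 32768):
    print(n, f"{modeled_time(n, OVERHEAD_S, FLOPS_PER_S):.4f} s")
```

With these parameters, doubling n from 1024 to 2048 raises the modeled time by only about 15%, because the overhead dominates. Doubling from 16384 to 32768 raises it by close to the full factor of 8.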
Unlike the matrix multiply case, here I'm scaling the number of steps, that is, the number of forward and backward passes through my MLP. How do I expect the runtime to behave? I expect linear scaling, and that's roughly what we see. Each MLP step takes about 5 seconds, so the end-to-end runtime is about n times five. Okay, let me see if I can reset what's being monitored here. Oh, nope, I can't. I'll zoom out a little bit; sorry about that.

We can also scale the number of layers from two to three, four, and five. Once again the runtime increases linearly, this time in the number of layers. One layer takes a little under 5 seconds, so we get roughly four times the number of layers, and the linear scaling shows up. Again, unsurprising: both steps and layers clearly have a linear relationship with runtime, and that's exactly what we see. I'll skip the batch size experiment, because the number of things being tracked here is getting a little unwieldy.

All right, that's the end of the benchmarking part. We have a nice function that does a little warm-up and a CUDA synchronize, and it can measure the runtime of anything we want. That's good, and you should do it all the time, for example to see how long your new fancy architecture takes to run. But if you want to fix problems, benchmarking is a very coarse-grained tool. It tells you that your code is slow, but not where the time goes. So instead we want to profile, which is much more fine-grained.
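One way to check the linear-scaling claim quantitatively is to fit a line to runtime versus number of steps (or layers). The slope estimates the per-step cost, and the intercept estimates the fixed overhead. This least-squares helper is a generic sketch, and the name `fit_line` is an assumption.

```python
def fit_line(xs, ys):
    """Ordinary least-squares fit ys ~ slope * xs + intercept; returns (slope, intercept)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x
```

To use it, pass the step counts as `xs` and the corresponding measured runtimes as `ys`. Large residuals around the fitted line would suggest that something other than per-step work dominates.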
Profiling is really nice because it shows more than where the time goes and in which functions. It also shows what you're actually calling. Usually you interact with the PyTorch interface, the parts of PyTorch that you call. But beneath PyTorch there's a whole universe of CUDA calls. A profiler lets you see all the way down to those low-level calls, so you get a much better intuition for how the program actually executes on the hardware. We'll step through profiling a few simple functions to build that intuition.

One nice thing is that for basic profiling, PyTorch has a very good built-in profiler. It lets you stay in the Python/PyTorch world and still get fairly reasonable output. I've profiled some functions here, and you can see the output too. I've taken the sleep example from before; here's the sleep function. The profile function looks something like this. There's a warm-up again, then torch.cuda.synchronize(), then I start the profiler, tracking both CPU and GPU times. I run something, synchronize again, and print a table of averages.

Okay, let me go back and profile the sleep function. Here, 100% of the time is spent in something called cudaDeviceSynchronize, because there's no GPU work being done. From the GPU's point of view it's a no-op, so it's a silly thing to profile. So let's look at something nontrivial: the basic operation of adding two matrices. I defined an add function that takes a and b and adds them together.
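The same warm-up, profile, print-a-table pattern can be shown with Python's built-in `cProfile` and `pstats` as a CPU-only stand-in. For GPU work the lecture uses the PyTorch profiler instead. As far as we know, that would look like `torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])` followed by `prof.key_averages().table(sort_by="cuda_time_total")`. Note that cProfile sees only Python-level calls, not CUDA kernels. The wrapper name `profile_call` is hypothetical.

```python
import cProfile
import io
import pstats


def profile_call(run, num_warmups=1, sort_by="cumulative", limit=10):
    """Warm up, profile one call of run(), and return a text table of the hottest functions."""
    for _ in range(num_warmups):      # keep one-time startup costs out of the profile
        run()
    profiler = cProfile.Profile()
    profiler.enable()
    run()
    profiler.disable()
    out = io.StringIO()
    pstats.Stats(profiler, stream=out).sort_stats(sort_by).print_stats(limit)
    return out.getvalue()


def work():
    return sum(i * i for i in range(10000))


print(profile_call(work))
```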
This helper function creates two random Gaussian matrices and then invokes whatever is passed as the operation argument. Here it adds two 2048 × 2048 matrices. Now I profile this and get back something like this block over here. I'll have to zoom back out, because this isn't going to fit. Okay. Is this visible from the back? Thumbs up if it is, thumbs down if it's not. Okay, good, good.

All right. When we call the add function in Python, that's all we interact with: a + b. But underneath, below the tip of the iceberg so to speak, a lot more happens. The call gets dispatched toward the GPU. First there's ATen, the C++ tensor library that sits underneath PyTorch. Its wrapper, aten::add, gets called and says, okay, I'm going to add some numbers. That's the outer wrapper. It then dispatches to a particular kernel, `vectorized_elementwise_kernel<4, at::native::CUDAFunctor_add...>`, which does the actual adding.

There's also cudaLaunchKernel, which takes some time. That's the CPU taking the command and sending it over to the GPU: the kernel launch. Finally there's cudaDeviceSynchronize, where we wait for the GPU to finish and report back, and that takes time too. Just having a synchronization barrier costs us something. In the end, the total is 1.4 milliseconds on the CPU and 17 μs on CUDA. So the add is really fast on the GPU and much slower on the CPU side.
If we look at the self CPU time, we see that the C++ interface itself costs us a whole bunch of CPU time. Anything that sends work over to the GPU has overhead. So that's the add function, and what happens under the hood.

It's the same story for a matrix multiply, a times b, again with 2048 × 2048 matrices. When I profile it, this time I see aten::matmul. That's the lower-level interface for matrix multiplies, and it dispatches to CUTLASS, NVIDIA's high-performance CUDA matrix multiply library. From there it goes to one very particular CUTLASS kernel with some tile size; the names are truncated here, but I'll show you a more detailed version in a minute. The kernel name basically points to a particular choice of tile sizes, number of blocks, and so on. So this kernel is parameterized, and it's what actually does the matrix multiply. Once again we see the same two things at the bottom: the kernel launch and the CUDA device synchronization. And you can again see the split between CPU time and CUDA time. We spend way more time in CUDA, because a matrix multiply takes more time than adding two vectors.

Okay, any questions so far? I've been going very quickly through the profiler on my own, so I can pause here if anyone has questions. If not, I'll keep going. Okay. Oh, yes. (Question, partly inaudible: there's a barrier that makes the CPU wait in synchronize, so shouldn't the CPU time...?) Yeah, I don't think that's counted... (answer partly inaudible). Cool. Oh yes, sorry, there are two questions there.
Is there a particular reason why the CPU time goes down when we go from adding to matmul? To be entirely honest, I'm not sure. Yes? Does the profiler add overhead that distorts things compared to running in the real world? Yes, there's overhead in the profiler; the barriers will do that. I'll show you a more advanced profiler from NVIDIA where you can add annotations, and those also slightly distort the timings, though not by much. The really large-scale things you see won't be distorted much. If you're looking at micro-timings, probably yes, but for most of what we care about in this class, no.

Yes? I just want to make sure I'm reading this correctly: for the add case, is that CPU utilization? That's right, yeah. This is a percentage of time. You can also see the actual time in milliseconds that aten::add was executing in some capacity on the CPU. So it's how much of the time the CPU is active, not the percentage utilization? Yeah, right. It's not the total amount of CPU FLOPs or anything like that. It's the percentage of time that the CPU is doing something. Yes. Okay, cool.

All right, here's another matmul example with a different dimensionality. This time I'm multiplying 128 × 128 matrices, which is much smaller. Now it directly executes a different kernel, a GEMM kernel (GEMM being the general matrix multiply), and this one is float32. From the kernel's name you can see that it's some kind of tiled matrix multiply. It doesn't go through CUTLASS; it executes this particular kernel directly.
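The kernel names above encode tile sizes. To show what tiling means, here is a plain-Python tiled matrix multiply. Each output tile corresponds roughly to the work one thread block would own. The inner `k0` loop walks the shared dimension one tile at a time. On a GPU, each input tile would be staged in shared memory and reused. This sketch only reproduces the loop structure, and `tile=2` is an arbitrary choice.

```python
def tiled_matmul(A, B, tile=2):
    """C = A @ B for lists of lists, computed tile by tile."""
    n, k = len(A), len(A[0])
    k2, m = len(B), len(B[0])
    assert k == k2, "inner dimensions must match"
    C = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):              # each (i0, j0) output tile ~ one thread block
        for j0 in range(0, m, tile):
            for k0 in range(0, k, tile):      # march along the shared dimension tile by tile
                for i in range(i0, min(i0 + tile, n)):
                    for j in range(j0, min(j0 + tile, m)):
                        acc = C[i][j]
                        for kk in range(k0, min(k0 + tile, k)):
                            acc += A[i][kk] * B[kk][j]
                        C[i][j] = acc
    return C
```

The result does not depend on the tile size; only the order of memory accesses changes, which is exactly what a GPU kernel tunes.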
So a small matrix multiply dispatches to a different kernel. That shows you the complexity hidden behind matrix multiply. At this high level of abstraction we think of a matrix multiply as a single thing: we call a @ b and we're done. But under the hood, depending on your matrix dimensions and your hardware, it dispatches to very different matrix multiply primitives. Those show very different performance characteristics.

One fun tip: torch.compile, which I'll talk about later, has an option to micro-benchmark matrix multiply performance on your hardware. It then picks the highest-performing matrix multiply subroutines for your model. In the past I've found that gives you something like a 10% speedup for free. It's very cool that optimizing for these things gives you free gains in the real world.

Okay, so that's another matmul example. The cool thing about the profiler, compared to raw benchmarking, is that we can see which CUDA kernels are being called. We can see that different matrix sizes lead to different CUDA kernels. We can also see that cutlass_80_simt_sgemm comes from the CUTLASS linear algebra library, and the name tells us things like the tile size. So far these operations are boring in a way. Matrix multiplies and adds are basically one-to-one: an operation on the CPU side translates to a single GPU operation that just gets shipped over. In each of these examples, only one operation does anything on the GPU. So I want to look at two more complicated operations with more compound behavior. The first is torch.cdist.
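The autotuning tip can be illustrated with a small sketch: time each candidate implementation on the actual inputs and keep the fastest. We believe the torch.compile option the lecture refers to is `mode="max-autotune"`, which benchmarks GPU kernel choices. The pure-Python candidates and the `autotune` helper here are hypothetical and only show the selection logic.

```python
import time


def matmul_naive(A, B):
    # Inner loop strides down the columns of B
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def matmul_bt(A, B):
    Bt = list(zip(*B))  # transpose once so the inner loop walks contiguous rows
    return [[sum(a * b for a, b in zip(row, col)) for col in Bt] for row in A]


def autotune(candidates, args, repeats=3):
    """Time each candidate on args (after one warm-up call) and return (name, fn) of the fastest."""
    best = None
    for name, fn in candidates.items():
        fn(*args)                                    # warm-up
        start = time.perf_counter()
        for _ in range(repeats):
            fn(*args)
        elapsed = (time.perf_counter() - start) / repeats
        if best is None or elapsed < best[1]:
            best = (name, elapsed, fn)
    return best[0], best[2]
```

Which candidate wins can depend on the input shape. That is why the selection is made per shape on the target machine rather than fixed in advance.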
Given two sets of vectors, it computes the pairwise Euclidean distances between them: a big distance matrix between the a's and the b's. That's cdist. It's obviously a much more complicated operation. To compute Euclidean distances you need dot products and square roots, and we'll see those in the profile.

Here is the profiled output of cdist. The torch Python call maps to a lower-level cdist in the C++ interface, aten::cdist, which in turn maps to aten::_euclidean_dist. That decomposes into a whole bunch of things, like aten::matmul, aten::pow, and aten::sum, because these are the primitives needed to compute Euclidean distances between all your vectors. Each of these operations (the matrix multiply, the concatenations, the powers) has a corresponding CUDA kernel being called here. There's a GEMM, which we've become familiar with. That's a matrix multiply, and it takes 78% of our compute time on the GPU. Copies and concatenation of arrays take 6% of the execution time. The vectorized elementwise kernel that computes the power takes 5% of the GPU time, and 3% goes to the sum. So now we have a very nice low-level breakdown of where the GPU spends its time. From this I get some sense of where I should spend my own time optimizing. If I think I can optimize the matrix multiply, that would be great, because it's over 70 percent of the GPU time.

The final two examples I want to talk about are GeLU and softmax. These will be our running... oh, sorry, there's a question.
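The profile is consistent with computing distances through the identity ‖a − b‖² = ‖a‖² + ‖b‖² − 2 a·b. That identity turns the bulk of the work into a matrix multiply, which fits the large GEMM share, plus elementwise powers and sums. This plain-Python sketch keeps the terms separate. The concatenation in the profile suggests ATen folds the norms into a single matmul by appending extra columns, but this sketch does not attempt that. The function name is hypothetical.

```python
import math


def cdist_via_matmul(A, B):
    """Pairwise Euclidean distances between rows of A and rows of B.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the main cost is the
    dot-product matrix (a matmul A @ B^T) plus elementwise square/sum/sqrt.
    """
    a_sq = [sum(v * v for v in a) for a in A]   # elementwise square + row sum
    b_sq = [sum(v * v for v in b) for b in B]
    dots = [[sum(x * y for x, y in zip(a, b)) for b in B] for a in A]  # the matmul A @ B^T
    return [[math.sqrt(max(a_sq[i] + b_sq[j] - 2.0 * dots[i][j], 0.0))  # clamp rounding below 0
             for j in range(len(B))]
            for i in range(len(A))]
```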
Okay, I'll answer that question in a few minutes, because there's a cooler profiler that shows a much nicer picture. I could just gesticulate here, but I think it's better to show it with pictures. Okay?

So now let's talk about GeLU and softmax. GeLU will be our running example throughout the class. If you remember, it's a nonlinearity, the Gaussian error linear unit. In its common approximation, it's computed as x times a tanh-based term, if I remember right. In this example we add a and b and then call GeLU, simulating the linear-plus-nonlinearity structure we might have in our MLP. Once again we see basically the same mapping. aten::add corresponds to a + b, followed by its CUDA equivalent. Then there's a GeLU function actually implemented in CUDA, all the way down here, which takes about 33% of the compute. Okay, fairly reasonable. Then we have the softmax. I won't go through all of these in gory detail, since they all start to look the same after a while. But the cool thing to point out is that many really core primitives like softmax and GeLU have dedicated kernels written for them. The GPU isn't executing a sequence of basic primitives; a fused operator computes the whole thing, so there's no back and forth between CPU and GPU for any of these.

Okay, I said I'd answer the question about what the CPU was doing, so let's think about something a little more sophisticated. I started with the MLP example for benchmarking, and let's say I'd like to optimize the MLP and make it run really fast. How can we do that? Ideally, we'd profile it in a nice, fine-grained way.
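For reference, GeLU is x·Φ(x), where Φ is the standard normal CDF. The widely used approximation replaces Φ with a tanh of a cubic polynomial. Softmax is usually computed after subtracting the maximum so that the exponentials cannot overflow. The sketch below implements both in plain Python for scalars and lists. A fused GPU kernel would compute the same math without writing intermediates back to global memory.

```python
import math


def gelu_exact(x):
    # GeLU(x) = x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Common tanh approximation of GeLU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def softmax(xs):
    m = max(xs)                           # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]
```

For example, `softmax([1000.0, 1001.0])` works fine here, while exponentiating 1000 directly would overflow a float.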
If we use the torch profiler, this is roughly what we get. If you remember the MLP, it's a stack of linear layers, with a forward and a backward pass. You can see this backward thing happening: there's a matrix multiply, there's linear, and then there's an accumulate-grad operation for the backward. Here's the matrix multiply kernel. Only ten entries fit in this table, so it gets cut off at a certain point. This is nice, and it does tell you that most of the time is spent in the matmuls, but you do wonder where all the rest of the time goes. Why does only 31% of my time show up here, and where is the 60% here? It's aten::mm, but there's no corresponding kernel. That's a little mysterious, and for a very complex module this is not a very good visualization. So for that we have to get out a real, grown-up profiler, and we will ask you to look at it: NVIDIA's Nsight Systems. This is NVIDIA's detailed way of looking at GPU behavior and performance, and with it we can see exactly what happens as we run this MLP. Can you see this tiny text in the back? Thumbs up? Okay. If you can see it, I won't zoom in, but it does seem small even from here. So if we look here, we see several different things. We see CUDA HW over here, and then we see threads. The top half, the CUDA part, is what the GPU is doing, and the threads part shows what the CPU is doing. Let me also pull up the code. When I profiled it, I added a few annotations. This one I definitely need to zoom in on. Excellent.
So I've annotated the code with NVTX, which marks regions of my code. When the profiler comes in, it will know that this piece of code belongs to a block called define_model. For example, this part with range_push and range_pop labels this range of lines with step_{step}. I added all these annotations before running the profiler. Now let's go back. If we go to the row that says NVTX, we can see define_model, which is what I wrapped my model-construction call in, and then step 0, step 1, step 2, step 3, step 4, step 5. Each step is now nicely annotated in the profiler, and we can see everything the model does as it goes along. I'll start on this side. One thing we see is that this piece of code doesn't do very much work: the whole run takes only about 14 seconds, and most of that is overhead. Everything up to roughly here is things like loading the libraries, and that takes a long time; apparently 7.5 seconds to initialize everything. Then, about 7.5 seconds into the program, it starts actually building the model. You can see in the memory footprint that this is where memory starts being allocated, and GPU memory usage starts to grow right here. The model is constructed at this point, and step 0 is where the action starts. You were asking earlier what's happening between the CPU and the GPU, so here's how the execution model works. This is step 0 on the CPU; I start right here, here's the forward pass, and this is layer 0. Let's think through what's happening.
As I said before, when you first call a piece of code in PyTorch, it doesn't just execute directly. It does things like compiling on the fly. So this "runtime triggered module loading" is overhead work done to initialize the layer and the computation and to move various bits of code onto the GPU, and it takes a long time. After layer 0 is done, if I look at any slice here (let me zoom in to the selection), we see that each of these layers is really, really quick. And when I highlight layer 1 here on the CPU side, notice that it's not where layer 1 is on the GPU side. As I said, the CPU and the GPU are two different execution devices. I start layer 0, I finish layer 0, I start layer 1. At this point the CPU is just sending CUDA commands: it's already launching all the CUDA kernels to the GPU. So when the CPU says "I'm doing layer 1," what it's actually doing is queuing commands for the GPU: run this next, then this, then this. The CPU is running way ahead of the GPU; by the time layer 1 starts executing on the GPU, the CPU is already at layer 9. There's basically a queue that the CPU maintains, holding up to a fixed number of CUDA kernels for the GPU. Once you hit that queue depth, it stops running ahead, but until then it keeps going as far as it can. In this case (let me zoom out again, undo the zoom, there we go) it gets a little extreme, because if I zoom out once more, notice how far ahead I'm running in these steps.
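The run-ahead behavior can be captured with a tiny discrete-time model. The assumptions are ours, not the profiler's: each launch costs the CPU `launch_cost` time units, each kernel occupies the GPU for `kernel_cost` units, kernels run one at a time in launch order, and the CPU blocks once `queue_depth` launched kernels are still unfinished. The optional `sync_every` argument makes the CPU wait for the GPU after every that-many kernels, which is roughly what fetching a loss value in order to print it does.

```python
from typing import List, Optional, Tuple

def simulate(num_kernels: int, launch_cost: float, kernel_cost: float,
             queue_depth: int, sync_every: Optional[int] = None) -> Tuple[List[float], List[float]]:
    """Return (time each launch completes on the CPU, time each kernel finishes on the GPU)."""
    launched: List[float] = []
    finished: List[float] = []
    cpu_t = 0.0
    for k in range(num_kernels):
        if k >= queue_depth:
            # Queue is full: block until the kernel queue_depth launches back has finished.
            cpu_t = max(cpu_t, finished[k - queue_depth])
        cpu_t += launch_cost
        launched.append(cpu_t)
        # The GPU starts a kernel once it is launched and the previous kernel is done.
        gpu_start = max(cpu_t, finished[-1] if finished else 0.0)
        finished.append(gpu_start + kernel_cost)
        if sync_every is not None and (k + 1) % sync_every == 0:
            # e.g. printing a loss: the CPU waits until the GPU has caught up.
            cpu_t = max(cpu_t, finished[-1])
    return launched, finished
```

With ten kernels, `launch_cost=1`, `kernel_cost=5`, and a deep queue, the CPU finishes issuing launches at t = 10 while the GPU stays busy until t = 51. With `queue_depth=4` the CPU is throttled and its last launch slips to t = 32, and with `sync_every=5` the GPU sits idle briefly after each synchronization, finishing at t = 52 instead of 51.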
Step 0 is here. This was step 1, which took basically no time at all. Step 2 is here. So the CPU is running one entire forward-and-backward step ahead of the GPU. Here's something interesting. If you're writing code to train a language model, one normal thing you might do (let's go back to the code) is print your losses between iterations. That seems like it should have no effect on what the GPU is doing; it's a print statement, how much could it do? But if you think about it for a moment, it has a big impact on the execution layout on the GPU, because the print happens on the CPU, and the CPU needs the loss value. That means it has to wait for the GPU to compute that loss. Let's look at what happens. Here, as I said, step 4 on the CPU happens way before its GPU equivalent. Now let's switch to the version I profiled with the print statement and zoom into the selection. See how step 1 and step 2 are now basically synchronized? That's because the CPU has to wait for the loss to be computed. You might look at this and say they're still a little offset, since step 1 isn't exactly aligned on the two sides. So let's zoom back in and see what happened to step 1 on the CPU. The end of step 1 on the CPU is roughly where the optimizer step starts. And this cudaStreamSynchronize call on the CPU is basically saying: I'm just waiting for the GPU, because I can't run ahead; I'm waiting for this loss to be computed and sent back to me. So it's effectively a dummy operation where the CPU waits, and waits, and waits.
Then the backward step is done, so now I can print the loss. I've printed the loss, and now the CPU can run ahead again, and it does: it starts sending step 2 work. Once it gets here, it has run out of commands to issue, because it's waiting for the loss again. CUDA synchronize: wait, wait, wait. The backward step is done, now I can print the loss, and now I run ahead again. In this case the GPU is still at essentially full utilization either way. But in extreme cases, say if you're printing tons of stuff all the time, you'll actually introduce a CPU bottleneck, because the CPU keeps waiting for the GPU and can't launch kernels ahead of time. So that's a really cool thing you can see with the profiler: the CPU and GPU are different devices that communicate with each other, not a single unified object, and you wouldn't see that unless you looked at one of these more advanced profilers. Any questions about that? Cool. The other thing I want to show you is that the profiler views I was playing with before can also be generated in nsys. You select the range of things you want. Let's do a warm-up: as I said, we should exclude the first couple of steps, so we'll start at step 3 and measure the steps in this range. We can take the kernels, which are what's doing the computation, and you can see there are many different kinds of matrix multiply. This is one matrix multiply kernel, this is a different one, and there's a vectorized elementwise kernel. All of these take different amounts of time. We can then ask to see, in the events view, everything that's happening.
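When timing GPU code by hand rather than in a profiler, the same two lessons apply: skip warm-up iterations, which pay one-off costs such as runtime module loading, and synchronize before reading the clock, because kernel launches return as soon as the work is queued. A minimal sketch, assuming a caller-supplied `synchronize` callable; with PyTorch on a GPU that would be `torch.cuda.synchronize`, and for CPU-only code it can stay a no-op. The helper name `benchmark` is our own.

```python
import statistics
import time
from typing import Callable

def benchmark(fn: Callable[[], object], warmup: int = 3, trials: int = 10,
              synchronize: Callable[[], None] = lambda: None) -> float:
    """Return the mean wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()                      # first calls pay one-off costs (compilation, module loading)
    synchronize()                 # drain queued work before starting the clock
    times = []
    for _ in range(trials):
        start = time.perf_counter()
        fn()
        synchronize()             # wait for asynchronously launched work to finish
        times.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(times)
```

Without the `synchronize()` inside the loop, a GPU timing would mostly measure how long it takes the CPU to enqueue kernels, not how long they take to run.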
I can also see, in the stats view, how much time each one takes. Let's see: not the average time; we want the CUDA kernel execution summary, with the total duration of each kernel. That shows which kernels take the most time, aggregated across the selection. So this is a very powerful tool: it gives you both the aggregate view of what's slow and what's fast, and the individual kernels, when they're launched, and which CPU commands launched them. One final side note: this is one of the reasons it doesn't matter that we're programming in Python, even though Python is not a very high-performance language. The CPU is rarely the bottleneck, because it can run ahead and queue commands for the GPU. This decoupling between the GPU and the CPU is one of the key reasons we can use a nice high-level language and still get full utilization out of our GPUs. Cool. Any questions before I switch back? I'm going to leave nsys for the rest of this lecture, but you'll get to play with it in assignment two, and I think you'll appreciate it, because it gives you a really interesting view into what your hardware is actually doing to train these language models. Okay, that was benchmarking and profiling. Now you have the tools you need to do performance work, and we're going to write some kernels in the remaining time. Remember kernel fusion? This was the image I showed you in lecture: there's a little factory, and every time I need to do an operation, I need to ship things from the warehouse to the factory and back.
If I naively do a bunch of operations in sequence without thinking about it, I pay a lot in shipping costs back and forth from the warehouse. What I should do is have one factory that does all the operations at once, so I don't pay that cost multiple times. That's very important. So now we're going to write a kernel for GeLU, in several different ways, and look at the performance impact of each. First, there's the PyTorch implementation, which is just torch.nn.functional.gelu. I pass approximate="tanh" because I want it to match exactly the naive thing I'm going to do next. So this isn't actually multiplying by the CDF of the Gaussian; it's an approximation to that which is easier to compute. Okay, that's the PyTorch GeLU. Now I'm going to do the dumb thing, and you'll look at this code and say it's going to be slow. In PyTorch, I write GeLU as 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))). It's a magic formula, but it's a good approximation to GeLU; you can look it up or convince yourself. If you do this, you'll see that a lot of operations happen: there's a tanh, there's an x cubed, there's multiplication by a constant, an addition, and multiplication by 0.5 and by x. If this involves multiple different CUDA kernels, it's probably going to be slow; that should be our intuition at this point from fusion. So let's see. First, these two are the same: at the top left you can see they compute exactly the same numbers, and we can check this systematically on random Gaussians. Now let's benchmark the two. The manual time is 8.1 milliseconds for a really, really big GeLU, and the PyTorch time is 1.1 milliseconds.
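For reference, the two functions being compared are exact GeLU, x·Φ(x) with Φ the standard normal CDF, and the tanh approximation written out above. In PyTorch these correspond to `torch.nn.functional.gelu(x)` and `torch.nn.functional.gelu(x, approximate="tanh")`. The scalar sketch below uses only Python's `math` module so the two can be compared directly; the function names are our own.

```python
import math

def gelu_exact(x: float) -> float:
    # x * Phi(x), where Phi is the standard normal CDF written with erf
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
```

Sweeping x over [-5, 5] shows the two agree to within 1e-3 everywhere, which is why the cheaper tanh form is an acceptable stand-in.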
So the fused version is significantly faster, about seven to eight times faster. Wow, that's a big difference from writing a simple kernel. Of course, your matmuls are probably still going to be the bottleneck, but it would be really cool if we could go from eight milliseconds to one millisecond; that would feel very satisfying. So we're going to try to get close to that 1.1 milliseconds in the next few parts of the lecture. Let's look at what's happening under the hood. I don't need nsys here, because all I want is some very high-level information. For the manual GeLU, just as I said, it does a whole bunch of operations, a bunch of multiplications. They're vectorized, but it's a bunch of CUDA kernels being launched. And notice on the right, this CUDA kernel gets called three times, because we have a whole bunch of multiplications floating around. We've also got an addition, and we've got a tanh. Each of these is probably fairly slow, and in the end we're incurring fairly large overhead. Now let's do the same thing with the PyTorch GeLU, and this is really great: there's a single CUDA kernel launch. It happens once and processes the whole thing. This is what we'd like to see, and of course it's very, very fast because it's a single CUDA kernel. We'd like to somehow get to that single kernel ourselves. The first thing you might think, depending on how much you know about writing efficient GPU code, is that the PyTorch people must have written this in the lowest-level language possible, so we should do the same. We won't go all the way to the lowest level, but we will go to the C++ API and write the CUDA kernel in C++. So let's open it up and write our own CUDA kernel. How is that going to work?
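A rough way to see why the manual version is slow is to count DRAM traffic. Assume, hypothetically, that each eager operation in the manual formula becomes its own elementwise kernel that reads its tensor operands from global memory and writes one result back. The op list below is our own breakdown; the exact split PyTorch uses may differ, so treat it as an illustration rather than a reading of the profile.

```python
# Hypothetical breakdown of the manual GeLU into eager elementwise ops:
# (description, number of tensor operands read). Scalar constants cost no traffic.
MANUAL_GELU_OPS = [
    ("0.5 * x", 1),
    ("x * x", 2),
    ("(x * x) * x", 2),
    ("0.044715 * x^3", 1),
    ("x + 0.044715 * x^3", 2),
    ("sqrt(2/pi) * (...)", 1),
    ("tanh(...)", 1),
    ("1 + tanh(...)", 1),
    ("(0.5 * x) * (1 + tanh(...))", 2),
]

def dram_bytes(ops, n: int, bytes_per_elem: int = 4) -> int:
    """Bytes moved if every op is its own kernel: read each operand, write one result."""
    return sum((reads + 1) * n * bytes_per_elem for _, reads in ops)

def fused_bytes(n: int, bytes_per_elem: int = 4) -> int:
    """A fused kernel reads x once and writes y once."""
    return 2 * n * bytes_per_elem
```

Under these assumptions the unfused version moves 11 times as many bytes as a fused kernel for float32 data. That is the right order of magnitude for the measured 8.1 ms versus 1.1 ms, though the real gap is smaller, so the count is best read as intuition for why fusion matters.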
Okay, so we've gone in and created a C++ version of the whole thing. When we say CUDA, we really mean the C++ API for interfacing with and programming GPUs. Just like in the logical model of a GPU we described, we write some function f, and when we invoke this CUDA kernel, it automatically calls f on all the elements of a vector or a matrix, so we get to compute everything we want in parallel. As nomenclature, we have a grid, which is a collection of thread blocks. Think of it as: I have a task, I cut it into pieces, and there's some number of blocks. In a 2D grid, for example, there's a row coordinate and a column coordinate, which is very useful if you're working with matrices. Then there's the size of each block, meaning how many threads it contains; that's the block dimension. Within each block there's a collection of threads, and each block has a coordinate saying where it lives in the grid. So there's a hierarchical structure: a grid contains thread blocks, and each thread block contains threads. Each thread then has access to three things: the block index (which thread block do I belong to?), the block dimensions, and its own thread index within the block. With these, I can figure out which coordinate I'm responsible for in the matrix or vector, and then decide what logic to run. One last thing before we go through the actual C++ code: whenever you're trying to debug CUDA, you want to launch with CUDA_LAUNCH_BLOCKING=1. This lets you actually debug your CUDA kernel.
It gives you error messages back, at a cost in runtime. If you don't do that, you're going to have a bad time if you need to debug CUDA code. Okay, here is my GeLU code. Let's go through it piece by piece, and I'll talk about what each piece does. Other than the machine code, this will probably take the longest of the things we walk through, but once you understand this, you should be able to understand all the other pieces, so we'll go a little slowly. There are two parts to this code. The first part, this gelu_kernel piece up here, is the actual kernel. It does the computation: it gets sent to the GPU, does the computation, and writes out the results. The second part, the gelu function here, is a wrapper. It lives on the CPU and orchestrates the launch of the kernel, which actually runs on the GPU. So let's start with the wrapper, the gelu function. In both the Triton and the CUDA code, we're always going to check... oh, sorry, there's a question back there. You can't see? Okay, that's an easy fix, but I needed to know. Is this good? Excellent. So, starting with the gelu function, there are two things we'll always need to do. The first is to make sure that x lives on the GPU device, as a CUDA tensor of some kind. If it doesn't, that's a problem: we won't be able to do anything on the GPU. The second, maybe less obvious, is to check that x is contiguous.
That means it lives in a contiguous block of memory. When we index into x, we do a whole bunch of index arithmetic that assumes x lives in one block of memory; if it doesn't, it's basically impossible to do this with any generality. When we compute GeLU, we take an input x and output a y, so we need to allocate an output: torch::Tensor y = torch::empty_like(x). This just says: give me output storage with the same shape as x. Notice that I'm not calling zeros. That saves an extra operation: I don't need to zero out y, because I'm going to write into it anyway. It's a minor optimization, but you might as well do it. Then, in all the code we write, we need to figure out the grid: what's the total number of elements, what's the block size (the number of threads in each block), and how many blocks do I have in total? To get the number of blocks, I call cdiv, which takes the ratio of num_elements to block_size and rounds up. I need the ceiling rather than the floor so that the last group of elements, when num_elements isn't divisible by the block size, still gets computed. This is all very simple bookkeeping. Then I launch the kernel: gelu_kernel gets launched, and the triple angle brackets say to launch it with the given number of blocks and the given block size. Then I pass in pointers to x and y (not the values of x and y themselves) and the total number of elements.
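The grid bookkeeping in the wrapper is just integer arithmetic. A sketch in Python, where `cdiv` rounds up and `launch_config` is a hypothetical helper returning the two values that would go inside the triple angle brackets; the default block size of 1024 is an assumption, not a value taken from the lecture code.

```python
from typing import Tuple

def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive ints: the smallest q with q * b >= a."""
    return (a + b - 1) // b

def launch_config(num_elements: int, block_size: int = 1024) -> Tuple[int, int]:
    """Hypothetical helper: (num_blocks, threads_per_block) covering every element."""
    return cdiv(num_elements, block_size), block_size
```

For a million elements this gives 977 blocks of 1024 threads, so 977 * 1024 - 1,000,000 = 448 threads in the last block have nothing to do, which is why the kernel needs a bounds check.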
I need that to handle the boundary conditions in my kernel. Now let's go to the kernel itself. I have __global__ void gelu_kernel, which takes pointers for in and out, plus the number of elements. The website's rendering has mangled this keyword a bit, but you should read it as __global__ (underscore underscore global); it's the keyword that marks this as a CUDA kernel function. So what am I doing? Each thread is supposed to operate on a single element i, but I don't get i as input; the code doesn't tell me which coordinate of the vector I'm at. So I need to compute where I am. I take my block index (I only have one dimension, so it's blockIdx.x, just the first coordinate) and multiply it by the size of each block, blockDim.x. That tells me the starting point of my current block. Then I add threadIdx.x, my offset within the block, which gives me my global coordinate i. That's just bookkeeping to get the coordinates. Then comes something important, a pattern you see in basically all the CUDA code people write: there's no automatic out-of-bounds checking. So once I have my coordinate, I check that I'm supposed to be processing something in bounds. Some of the threads at the very end of the last block would be processing memory that's out of bounds, and you do not want to touch it. So you condition on i < num_elements and do nothing if you're outside that range. Sorry, yes? Oh, the .cu file: that's just the extension you write CUDA code in, to distinguish it from standard C++ code. So it's just a file-naming thing.
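To make the indexing concrete, here is a serial Python emulation of the launch. `gelu_kernel` mirrors the per-thread body (global index from block index, block dimension, and thread index, then a bounds check), and `launch` loops over every (block, thread) pair that the GPU would run in parallel. This is a teaching model with our own names, not how CUDA actually schedules threads.

```python
import math
from typing import List

def gelu_tanh(v: float) -> float:
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))

def gelu_kernel(block_idx: int, block_dim: int, thread_idx: int,
                inp: List[float], out: List[float], num_elements: int) -> None:
    # CUDA equivalent: int i = blockIdx.x * blockDim.x + threadIdx.x;
    i = block_idx * block_dim + thread_idx
    if i < num_elements:  # threads past the end of the array do nothing
        out[i] = gelu_tanh(inp[i])

def launch(block_dim: int, inp: List[float], out: List[float]) -> None:
    num_elements = len(inp)
    num_blocks = (num_elements + block_dim - 1) // block_dim  # cdiv
    # A GPU runs these (block, thread) pairs in parallel; here we simply loop.
    for b in range(num_blocks):
        for t in range(block_dim):
            gelu_kernel(b, block_dim, t, inp, out, num_elements)
```

With 10 elements and `block_dim=4` there are 3 blocks and 12 threads. Without the `if`, the last two threads would index past the end of `out` and Python would raise `IndexError`; on a GPU the same bug would be an out-of-bounds memory access.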
That's what .cu is; there's nothing particularly special about it. Okay. Now, inside here, we just do our computation: I have my input in, I index into the i-th element, compute GeLU just as before, assign it to out[i], and I'm done. That's all I need to do. Since this is all pointer arithmetic, I don't need to worry too much about what's happening underneath. So that's basically it. I can take my CUDA GeLU source, load this C++ code inline, and have it compile into a module, all from within Python. It's very nice and convenient; you don't have to go out to the command line. Now we have cuda_gelu defined. It's a compiled version of this code, and I can call it from Python through the C bindings. Okay, calling cuda_gelu: I can check that the manual GeLU and the CUDA GeLU produce the same values, and now let's benchmark them. The PyTorch time, just like last time, is about 1.1 milliseconds, and the manual time, remember, is 8.1 milliseconds. Drum roll: what's our CUDA time? We've gotten it down to 1.8 milliseconds. Not quite as good as PyTorch's implementation, but we're getting pretty close: we've gone from 8.1 milliseconds to 1.8 milliseconds, which is not bad, because that C++ code wasn't hard to write. We can also profile it and see what's happening now. It calls gelu_kernel, which is the code that got shipped off to the GPU, then empty_like, which is the initialization, and empty_strided, and then cudaLaunchKernel and cudaDeviceSynchronize. That's basically all that's happening.
And notice that, once again, a single CUDA kernel eats up 100% of the GPU time, which is what we wanted. There's some further optimization we could do, but this has already solved the kernel-fusion problem: we fused all the operators together, so pretty good. These kinds of elementwise operations are easy to write in CUDA. If you have a new kind of nonlinearity, say, you could easily write a CUDA kernel for it yourself if you wanted to. But more interesting operations require reading multiple values, like reductions, and those get a little more complicated. Flash attention will be a little more complicated, but not by much, when you have to do it in the assignment. Okay, any questions on the simple C++ CUDA kernel? Yes, about the beginning? So the question was what happens if x is not contiguous. At least in the code we wrote, it just throws an error, because it's an assert. You could potentially write code to handle it, but there's almost no reason for memory to be fragmented: it gets allocated contiguously, and you won't deallocate the middle of a block of memory unless you're doing something really tricky. So unless you're doing something fairly advanced, you should expect contiguous memory. [Partly inaudible follow-up about transposes.] Yeah, so the question was: if you transpose, you're no longer contiguous, because traversing by rows something that's stored by columns means jumping between elements in the index. Yes, transposes, views, and other dimension shuffles are the one exception. But that can be handled in the outer wrapper part.
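Contiguity is a statement about strides: a row-major tensor is contiguous when each dimension's stride equals the product of the sizes of the dimensions after it. A transpose only swaps shape and strides without moving any data, which is why it breaks contiguity. The helpers below are a plain-Python sketch of that rule with hypothetical names; in PyTorch you would inspect `x.stride()` and `x.is_contiguous()`, and the wrapper could call `x.contiguous()` to get a packed copy.

```python
from typing import Tuple

Shape = Tuple[int, ...]

def row_major_strides(shape: Shape) -> Shape:
    """Strides (in elements) of a freshly allocated, packed row-major tensor."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

def is_contiguous(shape: Shape, strides: Shape) -> bool:
    """True if the strides match row-major packing. Size-1 dims are never stepped
    over, so their stride does not matter."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True

def transpose_2d(shape: Shape, strides: Shape) -> Tuple[Shape, Shape]:
    """A transpose is a view: swap shape and strides, move no data."""
    return shape[::-1], strides[::-1]
```

A 3x4 tensor has strides (4, 1); its transpose is 4x3 with strides (1, 4), so walking along a row now jumps 4 elements at a time and `is_contiguous` returns False.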
You can basically pass the kernel something that is contiguously indexed, and for a lot of matrices you won't really care. Yes? So the question is what would happen if you chose a different block size. Then the usual GPU-related concerns kick in: do you have enough blocks to saturate your SMs, and do you have enough work within each block? Those are the two things that could matter here. My guess is that for relatively large block sizes, like 1024, it probably won't matter past a certain point, because we're not doing anything advanced; it's all entrywise operations in this very simple example. Yes? [Partly inaudible question: is the manual version slow because each small operation gets sent back and forth?] So the question was why our naive manual version was so slow. It's not that it's sending things back from the GPU to the CPU per se. x lives on the GPU (we allocate it there with device="cuda"), but it doesn't stay in the SMs the whole time. Once we do something like x squared, that's a CUDA kernel, and that multiplication reads the vector from global memory into the SMs, does the computation, and writes it back. So this is all DRAM-to-SM communication cost rather than CPU-to-GPU communication cost. Of course, if you use device="cpu", you'll pay the CPU transfer cost in addition to the DRAM transfer cost. Okay. So that was not too painful, but it would be really nice to have nicer Python abstractions for writing CUDA kernels, and that's what Triton is. Triton is quite nice: it hits a very nice middle ground where you don't have to manage literally everything about the GPU.
Triton is a domain-specific language developed by OpenAI in 2021, and it makes GPU programming much more accessible. You write everything in Python, and you don't really think about threads anymore; you think about thread blocks. Triton manages a lot of things that are annoying but can be optimized automatically. It can manage memory coalescing: remember that from DRAM you get four adjacent values at once, through something called burst mode, so you really want your memory accesses grouped into adjacent chunks of four or more elements. Triton handles that automatically and groups them. It also does shared memory management: when you need to manage which memory you're writing to within an SM, with multiple threads in each SM, you might need to stop or start threads, and that's all managed automatically. But scheduling across SMs, deciding what the different SMs do, is manual. So the programming model is that you think at the SM-centric level, and the compiler handles a lot more of the lower-level details. Triton is quite nice because it can outperform a lot of PyTorch implementations by quite a bit. It's like going all the way to writing CUDA, but you're still in very familiar Python land. And I think a very underappreciated advantage is that, as it says here, it's all in Python: you can step through it and debug it fairly nicely. So let's step through a Triton kernel. Once again we're going to write GeLU, this time in Triton. I've structured the code to be as similar as possible to our other code. This is the CPU-side code, so to speak: the triton_gelu wrapper. It takes in x, which is a torch tensor.
I've got my two asserts at the top, and I allocate an output tensor y using empty_like once again. It has exactly the same coordinate-computation components, and even the kernel launch looks very similar. I've got the num_blocks annotation, and my block size comes at the end here rather than inside the brackets, but basically I'm passing the same information to my kernel. The triton_gelu_kernel is this code over here, and it does the same thing as before, but now it's nicely written in Python. The mental model: the inputs are x_ptr, the starting address of the input vector, and y_ptr, the starting address of the output vector; block_size is how big each block is; and num_elements marks the end of my array. Lines 557 to 561 compute my index. Before, I computed i with a formula; this does the same calculation. First, where does my current block start? That's my block ID times the block size. Say I live in block 1; that gets me to this point right here in the middle. Then I need to know where I live within my block, which is the offset. But notice one difference: I don't get a single offset, because I'm not programming threads, I'm programming blocks. What does that mean? My offsets are a vector, not a single value, because I'm doing a vectorized operation, and that vectorized operation will be handled by different threads. So my offsets are the start of the block plus a range of block_size offsets: all the coordinates within block 1 at once. Of course, if I'm at the very end, I might go off the edge.
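The block-level view can also be emulated in plain Python. Each call to `gelu_block_program` plays the role of one Triton program instance: it builds a whole vector of offsets (what `tl.arange(0, BLOCK_SIZE)` provides in real Triton), builds a mask for the ragged last block, and does a masked load, compute, and masked store. This is a sketch with our own names; in an actual Triton kernel the block index would come from `tl.program_id(axis=0)`, and the loads and stores would be `tl.load(..., mask=mask)` and `tl.store(..., mask=mask)`.

```python
import math
from typing import List

def gelu_block_program(pid: int, x: List[float], y: List[float],
                       num_elements: int, BLOCK_SIZE: int) -> None:
    block_start = pid * BLOCK_SIZE
    # One vector of offsets per program instance, not one index per thread.
    offsets = [block_start + j for j in range(BLOCK_SIZE)]
    mask = [o < num_elements for o in offsets]
    # Masked load: out-of-bounds lanes get a placeholder and are never stored.
    xs = [x[o] if m else 0.0 for o, m in zip(offsets, mask)]
    c = math.sqrt(2.0 / math.pi)
    ys = [0.5 * v * (1.0 + math.tanh(c * (v + 0.044715 * v ** 3))) for v in xs]
    # Masked store: only in-bounds lanes write back.
    for o, m, v in zip(offsets, mask, ys):
        if m:
            y[o] = v

def run_grid(x: List[float], BLOCK_SIZE: int) -> List[float]:
    n = len(x)
    y = [0.0] * n
    num_blocks = (n + BLOCK_SIZE - 1) // BLOCK_SIZE  # cdiv
    for pid in range(num_blocks):  # Triton would run these instances in parallel
        gelu_block_program(pid, x, y, n, BLOCK_SIZE)
    return y
```

Compared with the per-thread CUDA emulation, the bounds check has moved from an `if` on a scalar index to a mask over a whole vector of offsets.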
So I need a mask to handle anything that lies past the boundary of my vector. Now I load everything at once in a single vectorized operation: x pointer plus offsets, which are the values I'm responsible for, masked. That's loaded into x, my internal temporary vector. With this temporary vector, I do exactly the old GeLU computation. There's no tanh, so I compute it manually, but you can convince yourself this formula is the same as what we have here. Then y is the formula computed up here. Once I'm done, I need to write the result back into my output buffer. So I compute my targets, y pointer plus offsets, take my temporary values y, and store them. This is very similar to what came before, but it's the vectorized version: I get to operate on an entire block at once. Instead of thinking from the perspective of a thread, I'm thinking from the perspective of a block, but it's not too different. This is all fairly similar stuff. So now I've written my Triton GeLU. I'll go through this next part fairly quickly and only point out a few things, because I don't want to get so deep into the weeds that you all get up and leave. The last cool thing we can do is this: Triton compiles into low-level, almost-machine code for the GPU, and we can look at that very low-level code, called PTX, after the Triton compiler has gone over it. It's actually kind of cool, because you can see how the GPU works at the thread level. So this is the Triton GeLU kernel as generated by the compiler. At first it does some really basic stuff. What's it doing here? It's saying, well, I'm going to need to store some values, right?
I'm going to need to store intermediate computations. .b means untyped, basically raw bytes, so I need untyped registers that are 32 bits wide. I need floats for doing computations, called .f32, and I need another set of registers that are 64 bits wide. So I have all these registers that I need for temporary computations. Then, starting here, this part loads the various arguments to the function, so things like the x pointer and the y pointer get loaded here. After that, I start computing the coordinate offsets of my Triton kernel. Once I get down here, this ld.global is the code that loads the values from the x pointer into my temporary registers. It's basically saying: load %r2, %r3, %r4, %r5 using the memory position in %rd1. Notice how it loads four things at once, because it's cleverly handling coalescing. We know we can get four values for free, so we should operate on all four of them at once. Then it does the same thing again here, and then you start to get the floating-point operations, mul.f32, which go through and do the tanh computation. I'm not going to explain every piece, but here it multiplies by a constant, and it computes x cubed by multiplying the same number several times. Then it computes 2 to the x, but we want e to the x, so it first multiplies by log2(e) to change the base of the exponential. You can really see all of the literal step-by-step operations the GPU performs to get you the final result. I'll skip over to the end; this is all floating-point computation that it needs to do.
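The base change in the PTX rests on the identity e^x = 2^(x · log2 e). The GPU's fast exponential instruction (ex2.approx.f32 in PTX) is base 2, so the compiler multiplies by the constant log2(e) ≈ 1.4427 first. A quick check in plain Python; the function name is just for illustration:

```python
import math

LOG2_E = 1.0 / math.log(2.0)  # log2(e) ~= 1.4427

def exp_via_exp2(x: float) -> float:
    """e**x rewritten as 2**(x * log2(e)), the form a base-2 exp instruction needs."""
    return 2.0 ** (x * LOG2_E)

for x in (-3.0, 0.0, 0.5, 2.0):
    print(x, exp_via_exp2(x), math.exp(x))  # the last two columns agree to within rounding
```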
At the very end, it stores the values it has, %r38 through %r41, into %rd4, which is the memory position of our output. So this is what's actually happening at the low level. Each thread operates on four values at a time, and its temporary storage is the registers, the really high-speed storage it has locally. Just looking at it, this is probably going to be pretty fast code. So that was the PTX, and we could go through and see what it does for all sorts of things. But now let's go back and actually benchmark. We've got manual GeLU at 8.1s, PyTorch at 1.1s, CUDA at 1.84s, and Triton at 1.848s. So we didn't get any faster, but the Triton code was much easier to write. We wrote it in Python, we thought about blocks, and we could do vectorized additions. If you're doing more sophisticated stuff, Triton will basically handle a lot of the memory management for you, so it's actually pretty good. And in profiling, once again, we see a single kernel launch that consumes all of the GPU time, which is great. That's Triton kernels. The last thing I want to talk about, at least in this part (whoops, here it is), is torch.compile. Writing CUDA kernels is cool, and it makes you feel really good, but maybe we don't need to do that. The things we did here were very simple: we took the x-cubed and exponentiation operations and shoved them all into a single CUDA kernel. Maybe we can get that without doing much work ourselves. We've shown you several different ways so far. The last one is torch.compile, which takes non-optimized PyTorch code and writes more optimized code. Here it will attempt optimizations like kernel fusion automatically.
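Why shoving everything into one kernel helps can be seen with a toy count of DRAM traffic. Assume a hypothetical chain of unary elementwise ops over fp32 data. Unfused, every op is its own kernel that reads its input from global memory and writes its output back. Fused, one kernel reads x once and writes y once, and the intermediates stay in registers. The model ignores caches and binary ops (which read two inputs). It is an illustration, not a derivation of the timings quoted above.

```python
def dram_bytes(num_elements: int, num_ops: int, bytes_per_elem: int = 4, fused: bool = False) -> int:
    """Global-memory traffic for a chain of unary elementwise ops (toy model)."""
    kernels = 1 if fused else num_ops
    return 2 * kernels * num_elements * bytes_per_elem  # 2 = one read + one write per kernel

n = 1 << 20                                   # hypothetical 1M fp32 elements
unfused = dram_bytes(n, num_ops=8)            # hypothetical 8-op chain, one kernel per op
fused = dram_bytes(n, num_ops=8, fused=True)  # the same chain in a single kernel
print(unfused // 2**20, "MiB vs", fused // 2**20, "MiB")  # 64 MiB vs 8 MiB
```

For a memory-bound op, runtime tracks bytes moved, so in this model fusion cuts the traffic by a factor equal to the number of ops.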
This compiled GeLU produces the same outputs. But now let's look at the runtimes. There's some run-to-run variation, but the numbers are basically the same: 8.1s manual, 1.1s PyTorch, 1.8s CUDA, and then 1.47s for torch.compile. So the punchline is that modern JIT compilers are pretty good. They can do optimizations like operator fusion without you having to do much at all. If you look under the hood, once again basically one thing happens: a fused add-multiply-tanh Triton kernel. So torch.compile generates Triton under the hood that does similar things to our Triton code, but it's slightly more optimized than what we wrote, so it gets slightly better performance than even our code. So torch.compile is quite nice. Yes? [Audience question, partially inaudible, about when a hand-written implementation can beat torch.compile, e.g. for something like FlashAttention.] Yeah, so the question is really: when do you know you can do better than torch.compile? That's the relevant question. For simple stuff like operator fusion, or the other thing it's very good at, optimizing matrix multiplies, it is very good. As I said before, if torch.compile knows the shapes of the matrices, it can figure out which kernels to dispatch. I doubt you can get much better than that. But then there are things like FlashAttention 1, 2, and 3, which are pretty nontrivial optimizations. These days torch.compile and JAX's XLA compiler can do those, but that's because we know in hindsight that they are the right optimizations to do. I think some of those things are a bit nontrivial to figure out.
For example, FlashAttention 3 has additional hardware-level optimizations that leverage the H100 hardware, and that's not obvious to do with a JIT compiler. So there are some things that are quite hard for torch.compile, where I think you could do better. But in general, the point is that you shouldn't go home and say, "I'm going to write CUDA kernels for every single part of my language model." That's probably not a good use of your time. But if you're writing a new architecture with some complicated piece, and you're not getting utilization but you think you can, that's maybe the time to really bust out the Triton. Okay? We're basically at time, but we can quickly go through one last Triton example, softmax, which may be useful for you in assignment two. One difference is that until now we were doing basic elementwise operations. Those are really easy, because you just operate on each element and there's no real complexity. Softmax is different: it has a reduction operation where you have to add across all of the elements. So how do we do that? We want to normalize across each row of the matrix, and we'd like to make this fast. A naive version is going to be pretty slow, so now we're going to write the Triton kernel. If I want to be lazy, what's the easiest way to do this? Think about it for a moment. Say you want to write a softmax that normalizes each row of a matrix, and imagine these matrices are pretty small, so you're just writing a kernel for small matrices. What's the right kind of block design? Maybe our grid should just be rows, so that each SM handles a single row.
That's kind of the optimal thing to do: if we can fit a whole row into an SM, we just sum across that row in the SM and then divide. That's great. So that's the simple design for our very naive softmax kernel: we make each block a row. The block size should then be the number of columns plus a little buffer so that all the columns fit. That's triton.next_power_of_2 of the number of columns, which is a nice way of padding out your columns. Since each block is a row, the number of blocks is exactly the number of rows. Then I have my Triton softmax kernel, which is written the way you'd expect. Now we have a matrix rather than a vector, so we have x pointers and y pointers, and we need the strides of the matrices. Then we figure out which row index I'm in and get the column offsets. This is the same kind of code as before; in fact, getting the row offset is simpler, because each row is a block. Then I do basically the same kind of stuff. I load each row into the SM's local memory, and then I do a computation that looks exactly like a softmax. I take my row, subtract the max, take the exponent, sum it, and divide, which gives me my softmax-normalized row, and I write it back to global memory. No complexity at all. Whenever your computation fits nicely in an SM, writing Triton code looks very similar to writing normal Python code, just with a bit of load and store and some bookkeeping about where the blocks are. So life is pretty simple. Let's go back. Wait, where was the Triton? Here we go. Now we can see how fast all of our different pieces of code are. I'll zoom out again, just to make sure. Okay, so the manual version takes 3.7s.
torch.compile takes 1.3s, PyTorch takes 1.5s, and Triton takes 1.9s, so ours is still a bit slow. torch.compile can actually do better than the native PyTorch implementation, especially when it knows the shapes and sizes of certain operations. Finally, we can look in the profiler. The manual softmax is kind of a disaster here: all sorts of operations happening all over the place. Let me clear this and go back up here. Okay. Yep, we see all sorts of operations: exp, max, sum, because we've implemented things naively, and we've got memory reads and writes everywhere. The compiled softmax is just one fused softmax operation that goes quite fast. PyTorch's softmax is also a single CUDA kernel call. The same goes for our Triton softmax: one nice fused kernel for everything. Okay. I won't go through the PTX code for this one. We're at time, and I don't want to drag you through that low level again, but hopefully this has given you a flavor of lower-level GPU programming for the purpose of making language models go fast. Hopefully you'll have fun doing assignment two. Thanks.
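The naive one-block-per-row softmax kernel described above can be emulated with NumPy so that it runs without a GPU. This is a sketch under assumptions, with hypothetical function names:

- each program instance owns one full row;
- columns are padded to a power of two;
- padded lanes are loaded as -inf, which is the kind of value tl.load's `other` argument supplies in real Triton code, so they become exp(-inf) = 0 and drop out of the sum.

Subtracting the row max before exp keeps large inputs, such as a row of 1000s, from overflowing. It doesn't change the result, because the factor exp(-max) cancels between numerator and denominator.

```python
import numpy as np

def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (n >= 1); the lecture uses triton.next_power_of_2 for this."""
    return 1 << (n - 1).bit_length()

def softmax_row_program(x, y, row_idx, n_cols, block_size):
    """Emulates one program instance: one block owns one whole row."""
    col_offsets = np.arange(block_size)
    mask = col_offsets < n_cols
    row = np.full(block_size, -np.inf)         # padded lanes hold -inf
    row[mask] = x[row_idx, col_offsets[mask]]  # masked load of the real columns
    row = row - row.max()                      # reduction 1: subtract the row max
    numerator = np.exp(row)                    # padded lanes become exp(-inf) = 0
    softmax_row = numerator / numerator.sum()  # reduction 2: normalize by the row sum
    y[row_idx, col_offsets[mask]] = softmax_row[mask]  # masked store

def softmax_emulated(x):
    n_rows, n_cols = x.shape
    block_size = next_power_of_2(n_cols)  # pad the columns out to a power of two
    y = np.empty_like(x)
    for row_idx in range(n_rows):         # grid = (n_rows,): one block per row
        softmax_row_program(x, y, row_idx, n_cols, block_size)
    return y
```

This design only works while a whole row fits in one block. Longer rows would need the reductions split across several passes or blocks.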
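Latest Summary (Detailed Summary)
Overview / Executive Summary
This lecture (Stanford CS336, Spring 2025, 06 Kernels, Triton) takes an in-depth look at writing high-performance GPU code, focusing on the standard components of language models. It centers on a review of GPU architecture and on the importance and methods of benchmarking and profiling. The lecture stresses that before any optimization you must profile to find the bottleneck, rather than optimizing blindly. In practice, it demonstrates detailed performance analysis with PyTorch's built-in tools and NVIDIA Nsight Systems. These reveal the asynchronous execution of CPU and GPU and how it affects performance, for example the correct use of torch.cuda.synchronize() and the hidden synchronization overhead that print statements can introduce.
Using GELU (Gaussian Error Linear Unit) and Softmax as case studies, the lecture compares the performance of several ways of implementing a kernel: naive PyTorch, hand-written CUDA C++, Triton, and JIT compilation with torch.compile. The results show that kernel fusion is the key to performance, because it greatly reduces memory read/write overhead. Hand-written CUDA C++ and Triton kernels both far outperform the naive PyTorch version and come close to PyTorch's native fused kernel, with Triton offering a friendlier Python programming interface. torch.compile shows strong automatic optimization: it often generates efficient Triton code whose performance matches or beats the hand-written kernels. The lecture also goes down to the PTX (Parallel Thread Execution) assembly level, analyzing the low-level instructions Triton compiles to. This shows how the GPU actually executes the code and where optimizations such as memory coalescing appear. The overall conclusion is that modern JIT compilers are very powerful, but understanding and hand-writing or hand-optimizing GPU kernels is still worthwhile in certain complex scenarios or when chasing peak performance.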
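GPU Architecture Review
Speaker 1 began with a brief review of how GPUs work, laying the groundwork for writing high-performance code.
- Core components:
  - SM (Streaming Multiprocessors): a GPU contains many SMs, each with a large number of compute units (e.g., INT32, FP32).
  - Threads: each SM can launch a large number of threads to carry out computation.
  - Memory hierarchy:
    - DRAM (global memory): large capacity, slow.
    - Caches: much faster than DRAM.
    - Register file: extremely fast, accessible by each thread, and used heavily in high-performance GPU programming.
- Execution model:
  - Thread blocks: a group of threads scheduled to execute on a single SM. This is the basic atomic unit you reason about in programming models such as Triton.
  - Communication: threads within a block can communicate efficiently through shared memory, which is nearly as fast as L1 cache. Communication across blocks is very expensive.
  - Synchronization: threads can be synchronized within a block, but not across blocks.
  - Warps: threads are grouped into warps of 32 that execute together on an SM, which reduces control-logic overhead.
  - Performance considerations: ideally, all warps do an equal amount of work, and the number of thread blocks is divisible by the number of SMs (or much larger than it).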
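- Arithmetic intensity:
  - Definition: the ratio of computational operations (FLOPs) to bytes of memory accessed.
  - Goal: keep arithmetic intensity high, because compute capability has been growing much faster than memory bandwidth.
  - Reality: many computations are memory bound. Matrix multiplication can be compute bound if implemented cleverly; most other operations are memory bound.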
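A quick worked example of the arithmetic-intensity and block-count points above, under these assumptions:

- fp32 data (4 bytes per element);
- ideal reuse for matmul, meaning each matrix moves to or from DRAM exactly once, which real kernels only approach through tiling;
- a hypothetical GPU with 108 SMs, each running one equally long block at a time.

```python
import math

def intensity_elementwise_add(bytes_per_elem=4):
    """z = x + y: 1 FLOP per element; read x and y, write z."""
    return 1 / (3 * bytes_per_elem)

def intensity_matmul(n, bytes_per_elem=4):
    """C = A @ B for n x n matrices: 2*n**3 FLOPs (one multiply and one add per term),
    assuming A and B are read once and C is written once."""
    return (2 * n ** 3) / (3 * n * n * bytes_per_elem)

def wave_efficiency(num_blocks, num_sms):
    """Fraction of SM time doing useful work if blocks run in equal-length 'waves'."""
    waves = math.ceil(num_blocks / num_sms)
    return num_blocks / (waves * num_sms)

print(intensity_elementwise_add())  # ~0.083 FLOP/byte: memory bound
print(intensity_matmul(1024))       # ~170.7 FLOP/byte: grows like n/6, can be compute bound
print(wave_efficiency(109, 108))    # ~0.505: one extra block costs a whole extra wave
```

Elementwise ops have a fixed, tiny intensity. Matmul intensity grows with n, which is why only matmul can realistically become compute bound. The last line shows why block counts that are a multiple of (or much larger than) the SM count are preferred.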
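Benchmarking and Profiling
Speaker 1 stressed that the core of writing high-performance code is to benchmark and profile first, so that bottlenecks are located accurately.
- Key point: > "if you want to write high performance code, you should remember to benchmark and profile your code."
- Benchmarking:
  - Definition: measuring the end-to-end execution time (wall-clock time) of an operation.
  - Purpose: compare the performance of different implementations and understand how code scales with input size.
  - Key practices:
    - Warm-up: run several iterations first to exclude first-run costs such as initialization and JIT compilation, so that steady-state performance is measured.
    - Synchronizing CPU and GPU (torch.cuda.synchronize()): the CPU and GPU execute asynchronously, so after submitting work the CPU does not wait for the GPU to finish. Call torch.cuda.synchronize() both before timing starts and when it ends, so that the measurement reflects actual GPU execution time.
      - Speaker 1 explained: > "the GPU and the cpu are basically two independent compute units... their execution model is going to be this Python code that I have here. This lives on the cpu, right? And when I run something, it's going to dispatch a bunch of cuda kernels to the GPU... And the cpu will actually go on and keep running, right? It doesn't wait for those cuda executions to stop."
    - Average over multiple measurements: this smooths out run-to-run variation (e.g., effects of GPU temperature).
  - Examples:
    - Matrix multiplication: runtime grows superlinearly as the matrices get larger; for small matrices, launch overhead dominates.
    - MLP: runtime scales linearly with the number of layers and the number of steps.
- Profiling:
  - Definition: a finer-grained analysis of where time is spent inside a function.
  - Advantages:
    - Identifies the specific functions that are bottlenecks.
    - Reveals the low-level CUDA calls underneath the PyTorch interface, clarifying what the hardware actually executes.
  - PyTorch built-in profiler:
    - Tracks both CPU and GPU time.
    - Example analyses:
      - add: shows the time spent in aten::add (the PyTorch C++ interface), the actual CUDA kernel (vectorized_elementwise_kernel), the kernel launch (cudaLaunchKernel), and synchronization (cudaDeviceSynchronize).
      - Matrix multiplication: shows aten::matmul, which may call a specific kernel from NVIDIA's CUTLASS library underneath. Matrices of different sizes may be dispatched to different kernels.
      - torch.cdist (Euclidean distance): decomposes into several lower-level operations (e.g., aten::matmul, aten::pow, sum) and their corresponding CUDA kernels.
      - GELU, Softmax: usually have precompiled fused kernels.
  - NVIDIA Nsight Systems (advanced profiler):
    - Provides a detailed timeline of GPU hardware activity (cuda hw) and CPU threads (threads).
    - Code annotations (nvtx.range_push, nvtx.range_pop): help map sections of code onto the profiler's output.
    - What it reveals:
      - Initialization overhead: operations such as loading libraries can take a long time.
      - Asynchronous CPU-GPU execution: the CPU usually runs ahead of the GPU, queuing CUDA kernels in advance.
        - Speaker 1 noted: > "the cpu is running way ahead of the GPU."
      - Effect of print statements: printing the loss (or similar) inside the loop forces the CPU to wait for the GPU's result. This synchronization can make the CPU a bottleneck and break the pipeline.
        - Speaker 1 explained: > "this CUDA stream synchronize command on the cpu. This is basically saying, I'm just waiting for the GPU because I can't run ahead. I'm waiting for this loss to be computed and to be sent back to me."
      - Python performance: Python itself is slow, but because the CPU can submit work to the GPU quickly and keep going, the CPU is usually not the bottleneck.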
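The benchmarking practices listed above (warm-up, synchronize before reading the clock, average several trials) fit in a small helper. This is a sketch, not the lecture's benchmarking code. The `sync` hook is an assumption standing in for torch.cuda.synchronize, and it defaults to a no-op so the example runs on a CPU.

```python
import time
from typing import Callable, Optional

def _no_sync() -> None:
    """Default sync: nothing to wait for when fn does pure-CPU work."""

def benchmark(fn: Callable[[], object],
              num_warmups: int = 3,
              num_trials: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Mean wall-clock seconds per call of fn."""
    sync = sync or _no_sync
    # Warm-up: absorb one-time costs (initialization, JIT compilation, caching).
    for _ in range(num_warmups):
        fn()
    sync()  # make sure no warm-up work is still queued when timing starts
    times = []
    for _ in range(num_trials):
        start = time.perf_counter()
        fn()
        sync()  # kernel launches return immediately; wait for the work itself
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)  # average out run-to-run variation

mean_s = benchmark(lambda: sum(range(100_000)))
print(f"{mean_s * 1e3:.3f} ms per call")
```

For a CUDA workload you would pass something like sync=torch.cuda.synchronize. Without it, the loop would mostly measure how fast the CPU enqueues kernels, not how long the GPU takes to run them.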
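Writing High-Performance Kernels
Speaker 1 used several implementations of GELU and Softmax to show how to optimize GPU operations.
- Kernel fusion:
  - Core idea: merge several consecutive operations into a single GPU kernel. This reduces the number of round trips data makes between global memory and the SMs, and so lowers memory-access overhead.
  - Analogy: > "There's a little factory. Every time I need to do an operation, I need to ship it from the warehouse to the factory in back... What I should do is have one factory that does all the operations at once."
- GELU implementation comparison:
  - Native PyTorch implementation (torch.nn.functional.gelu): already fused internally, and fast.
    - Performance: about 1.1 ms (for a particular large input).
  - Naive PyTorch implementation (formula written out by hand): each PyTorch operation (multiplication, addition, tanh, etc.) triggers a separate CUDA kernel launch, so performance is poor.
    - Performance: about 8.1 ms (roughly 7x slower).
    - Profiler shows: multiple calls to kernels such as vectorized_elementwise_kernel.
  - CUDA C++ implementation:
    - __global__ void gelu_kernel(...): defines the GPU kernel.
    - Thread index computation: int i = blockIdx.x * blockDim.x + threadIdx.x;
    - Bounds check: if (i < n_elements).
    - CPU-side wrapper: checks the input (e.g., .is_cuda(), .is_contiguous()), allocates the output (torch.empty_like), computes the grid and block sizes, and launches the kernel.
    - Debugging: set the environment variable CUDA_LAUNCH_BLOCKING=1.
    - Performance: about 1.8 ms. Far better than the naive implementation and close to native PyTorch.
  - Triton implementation:
    - Triton is a domain-specific language developed by OpenAI for writing GPU kernels in Python, with a strong emphasis on ease of use.
    - Programming model: block-oriented; the Triton compiler takes care of low-level details such as memory coalescing and shared-memory management.
    - The @triton.jit decorator defines the kernel.
    - tl.program_id(axis=0) gets the block ID, and tl.arange creates the vector of offsets within the block.
    - tl.load and tl.store perform masked memory accesses.
    - PTX (Parallel Thread eXecution) code analysis: Triton compiles to PTX (assembly-level GPU instructions).
      - Shows register allocation (.reg .b32 %r<id>;).
      - ld.global (load from global memory) typically loads several elements at once (e.g., 4), which achieves memory coalescing.
      - st.global (store to global memory).
      - Each thread actually operates on several data elements, using registers as fast local storage.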
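    - Performance: about 1.848 ms. Comparable to the CUDA C++ version, but more convenient to write.
  - torch.compile (JIT compilation):
    - PyTorch's JIT compiler, which automatically performs optimizations such as kernel fusion.
    - Performance: about 1.47 ms. Faster than the hand-written CUDA C++ and Triton versions, though still slower than PyTorch's native fused kernel (about 1.1 ms).
    - Under the hood: torch.compile usually fuses the operations and generates Triton code.
  - When to write kernels by hand: > "if you're writing a new architecture with some complicated piece and you're not getting utilization but you think you can, that's maybe the time to really bust out the Triton." This applies to complex optimizations such as FlashAttention, or to cases that need specific hardware features.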
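- Softmax implementation comparison (involving reductions):
  - Challenge: Softmax includes per-row reductions such as the max and the sum.
  - Naive Triton Softmax design:
    - Assume the rows are short enough that each SM can handle one row.
    - The grid size (num_blocks) equals the number of rows.
    - The block size (block_size) is at least the number of columns (usually rounded up to a power of 2).
    - Inside the kernel: load the whole row into the SM's local memory, compute the max, subtract it, exponentiate, sum, normalize, and write back.
  - Performance comparison (for a particular input):
    - Manual PyTorch (naive): 3.7 ms
    - torch.compile: 1.3 ms
    - Native PyTorch: 1.5 ms
    - Triton (naive): 1.9 ms
  - Profiler shows: the manual Softmax involves many separate operations and performs poorly. torch.compile, native PyTorch, and the Triton version each run as a single fused kernel.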
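The CUDA C++ indexing scheme listed above assigns one thread per element, computes i = blockIdx.x * blockDim.x + threadIdx.x, and adds a bounds check. Unlike Triton's vector of offsets, each CUDA thread handles a single index. The scheme can be emulated sequentially in plain Python. This is a hypothetical sketch that assumes the tanh approximation of GeLU; on the GPU every (block, thread) pair runs concurrently rather than in a loop.

```python
import math

def gelu_scalar(x: float) -> float:
    """tanh-approximation GeLU for one element: what a single CUDA thread computes."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_cuda_style(xs, block_dim=128):
    """Sequential emulation of a 1-D CUDA launch with one thread per element."""
    n_elements = len(xs)
    ys = [0.0] * n_elements                               # like torch.empty_like(x)
    grid_dim = (n_elements + block_dim - 1) // block_dim  # enough blocks to cover every element
    for block_idx in range(grid_dim):                     # blockIdx.x
        for thread_idx in range(block_dim):               # threadIdx.x
            i = block_idx * block_dim + thread_idx        # i = blockIdx.x * blockDim.x + threadIdx.x
            if i < n_elements:                            # bounds check for the last, partial block
                ys[i] = gelu_scalar(xs[i])
    return ys

print(gelu_cuda_style([-1.0, 0.0, 1.0], block_dim=2))  # 2 blocks; thread 1 of block 1 is idle
```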
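Key Conclusions
- Profiling is essential: before optimizing, use a profiler (e.g., PyTorch's built-in tools or NVIDIA Nsight Systems) to identify the real bottlenecks.
- Understand CPU-GPU asynchrony: torch.cuda.synchronize() is required for accurate benchmarking; watch out for operations that cause implicit synchronization (such as print).
- Kernel fusion is key: reducing the number of GPU kernel launches and the amount of memory I/O is the core strategy for improving performance.
- Triton simplifies GPU programming: it offers a Pythonic way to write efficient GPU kernels and handles many low-level details automatically.
- torch.compile is very powerful: modern JIT compilers can often achieve efficient kernel fusion automatically, with performance comparable to or better than hand optimization.
- Manual optimization still has value: for complex algorithms (such as FlashAttention) or cases that need specific hardware features, hand-writing and optimizing CUDA/Triton kernels is still necessary.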