Stanford CS336 Language Modeling from Scratch | Spring 2025 | 02 Pytorch, Resource Accounting
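This lecture covers building a language model from scratch with PyTorch. Its focus is resource efficiency in model training, particularly memory and compute usage. Through worked calculations, such as how long it takes to train a large model or the largest model that a given set of hardware can train, the lecture stresses resource estimation ("napkin math") as the way to keep costs under control. It does not go into the details of the Transformer architecture. Instead it concentrates on PyTorch's basic building blocks and on resource accounting, with the aim of building students' efficiency mindset and practical skills.
On memory accounting, the lecture treats tensors as the basic unit for storing parameters, gradients, optimizer states, and other data, and shows how the choice of floating-point representation affects memory usage. It compares float32 (single precision, the default, 4 bytes), float16 (half precision, 2 bytes, with a limited dynamic range that can make training unstable), bfloat16 (brain floating point, 2 bytes, with a dynamic range similar to float32 but lower precision, well suited to deep learning computation), and fp8 (8-bit floating point, smaller still, supported on newer hardware such as the H100 for aggressive optimization). The lecture recommends bfloat16 for computation to balance efficiency and stability, while parameters and optimizer states should still be stored in float32 to keep training stable.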
Media details
- Upload date
- 2025-05-13 16:30
- Latest LLM Model
- gemini-2.5-pro-exp-03-25
Transcript
speaker 1: Okay. So last lecture, I gave an overview of language models, what it means to build one from scratch, and why we want to do that. I also talked about tokenization, which is going to be the first half of the first assignment. Today's lecture will go through actually building a model. We'll discuss the primitives in PyTorch that are needed. We're going to start with tensors, then build models, optimizers, and the training loop. And we're going to pay close attention to efficiency, in particular how we're using resources, both memory and compute. Okay. So to motivate things a bit, here are some questions. These questions are going to be answerable by napkin math, so get your napkins out. How long would it take to train a 70 billion parameter dense transformer model on 15 trillion tokens on 1024 H100s? Okay. I'm just going to sketch this out to give you a flavor of the type of things that we want to do. Here's how you go about reasoning about it. You count the total number of flops needed to train. That's six times the number of parameters times the number of tokens. And where does that come from? That's what we'll talk about in this lecture. You can look at the promised number of flops per second that an H100 gives you, and the MFU, which is something we'll see later; let's just set it to 0.5. Then you can compute the number of flops per day that your hardware is going to give you at this particular MFU, so 1024 of them for one day. And then you just divide the total number of flops you need to train by the number of flops per day that you're supposed to get. That gives you about 144 days. Okay? So these are very simple calculations. We're going to go through a bit more of where these numbers come from, and in particular where the six times number of parameters times number of tokens comes from. Okay, so here's the next question.
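The training-time estimate above can be reproduced in a few lines. This is a sketch under stated assumptions: training costs about 6 × parameters × tokens flops (the rule of thumb quoted in the lecture), the promised H100 rate is taken as the dense half of the 1979 TFLOP/s spec-sheet figure (which assumes sparsity), and MFU is 0.5. The function name `training_days` is illustrative, not from the lecture.

```python
# Napkin math: days to train a 70B dense transformer on 15T tokens with 1024 H100s.
# Assumptions: ~6 * params * tokens training FLOPs, dense H100 peak of 1979/2 TFLOP/s
# (the spec-sheet 1979 figure assumes sparsity), and MFU = 0.5.

def training_days(num_params: float, num_tokens: float, num_gpus: int,
                  peak_flop_per_sec: float, mfu: float) -> float:
    total_flops = 6 * num_params * num_tokens                 # forward + backward estimate
    flops_per_day = num_gpus * peak_flop_per_sec * mfu * 60 * 60 * 24
    return total_flops / flops_per_day

h100_flop_per_sec = 1979e12 / 2
days = training_days(70e9, 15e12, 1024, h100_flop_per_sec, mfu=0.5)
print(f"{days:.0f} days")  # 144 days
```

Changing any single input (more GPUs, a better MFU) scales the answer linearly, which is exactly why these estimates are worth doing before a run.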
What is the largest model you can train on eight H100s using AdamW if you're not being too clever? Okay. So an H100 has 80 GB of HBM memory. The number of bytes per parameter that you need for the parameters, the gradients, and the optimizer state is 16, and we'll talk more about where that comes from. The number of parameters is basically the total amount of memory divided by the number of bytes you need per parameter, and that gives you about 40 billion parameters. Okay? This is very rough because it doesn't take into account activations, which depend on batch size and sequence length. I'm not really going to talk about that, but it will be important for assignment one. So this is a rough back-of-the-envelope calculation. And this is something that you're probably not used to doing: you just implement a model, you train it, and whatever happens, happens. But remember that efficiency is the name of the game. To be efficient, you have to know exactly how many flops you're actually expending, because when these numbers get large, they translate directly into dollars, and you want that to be as small as possible. Okay, so we'll talk more about the details of how these numbers arise. We will not actually go over the transformer; Tatsu is going to give the conceptual overview of that next time. And there are many ways to learn about the transformer if you haven't already looked at it. There's assignment one: if you do assignment one, you'll definitely know what a transformer is. The handout actually does a pretty good job of walking through all the different pieces. There's a mathematical description, and if you like pictures, there are pictures. There's also a lot of material you can look at online. But instead, I'm going to work with simpler models and really talk about the primitives and the resource accounting piece. Okay, so remember last time I talked about what kinds of knowledge you can learn?
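The 40-billion-parameter bound works out as follows. The lecture only states the 16 bytes per parameter; the breakdown below (fp32 parameters, fp32 gradients, and two fp32 AdamW moment buffers) is an assumption about where that number comes from, and activations are ignored, as in the lecture.

```python
# Rough upper bound on parameter count for AdamW training on 8 H100s, ignoring
# activations. Assumed breakdown of the 16 bytes/parameter (all float32):
# 4 (params) + 4 (grads) + 4 + 4 (AdamW first and second moments).
h100_bytes = 80e9
num_gpus = 8
bytes_per_param = 4 + 4 + (4 + 4)
max_params = num_gpus * h100_bytes / bytes_per_param
print(f"{max_params / 1e9:.0f}B parameters")  # 40B parameters
```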
Of the kinds of knowledge from last time, mechanics in this lecture is going to be just PyTorch and understanding how PyTorch works at a fairly primitive level, so that will be pretty straightforward. Mindset is about resource accounting; it's not hard, you just have to do it. And intuitions: unfortunately, this is just going to be broad strokes for now. There's not really much intuition I'm going to give about how anything we're doing translates to good models. This lecture is more about mechanics and mindset. Okay, so let's start with memory accounting, then I'll talk about compute accounting, and then we'll build up from the bottom. The best place to start is a tensor. Tensors are the building block for storing everything in deep learning: parameters, gradients, optimizer state, data, activations. They're sort of the atoms. You can read lots of documentation about them, and you're probably very familiar with how to create tensors. There are different ways of creating them. You can also create a tensor without initializing it, and then apply some special initialization for the parameters if you want. Okay, so those are tensors. Now let's talk about memory and how much memory tensors take up. Every tensor that we'll be interested in stores floating point numbers, and there are many ways to represent floating point. The default is float32. float32 has 32 bits: one for the sign, eight for the exponent, and 23 for the fraction. The exponent gives you dynamic range, and the fraction gives you precision, basically how many distinct values you can specify. float32, also known as FP32 or single precision, is sort of the gold standard in computing. Some people also refer to float32 as full precision. That's a little bit confusing, because "full" really depends on who you're talking to. If you're talking to a scientific computing person, they will kind of laugh at you when you say float32 is full precision, because they use float64 or even more.
For a machine learning person, though, float32 is the most precision you'll probably ever need, because deep learning is kind of sloppy like that. Okay, so let's look at memory. It's very simple: it's determined by the number of values in your tensor and the data type of each value. If you create a torch tensor of a 4 × 8 matrix, by default it gives you float32. The size is 4 × 8, so the number of elements is 32. Each element is four bytes, since 32 bits is four bytes. The memory usage is simply the number of elements times the size of each element, which gives you 128 bytes. Okay? So this should be pretty easy. And just to give some intuition, if you take one matrix in the feedforward layer of GPT-3, which is (4 × 12,288) by 12,288, that gives you about 2.3 GB. So that's one matrix; these matrices can be pretty big. float32 is the default, but since these matrices get big, you naturally want to make them smaller so you use less memory. And it turns out that making them smaller also makes things go faster. Another representation is called float16. As the name suggests, it's 16 bits, where the exponent and fraction are shrunk from 8 to 5 bits and from 23 to 10 bits respectively. This is known as half precision, and it cuts the memory in half. That's all great, except that the dynamic range of float16 isn't great. For example, if you try to represent a number like 1e-8 in float16, it basically rounds down to zero and you get underflow. So float16 is not great for representing very small numbers, or very big numbers for that matter. If you use float16 to train small models, it's probably going to be okay. But for large models with lots of matrices, you can get instability, underflow, or overflow, and bad things happen.
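A minimal sketch of tensor creation and the memory rule above (memory = number of elements × bytes per element). The truncated-normal initializer is just one example of a "special initialization"; the helper `memory_bytes` is hypothetical, and the GPT-3 shape assumes d_model = 12,288 with a 4× feedforward expansion. The `meta` device lets us check the size of the big matrix without actually allocating it.

```python
import torch

# Common ways to create tensors (float32 on the CPU by default).
x = torch.zeros(4, 8)
o = torch.ones(4, 8)
r = torch.randn(4, 8)                      # samples from N(0, 1)
w = torch.empty(4, 8)                      # allocated but uninitialized ...
torch.nn.init.trunc_normal_(w, mean=0.0, std=1.0, a=-2.0, b=2.0)  # ... then initialized

def memory_bytes(t: torch.Tensor) -> int:
    # Memory = number of elements * bytes per element.
    return t.numel() * t.element_size()

print(x.dtype, x.numel(), x.element_size(), memory_bytes(x))  # torch.float32 32 4 128

# One GPT-3 feedforward weight matrix, assuming d_model = 12288 and a 4x expansion.
# The "meta" device records shape and dtype without allocating real memory.
ff = torch.empty(4 * 12288, 12288, device="meta")
print(memory_bytes(ff) / 2**30)  # 2.25 (GiB)

# Halving the bytes per element halves the memory.
assert memory_bytes(x.to(torch.bfloat16)) == 64
```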
One nice development is another representation called bfloat16, which stands for brain float. It was developed in 2018 to address the fact that for deep learning, we actually care about dynamic range more than about the fraction. So bf16 allocates more bits to the exponent and fewer to the fraction. It uses the same memory as float16, but it has the dynamic range of float32. That sounds really good. The catch is that the resolution, which is determined by the fraction, is worse, but this doesn't matter as much for deep learning. So now if you try to create a tensor with 1e-8 in bf16, you get something that's not zero. You can dive into the details; I'm not going to go into them, but you can stare at the full specs of all the different floating point formats. So bf16 is basically what you will typically use for computation, because it's good enough for the feedforward computation. It turns out that for storing optimizer states and parameters, you still need float32, otherwise your training can go haywire. If you're bold, there's now something called FP8, or 8-bit floating point. As the name suggests, it was developed in 2022 by NVIDIA. If you look at FP16 and BF16 and then at FP8, wow, you really don't have that many bits to store stuff, right? So it's very crude. There are two variants, depending on whether you want more resolution or more dynamic range. I'm not going to say too much about this, but FP8 is supported by the H100; it's not really available on the previous generation. At a high level, training with float32, which I think is what you would do if you're not trying to optimize too much, is safe but requires more memory. You can go down to FP8 or BF16, but you can get some instability.
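The trade-off between dynamic range and resolution can be inspected directly with `torch.finfo`. This sketch prints, for each format, its bit width, largest value, smallest normal value (`tiny`), and `eps` (the gap between 1.0 and the next representable number), and then reproduces the 1e-8 underflow example. FP8 is left out here; recent PyTorch versions expose it as `torch.float8_e4m3fn` and `torch.float8_e5m2`, matching the two variants mentioned above.

```python
import torch

# Bits, dynamic range (max / smallest normal) and resolution (eps) of each format.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} bits={info.bits:2d} max={info.max:.3e} "
          f"tiny={info.tiny:.3e} eps={info.eps:.3e}")

# 1e-8 underflows to zero in float16 but survives (approximately) in bfloat16.
fp16_small = torch.tensor([1e-8], dtype=torch.float16)
bf16_small = torch.tensor([1e-8], dtype=torch.bfloat16)
print(fp16_small.item())   # 0.0 (underflow)
print(bf16_small.item())   # nonzero, approximately 1e-08
```

Note that bfloat16's `eps` is much larger than float16's: it buys range by giving up resolution, which is the catch described above.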
I don't think you would want to use float16 for deep learning at this point. You can get more sophisticated by looking at particular places in your pipeline, whether the forward pass, the backward pass, the optimizer, or gradient accumulation, and really figuring out the minimum precision you need at each of those places. That gets into what's called mixed precision training. For example, some people like to use float32 for attention to make sure it doesn't get messed up, whereas for simple feedforward passes with matmuls, BF16 is fine. Okay, let me pause for questions. So we talked about tensors, and we looked at how much storage they take depending on the representation. Yeah? Can you clarify mixed precision, like when you would use float32 versus bf16? So the question is when you would use float32 or bf16. I don't have time to get into the exact details, and it varies depending on the model size and everything. But generally, for the parameters and optimizer states you use float32. You can think of bf16 as something more transitory: you take your parameters, cast them to bf16, and run ahead with that model. But for the things you accumulate over time, you want higher precision. Okay, so that was memory; now let's talk about compute. Compute obviously depends on the hardware. By default, tensors are stored on the CPU. For example, if in PyTorch you say x = torch.zeros(32, 32), it will be put in CPU memory. Of course that's no good, because if you're not using your GPU, you're going to be orders of magnitude too slow. So you need to explicitly tell PyTorch to move the tensor to the GPU. Just to make it very clear in pictures: there's a CPU, it has RAM, and the data has to be moved over to the GPU.
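Here is a minimal sketch of the mixed-precision pattern described in the answer above: the master weights and the update stay in float32, while the forward matmul runs on bfloat16 casts. The shapes, learning rate, and loss are arbitrary choices for illustration. In practice `torch.autocast` can insert such casts automatically; the explicit casts here just make each step visible.

```python
import torch

torch.manual_seed(0)
# Master weights (and the accumulated update) stay in float32.
w = torch.randn(16, 8, requires_grad=True)
x = torch.randn(4, 16)

# Transitory compute: cast to bfloat16 and run the matmul in low precision.
y = x.to(torch.bfloat16) @ w.to(torch.bfloat16)
loss = y.float().pow(2).mean()   # compute the loss in float32
loss.backward()                  # the gradient flows back through the cast into fp32 w

with torch.no_grad():
    w -= 0.01 * w.grad           # the update is accumulated in float32
print(y.dtype, w.grad.dtype, w.dtype)  # torch.bfloat16 torch.float32 torch.float32
```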
Moving data between the CPU and the GPU is a data transfer, which takes some work and some time. So whenever you have a tensor in PyTorch, you should always keep in mind where it resides, because just by looking at the variable or the code, you can't always tell. If you want to be careful about computation and data movement, you have to really know where it is. You can do things like assert where it is at various places in your code, just to document it or to be sure. Okay, so let's look at what hardware we have. In this case, we have one GPU. This was run on the H100 clusters that you have access to. This GPU is an H100 with 80 GB of high-bandwidth memory, and it also reports the cache size and so on. Remember that x is on the CPU; you can move it just by calling .to(), which is a general PyTorch function. You can also create a tensor directly on the GPU, so you don't have to move it at all. If everything goes well, when I look at the memory allocated before and after, the difference should be exactly two 32 × 32 matrices of 4-byte floats, so 8192 bytes. This is a sanity check that the code is doing what is advertised. Okay, so now that you have your tensors on the GPU, what do you do? There are many operations that you'll need for assignment one and, in general, for any deep learning application. Most tensors are created by performing operations on other tensors, and each operation has some memory and compute footprint. So let's make sure we understand that. First of all, what actually is a tensor in PyTorch? Mathematically a tensor is an abstract object, but in PyTorch tensors are actually pointers into some allocated memory. If you have, say, a 4 × 4 matrix, what it actually looks like in memory is a long array. And the tensor carries metadata that specifies how to address into that array.
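Returning to device placement, the checks described above might look like this. It is a sketch that assumes a single CUDA device named `"cuda"` and falls back gracefully on a CPU-only machine; `torch.cuda.memory_allocated()` reports the bytes currently held by tensors, so the difference should correspond to the two new 32 × 32 float32 matrices.

```python
import torch

x = torch.zeros(32, 32)
assert x.device.type == "cpu"  # tensors live in CPU memory by default

if torch.cuda.is_available():
    before = torch.cuda.memory_allocated()
    y = x.to("cuda")                         # copy CPU -> GPU (a data transfer)
    z = torch.zeros(32, 32, device="cuda")   # allocate directly on the GPU
    after = torch.cuda.memory_allocated()
    assert y.device.type == z.device.type == "cuda"
    # Expected: two 32x32 float32 matrices = 2 * 32 * 32 * 4 = 8192 bytes.
    print(after - before)
else:
    print("no GPU available; x stays on", x.device)
```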
The metadata a tensor carries includes a stride for each dimension: one number per dimension of the tensor. In this case, because there are two dimensions, there's stride 0 and stride 1. Stride 0 specifies how many elements in the underlying array you have to skip to get to the next row, that is, to increment the index in dimension 0. Going down the rows, you skip four, so stride 0 is four. To go to the next column, you skip one, so stride 1 is one. With that, to find an element, say (1, 2), you simply multiply each index by its stride and add them up, which gives you position 1 × 4 + 2 × 1 = 6 in the array. So that's basically what's going on under the hood for tensors. This is relevant because you can have multiple tensors that use the same storage, which is useful because you don't want to copy tensors all over the place. Imagine you have a 2 × 3 matrix here. Many operations don't actually create a new tensor; they just create a different view and don't make a copy. So you have to be careful with mutations: if you start mutating one tensor, it's going to cause the other one to change as well. For example, take row 0: x is [[1, 2, 3], [4, 5, 6]], and y is x[0], which is just the first row. You can double-check by looking at the underlying storage and checking whether the two tensors share it. This definitely doesn't copy the tensor; it just creates a view. You can get column 1; this also doesn't copy the tensor. You can call the view function, which takes any tensor and looks at it in terms of different dimensions, for example the 2 × 3 tensor viewed the other way around as a 3 × 2 tensor. That also doesn't do any copying. You can transpose it, and that also doesn't copy.
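The stride arithmetic and the no-copy views above can be checked in a few lines. The helper `same_storage` is hypothetical; it compares the base addresses of the tensors' underlying storages via `untyped_storage().data_ptr()` (available in PyTorch 2.x).

```python
import torch

# Strides: how many storage elements to skip per step along each dimension.
m = torch.arange(16.).view(4, 4)
assert m.stride() == (4, 1)
i, j = 1, 2
offset = i * m.stride(0) + j * m.stride(1)       # 1*4 + 2*1 = 6
assert m.flatten()[offset] == m[i, j]

def same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Hypothetical helper: do two tensors point into the same underlying memory?
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

x = torch.tensor([[1., 2, 3], [4, 5, 6]])
assert same_storage(x, x[0])                     # row 0: a view, no copy
assert same_storage(x, x[:, 1])                  # column 1: a view, no copy
assert same_storage(x, x.view(3, 2))             # reinterpret 2x3 as 3x2
assert same_storage(x, x.transpose(1, 0))        # transpose: a view, no copy
print(x.transpose(1, 0).stride())                # (1, 3)
```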
And like I said, if you start mutating x, then y actually gets mutated as well, because x and y are just pointers into the same underlying storage. One thing you have to be careful of is that some views are contiguous, which means that if you run through the tensor, you're just sliding through the array in your storage, but some are not. In particular, think about what transposing means: now you're going down the columns, so if you imagine running through the tensor, you're skipping around in memory. And if you have a non-contiguous tensor and try to further view it in a different way, that's not going to work. In such cases, you can make the tensor contiguous first and then apply whatever view operation you want. But then x and y no longer share storage, because contiguous makes a copy here. Okay, so these are just ways of slicing and dicing. Tensor views are free, so feel free to use them to define different variables that make your code easier to read, because they don't allocate any memory. But remember that contiguous, or reshape, which is basically contiguous followed by view, can create a copy, so just be careful what you're doing. Okay, questions before moving on? All right. Hopefully a lot of this is review for those of you who have done a lot of PyTorch before, but it's helpful to go through it systematically and make sure we're on the same page. Here are some operations that do create new tensors, in particular elementwise operations. They create new tensors, obviously, because you need somewhere else to store the new values. There's also triu, which takes the upper triangular part of a matrix; it comes in handy when you want to create a causal attention mask, which you'll need for your assignment. But nothing that interesting here.
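A sketch of the pitfalls above: mutation through a shared view, a failing `view` on a transposed (non-contiguous) tensor, the copy made by `contiguous()`, and `triu` used to mark the future positions that a causal mask blocks. The 4 × 4 mask size is arbitrary.

```python
import torch

x = torch.tensor([[1., 2, 3], [4, 5, 6]])
y = x[0]                     # a view of row 0
x[0][0] = 100                # mutate x through the shared storage ...
assert y[0] == 100           # ... and y sees the change

t = x.transpose(1, 0)        # 3x2 view with strides (1, 3): not contiguous
assert not t.is_contiguous()

view_failed = False
try:
    t.view(2, 3)             # view requires a layout compatible with the strides
except RuntimeError:
    view_failed = True
print("view on non-contiguous tensor failed:", view_failed)

c = t.contiguous().view(2, 3)  # contiguous() copies, after which view works
assert c.untyped_storage().data_ptr() != x.untyped_storage().data_ptr()

# triu keeps the upper triangle. With diagonal=1, the True entries mark the
# "future" positions that a causal attention mask blocks.
future = torch.triu(torch.ones(4, 4), diagonal=1).bool()
print(future.int())
```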
Okay, so let's talk about matmuls. The bread and butter of deep learning is matrix multiplication. I'm sure all of you have done a matrix multiplication, but just in case, this is what it looks like: a 16 × 32 matrix times a 32 × 2 matrix gives you a 16 × 2 matrix. But in machine learning applications, you generally want to do all operations in a batch. In the case of language models, this usually means that for every example in a batch and for every position in a sequence, you want to do something. So instead of just a matrix, you're generally going to have a tensor whose leading dimensions are batch and sequence, followed by whatever thing you're trying to do; in this case, it's a matrix for every token in your dataset. PyTorch is nice enough to make this work well for you. When you multiply this four-dimensional tensor by this matrix, what actually happens is that for every example in the batch and every token, you're multiplying these two matrices, and the result is one output matrix for each index of the first two dimensions. There's nothing fancy going on, but this is a pattern that I think is helpful to keep in mind. Okay, now I'm going to take a little bit of a digression and talk about einops. The motivation for einops is the following. Normally in PyTorch, you define some tensors and then you see code like this, where you take x and multiply it by y.transpose(-2, -1), and you look at it and ask, what is -2? Well, I think that's the sequence dimension. And -1 is the hidden dimension, because you're indexing backwards. It's really easy to mess this up, because when you look at your code and see -1 and -2, you have to work out what they mean. If you're good, you write a bunch of comments, but then the comments can get out of date with the code, and you have a bad time debugging.
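The batched-matmul behavior above, plus the index-based transpose pattern that motivates einops, in a short sketch (all shapes are arbitrary).

```python
import torch

a = torch.ones(16, 32)
b = torch.ones(32, 2)
assert (a @ b).shape == (16, 2)

# Batched: (batch, seq, 16, 32) @ (32, 2) applies the same matmul at every
# (batch, seq) index, broadcasting the matrix over the leading dimensions.
x = torch.ones(4, 8, 16, 32)
w = torch.ones(32, 2)
y = x @ w
print(y.shape)  # torch.Size([4, 8, 16, 2])

# The index-based pattern einops is meant to replace: which dims are -2 and -1?
q = torch.randn(2, 5, 4)   # batch, seq1, hidden
k = torch.randn(2, 7, 4)   # batch, seq2, hidden
scores = q @ k.transpose(-2, -1)  # batch, seq1, seq2
assert scores.shape == (2, 5, 7)
```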
The solution is to use einops here. It's inspired by Einstein summation notation, and the idea is that we name all the dimensions instead of relying on indices. There's also a library called jaxtyping, which is helpful as a way to specify the dimensions in the types. Normally in PyTorch you would just write your code and then add a comment saying, "Here's what the dimensions would be." With jaxtyping, you have a notation where you write down, as a string, what the dimensions are. This is a slightly more natural way of documenting. Notice that there's no enforcement here, because types are only loosely enforced in PyTorch. You can use a checker, right? Yes, you can run a checker, but not by default. Okay, so let's look at einsum. einsum is basically matrix multiplication on steroids, with good bookkeeping. Here's our example: we have x, which has a batch dimension, a sequence dimension, and, say, four hidden units, and y is the same size. Before, you had to do this transpose thing. Now, instead, you write down the names of the dimensions of the two tensors, batch seq1 hidden and batch seq2 hidden, and then you write which dimensions should appear in the output. I write batch here because I just want to carry it over. Then I write seq1 and seq2, and notice that I don't write hidden. Any dimension that is not named in the output is summed over, and any dimension that is named is just iterated over. Once you get used to this, it's actually very, very helpful. If you're seeing it for the first time, it might seem a bit strange and long, but trust me, once you get used to it, it'll be better than writing -2, -1.
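To stay dependency-free, this sketch uses PyTorch's built-in `torch.einsum`, which needs single-letter dimension labels. The einops spelling with full dimension names, as discussed above, is shown in a comment. The rule is the same in both: labels missing from the output (`h`, for hidden) are summed over, while the others are iterated over. The `...` form broadcasts over any number of leading dimensions.

```python
import torch

x = torch.randn(2, 3, 4)  # batch seq1 hidden
y = torch.randn(2, 5, 4)  # batch seq2 hidden

# Index-based version: easy to get -2/-1 wrong.
old = x @ y.transpose(-2, -1)                     # batch seq1 seq2

# torch.einsum uses one letter per dimension: b=batch, s=seq1, t=seq2, h=hidden.
# h is absent from the output, so it is summed over; b, s, t are iterated over.
new = torch.einsum("bsh,bth->bst", x, y)
# einops spells the same thing with full names:
#   einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")

# "..." broadcasts over any number of leading dimensions.
anydims = torch.einsum("...sh,...th->...st", x, y)

assert torch.allclose(old, new, atol=1e-5) and torch.allclose(new, anydims, atol=1e-5)
print(new.shape)  # torch.Size([2, 3, 5])
```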
If you're a little bit slicker, you can use "..." (an ellipsis) to represent broadcasting over any number of dimensions. In this case, instead of writing batch, I can just write "...", and this handles the case where, instead of just batch, I have batch1, batch2, or some other arbitrarily long sequence of dimensions. Yeah, question: does torch compile this? Is it guaranteed to compile to something efficient? So the question is whether it's guaranteed to compile to something efficient. I think the short answer is yes. I don't know if you have any nuances to add? It figures out the best order in which to reduce the dimensions and then uses that. If you use torch.compile, it only does that once and then reuses the same implementation over and over again, which is better than anything designed by hand. Okay, so let's look at reduce. reduce operates on one tensor and aggregates over some dimension or dimensions of the tensor. Before, you would write x.mean(dim=-1) to aggregate over the final dimension. Now you replace this with reduce, naming the aggregation (for example, sum), and again you write hidden on the input side; hidden disappears on the output side, which means that you are aggregating over that dimension. You can check that this indeed works. Okay, one final example: sometimes one dimension of a tensor actually represents multiple dimensions, and you want to unpack it, operate over one of them, and pack it back. In this case, let's say you have batch and sequence, and then this eight-dimensional vector is actually a flattened representation of the number of heads times some hidden dimension. And then you have a weight matrix that needs to operate on that hidden dimension. You can do this very elegantly with einops by calling rearrange. We saw view before, and you can think of rearrange as a fancier version of view that looks at the same data, but differently.
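The reduce and rearrange steps can be sketched with plain PyTorch. The einops equivalents appear in comments; the shapes (2 heads × hidden1 = 4, flattened to 8) follow the example in the lecture, but the specific tensors are illustrative. `view` splits the flattened dimension (valid here because `x` is contiguous), `einsum` transforms each head's hidden vector, and `flatten(-2)` merges the two dimensions back.

```python
import torch

x = torch.randn(2, 3, 8)  # batch seq (heads hidden1), with heads=2, hidden1=4
w = torch.randn(4, 4)     # hidden1 hidden2

# reduce: aggregate away the last ("hidden") dimension.
#   einops: reduce(x, "... hidden -> ...", "sum")
summed = x.sum(dim=-1)
assert summed.shape == (2, 3)

# rearrange: split the flattened last dimension into (heads, hidden1).
#   einops: rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)
x4 = x.view(*x.shape[:-1], 2, 4)                 # batch seq heads hidden1

# Transform each head: "... hidden1, hidden1 hidden2 -> ... hidden2".
h = torch.einsum("...i,ij->...j", x4, w)         # batch seq heads hidden2

# Inverse rearrange: merge (heads, hidden2) back into one dimension.
#   einops: rearrange(h, "... heads hidden2 -> ... (heads hidden2)")
out = h.flatten(-2)
print(out.shape)  # torch.Size([2, 3, 8])
```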
So the rearrange pattern here says that this last dimension is actually heads and hidden1, and I'm going to explode it into two dimensions. You have to specify the number of heads, because there are multiple ways to split a number into two factors. Let's see, this might be a little bit long; maybe it's not worth looking at right now. Given that x, you can perform your transformation using einsum: "... hidden1", which corresponds to x, and "hidden1 hidden2", which corresponds to w, give you "... hidden2". And then you can rearrange back, which is just the inverse of breaking it up: you take your two dimensions and group them into one. That's just a flattening operation, with all the other dimensions left alone. There is a tutorial for this that I would recommend you go through; it goes into more detail. You don't have to use einops, because you're building things from scratch, so you can do anything you want. But in assignment one we do give you guidance, and it's probably worth investing in. Okay. So now let's talk about computation: the cost of tensor operations. We introduced a bunch of operations; how much do they cost? A floating point operation is any basic arithmetic operation on floating point numbers, like addition or multiplication. These are the main ones that matter in terms of flop count. One pet peeve of mine is that when people say flops, it's actually unclear what they mean. You could mean FLOPs with a lowercase s, which stands for the number of floating point operations; this measures the amount of computation you've done. Or you could mean FLOPS with an uppercase S, which means floating point operations per second and is used to measure the speed of hardware.
In this class we're not going to use the uppercase S, because I find that very confusing; I'll just write "/s" to denote floating point operations per second. Okay. Just to give you some intuition about flops: GPT-3 took about 3.14e23 flops. GPT-4 was 2e25 flops, though that's speculation. There was a US executive order saying that any foundation model trained with over 1e26 flops had to be reported to the government; it has now been revoked. But the EU still has the EU AI Act, with a threshold of 1e25, which hasn't been revoked. Some more intuition: an A100 has a peak performance of 312 teraflop/s, and an H100 has a peak performance of 1979 teraflop/s with sparsity, and approximately 50% of that without. If you look at NVIDIA's specification sheets, you can see that the flop/s actually depends on what you're trying to do. If you're using FP32, it's really, really bad: running FP32 on an H100 is more than an order of magnitude slower than FP16, and if you're willing to go down to FP8, it can be even faster. When I first read the spec sheet, I didn't notice that there's an asterisk here, and it means "with sparsity." Most of the matrices in this class are dense, so you don't actually get this; you get exactly half. Okay, so now you can do back-of-the-envelope calculations. Eight H100s for one week is just eight times the flop/s times the number of seconds in a week, which is about 4.7e21, some number. And you can contextualize these flop counts against those of other models. Yeah? What does sparsity mean? It means your matrices are sparse in a specific way, structured sparsity: two out of every four consecutive elements are zero.
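The eight-H100 calculation above, written out. It uses the spec-sheet figures quoted in the lecture and halves the H100's 1979 TFLOP/s because that figure assumes 2:4 structured sparsity. The GPT-3 comparison at the end is only there to contextualize the result, and assumes 100% utilization.

```python
# Back-of-the-envelope FLOP budgets using the spec-sheet figures quoted above.
a100_flop_per_sec = 312e12          # A100 peak (bf16/fp16, dense)
h100_flop_per_sec = 1979e12 / 2     # H100: 1979 TFLOP/s assumes 2:4 sparsity; halve for dense

seconds_per_week = 60 * 60 * 24 * 7
total_flops = 8 * h100_flop_per_sec * seconds_per_week   # eight H100s for one week
print(f"{total_flops:.2e}")  # 4.79e+21

gpt3_flops = 3.14e23
print(f"GPT-3 would take ~{gpt3_flops / total_flops:.0f} such weeks at 100% utilization")
```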
That structured-sparsity case is the only one where you get the advertised number. No one uses it; it's a marketing department thing. Okay, so let's go through a simple example. Remember, we're not going to touch the transformer, but I think even a linear model gives us a lot of the building blocks and intuitions. Suppose we have B points, each point is D-dimensional, and the linear model maps each D-dimensional vector to a K-dimensional vector. So let's set B as the number of points, D as the input dimension, and K as the number of outputs, and create our data matrix x and our weight matrix w; the linear model is just a matmul. Nothing too interesting going on. The question is, how many flops was that? The way to look at it is: when you do the matrix multiplication, for every (i, j, k) triple you have to multiply two numbers together, and you also have to add that product to the running total. So the answer is two times the product of all the dimensions involved: the left dimension, the middle dimension, and the right dimension. This is something you should just remember: if you're doing a matrix multiplication, the number of flops is two times the product of the three dimensions. The flops of other operations are usually linear in the size of the matrix or tensor, and in general, no other operation you encounter in deep learning is as expensive as matrix multiplication for large enough matrices. This is why a lot of the napkin math is very simple: we only look at the matrix multiplications performed by the model. Of course, there are regimes where, if your matrices are small enough, the cost of other things starts to dominate. But generally that's not a regime you want to be in, because the hardware is designed for big matrix multiplications.
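The "2 × product of the three dimensions" rule can be checked by brute force. This sketch implements a naive (B × D) @ (D × K) matmul on nested Python lists and counts one multiply and one add per (i, j, k) triple. It also relates the count to the toy linear model's parameter count D × K, giving 2 × (number of data points) × (number of parameters) for the forward pass. The function name is illustrative.

```python
def naive_matmul_with_count(x, w):
    """Naive (B x D) @ (D x K) on nested lists, counting multiplies and adds."""
    B, D, K = len(x), len(w), len(w[0])
    out = [[0.0] * K for _ in range(B)]
    flops = 0
    for i in range(B):
        for k in range(K):
            acc = 0.0
            for j in range(D):
                acc += x[i][j] * w[j][k]   # one multiply + one add per (i, j, k)
                flops += 2
            out[i][k] = acc
    return out, flops

B, D, K = 3, 5, 4
x = [[1.0] * D for _ in range(B)]
w = [[1.0] * K for _ in range(D)]
out, flops = naive_matmul_with_count(x, w)
num_params = D * K                       # weights of the toy linear model
assert flops == 2 * B * D * K == 2 * B * num_params
print(flops)  # 120
```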
It's a little bit circular, but we end up in the regime where we only consider models in which matmuls are the dominant cost. Okay. Any questions about this number, two times the product of the three dimensions? Is it always the same, or could it depend on the chip? So the question is essentially whether this depends on the matrix multiplication algorithm. We'll look at this next week or the week after, when we look at kernels. There's actually a lot of optimization that goes on under the hood for matrix multiplications, and a lot of specialization depending on the shape. So I would say this is just a crude estimate that gets the order of magnitude right. Are additions and multiplications counted the same? Yes, additions and multiplications are considered equivalent. Here's one way I find helpful to interpret this. At the end of the day, this is just a matrix multiplication, but I'm going to try to give it a little bit of meaning, which is why I've set it up as a toy machine learning problem. B stands for the number of data points, and D times K is the number of parameters. So for this particular model, the number of flops required for the forward pass is two times the number of tokens (or data points) times the number of parameters. This turns out to generalize to transformers. There's an asterisk there because of the sequence length and other factors, but it's roughly right if your sequence length isn't too large. Okay, so this is just the number of floating point operations. How does it translate into wall-clock time, which is presumably what you actually care about: how long you have to wait for your run? So let's time this.
I have this function that runs the matrix multiply five times. We'll talk later, two weeks from now, about why the rest of the code is there, but for now we get an actual time. That matmul took 0.16 s, and the actual flop/s, which is how many flops it did per second, is 5.4e13. Now you can compare this with the marketing materials for the A100 and H100. As we saw in the spec sheet, the flop/s depends on the data type, and the promised flop/s for the H100 in float32 is 67 teraflop/s. So that's the promised flop/s. There's a helpful notion called model flops utilization, or MFU, which is the actual flop/s divided by the promised flop/s. You take the actual number of flops, meaning the floating point operations that are useful for your model, divide by the time it took, and divide by the promised flop/s from the glossy brochure, and you get an MFU of 0.8. You'll often see people talking about their MFUs; something greater than 0.5 is usually considered good, and 5% MFU is considered really bad. You usually can't get close to 90 or 100%, because MFU ignores all communication and overhead; it only reflects the literal computation of the flops. And MFU is usually much higher when matrix multiplications dominate. Any questions about this? Yes, so the question is what the promised flop/s takes into account. One note: there's also something called hardware flops utilization.
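A minimal timing-and-MFU sketch under stated assumptions: it times a square matmul on the GPU if one is available (synchronizing, because CUDA kernels launch asynchronously) or on the CPU otherwise, and it divides by the 67 TFLOP/s H100 float32 figure quoted above purely as a placeholder; for a meaningful MFU, substitute the promised number for your own hardware and dtype. `time_matmul` and `mfu` are illustrative names, not the lecture's code.

```python
import time
import torch

def time_matmul(B: int, D: int, K: int, dtype=torch.float32, num_trials: int = 5) -> float:
    """Average seconds per (B x D) @ (D x K) matmul on the GPU if present, else CPU."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(B, D, device=device, dtype=dtype)
    w = torch.randn(D, K, device=device, dtype=dtype)

    def sync():
        if device == "cuda":
            torch.cuda.synchronize()  # CUDA kernels are async; wait before reading the clock

    x @ w  # warm-up run
    sync()
    start = time.perf_counter()
    for _ in range(num_trials):
        x @ w
    sync()
    return (time.perf_counter() - start) / num_trials

def mfu(num_flops: float, seconds: float, promised_flop_per_sec: float) -> float:
    """Model FLOPs utilization: achieved useful FLOP/s over promised FLOP/s."""
    return (num_flops / seconds) / promised_flop_per_sec

B = D = K = 512
seconds = time_matmul(B, D, K)
num_flops = 2 * B * D * K
print(f"{num_flops / seconds:.2e} FLOP/s, MFU vs 67e12: {mfu(num_flops, seconds, 67e12):.3f}")
```

Passing `dtype=torch.bfloat16` repeats the measurement in bf16; compare against the bf16 promised number rather than the float32 one.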
It's called *model* FLOPs utilization because we count the effective, useful operations the model performs, which standardizes the measure. It's not the number of FLOPs actually executed, because your code could cache a few things or recompute others, and in some sense you're still computing the same model. What matters is that this measures the model's complexity, so your MFU shouldn't change just because you were clever about which FLOPs you actually performed. You can also do the same with bf16. Here we see that for bf16 the time is much better: 0.03 s instead of 0.16 s, so the actual FLOPs per second is higher. But even accounting for sparsity, the promised FLOPs is still quite high, so the MFU for bf16 is actually lower. That's maybe surprisingly low, but sometimes the promised FLOPs are a bit optimistic. So always benchmark your code, and don't just assume you're going to get certain levels of performance. Okay. Just to summarize: matrix multiplications dominate the compute, and the general rule of thumb is that a matmul costs two times the product of its dimensions in FLOPs. FLOPs per second depends on the hardware and also on the data type. The fancier the hardware, the higher it is, and the smaller the data type, usually the faster. MFU is a useful way to see how well you're squeezing your hardware. [Audience question: to get maximum utilization you want to use the tensor cores on the machine; does PyTorch use them by default?] Yeah, so the question is, what about those tensor cores? If you go to this spec sheet, you'll see that these numbers are all for the tensor cores.
The tensor core is basically specialized hardware for matmuls. It should be used by default, especially if you're using torch.compile, which will generate code that uses the hardware properly. Okay. Let's talk a little about gradients. So far we've only looked at matrix multiplication, in other words the forward pass and its FLOPs. But computing gradients also costs computation, and we want to know how much. Consider a simple linear model: you take its prediction and compute the MSE with respect to 5. It's not a very interesting loss, but it's illustrative for looking at gradients. Remember, in the forward pass you have your x and your W, which you want the gradient with respect to. You make a prediction by taking a linear product, and then you have your loss. In the backward pass, you just call loss.backward(). The gradient, which is the variable attached to the tensor, turns out to be what you want. Everyone has computed gradients in PyTorch before. So let's look at how many FLOPs computing gradients requires, using a slightly more complicated model. Now it's a two-layer linear model: x is B by D, and W1 is D by D. That's the first layer. Then you take your hidden activations H1 and pass them through another linear layer, W2, to get a K-dimensional vector, and you compute some loss. So this is a two-layer linear network. As a review, for the forward FLOPs you multiply x by W1 and accumulate into H1, and you multiply H1 by W2 and accumulate into H2.
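The two-layer linear network described above can be sketched in PyTorch as follows. The sizes and the squared-output loss are illustrative assumptions. The sketch shows that each gradient has the same shape as its tensor. It also shows that the two backward matmuls discussed next reproduce what autograd computes: dL/dW2 = H1ᵀ · dL/dH2 and dL/dH1 = dL/dH2 · W2ᵀ.

```python
import torch

torch.manual_seed(0)
B, D, K = 4, 8, 3  # illustrative sizes (assumption)
x = torch.randn(B, D)
w1 = torch.randn(D, D, requires_grad=True)
w2 = torch.randn(D, K, requires_grad=True)

h1 = x @ w1          # forward: about 2*B*D*D FLOPs
h2 = h1 @ w2         # forward: about 2*B*D*K FLOPs
h1.retain_grad()     # keep gradients of non-leaf tensors so we can inspect them
h2.retain_grad()
loss = h2.pow(2).mean()
loss.backward()

# Each gradient has the same shape as the tensor it is taken with respect to.
print(w1.grad.shape, w2.grad.shape)  # torch.Size([8, 8]) torch.Size([8, 3])

# Chain rule for W2: a (D x B) @ (B x K) matmul, about 2*B*D*K FLOPs.
manual_w2_grad = h1.detach().T @ h2.grad
# Chain rule for H1: a (B x K) @ (K x D) matmul, also about 2*B*D*K FLOPs.
manual_h1_grad = h2.grad @ w2.detach().T
```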
So the total number of FLOPs, again, is two times the product of all the dimensions in the first matmul, plus two times the product of the dimensions in the second matmul. In other words, it's two times the number of data points times the total number of parameters. What about the backward pass? This part is a little more involved. Recall the model: x to H1 to H2 to the loss. In the backward pass, you have to compute the gradients of the loss with respect to H1, H2, W1, and W2, that is, dloss/d of each of these variables. How long does that take? Let's just look at W2 for now. Whatever touches W2 you can compute with the chain rule. W2.grad, the gradient dloss/dW2, is the sum over examples of H1 times the gradient of the loss with respect to H2. That's just the chain rule for W2. All the gradients are the same size as the underlying tensors. This essentially looks like a matrix multiplication, so the same accounting holds: it's two times the product of all the dimensions, B times D times K. But this is only the gradient with respect to W2. We also need the gradient with respect to H1, because we have to keep backpropagating to W1 and so on. That is the product of W2 and H2.grad (the gradient with respect to H2, not H2 itself). That also essentially looks like a matrix multiplication, with the same number of FLOPs. So adding the two, W2 costs four times B times D times K. You do the same thing for W1, which has D times D parameters.
For W1, it's four times B times D times D, because W1 is D by D. I know there are a lot of symbols here, so I'll also try to give you a visual account. It comes from a blog post, and I think it may work better. We'll see. I have to wait for the animation to loop back. Basically, this is one layer of the neural network, with the hiddens and then the weights to the next layer. The problem with this animation is that I have to wait. Okay, ready, set: first I multiply W and a and add the result here; that's the forward pass. Now I multiply these two and add the result there, and I multiply and add the result there. Okay, any questions? I wish there were a way to slow this down, but I'll let you ruminate on the details. At a high level, the forward pass costs two times the number of parameters and the backward pass costs four times the number of parameters, and you can work that out with the chain rule. [Audience question: for the homeworks, can we use autograd, or do we compute the gradients entirely by hand?] So the question is, in the homework, will you compute gradients by hand? The answer is no. You'll just use PyTorch gradients. This breakdown is only so we can count FLOPs. Any questions about this before I move on? Okay. To summarize: for this particular model, the forward pass costs two times the number of data points times the number of parameters, and the backward pass costs four times that. So the total is six times the number of data points times the number of parameters. That explains the six in the motivating question at the beginning.
Now, this is for a simple linear model, but it turns out that for many models this is basically the bulk of the computation, because essentially every computation you do touches new parameters, roughly. Obviously this doesn't always hold. You can find models where it doesn't: with parameter sharing you can have one parameter and a billion FLOPs. But that's generally not what models look like. Okay, let me move on. I've basically finished talking about resource accounting. We looked at tensors and at some computation on them. We looked at how much memory tensors take to store and how many FLOPs various operations on them take. Now let's start building models. This part isn't necessarily conceptually interesting or challenging; it's more for completeness. Parameters in PyTorch are stored as nn.Parameter objects. Let's talk a little about parameter initialization. Say your W parameter is an input-dimension by hidden-dimension matrix; we're still in the linear model case. Let's generate an input and feed it through to get the output. A unit Gaussian from randn seems innocuous, but if you look at the output, you get some pretty large numbers. That's because the output grows essentially as the square root of the input dimension. With large models this will blow up, and training can be very unstable. So typically you want to initialize in a way that's invariant to the dimension, or that at least guarantees it won't blow up. One simple way is to rescale by one over the square root of the number of inputs. So let's redo this: W is a parameter where I simply divide by the square root of the input dimension.
Now when you feed it through, the outputs are stable; they actually concentrate to something like Normal(0, 1). This has been explored pretty extensively in the deep learning literature and is known, up to a constant, as Xavier initialization. If you want to be extra safe, it's also fairly common not to trust the normal, since it has unbounded tails, and to truncate to [-3, 3] so you don't get any large values that mess things up. Okay, let's build a simple model with D dimensions and two layers. I just made up the name Cruncher. It's a custom model: a deep linear network with num_layers layers, where each layer is a linear model, essentially just a matrix multiplication. The parameters of this model are the first layer, a D by D matrix, the second layer, also D by D, and then a head, or final layer. So the number of parameters is D² + D² + D; nothing too surprising there. I'm going to move it to the GPU because I want it to run fast, generate some random data, and feed it through the model. The forward pass just goes through the layers and then applies the head. I'm going to use this model and do some things with it. But first, a general digression. Randomness can be annoying in some cases, for example when you're trying to reproduce a bug. It shows up in many places: initialization, dropout, data ordering. As a best practice, we recommend that you always pass a fixed random seed, so you can reproduce your model as well as you can.
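Returning to the initialization discussion, here is a sketch of the three variants, assuming PyTorch. The sizes are illustrative. `torch.nn.init.trunc_normal_` is PyTorch's truncated-normal initializer, and its `a`/`b` arguments are absolute cutoffs, so we pass ±3 standard deviations.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
input_dim, output_dim = 16384, 32  # illustrative sizes
x = torch.randn(input_dim)

# Naive unit-Gaussian weights: each output sums input_dim terms,
# so its standard deviation grows like sqrt(input_dim) = 128.
w_naive = nn.Parameter(torch.randn(input_dim, output_dim))
naive_std = (x @ w_naive).std().item()

# Xavier-style scaling: divide by sqrt(input_dim) so outputs stay roughly N(0, 1).
w_scaled = nn.Parameter(torch.randn(input_dim, output_dim) / math.sqrt(input_dim))
scaled_std = (x @ w_scaled).std().item()

# Extra-safe variant: truncated normal with no values beyond 3 standard deviations.
std = 1 / math.sqrt(input_dim)
w_trunc = nn.Parameter(torch.empty(input_dim, output_dim))
nn.init.trunc_normal_(w_trunc, mean=0.0, std=std, a=-3 * std, b=3 * std)
```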
In particular, having a different random seed for every source of randomness is nice, because then you can, for example, fix the initialization or the data ordering while varying other things. Determinism is your friend when you're debugging code. Unfortunately, there are many places where randomness can come in, so be aware of which ones you're using. If you want to be safe, set the seed for all of them. Data loading: I'll go through this quickly, but it'll be useful for your assignment. In language modeling, data is typically just a sequence of integers, because, remember, it's the output of the tokenizer, and you can serialize it into NumPy arrays. One maybe useful point is that you don't want to load all your data into memory at once. For example, the Llama data is 2.8 TB. But you can sort of pretend to load it with a handy function called memmap. It gives you a variable that is mapped to a file, so when you access the data, it's loaded from the file on demand. Using that, you can create a data loader that samples data for your batch. I'm going to skip over that in the interest of time. Let's talk a little about optimizers. We've defined our model, and there are many optimizers, so let me go through the intuitions behind some of them. Of course, there's stochastic gradient descent: you compute the gradient of your batch and take a step in that direction, no questions asked. There's momentum, an idea that dates back to classic optimization: you keep a running average of your gradients and update against that average instead of the instantaneous gradient. Then there's AdaGrad, where you scale the gradients by the average of, not the norms, but the squares of the gradients.
You also have RMSProp, an improved version of AdaGrad that uses exponential averaging rather than a flat average. Finally, Adam, which appeared in 2014, essentially combines RMSProp and momentum. That's why it maintains both a running average of your gradients and a running average of your squared gradients. Since you're going to implement Adam in homework 1, I won't do that here; instead, I'll implement AdaGrad. To implement an optimizer in PyTorch, you subclass the Optimizer class, and I'll get to the implementation once we step through it. Let's define some data, compute the forward pass and the loss, and then compute the gradients. When you call optimizer.step(), that's where the optimizer is actually active. Your parameters are grouped, for example one group each for layer zero, layer one, and the final weights. You can access a state, which is a dictionary from parameters to whatever you want to store as optimizer state. You assume the gradient of each parameter has already been calculated by the backward pass. In AdaGrad you store the sum of the squared gradients, so you get that g2 variable and update it with the square of the gradient. This is an element-wise squaring of the gradient, and you put the result back into the state. Then your optimizer is responsible for updating the parameters: you subtract the learning rate times the gradient divided by this scaling. This state is kept across multiple invocations of the optimizer. At the end of your optimizer step, you can free up memory, which I think will matter more when we talk about model parallelism.
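Here is a minimal AdaGrad sketch that follows the structure described above: subclass `torch.optim.Optimizer`, keep the running sum of squared gradients in `self.state`, and update parameters in `step()`. The class name, the `lr` default, and the `eps` term (added to avoid division by zero) are our assumptions. This is not PyTorch's built-in `torch.optim.Adagrad`.

```python
import torch
from torch import nn

class AdaGrad(torch.optim.Optimizer):
    def __init__(self, params, lr: float = 0.01, eps: float = 1e-10):
        super().__init__(params, dict(lr=lr, eps=eps))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]          # per-parameter dict, kept across steps
                g2 = state.get("g2", torch.zeros_like(p))
                g2 += p.grad ** 2              # element-wise square, accumulated over steps
                state["g2"] = g2
                # Step size shrinks for coordinates that have seen large gradients.
                p -= group["lr"] * p.grad / (g2.sqrt() + group["eps"])


torch.manual_seed(0)
model = nn.Linear(4, 1, bias=False)
opt = AdaGrad(model.parameters(), lr=0.1)
x, y = torch.randn(16, 4), torch.randn(16, 1)
initial_loss = ((model(x) - y) ** 2).mean().item()
for _ in range(50):
    loss = ((model(x) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
final_loss = ((model(x) - y) ** 2).mean().item()
```

The `g2` tensor has the same shape as its parameter, which is why AdaGrad adds one extra copy of the parameters to memory.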
Okay, let's talk about memory requirements, now that we have everything. The number of parameters in this model is D² times the number of layers, plus D for the final head. The number of activations is something we didn't compute before, but for this simple model it's easy: B times D times the number of layers. For every layer, every data point, and every dimension, you have to hold the activation. The number of gradients equals the number of parameters. For AdaGrad, the optimizer state is the squared-gradient sum, so that's another copy of the parameters. Putting it all together, assuming fp32 (four bytes per value), total memory is four bytes times the sum of parameters, activations, gradients, and optimizer states. Here that comes to 496. This is a fairly simple calculation. In assignment 1 you'll do it for the Transformer, which is more involved: there's more than a single matmul, with many matrices, attention, and other components. But the general form of the calculation is the same: parameters, activations, gradients, and optimizer states. The FLOPs required for this model are again six times the number of tokens (data points) times the number of parameters. That concludes the resource accounting for this particular model. For reference, if you're curious about working this out for Transformers, you can consult some of these articles. In the remaining time, let me pause for questions. We built up from tensors to a very small model, and we covered optimization and how much memory and compute it requires.
[Audience question: why do you need to store the activations?] Yeah, so the question is why you need to store the activations. Naively, you need them because in the backward pass, the gradient for, say, the i-th layer depends on the activation there. If you're smarter, you don't have to store all of them; you can recompute them. That technique is called activation checkpointing, and we can talk about it later. Let's go through this part quickly; there isn't much to say. Here's a typical training loop: define the model, define the optimizer, get the data, run the forward and backward passes, and take a step in parameter space. It would be more interesting to show an actual wandb plot, but that isn't available in this version. One note about checkpointing. Training language models takes a long time, and you will certainly crash at some point. You don't want to lose all your progress, so you periodically save your model to disk. To be very clear, you want to save both the model and the optimizer, and probably which iteration you're on; then you can just load it back up. One final note: I alluded to mixed precision training earlier. The choice of data type involves tradeoffs. Higher precision is more accurate and stable but more expensive, and lower precision is the reverse. As mentioned before, the default recommendation is float32, but try to use bf16 or even fp8 whenever possible. You can use lower precision for the forward pass and float32 for the rest. This idea goes back to 2017, with work exploring mixed precision training.
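Here is a sketch of saving and restoring a checkpoint that holds the model, the optimizer, and the iteration number, as described above. It assumes PyTorch. The temporary file location and the dictionary keys are illustrative choices.

```python
import os
import tempfile
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Take one training step so the optimizer has state (AdamW moments) worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
iteration = 1

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")  # illustrative location
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "iteration": iteration,
}, path)

# Later (e.g. after a crash): rebuild the objects, then load the saved state into them.
restored_model = nn.Linear(4, 2)
restored_optimizer = torch.optim.AdamW(restored_model.parameters(), lr=1e-3)
checkpoint = torch.load(path)
restored_model.load_state_dict(checkpoint["model"])
restored_optimizer.load_state_dict(checkpoint["optimizer"])
start_iteration = checkpoint["iteration"] + 1
```

Saving only the model weights would silently reset AdamW's moment estimates on resume, which is why the optimizer state goes into the checkpoint too.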
PyTorch has tools that let you do mixed precision training automatically, because it can be annoying to specify which parts of your model should run in which precision. Generally you define your model as a clean, modular thing, and specifying precision cuts across that structure. One general comment: people are pushing the envelope on how much precision is needed. Some papers show you can actually use fp8 all the way through. One challenge, of course, is that lower precision gets numerically unstable, but there are various tricks for controlling the numerics during training so you don't get into bad regimes. This is where systems and model architecture design are synergistic: a lot of model design is now governed by hardware. Even the Transformer, as we mentioned last time, is shaped by GPUs. NVIDIA chips run faster at lower precision, even int4. If you can make model training actually work in int4, which I think is quite hard, you can get massive speedups and a more efficient model. Another thing we'll talk about later: you often train with more sane floating point, but at inference time you can go crazy. You take your pretrained model, quantize it, and get a lot of the gains of very aggressive quantization. So training is much harder to do in low precision, but once you have a trained model, it's much easier to make it low precision.
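Here is a sketch of PyTorch's automatic mixed precision using `torch.autocast`. It assumes a bf16-capable device and falls back to the CPU, which also supports bf16 autocast. Parameters stay in float32, while the matmul inside the autocast region runs in bfloat16. Doing the reduction in float32 is our choice for illustration.

```python
import torch
from torch import nn

use_cuda = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
device = "cuda" if use_cuda else "cpu"
model = nn.Linear(64, 64).to(device)      # parameters stored in float32
x = torch.randn(8, 64, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
    y = model(x)                          # the linear layer's matmul runs in bfloat16
    loss = y.float().pow(2).mean()        # reduce in float32 for stability
loss.backward()                           # gradients match the float32 parameters' dtype
```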
So I will wrap up there just to conclude, we have talked about the different primitives to use to train a model building up from tensors all the way to the training loop. We talked about memory accounting and flops accounting for these simple models. Hopefully, once you go through assignment one, all these concepts will be really solid because you'll be applying these ideas for the actual transformer. Okay. See you next time.
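Latest Summary (Detailed Summary)
Overview / Executive Summary
This lecture (Stanford CS336, Spring 2025) covers the core PyTorch components and resource-accounting methods needed to build a language model from scratch. The goal is for learners to understand the low-level mechanics of model building and to be able to quantify memory and compute efficiency. The lecture opens with two estimates: the training time of a large model (e.g., 70 billion parameters), and the largest model trainable on a single node of 8 H100s (about 40 billion parameters, excluding activations). These estimates show why resource accounting matters: efficiency translates directly into cost.
The lecture then introduces PyTorch tensors and how different floating-point formats (float32, float16, bfloat16, fp8) affect memory use, highlighting bfloat16's combination of wide dynamic range and memory efficiency. Next it covers compute resources (CPU/GPU), tensor operations (views, copies, the einops library), and their cost in FLOPs. The central conclusion is that matrix multiplications (matmuls) dominate compute, at about 2 * M * K * N FLOPs each. Model FLOPs utilization (MFU) is introduced as a measure of hardware efficiency. The backward pass costs about twice the forward pass, so total training compute is about 6 * num_tokens * num_params FLOPs. Finally, the lecture outlines the practical pieces: parameter initialization, building a custom model, data loading, optimizers, the training loop, and mixed precision training. Adagrad illustrates optimizer-state memory. Adam in float32 stores two extra values per parameter, so parameters, gradients, and optimizer state together take about 16 bytes per parameter (excluding activations).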
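PyTorch Basics and the Motivation for Resource Accounting
Speaker 1 first recapped the previous lecture (an overview of language models and why to build them from scratch) and mentioned tokenization. This lecture focuses on actually building models: PyTorch primitives, models, optimizers, and the training loop. It pays particular attention to efficiency, i.e., how memory and compute are used.
Motivating Questions for Resource Estimation
To show why resource accounting matters, the lecture poses two questions that can be answered with "napkin math":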
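- Training time: how long would it take to train a 70-billion-parameter dense Transformer on 15 trillion tokens using 1024 H100s?
  - Method:
    - Total FLOPs ≈ 6 * num_parameters * num_tokens
    - FLOPs/sec per H100 (assuming an MFU of 0.5)
    - Total FLOPs per day for the cluster
    - Total FLOPs / cluster FLOPs per day ≈ 144 days
  - "Where this 6 × parameters × tokens formula comes from is what we will discuss in this lecture."
- Largest trainable model: on 8 H100s (80 GB of HBM each), using AdamW without special tricks, what is the largest number of parameters you can train?
  - Method:
    - Bytes per parameter (parameter, gradient, optimizer state): 16 bytes (explained later in the lecture)
    - Total available GPU memory / bytes per parameter ≈ 40 billion parameters
  - "This is a very rough estimate, because it ignores the memory used by activations, which depends on the batch size and sequence length."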
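Both estimates come down to a few lines of arithmetic. In this sketch, the H100 peak (1979 TFLOPs/sec with sparsity, halved for dense bf16), the MFU of 0.5, and the 16 bytes/parameter breakdown come from the lecture. Treating 80 GB as 80e9 bytes is our simplification.

```python
# Question 1: training time for a 70B dense Transformer on 15T tokens with 1024 H100s.
num_params = 70e9
num_tokens = 15e12
total_flops = 6 * num_params * num_tokens          # 6.3e24 FLOPs

h100_flops_per_sec = 1979e12 / 2                   # bf16 peak without sparsity
mfu = 0.5
num_gpus = 1024
flops_per_day = h100_flops_per_sec * mfu * num_gpus * 60 * 60 * 24
days = total_flops / flops_per_day                 # about 144 days

# Question 2: largest model on 8 H100s (80 GB each) with AdamW, ignoring activations.
bytes_per_param = 4 + 4 + (4 + 4)                  # fp32 parameter, gradient, two AdamW moments
total_memory = 8 * 80e9                            # bytes
max_params = total_memory / bytes_per_param        # 4e10, i.e. 40 billion parameters
```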
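Speaker 1 emphasized: "Efficiency is the name of the game. To be efficient, you have to know exactly how many FLOPs you are actually spending, because when these numbers get large, they translate directly into dollars, and you want that cost to be as small as possible."
Learning Objectives
- Mechanics: PyTorch and how it works under the hood.
- Mindset: resource accounting.
- Intuitions: the focus for now is on mechanics and mindset; intuitions about model performance are covered only at a high level.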
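Memory Accounting
Tensors
- The basic building block for storing everything in deep learning: parameters, gradients, optimizer states, data, and activations.
- They can be created in many ways, including without initialization.
Floating-Point Representations and Memory Use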
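- float32 (fp32 / single precision):
  - 32 bits: 1 sign bit, 8 exponent bits, 23 mantissa bits.
  - 4 bytes per element.
  - Considered the "gold standard" in computing, and often called "full precision" in machine learning (although scientific computing may use float64 or higher).
  - Memory: num_elements * element_size_in_bytes.
  - Example: a 4x8 float32 tensor takes 32 * 4 = 128 bytes.
  - A GPT-3 feed-forward weight matrix (4 * 12288 x 12288, i.e., d_ff x d_model) takes about 2.3 GB (2304 MiB) in float32.
- float16 (fp16 / half precision):
  - 16 bits: 1 sign bit, 5 exponent bits, 10 mantissa bits.
  - Halves memory, and computation is usually faster.
  - Drawback: limited dynamic range, so overflow and underflow happen easily; for example, 1e-8 becomes 0 in float16. It is poorly suited to training large models or to anything that needs a high-precision representation.
- bfloat16 (bf16 / Brain Float):
  - Developed by Google in 2018 specifically for deep learning.
  - 16 bits: 1 sign bit, 8 exponent bits, 7 mantissa bits.
  - Same memory as float16 but the same dynamic range as float32, at the cost of mantissa precision.
  - Deep learning needs dynamic range more than precision, so bf16 works well; 1e-8 is not 0 in bf16.
  - "bf16 is basically the data type you'd typically use for computation, because it's good enough for the forward and backward computations."
  - Note: storing optimizer states and the parameters themselves in float32 is still generally recommended, to avoid training instability.
- fp8 (8-bit floating point):
  - Developed by NVIDIA in 2022.
  - Very few bits, so the representation is very coarse. There are two variants, one favoring resolution and one favoring dynamic range.
  - H100 GPUs support fp8.
- Mixed precision training:
  - Use the lowest precision that each part of the model can tolerate.
  - Example: float32 for parameters and optimizer states, bf16 for matmuls in the forward and backward passes, and possibly float32 for sensitive parts such as attention.
  - Speaker 1's view: "At this point, you probably don't want to use float16 for deep learning."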
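Here is a short PyTorch sketch of the memory and dynamic-range points above. `torch.finfo` reports each dtype's bit width, largest value, and machine epsilon. The GPT-3 matrix size is computed arithmetically rather than allocated.

```python
import torch

# Memory = number of elements * bytes per element.
x = torch.zeros(4, 8)                          # float32 by default
x_bytes = x.numel() * x.element_size()         # 32 * 4 = 128 bytes

# GPT-3 feed-forward weight (4 * 12288 x 12288) in float32, without allocating it.
gpt3_ffn_bytes = (4 * 12288) * 12288 * 4       # 2304 MiB, about 2.3 GB

# Exponent bits set the dynamic range: float16 underflows at 1e-8, bfloat16 does not.
fp16_small = torch.tensor([1e-8], dtype=torch.float16)
bf16_small = torch.tensor([1e-8], dtype=torch.bfloat16)

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(dtype, "bytes:", info.bits // 8, "max:", info.max, "eps:", info.eps)
```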
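Compute Accounting and Hardware
Moving Data Between CPU and GPU
- PyTorch tensors are created and stored on the CPU by default.
- To compute on a GPU, tensors must be moved there explicitly with tensor.to('cuda'), or created there by passing device='cuda'.
- Moving data from CPU RAM to GPU HBM (High Bandwidth Memory) takes time.
- "When working with tensors in PyTorch, you should always know where they live."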
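Here is a minimal sketch of the two ways to place a tensor on the GPU. The CPU fallback is there only so the example runs on a machine without a GPU.

```python
import torch

x = torch.zeros(32, 32)                                  # created on the CPU by default
device = "cuda" if torch.cuda.is_available() else "cpu"
y = x.to(device)                                         # copies CPU RAM -> GPU HBM when device is "cuda"
z = torch.zeros(32, 32, device=device)                   # allocated directly on the target device
```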
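Hardware Overview
- GPU used in the examples: NVIDIA H100 with 80 GB of HBM.
Tensor Operations and Views
- What a tensor is: a PyTorch tensor is a pointer into an allocated region of memory, plus metadata (such as stride) that describes how to index that memory.
- Views: many operations (indexing, tensor.view(), tensor.transpose()) do not copy data; they create a new "view" onto the same underlying storage.
  - Modifying a view modifies the original tensor, and vice versa.
  - tensor.storage().data_ptr() can be used to check whether two tensors share underlying storage.
- Contiguity:
  - A tensor is contiguous if its elements are laid out in memory in their logical order.
  - Operations such as transpose can produce non-contiguous tensors.
  - Some view operations fail on non-contiguous tensors. In that case, call tensor.contiguous() first, which makes a contiguous copy of the data.
  - tensor.reshape() returns a view when possible and otherwise behaves like tensor.contiguous().view(), i.e., it copies.
- Advice: view operations are cheap (they allocate no new memory), so use them freely to make code more readable, but be aware that contiguous() or reshape() may trigger a copy.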
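Here is a sketch of views, shared storage, and contiguity in PyTorch. It uses `untyped_storage().data_ptr()`, the spelling recent PyTorch versions prefer over the older `storage()`, to check whether two tensors share memory.

```python
import torch

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

y = x.view(3, 2)                      # a view: same storage, new shape/stride metadata
same_storage = x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()
x[0, 0] = 100.0                       # mutating x is visible through the view
print(y[0, 0])                        # tensor(100.)

t = x.transpose(0, 1)                 # also a view, but no longer contiguous
# t.view(6) would raise a RuntimeError here, because t's memory is not in row-major order.
flat = t.contiguous().view(6)         # contiguous() copies the data into a new buffer
copied = flat.untyped_storage().data_ptr() != x.untyped_storage().data_ptr()
```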
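Operations That Create New Tensors
- Element-wise operations generally create new tensors.
- torch.triu() (the upper-triangular part of a matrix) is covered alongside them; it is commonly used to build causal attention masks.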
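Here is a sketch of how `torch.triu` yields a causal mask. The `masked_fill` and softmax steps are our illustration of how such a mask is typically applied to attention scores; the all-zero scores are placeholders.

```python
import torch

seq_len = 4
scores = torch.zeros(seq_len, seq_len)   # placeholder attention scores
# Entries strictly above the diagonal are future positions a query must not see.
future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
masked = scores.masked_fill(future, float("-inf"))
weights = torch.softmax(masked, dim=-1)  # row i attends only to positions 0..i
```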
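Matrix Multiplications (Matmuls)
- The core operation of deep learning.
- Batching is supported. For example, when a (batch_size, seq_len, input_dim) tensor is multiplied by an (input_dim, output_dim) weight matrix, PyTorch performs an independent matmul for every batch element and sequence position, producing (batch_size, seq_len, output_dim).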
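Here is a short check of the batching behavior just described. The sizes are illustrative, and the explicit loop is only there to show that the batched result matches a per-position matmul.

```python
import torch

batch_size, seq_len, input_dim, output_dim = 2, 5, 16, 8
x = torch.randn(batch_size, seq_len, input_dim)
w = torch.randn(input_dim, output_dim)
y = x @ w  # the weight is broadcast over the leading (batch, seq) dimensions

# Same result as an explicit matmul for every batch element and sequence position.
y_loop = torch.stack([
    torch.stack([x[b, s] @ w for s in range(seq_len)]) for b in range(batch_size)
])
```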
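The einops Library
- Inspired by Einstein summation notation; it makes tensor dimension manipulation clearer and less error-prone.
- Motivation: avoid hard-to-read negative indices (e.g., transpose(-2, -1)).
- Core idea: name the dimensions.
  - The jaxtyping library lets you declare dimension names as strings in type annotations (e.g., Float[Tensor, "batch seq_len hidden_dim"]), but PyTorch itself does not enforce them.
- einsum (from einops):
  - Performs matrix multiplication, or more general tensor contractions, with named dimensions.
  - Example: einsum(x, y, 'b s1 h, b s2 h -> b s1 s2') multiplies x (batch, seq1, hidden) by y (batch, seq2, hidden) and sums over the hidden dimension, keeping the batch, seq1, and seq2 dimensions.
  - ... can stand for any number of leading broadcast dimensions.
  - Speaker 1's view: "Once you get used to it, it's much better than indexing with -2 and -1." torch.compile can also compile einsum operations efficiently.
- reduce (from einops): aggregates over dimensions by name.
- rearrange (from einops): reorders, splits, or merges dimensions, like a more powerful view.
  - Example: splitting batch seq (heads hidden_dim) into batch seq heads hidden_dim.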
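To avoid depending on einops, this sketch writes the same contraction with `torch.einsum`, which accepts only single-letter dimension names, and checks it against the transpose-based version. The `view` call does what `einops.rearrange(z, 'b s (h d) -> b s h d', h=2)` would produce here. All sizes are illustrative.

```python
import torch

x = torch.randn(2, 3, 4)   # (batch, seq1, hidden)
y = torch.randn(2, 5, 4)   # (batch, seq2, hidden)

old_way = x @ y.transpose(-2, -1)             # index-based: which axes are -2 and -1?
new_way = torch.einsum("bih,bjh->bij", x, y)  # named: contract over h, keep b, i, j

# Split the last dimension: (batch, seq, heads * hidden) -> (batch, seq, heads, hidden).
z = torch.randn(2, 3, 8)
heads = 2
z_split = z.view(2, 3, heads, 8 // heads)
```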
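Compute Cost of Tensor Operations (FLOPs)
- FLOP (floating-point operation): a single floating-point operation such as an addition or multiplication; FLOPs measure the amount of computation.
- FLOPs/sec (floating-point operations per second): measures hardware speed. This course avoids writing "FLOPS" with a capital S and writes FLOPs/sec explicitly.
- Orders of magnitude:
  - Training GPT-3 took about 3.23e23 FLOPs.
  - Training GPT-4 took (speculatively) about 2e25 FLOPs.
  - A US executive order required foundation models trained with more than 1e26 FLOPs to be reported to the government (it was later revoked). The EU AI Act has a threshold of 1e25 FLOPs.
- Hardware performance:
  - A100 peak (fp16/bf16): 312 TFLOPs/sec.
  - H100 peak (fp16/bf16): 1979 TFLOPs/sec with sparsity, and about half that (~990 TFLOPs/sec) without. fp32 on the H100 is far lower.
  - On sparsity: NVIDIA's advertised sparsity speedup refers specifically to "structured sparsity, where 2 of every 4 elements are zero." Speaker 1's comment: "Nobody uses it; it's the marketing department talking."
- FLOPs of a matmul: for (M, K) @ (K, N) -> (M, N), FLOPs ≈ 2 * M * K * N. Each output element needs K multiplications and K - 1 additions, which is approximated as 2K.
  - "You should remember that if you're doing a matrix multiplication, the number of FLOPs is twice the product of the three dimensions."
- Other operations usually cost FLOPs linear in the tensor size, far less than large matmuls, so performance analysis usually focuses on matmuls.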
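To see where the factor of 2 comes from, here is a sketch that counts every multiply and add in a naive triple-loop matmul. It is pure Python and for illustration only; real GPU kernels are organized very differently.

```python
def naive_matmul_with_flop_count(a, b):
    """Multiply lists-of-lists a (M x K) and b (K x N), counting every * and +."""
    m, k, n = len(a), len(b), len(b[0])
    flops = 0
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = a[i][0] * b[0][j]        # first product: 1 multiply
            flops += 1
            for p in range(1, k):
                acc += a[i][p] * b[p][j]   # 1 multiply + 1 add
                flops += 2
            out[i][j] = acc
    return out, flops


M, K, N = 3, 4, 5
a = [[1.0] * K for _ in range(M)]
b = [[1.0] * N for _ in range(K)]
out, flops = naive_matmul_with_flop_count(a, b)
# Exact count is M*N*(2K - 1) = 105; the rule of thumb 2*M*K*N = 120 drops the "-1".
```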
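Wall-Clock Time and Model FLOPs Utilization (MFU)
- MFU (Model FLOPs Utilization) = (actual useful FLOPs / actual time) / hardware peak FLOPs/sec.
  - Measures how much of the hardware's compute the model actually uses.
  - MFU > 0.5 is generally considered good.
  - MFU is usually higher when matmuls dominate.
- Measurements on an H100 with a large matmul (the reported numbers imply about 0.16 s × 5.4e13 ≈ 8.6e12 FLOPs):
  - float32: 0.16 s and 5.4e13 actual FLOPs/sec, against an H100 float32 peak of 67 TFLOPs/sec, so MFU ≈ 0.8.
  - bfloat16: 0.03 s, with higher actual FLOPs/sec. But relative to the bfloat16 peak (about 990 TFLOPs/sec), the MFU can actually be lower (0.26 in the example), which shows that peak numbers are sometimes optimistic.
- Advice: always benchmark your code; don't assume you will reach theoretical performance.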
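Compute Cost of Gradients and Backpropagation
- Two-layer linear network example: H1 = X @ W1, H2 = H1 @ W2, with X of shape B x D, W1 of shape D x D, and W2 of shape D x K.
  - Forward FLOPs: 2 * num_tokens * (params_W1 + params_W2). For a simple linear model this simplifies to 2 * num_tokens * total_params.
  - Backward FLOPs:
    - Computing dL/dW2 is matmul-like and costs ≈ 2 * B * D * K FLOPs (W2 is D x K).
    - Computing dL/dH1 (H1 is the output of the first layer) is also matmul-like and costs ≈ 2 * B * D * K FLOPs.
    - Repeat for W1 (D x D), which costs 4 * B * D * D.
    - Key conclusion: the backward pass costs about twice as much as the forward pass.
    - "The forward pass is 2 times the number of parameters, the backward pass is 4 times the number of parameters." (Both counts are per data point; multiply by the number of data points.)
  - Total FLOPs for one training step: forward + backward ≈ (2 + 4) * num_tokens * num_params = 6 * num_tokens * num_params.
    - This explains the 6x factor used in the training-time estimate at the start of the lecture.
    - The rule holds roughly for many models, including Transformers when the sequence length is not too large.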
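The accounting above can be written as a small function. Following the lecture, it counts two backward matmuls per layer, even though the gradient with respect to the input data X is not strictly needed. The sizes in the usage line are illustrative.

```python
def two_layer_flops(B: int, D: int, K: int) -> dict:
    """FLOP counts for H1 = X @ W1 (W1: D x D) and H2 = H1 @ W2 (W2: D x K)."""
    num_params = D * D + D * K
    forward = 2 * B * D * D + 2 * B * D * K
    # Backward: per layer, one matmul for the weight gradient and one for the
    # input gradient, each costing as much as that layer's forward matmul.
    backward = 2 * (2 * B * D * D) + 2 * (2 * B * D * K)
    return {"params": num_params, "forward": forward,
            "backward": backward, "total": forward + backward}


counts = two_layer_flops(B=1024, D=256, K=64)
```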
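Building Models
Parameters
- In PyTorch, model parameters are usually stored as torch.nn.Parameter objects, which are tensors whose gradients are tracked automatically.
Parameter Initialization
- Problem: naive Gaussian initialization (e.g., torch.randn) produces large outputs. Each output sums over all inputs, so its standard deviation grows roughly as sqrt(input_dimension). In large, deep networks this can make activations blow up and training unstable.
- Solution (Xavier-style initialization): scale the initial weights by 1 / sqrt(input_dimension).
  - Example: W = nn.Parameter(torch.randn(in_dim, out_dim) / math.sqrt(in_dim))
  - This keeps the variance of the output activations near 1.
- Extra trick: a truncated normal distribution, e.g., clipping values to [-3, 3] standard deviations, to avoid extreme values.
Custom Model Example: Cruncher (a Deep Linear Network)
- A simple model consisting of num_layers linear layers (matrix multiplications).
- Built with nn.Module and nn.ModuleList.
- Parameter count: with L layers of D x D matrices plus a D-dimensional output head, the total is L * D*D + D.
- Both the model and the data must be moved to the GPU (model.to('cuda'), data.to('cuda')).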
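Here is a sketch of a Cruncher-style deep linear network, assuming PyTorch. To keep it short, it uses `nn.Linear(..., bias=False)` layers. That is a substitution: PyTorch's default initialization for `nn.Linear` is not the 1/sqrt(D) Gaussian scheme from the lecture. D=64 and the batch size are illustrative.

```python
import torch
from torch import nn

class Cruncher(nn.Module):
    """Deep linear network: num_layers D x D matmuls, then a D -> 1 head (no biases)."""
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(dim, dim, bias=False) for _ in range(num_layers)]
        )
        self.final = nn.Linear(dim, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.final(x).squeeze(-1)   # shape: (batch,)


D, num_layers = 64, 2
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Cruncher(D, num_layers).to(device)
num_params = sum(p.numel() for p in model.parameters())   # num_layers * D*D + D
x = torch.randn(8, D, device=device)
y = model(x)
```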
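Managing Randomness
- Sources of randomness: parameter initialization, dropout, data shuffling, and more.
- Best practice: always pass a fixed random seed for each source of randomness to ensure reproducibility.
  - torch.manual_seed(), numpy.random.seed(), random.seed().
- "Determinism is your friend when you're debugging."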
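Here is a sketch of seeding every source of randomness with the three calls listed above. The helper name `set_seed` is hypothetical. The dedicated `torch.Generator` at the end shows one way to fix a single source, such as data order, independently of the others.

```python
import random
import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs (hypothetical helper name)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices

set_seed(0)
first = (random.random(), float(np.random.rand()), torch.rand(1).item())
set_seed(0)
second = (random.random(), float(np.random.rand()), torch.rand(1).item())

# Separate generators let you fix one source (e.g. data order) while varying another.
data_order = torch.randperm(10, generator=torch.Generator().manual_seed(1))
data_order_again = torch.randperm(10, generator=torch.Generator().manual_seed(1))
```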
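Training Loop Components
Data Loading
- Language-model data is usually a sequence of token IDs (integers).
- Memory mapping: for very large datasets (e.g., Llama's 2.8 TB of data), loading everything into memory at once is infeasible. numpy.memmap lets you treat a file on disk like an in-memory array, loading data on demand.
- torch.utils.data.Dataset and torch.utils.data.DataLoader are used to batch data efficiently.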
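Here is a sketch of memory-mapped token data with a simple batch sampler, assuming the tokens were saved as raw int32 values. The file name and the `get_batch` helper are hypothetical. Targets are the inputs shifted by one token, as in next-token prediction.

```python
import os
import tempfile
import numpy as np
import torch

# Stand-in for tokenizer output: token IDs written to disk as raw int32 values.
path = os.path.join(tempfile.mkdtemp(), "tokens.bin")
np.arange(1000, dtype=np.int32).tofile(path)

# np.memmap maps the file; slices are read from disk only when accessed.
data = np.memmap(path, dtype=np.int32, mode="r")

def get_batch(data: np.ndarray, batch_size: int, context_length: int):
    """Sample random windows; targets are inputs shifted by one (hypothetical helper)."""
    starts = np.random.randint(0, len(data) - context_length, size=batch_size)
    x = np.stack([data[s : s + context_length] for s in starts]).astype(np.int64)
    y = np.stack([data[s + 1 : s + 1 + context_length] for s in starts]).astype(np.int64)
    return torch.from_numpy(x), torch.from_numpy(y)

x, y = get_batch(data, batch_size=4, context_length=8)
```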
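Optimizers
- A review of common optimizers:
  - SGD (Stochastic Gradient Descent): updates in the direction of the negative gradient.
  - Momentum: keeps a first moment (exponential moving average) of the gradients, which speeds up convergence and damps oscillations.
  - Adagrad (Adaptive Gradient Algorithm): keeps a per-parameter sum of squared gradients and uses it to scale the learning rate, so parameters with large gradients take smaller steps.
  - RMSProp (Root Mean Square Propagation): an improvement on Adagrad that uses an exponential moving average of squared gradients, so the learning rate does not decay too early.
  - Adam (Adaptive Moment Estimation): combines momentum (first-moment estimate) and RMSProp (second-moment estimate); it is currently the most widely used optimizer.
- Implementing a custom optimizer (Adagrad as the example):
  - Subclass torch.optim.Optimizer.
  - In the step() method:
    - Iterate over the parameters in self.param_groups.
    - Read each parameter's gradient (param.grad).
    - Update the optimizer state (self.state[param]), e.g., Adagrad's accumulated sum of squared gradients (g_squared_sum).
    - Update the parameter values according to the algorithm (param.data.add_()).
    - (Optional) Free the gradients with zero_grad() at the end of step(), or call it explicitly in the training loop.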
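- Memory for optimizer state:
  - Adagrad: one extra float per parameter (the sum of squared gradients).
  - Adam: two extra floats per parameter (the first and second moments).
  - So with Adam, beyond the parameters and gradients, the optimizer state takes about twice the parameter memory.
  - Total, excluding activations: parameters (P) + gradients (P) + Adam state (2P) = 4P values. In float32 (4 bytes per value), that is about 16 bytes per parameter.
Full Resource Requirements (for the Simple Linear Network)
- Parameters: D*D * num_layers + D (4 bytes each in fp32).
- Activations: batch_size * D * num_layers (4 bytes each in fp32).
  - "Why store activations? Naively, because in the backward pass the gradient for a layer's weights depends on the activations at that layer. If you're smarter, you don't have to store all of them and can recompute them instead; this technique is called activation checkpointing."
- Gradients: the same count as the parameters.
- Optimizer states: 1x the parameters for Adagrad, 2x for Adam.
- Total memory = 4 bytes * (num_params + num_activations + num_gradients + num_optimizer_states)
- Total FLOPs ≈ 6 * num_tokens * num_params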
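The memory formula above can be expressed as a small function. The function name and the `optimizer_copies` parameter are ours. The sizes B=2, D=4, and two layers are also our assumption; they happen to reproduce the 496-byte figure quoted in the lecture for AdaGrad.

```python
def linear_net_memory_bytes(B: int, D: int, num_layers: int,
                            bytes_per_value: int = 4, optimizer_copies: int = 1) -> int:
    """Params + activations + gradients + optimizer state for the deep linear net.

    optimizer_copies = 1 for Adagrad (sum of squared gradients), 2 for Adam (two moments).
    """
    num_params = D * D * num_layers + D
    num_activations = B * D * num_layers
    num_gradients = num_params
    num_optimizer_states = optimizer_copies * num_params
    total_values = num_params + num_activations + num_gradients + num_optimizer_states
    return bytes_per_value * total_values


adagrad_bytes = linear_net_memory_bytes(B=2, D=4, num_layers=2)                   # 4 * (36 + 16 + 36 + 36)
adam_bytes = linear_net_memory_bytes(B=2, D=4, num_layers=2, optimizer_copies=2)  # 4 * (36 + 16 + 36 + 72)
```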
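Model Checkpointing
- Language-model training takes a long time and is prone to interruption.
- Periodically save the model state (model.state_dict()), the optimizer state (optimizer.state_dict()), and the current iteration to disk so training can be resumed.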
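Mixed Precision Training, Continued
- Tradeoff: high precision (accurate and stable, but expensive) vs. low precision (cheap, but potentially unstable).
- Recommendations:
  - float32 is the default.
  - Try bf16 or even fp8 whenever possible.
  - A common strategy: use low precision (e.g., bf16) for the forward and backward passes, and keep parameter updates and the master weights in float32.
- PyTorch provides Automatic Mixed Precision tools in torch.amp (also available under the older torch.cuda.amp namespace) to simplify this.
- Frontier research explores training in fp8 all the way through; the challenge is keeping the numerics stable. Model design and hardware co-evolve; for example, NVIDIA chips' support for low precision (e.g., int4) may drive new model architectures.
- Training vs. inference: training demands more precision. Once a model is trained, inference can use much more aggressive quantization (e.g., int8, int4) for performance gains.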
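Summary and Outlook
The lecture walked through PyTorch primitives from tensors to a complete training loop, with an emphasis on memory and FLOPs accounting. Assignment 1, which applies the same analysis to a Transformer, will help consolidate these concepts.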