Stanford CS336 Language Modeling from Scratch | Spring 2025 | 07 Parallelism 1

This lecture explores parallel computing for multi-machine optimization when training large language models. Because a single GPU cannot satisfy the compute and memory requirements of large models, cross-machine parallelism is necessary. The lecture first covers networking fundamentals, emphasizing the importance of the hardware hierarchy: multiple GPUs within a single machine are interconnected via high-speed interfaces such as NVSwitch, while inter-machine communication relies on comparatively slower network switches (such as InfiniBand). This heterogeneous communication profile (fast within a node, slower across nodes), along with all-to-all high-speed connectivity up to a certain scale (e.g., 256 GPUs), has deep implications for the choice of parallelization strategy. The lecture then reviews the key collective communication operations, such as AllReduce, Broadcast, AllGather, and ReduceScatter, and notes in particular that an AllReduce can be decomposed into a ReduceScatter followed by an AllGather, which achieves optimal communication efficiency in the bandwidth-limited regime. The lecture aims to explain how different parallelization strategies are combined to train very large models efficiently, and closes with case studies showing how these strategies are applied in real large-scale distributed training runs.

Media Details

Upload Date
2025-05-13 17:44

Transcript

speaker 1: All right. So today is going to be the second of the basic systems lectures, and now we're going to move on to multi-machine optimization. The focus today is going to be all about parallelism across machines. The goal is to move from optimizing a single GPU's throughput to understanding the complexities and the details that are required to train really large models. When models get large, they no longer fit on a single GPU. So you've got to split up your models across different machines, but you've also got to be able to leverage all of the different servers that you have in order to train these models quickly. So we've got both compute and memory concerns to deal with, and communication across different machines is going to be quite heterogeneous. We have different kinds of communication across GPUs at different levels of the hierarchy, and that's going to lead to different parallelization paradigms. People use many different parallelization strategies all together at once. We're going to talk through each of the very popular ones, then talk about how you combine them in order to efficiently train a very large model. And then I'm going to end the lecture by looking at some examples of how people actually use these parallelization strategies in their large-scale distributed training runs. That roughly maps to the different parts of this lecture: we'll talk about the basics of networking first, then how each of these networking hardware concepts maps to different parallelization strategies, and finally some case studies to close with, to show you how it all comes together. So I told you about GPU scaling last week, and it's quite impressive seeing this super-exponential curve of FLOPs per GPU going way, way up. But if we want to rapidly scale out both our compute and memory, a single GPU isn't enough; we'd have to wait another couple of years for that curve to keep climbing. So if we want to train a really powerful language model here and now, today, we have to rely on multi-machine parallelism. If we look at the world's fastest supercomputers, which is what's being shown on the right here, the fastest supercomputers have exaflops of compute; those are the green lines you see over there. That's what you're really going to have to rely on if you're going to try to train the biggest, baddest language models today. So that's the compute side of why you want to think about multi-machine parallelism, but we've also got a memory angle on the same thing. These two are really the core resources and the core concerns you're going to have to think about. In terms of memory, many of the models are getting quite big. Of course, memory on GPUs is also growing, but not quite as quickly, and a single GPU is not going to be able to fit these models. Maybe eventually, in the distant future, we won't have to worry about a lot of this, but we've got billions and billions of parameters, and they're not going to fit very nicely onto a single GPU. So we have to be very respectful of the memory constraints that we have.
So those are the realities we have to deal with, and what are the tools we have to handle them? Well, GPUs, as I'm sure you've noticed in the class cluster, don't come as singletons. A single machine will have multiple GPUs within the same physical chassis. Here's an example; I took this, I think, from the GPT-NeoX paper, so it's an old example, but the same lesson applies to the H100 machines that you have in class. Here there are eight different GPUs. They're connected to the various CPUs through fast interconnects, and within the node you see this NVSwitch thing at the bottom: very, very fast connections across these eight GPUs. But if these eight GPUs want to talk to GPUs on a different machine, they're going to have to go through a networking switch. You see this purple box that says HDR InfiniBand; that's a much slower connection compared to the NVLink connection. You can see the difference in throughput: it's about eight times slower per lane. So this hardware hierarchy is going to have big implications for how we end up parallelizing our models in practice, and you can keep this mental model with you as I talk through these things: we have very, very fast connections within a single machine, and then when we go across machines, it gets slower. And then, depending on the kind of hardware we're using, there might even be another level of slowness once we go beyond, let's say, 256 GPUs networked together. Many of you may already know this, having taken systems or networking classes, but here's a very brief refresher on collective communication operations. The reason I bring this up is that there is one particular, important identity or equivalence that you will need to know to really understand some of the finer points of the performance characteristics of the parallelization algorithms. So I'll talk through these, and then I'll talk through one important performance implication. The first one, which all of you have probably heard of, is all-reduce. You have four machines, four ranks in this case, each with its own piece of data. What you'd like to do is perform some reduction operation, say summing all these inputs, and then have the output copied to every single machine. This costs roughly two times the total number of things you're all-reducing. You have a broadcast operation: here I'm taking a single input from rank 2 and copying it out to all of the remaining ranks, which costs roughly one times the total number of outputs in communication. Then we've got reduce, where different inputs are summed up and sent to only one machine. And then the two that are quite important, even though they may not be as common, are all-gather and reduce-scatter. All-gather is an operation where I take a single subcomponent of, let's say, my parameters from rank 0 and copy it over to all the ranks, and the same for ranks 1, 2, and 3. Each rank is handling a different part of, let's say, the parameters, and those parts are copied over to the rest of the machines.
So that's copying what I have to everyone else. And reduce-scatter is where I take each of the rows, sum them up, and send the result only to rank 0. So it is a partial version of an all-reduce, and hopefully the diagram makes clear how reduce-scatter works. All-gather and reduce-scatter are quite important because, in some sense, they are the primitives from which many of the parallelization algorithms are built. So this is an important equivalence, or identity, and I will refer to it once or twice at key points in this lecture. Say I want to do an all-reduce: I've got different GPUs A, B, C, D, and each GPU is handling a different data point, so I've got different gradients for each of these data points. I need to sum those gradients, and then I need to pass all of those summed gradients back to the GPUs. This is the classic data-parallel operation I might need to do across my four GPUs, and that would be an all-reduce. One important thing, though, is that this can be replaced with two operations: a reduce-scatter and an all-gather, where the reduce-scatter sums each of the rows and leaves the result of each row on GPUs 0, 1, 2, 3 respectively, and then the all-gather copies those back out to the remaining GPUs. So each GPU first gets the full sum for a part of the parameters, and then copies it back to the remaining workers. And in the bandwidth-limited regime, this is basically the best you can do: the best all-reduce can do is roughly match the bandwidth of a reduce-scatter plus an all-gather. You can convince yourself of this by writing out how many communication operations happen in both the all-reduce and the right-hand side. The final thing I want to briefly touch on before I move on to the parallelization algorithms, and this is the one place I'll talk about GPUs versus TPUs; most of the discussion can abstract out the underlying hardware, but there is one important thing I'll mention up front so I can refer to it later. How do we network together different machines, or different accelerators? Well, as I showed you in the GPT-NeoX slide, the way this generally works in the GPU world is that you've got nodes, single machines containing, say, eight GPUs, and then you've got switches that connect them fairly quickly to each other. These machines are connected all-to-all up to about 256 GPUs. That's an important threshold, up to which you have very fast arbitrary communication between machines. Above that, you need much slower communication through leaf switches and spine switches, once you go beyond roughly a single rack's worth of GPUs. On the other hand, if you look at the TPU design from Google, they take a very different approach to networking their machines. You've got individual TPU chips, and each chip talks to its neighbors very, very quickly. This is a very easily expandable design, what they call a toroidal mesh, but you can only talk to your neighbors.
And the reason I'm talking about this right after the all-reduce slide is that if you think about doing these kinds of collective communications, like all-reduce or reduce-scatter, you can implement them just as efficiently on a toroidal mesh as you can on an all-to-all connection. So if you're optimizing purely for collective communications, it makes sense to think about things like TPU-style networking rather than GPU-style networking. I'll talk a little bit about the pros and cons of this later as I go through the different parallelization strategies. Okay. So just to put this together: right now we're going to start talking about a new unit of compute. Instead of the GPU, the new unit is the data center. The whole data center is going to be the thing we're working with. And now we're going to try to come up with algorithms and sharding strategies that get us two different things. The first one is linear memory scaling: as I scale up the number of GPUs, the biggest model I can train scales linearly with that, so I can train bigger and bigger models if I really want to. I also want linear compute scaling: as I get more and more GPUs, the useful computation I'm doing to train the model scales linearly. And finally, a lot of these algorithms are implemented by just calling these very simple collective communication primitives in various ways. So when we think about the performance characteristics of these parallel algorithms, it suffices to reason about them by basically counting the collective communication primitives. That's an important way to think about these; we don't go all the way down to the low-level implementations of these algorithms here. Okay. Any questions? Part one, yes.
speaker 2: From the previous slide, does that mean it's better to do the reduce-scatter and then the all-gather?
speaker 1: Right, so the conclusion of this slide is that they're equivalent. If you think about doing gradient descent in parallel, all-reduce is the very natural operation to do, because you distribute your data to different machines and then you have to all-reduce your gradients together. But what I'm saying is that this very natural thing, the all-reduce, can actually be written as a composition of two different operations, and they're equivalent. So there's no performance hit from going from the left representation to the right one, at least in bandwidth. And that's going to have important implications in maybe five slides, so you can wait a little bit to see why I mention this. Okay.
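As a concrete illustration of that equivalence, here is a minimal torch.distributed sketch, assuming a process group has already been initialized (for example via torchrun) and that the tensor is contiguous and divides evenly across ranks; the function name is illustrative:

```python
import torch
import torch.distributed as dist

def all_reduce_via_rs_ag(grad: torch.Tensor) -> torch.Tensor:
    """Sum `grad` across ranks by decomposing all-reduce into
    reduce-scatter + all-gather. Assumes grad is contiguous and
    grad.numel() is divisible by the world size."""
    world_size = dist.get_world_size()
    shards = list(grad.view(world_size, -1).unbind(0))
    owned = torch.empty_like(shards[0])
    # Reduce-scatter: this rank ends up with the summed values of its shard.
    dist.reduce_scatter(owned, shards, op=dist.ReduceOp.SUM)
    # (In ZeRO, the sharded optimizer step would slot in right here.)
    # All-gather: copy every rank's summed shard back to every rank.
    gathered = [torch.empty_like(owned) for _ in range(world_size)]
    dist.all_gather(gathered, owned)
    return torch.cat(gathered).view_as(grad)

# Bandwidth-equivalent one-liner: dist.all_reduce(grad, op=dist.ReduceOp.SUM)
```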
speaker 1: Any other questions? Good. Okay. So now we're going to get started. In some sense, this is the exciting algorithmic meat of the lecture. There are three kinds of parallelism strategies we should really be thinking about. The first one is data parallelism. Data parallelism, at a high level, is the idea that I'm going to roughly copy the parameters across my different GPUs. I'm not going to worry about splitting my parameters up, but I will take my batch and split it up, and different GPUs or different machines will get different slices of my batch. That's data parallelism; there are lots of subtleties in how we execute it. Model parallelism now starts to say: okay, I don't want all my GPUs to have all the different parts of my model. As my models get bigger, that becomes a very big problem, so I need to cut up my model in clever ways, and I need my GPUs to handle different parts of my model. That's model parallelism. And then the final piece is activation parallelism. We don't think too much about activations in our day-to-day lives because PyTorch handles them very transparently, but as the models get bigger and the sequence lengths get longer, activation memory starts to be a really big problem. So if you want to train these really big models with big batch sizes, you have to somehow manage the memory footprint of your activations, and so we have to split those up too. There are ways to handle that. When we put all these together, we will have all the tools we need to scale up both compute and memory gracefully as we add lots and lots of machines. Those are the core conceptual objects, and now we're going to talk about implementing each of these ideas efficiently. The starting point of data parallelism is just SGD. If we're doing naive batch stochastic gradient descent, the formula looks like the equation right here on the slide: I take a batch of size capital B, sum up all those gradients, and update my parameters. So naive data parallelism just says: take your batch size B, split it up, and send it to different machines. Each machine computes some part of the sum, and then I exchange all of my gradients to synchronize; before each gradient step, I synchronize my gradients and then take a parameter update. So I've been talking about compute and memory scaling, so let's talk through what this looks like for each of them. For compute scaling, data parallelism is pretty great. Each GPU gets B over M examples, so if my batch size is big enough, each GPU gets a pretty decent micro-batch size and can hopefully saturate its compute. So that's good. What's the communication overhead? Well, I have to transmit twice the number of parameters every batch; remember, an all-reduce costs roughly twice the amount of stuff you're all-reducing in communication. And this is okay if the batch size is big: if my batch sizes are really big, I can mask the communication overhead of synchronizing my gradients every so often. Memory scaling, though, I'm not touching at all.
Every GPU needs to replicate the parameters and the optimizer state, so it's pretty bad for memory scaling. If we didn't have to worry about memory at all, this would be an okay strategy. But in practice, memory is a problem. I think everyone sitting here has experienced trying to put a big model onto a GPU and having PyTorch tell you you're out of memory. And this is really a problem for your training as well, because if you can fit bigger batch sizes, that makes data parallelism more efficient. So ideally you'd like to save on memory. Let's take a closer look at the memory usage of naive data parallelism. The memory situation is actually worse than it looks; it's actually quite terrible. You've done this in assignment one, but we can think about how many copies of our model we need to store, and it's very large. Depending on the precision in which we're doing some of our training, you're going to need to store something like 16 bytes of data per parameter; in effect, you need to store something like five copies of your weights. And this is really quite bad, because if you just think about your model parameters, technically you only need two bytes. So where did that factor of eight come from? Well, you at least need gradients, and if you're computing your gradients in BF16, that's another two bytes. But then your optimizer state shows up, and that's a really big problem: you've got four bytes of master weights, the FP32 copies you're actually accumulating updates into; you need another four (or sometimes two) bytes for Adam's first-moment estimates, because remember, Adam keeps track of historical gradients; and then Adam also needs second-moment estimates, essentially the variance of the gradients you've seen in the past, and that needs another four (or two) bytes. So what originally looked fine is now looking quite grim. If I just draw that 16x as a picture, you realize that most of your memory usage, at least in terms of parameter memory, is really dominated by the optimizer state of your Adam optimizer. Your memory consumed is going to be a function of how many bytes are used for your optimizer state, and that's generally going to be even more than the core parameter and gradient memory usage. For a simple example of a 7.5B-parameter model distributed over 64 accelerators, you're using a ton of memory, and the total memory at least scales linearly upwards with the number of GPUs. So that's no good at all. But once we look at this picture, we get some very simple ideas. You might wonder: clearly, I need the parameters and gradients to be copied across devices; that seems necessary for data parallelism. But do I really need all the optimizer state to be on every single machine? Once you ask that question, you can get to the second row here, and this is going to be called optimizer state sharding. If we could do that, then at least in this case we can go from 120 GB of total memory usage down to 31.4 GB, and then maybe we can start sharding the gradients too, and now we get to 16.6 GB of memory usage.
And then if we also shard the parameters, we can go all the way down to 1.9 GB of memory usage. That would be a pretty good place to be, because now we've fully sharded all of the optimizer state, parameter, and gradient memory that we need. Yes?
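To make that accounting concrete, here is a small back-of-the-envelope script for the mixed-precision Adam setup described above (BF16 parameters and gradients, FP32 master weights plus two Adam moments, roughly 16 bytes per parameter); the 7.5B-parameter, 64-GPU case reproduces the numbers quoted on the slide:

```python
# Approximate bytes per parameter in mixed-precision Adam training.
PARAM_BYTES = 2        # BF16 weights
GRAD_BYTES = 2         # BF16 gradients
OPT_BYTES = 4 + 4 + 4  # FP32 master weights + Adam first and second moments

def per_gpu_gb(n_params, n_gpus, shard_opt=False, shard_grad=False, shard_param=False):
    """Per-GPU memory (GB) for parameters/gradients/optimizer state only
    (activations excluded), under the given ZeRO-style sharding choices."""
    p = PARAM_BYTES / (n_gpus if shard_param else 1)
    g = GRAD_BYTES / (n_gpus if shard_grad else 1)
    o = OPT_BYTES / (n_gpus if shard_opt else 1)
    return n_params * (p + g + o) / 1e9

N, K = 7.5e9, 64
print(per_gpu_gb(N, K))                                  # ~120 GB  (naive data parallel)
print(per_gpu_gb(N, K, shard_opt=True))                  # ~31.4 GB (ZeRO stage 1)
print(per_gpu_gb(N, K, shard_opt=True, shard_grad=True)) # ~16.6 GB (ZeRO stage 2)
print(per_gpu_gb(N, K, shard_opt=True, shard_grad=True, shard_param=True))  # ~1.9 GB (ZeRO stage 3)
```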
speaker 2: If we're doing, I guess, the gradient computation on each of them, how can we shard the optimizer state?
speaker 1: That is a very good question. The question is: how can we shard the optimizer state when we're doing data parallelism? GPU 0 has to be responsible for data point one, so clearly it needs to know about all the parameters to update them. So how can it possibly shard the optimizer state? ZeRO, which is what this is, the zero-overhead data-parallel optimizer, is in some ways a very clever idea, because it shows you that even when you're doing data parallelism, you don't actually need to copy everything onto every machine; you can be really clever about how you do the communication to avoid all of this. So I'll talk through exactly this; it's a great question. What we're going to do is split up the optimizer states, as I said. The first and second moments are now split up across all the GPUs, but everyone still has the parameters and the gradients. Why is this important? If I have the parameters and gradients, let's say I'm GPU 0, I have the parameters and gradients for everything, and that's enough information for me to compute the full gradient for my example. The only thing I can't do is take that gradient and take an Adam step; I can't update my parameters unless I see the corresponding optimizer state. So that's the key idea. What's going to happen is that GPU 0 computes the gradients for everything, but GPU 0 is now only responsible for updating the parameters of the shard that it owns. That's the key idea: we distribute the work of updating the parameters, and then we synchronize the parameters back. So let me show you in much more gory detail how this works, and why it's called zero overhead. Step one: every GPU gets a different data point; I'm just going to simplify away the batch computation. I have GPUs 0 through 3, every GPU gets a single example, and each computes a full gradient on the example it owns. What I do next is reduce-scatter the gradients. So I collect, in some sense, the gradients that each GPU owns. GPU 0, let's say, is responsible for the first quarter of the parameters; the parameters are the y-axis here and the x-axis is GPUs. We reduce-scatter to make sure GPU 0 has all of the gradient information from all the other GPUs for the subset of parameters it is responsible for. So it gets this gradient information from GPU 1, GPU 2, and GPU 3, and that's all reduced into GPU 0. Hopefully that's clear. Now GPU 0 has all the information it needs to update its own parameters, because it has the optimizer state corresponding to this first part and a fully summed gradient for this first part, so it takes a gradient update on this part of the parameters using that gradient and state. Now I have the fully updated parameters for this subset on GPU 0, and all I need to do is all-gather all of the updated parameters back to all the ranks. Okay, there are many questions; I'll start here.
speaker 2: Sorry, the communication cost, is that per machine or in total?
speaker 1: So the question was whether the communication cost, counted in number of parameters, is per machine or total. Here it's total: one quarter of the parameters gets sent three times to this machine, and then you repeat that for each of the four machines. The earlier figure was also total; two times the number of parameters is the total, because each block has to be sent to every other machine. Okay, yes.
speaker 2: So this question is not unique to what you're showing here, but the AdamW optimizer we covered seems to largely assume independence of the parameters, and yet we've drawn all these diagrams that show the opposite, with connected nodes and so on. It seems especially surprising when we're trying to split these up and update them separately. Does that create any issue?
speaker 1: So the question was that AdamW seems to assume the parameters operate independently. I'm assuming you're saying that because we track per-parameter gradient statistics and update the parameters diagonally, but we know the true curvature isn't fully diagonal, so is there a problem? There have been attempts to improve AdamW beyond the diagonal: things like K-FAC and other second-order-style optimizers that people have come up with. They haven't dethroned Adam, even though they do have their advantages, and there are some really interesting things you can do with these kinds of improved second-order preconditioning methods. Yes? What are the rows that we're reducing over? So you're asking what the rows of this picture are. Imagine the rows here are parameters: GPU 0 is responsible for some number of parameters, so this is a block of parameters up top. When we do the reduce-scatter, we're saying: take the gradients from example 0 for this block of parameters, take the gradients from example 1 for this same block of parameters, and so on, then sum them all and put them in rank 0. That's what we're saying here. Cool. Okay. And the key thing here is that we're doing a reduce-scatter and an all-gather. If you remember what I said before, a reduce-scatter plus an all-gather has the same cost as an all-reduce. So there's a slightly surprising, almost magic thing that happened here: before, we were doing an all-reduce on all the gradients to make sure everyone's gradients were synchronized, and that cost us two times the number of parameters. But if we're clever about how we do the updates, we can do a reduce-scatter and an all-gather, and in between the two steps we can do some computation. That gives us the same communication cost, but now, at least for the optimizer state, we've fully sharded it across the machines. So ZeRO stage 1 is, in some sense, free in the bandwidth-limited regime and gives you memory wins. Yes?
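A rough sketch of a single ZeRO stage 1 step under the assumptions above: every rank holds full flattened parameters and gradients, plus an optimizer that only keeps Adam state for this rank's contiguous shard. The names flat_params, flat_grads, and shard_optimizer are illustrative, not a real DeepSpeed API:

```python
import torch
import torch.distributed as dist

def zero1_step(flat_params, flat_grads, shard_optimizer):
    """One ZeRO stage 1 update. Every rank has full (flattened) params and
    grads; only the Adam state for this rank's contiguous shard is local.
    Total communication is one reduce-scatter plus one all-gather, i.e.
    roughly the same bandwidth as an all-reduce of the gradients."""
    world_size, rank = dist.get_world_size(), dist.get_rank()
    grad_shards = list(flat_grads.view(world_size, -1).unbind(0))
    param_shards = list(flat_params.view(world_size, -1).unbind(0))

    # Reduce-scatter: rank r receives the sum over ranks of gradient shard r.
    summed_grad = torch.empty_like(grad_shards[rank])
    dist.reduce_scatter(summed_grad, grad_shards, op=dist.ReduceOp.SUM)

    # Local Adam step, but only on the shard this rank owns
    # (illustrative sharded-optimizer interface).
    shard_optimizer.step(param_shards[rank], summed_grad)

    # All-gather: every rank sends its freshly updated shard to everyone,
    # writing directly into the views of flat_params.
    dist.all_gather(param_shards, param_shards[rank].clone())
```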
speaker 2: A question about higher moments: do people modify Adam to include higher moments?
speaker 1: What do you mean? Adding contributions beyond the first and second moments?
speaker 2: Yes. Since each GPU's share of the optimizer state is divided by the number of GPUs, it seems like you might as well track more moments.
speaker 1: I see. So to rephrase what you're saying: you could track way more optimizer state, an even more complicated optimizer state, because its cost gets divided by the number of GPUs. Well, what we're going to do next is make the other components scale with the number of GPUs too, so that's going to make extra optimizer state not free anymore; it will continue to be the bottleneck relative to everything else if we can divide everything by the number of GPUs. So hopefully that's a reasonably convincing answer. Okay, so we're going to build up stage by stage to ZeRO stage 3, which is more complicated. ZeRO stage 2 is still relatively simple. Hopefully that optimizer state sharding trick made sense; I think it's very cool. Now we want to shard even more stuff: I want to shard the gradients across the machines. Roughly, we can do the same kind of trick as stage 1, but there's one additional complexity. What is it? Well, we can never instantiate a full gradient vector. If I ever do the full backward pass and try to materialize a full gradient vector, I might go out of memory. I want my maximum memory usage to basically be bounded by full parameters, sharded gradients, and sharded optimizer state. So when we do the backward pass, as we're computing the gradient vector, we can't materialize the full gradient first and then do communication. What we have to do is, as we compute the gradients backwards, as soon as we compute a layer's worth of gradient, send it over to the GPU it belongs to. So this is how it works; it's roughly the same idea. Everyone has their own batch component, and everyone incrementally goes backwards through the computation graph. Let's say we operate layer by layer, so layers are sharded atomically to different GPUs. As we go backwards through the computation graph, after we compute a layer's gradients, we immediately call a reduce operation to send them to the right worker. A layer belongs to some worker, maybe GPU number 2 in this case, so we immediately reduce that gradient into that worker at that point. The gradients are then no longer needed elsewhere; I don't need to store them on ranks 0, 1, and 3, so I can immediately free them. We continue this process, and at the end all the machines have their fully updated gradients: each has a full gradient for its share of the parameters and a full optimizer state for its share of the parameters. Each machine can update its parameters and all-gather the parameters back together. This looks like more communication, because you're doing a reduce operation every layer, but each one is only for a small number of parameters, since it's sharded, so the total communication stays the same. ZeRO stage 2 has some more overhead, because we have to synchronize layer by layer and make sure the gradients are properly sent to the right workers, but the overhead is pretty minimal; it's still fairly simple and straightforward. Now, the last one of these, ZeRO stage 3, is definitely more complicated, but it gives you the greatest win of all: now essentially everything is divided by the number of GPUs that you have.
You get the maximum savings possible. If you've heard of FSDP, you've probably used it at some point; FSDP is exactly ZeRO stage 3, so after today you'll hopefully know how FSDP works. The same idea applies: we're going to shard everything, including the parameters. We do the same thing as ZeRO stage 2, incrementally communicating and computing so that we don't keep big gradient vectors around, and we send and request parameters on demand while stepping through the compute graph, in both the forward and backward passes. Of course, the key is to do this with as low overhead as possible. I think the thing that's really surprising about FSDP is not that this is possible, but that it's possible with relatively low overhead; you'll see why on the next slide. I admit this is maybe not the most friendly graphic to start with, but this is, I promise, the baby version of FSDP; the next slide is a bit more involved. Conceptually, though, this explains everything. What we're doing is all-gathering the model weights as we go. For each layer, no single GPU has all the parameters, so I can't do the normal thing of saying: GPU 0, go ahead and run the forward pass; that's not possible. GPU 0, let's say, only owns the bottom-most layer. So it does that computation, then it stops and requests all of the parameters for the next layer from all the other workers; that's the all-gather step you see right here. It gathers those parameters, and now it has what it needs to do a forward step and compute the layer it didn't have before. Then it can free the weights; it doesn't need them anymore, so get rid of them. Now I can all-gather the next layer, do another forward, free the weights, and repeat. The activations do have to be stored, so the activation memory here is growing; that's going to be an eventual problem. But if we ignore activations for the moment, this is great: I load a layer, I do a forward, I free it, and the memory overhead is very low. Once I get to the end, I can do the same thing with the backward pass. I call backward, and every time I move backwards through the network, I all-gather the parameters that I need, I do a reduce-scatter to aggregate the gradients that have been computed, and then I free the weights, both the gradients I don't need and the parameters. At the very end, I've got a fully updated model. So there are three different collective operations to worry about here: an all-gather in the forward, another all-gather in the backward, and then a reduce-scatter so the model can be updated after we take the gradient step. Conceptually, this is just a single step beyond ZeRO stage 2, but you can see that there's more overhead: the total communication cost is now higher. Before, we had two times the number of parameters and everything was, in some sense, free. Now it's not.
There's a total of three times the number of parameters in communication cost, and there's going to be cost associated with waiting for these communications to finish. But I think the really cool thing about FSDP is that it's actually surprisingly low overhead. You might imagine that because we're doing this crazy thing of asking for and sending parameters back and forth all the time, things would be really slow, that we'd be communicating constantly. But you can use this core idea of overlapping communication and computation: you want your GPU to be working while the communication happens in the background, almost like prefetching, so that by the time you need some piece of information, it's already been communicated to you and you're good to go. I'll talk through the example at the bottom here; this is the key to making FSDP actually somewhat efficient. Let's imagine we have a computation graph that looks something like y = W1 W0 x + W2 W0 x, for some input x. For a very simple computation graph like this, you might run FSDP and get computation and communication that looks like this block diagram at the bottom. It's nice that we did the Nsight Systems example last week, because hopefully this diagram will now be clear. The CPU dispatches a bunch of commands asking the communication part of the GPU to go and fetch some parameters, and it dispatches commands to the GPU to do some matrix multiplies, and it runs far ahead of the GPU in some sense; we saw this when we were looking at the profiler last week. Now let's look at the sequence of communication and computation that happens on the device. Remember that I need to gather things on demand. At the very beginning, I have to make sure everyone has the weights for layer zero, W0 here. So I do all-gather 0 and wait for it to complete. Once it's completed, I can do a forward step on W0; I can compute x times W0. At this point, all-gather 1 starts at the same time that all-gather 0 ends, so as I'm doing this matrix multiply, I'm already loading the next parameters I need. Of course, my communication is slower, so there is some gap, but it's much shorter than the initial load. Now forward 1 can happen, and in the background, once again, I've started to load parameter 2. In this yellow slice here, I'm freeing the parameters associated with forward 1. The other thing here is that W0 is used twice, so I don't need to communicate it again, and that computation happens very quickly. I have W2 already loaded before I needed it, so there's no bubble here, and then I can free it. That's the entirety of the forward pass, and you see that the gaps are relatively small; we were able to issue most of the loads before the compute needed them. So by queuing the requests for weights before you actually need them, you can avoid a lot of the overhead associated with communication. And now, at this point, after forward 2, I'm done with the forward pass.
I can free weight 2 and start on the backward pass. You see that all-gather 2 for the backward pass is already done, so I can start on backward 2; for backward 0, weight 0 is already stored, so that's fine too. The higher overhead here happens in the backward pass, because I need to do reduce-scatters and then all-gathers and so on. Hopefully you see this picture and say, wow, it's kind of surprising: even though we're doing this aggressive sharding, if you go back to the earlier picture, we've fully sharded the parameters, gradients, and optimizer states, but the total bandwidth we need is only three times rather than two times the number of parameters. That doesn't seem too bad, and the actual bubbles we see are not horrendous: the communication is almost fully utilized, and the computation isn't stalling for very long. So we're actually making pretty efficient use of the resources we do have, which is cool. Okay. Yes?
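For reference, ZeRO stage 3 is what PyTorch exposes as FSDP; a minimal usage sketch, with MyTransformer and dataloader as placeholders and the forward assumed to return a scalar loss, might look roughly like this:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes one process per GPU launched with torchrun, which sets the env vars.
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = FSDP(MyTransformer().cuda())  # placeholder model; default is full sharding (ZeRO-3)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # built after wrapping

for batch in dataloader:              # placeholder dataloader
    loss = model(batch)               # FSDP all-gathers each layer's weights on demand
    loss.backward()                   # and reduce-scatters gradients during the backward
    optimizer.step()
    optimizer.zero_grad()
```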
speaker 2: Just to check my understanding: when a GPU all-gathers the weights for a layer, where do those weights get stored?
speaker 1: Yeah, so you need a buffer in which you can store these weights. So this picture is not quite right; you will have some overhead associated with holding the weights for the current layer. And the other big elephant in the room is that I haven't talked at all about activations. That's going to be a big chunk, because you've got a big set of activations for the full model that has to live here in some sense. Yeah, cool. Right. Okay. So this is distributed data parallelism: ZeRO is, in some ways, how people do distributed data parallelism efficiently, and there are different stages. Stage 1 is basically free: it does the same communication pattern as naive data parallelism, but you get to shard your optimizer state. That's great; you might as well always do it. ZeRO stage 2 is still twice the number of parameters, so the total bandwidth consumption is the same, but there is additional overhead from doing this incremental freeing of the gradients as you go backwards. ZeRO stage 3 is more involved: you pay three times the number of parameters in communication cost, but it's not so bad. We did see some overhead in the diagram before, but if you cleverly mask your communication patterns, it's actually pretty good, and so people use data parallelism even over fairly slow links in their network. This is also conceptually very simple. One of the advantages here is that data parallelism doesn't care too much about the architecture: I didn't talk at all about how we actually implement a transformer in any of this; it's all very abstracted. That's one of the reasons FSDP is so popular: it's very easy to write a wrapper that parallelizes an arbitrary neural network without deep knowledge or deep introspection of what the architecture is actually doing. So here are some examples. I worked them out because I'm always running out of memory on my GPUs, and you can see the maximum size of model I can fit on an 8x A100 80GB node. As a baseline, you might barely fit a 6-billion-parameter model, whereas with ZeRO stage 3 I'm able to fit something like a 50-billion-parameter model. There are big savings in my ability to fit larger and larger models by doing things like FSDP to cleverly save on memory. Okay. Oh, sorry, there's a question.
speaker 2: Yes, I guess since this also shards the parameters, what's the difference between this and model parallelism?
speaker 1: So model parallelism is really, fundamentally, about making sure that the parameters live on separate machines... let me see if I can find it... yeah. In some ways ZeRO stage 3 has also sharded the parameters, so you could call it a kind of parallelism, but the whole point of model parallelism is that each parameter lives entirely on one machine; we're not going to try to ship the parameters across in various ways. Only the activations are going to get shipped across. So you'll see a very different discussion in the model parallelism section: the focus there will be on communicating activations rather than communicating parameters, and that will be a big difference. Yes?
speaker 1: Okay, so you're asking about this step: why are we doing an all-gather to get the weights onto all the machines when they only live on one machine? Is that right? Yeah. So we need to take the weights that live on one machine and... is it gather or scatter? Sorry, the terminology is a little slippery for me, so I want to make sure I get this right. What we want to do is the same as this: each machine has some parameters that I want to gather across all of the machines, to make sure that each layer is properly replicated across all the GPUs. Is that the question you're asking? Or are you asking whether there's a simpler primitive we could have invoked, that broadcast is the right operation rather than all-gather? I think maybe it's written this way because of some exceptions where a layer doesn't live on an individual GPU, but I'm not 100% sure; I agree that broadcast should be able to do the same thing if the parameters live on only one machine. Okay, cool. Alright. Okay, got it. Right. So there is a key resource in data parallelism, and this is an important idea that I want you to remember. With data parallelism, batch size is a really critical resource, in the sense that you can't parallelize to a degree greater than your batch size: you can have at most one example on each machine; you can't go to fractional examples per machine. So if there are limits to your batch size, you stop being able to use data parallelism, and there are diminishing returns to batch size. In assignment one, you may have played with varying batch sizes, and you know that as you crank the batch size up past a certain point, you start to see fairly rapid diminishing returns in your optimization rate. There are lots of papers written on this; OpenAI has a really nice one on something called the critical batch size, where they basically argue that past a certain point, each additional example contributes rapidly diminishing returns to your ability to optimize. The intuition is that below a certain point you have a lot of gradient noise, and reducing it is very valuable; past that point, you're fundamentally limited by the number of gradient steps you're taking rather than by variance reduction. So data parallelism alone isn't going to get you to arbitrarily large parallelism, and this batch size thing is a really important resource: you essentially have a fixed maximum batch size, and you can spend it in different ways. I'll talk about that later, because other kinds of parallelism also benefit from bigger batches, so you end up spending your batch size in different places. Okay. And issues remain with data parallelism. ZeRO stages 1 and 2 don't let you fully scale memory; ZeRO stage 3 is nice in principle, but it can be slow. And maybe more importantly, and this relates to the earlier question, it does not reduce activation memory. Ideally, I want to cut up my model entirely so the pieces live totally separately, because then the activation memory would also be reduced. So I want better ways to split up the model so I can fit these really big models onto these GPUs, and that brings us to model parallelism.
We want to scale up memory without changing the batch size, and we want an alternative axis along which we don't need to spend big batch sizes in order to parallelize. So what we're going to do is split up the parameters across GPUs. In some ways that's like ZeRO stage 3, but we're not going to communicate parameters anymore; we're going to pass activations around, and that's going to be different. Sometimes activations are much smaller than parameters, and that will be very good for us. We'll cover two different types of model parallelism. I'll talk about pipeline parallelism, which is conceptually simpler but much more horrible implementation-wise, and tensor parallelism, which is conceptually maybe less obvious but honestly much nicer to implement and more commonly used. They correspond to two different ways of cutting up the model. I think pipeline parallelism is maybe the most obvious way to cut up a neural network: a deep neural network comes in layers, so a very natural place to cut the network is at the layer boundaries. Each GPU handles some subset of the layers, and I pass activations around. In this case, each layer belongs to a GPU, and GPUs pass activations from one to the next; in the backward case, they pass the gradients backwards from GPU 3 to GPU 0. Okay, so that's cool, that's great. What's wrong with this picture? Well, you should see that most of your GPUs are idle most of the time; this is actually quite terrible utilization. If I do the naive kind of parallelism I just described, where each layer does a forward and, let's say, I have a single example, that results in a diagram that looks like this: different rows in the picture are different layers and also different GPUs, and the x-axis is time, going from left to right. What you see is that I first compute my first layer at the very left, then the activations get passed to the second layer; GPU 2 wakes up and says, all right, it's my turn, does its job, passes to GPU 3 and then GPU 4, and then the backward passes can begin, and so on. And you see this gigantic thing people call the bubble: a big stretch of overhead where you're doing absolutely nothing. The GPUs are active only one over n of the time. In some sense, this is the worst possible parallelism: I've added four GPUs, but I get the throughput of a single GPU. So one thing you can do is be a little more clever and build an actual pipeline: instead of just cutting things up into layers, I have a sequence of things that need to be processed by each GPU. Let's say I have micro-batches, so each machine handles, say, four examples. What I do is finish my first micro-batch, send its activations off to the second GPU as soon as I finish, and then get started on my second micro-batch. Now I've overlapped communication and computation: the second GPU can start working while the first GPU continues to work, and the size of the bubble can potentially be reduced by having bigger batch sizes.
And you can hopefully see why I said before that batch size is a resource. If you have a finite batch size and you're doing pipeline parallelism, you can use that batch size to make your pipeline bubble smaller, for example, or you could use it for data parallelism. There are many different ways to take your one batch and split it up. Okay. So the number of micro-batches controls the bubble time, and in fact the ratio of overhead to useful compute is the number of pipeline stages minus one, divided by the number of micro-batches. So if you have big batch sizes, pipeline parallelism can potentially be efficient. But as we said before, batch sizes are finite; we can't just crank them up to whatever value we want. So in general, pipelines seem really horrible. Why do we do it? Why do we incur the cost of a bubble in order to parallelize? Well, there are a couple of reasons. Pipelines help save memory compared to data parallelism; ZeRO stage 3 also shards the parameters, but this additionally shards the activations, which is nice. Pipelines can also have good communication properties: the communication only depends on activations, and it's point to point. So depending on your topology and what hardware you have, pipelines might actually be very favorable for the slower parts of your network. Pipeline parallelism is therefore often used over your slower network links: inter-node, or sometimes even across different racks. One example I was recently told by some Google folks is that one of the big advantages of TPUs is that they don't have to do much pipeline parallelism, because all of their connections are much bigger; they have this big toroidal mesh, so they don't hit the cliff at 256 GPUs where you suddenly drop to a slower network link and might want to switch to pipeline parallelism. So that's a real-world example of when you would start to think about pipeline parallelism. Here's an example from an NVIDIA paper; I'll talk about this paper in much greater detail later. They've done some really nice work showing the performance characteristics of different kinds of parallelism. You can see that with batch size 8, as you increase the pipeline parallel size, the number of devices, your utilization per GPU really starts to drop off, whereas with a big batch size of 128, you can get away with pretty good utilization for reasonably sized pipeline parallelism. So batch sizes are really key to hiding the size of the bubble; otherwise you have issues. Of course, you can use different kinds of pipeline schedules. Instead of the standard pattern for scheduling around the bubble, you can cut things up into finer pieces, assigning different stages, different sublayers, to different devices and doing different computations at different times, so that you can interleave the pipeline better. And an advanced version of this that I want to spend a moment on, which is very clever, is zero-bubble pipelining, or, in DeepSeek's lingo, I believe they call it DualPipe. The core trick is the same.
So here, if you think about the backward pass used to compute gradients, you can split it into two different components. The first part is back-propagating through the activations: as I go down the residual stream, I need to compute the derivative with respect to the activations. And then, when I reach a parameter, I also want to compute the gradient itself, how I'm going to update that parameter, not just how the activations change with respect to the previous layers. To give you a concrete example, look at the bottom-left diagram here. You see the forward pass of a single MLP: I multiply by a weight, apply a nonlinearity, and output the result. Now look at the backward pass: the derivative with respect to the loss comes in, and I can compute how that changes the x's, the inputs to my MLP; that is the derivative with respect to the activations. As I compute those, I can of course use them to compute the gradients I need to update my weights. But the important thing is that this part, computing the gradients for the weights, can be done whenever; there's no serial dependence on it. So I can reschedule that computation to any point in the schedule. What you can do is run your standard pipeline parallelism for the parts that are serially dependent, but any time you have computations that are just for updating the parameters, you can reschedule them wherever you like. So the key idea is to start with a nice, optimized schedule, what's called a 1F1B (one-forward, one-backward) pipeline, which already reduces the bubble size. Then you separate B, the backward computation through the activations, from W, the computation needed to produce the weight gradients, and you do the W computations where you would originally have had a bubble. The slots where I originally had these white, idle stretches can now be filled in with the W's. So by thinking carefully about what the serial dependencies actually are, I can get something really nice, where I'm actually getting good utilization out of my GPUs. To be clear, this is horrendously complicated. If you actually want to implement pipeline parallelism this way, you have to intervene in how your autodiff is calculating these things; you have to have a queue that can track where things go. I heard a funny anecdote recently in a conversation with someone at a frontier lab training LLMs. They said: actually, there are two people in the group who understand how the pipeline parallelism in our infrastructure works, and one of them left. So there's a single load-bearing person in their training infra. There are stories like this; pipeline parallelism is structurally very, very complicated. It looks simple here, but if you're interested, I encourage you to try to implement it; it gets pretty hairy pretty fast.
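To put numbers on the bubble fraction quoted a moment ago, overhead relative to useful compute is roughly (number of stages - 1) / (number of micro-batches); the stage and micro-batch counts below are just illustrative:

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Pipeline bubble overhead relative to useful compute: (p - 1) / m."""
    return (num_stages - 1) / num_microbatches

print(bubble_fraction(8, 8))    # 0.875 -> nearly half the schedule is idle
print(bubble_fraction(8, 128))  # ~0.055 -> many micro-batches hide the bubble
```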
And I think that's a good note on which to switch to the other kind of model parallelism, because this one is much simpler; it's cleanly supported by a lot of frameworks, and a lot of people training really big models rely very heavily, or even primarily, on it. So what other way can we split up a model? If we think about it, most of what we do is matrix multiplies, right? In a big model, most of the computation is matrix multiplies, and most of the parameters are matrices. So what can we do? Well, if we can parallelize just the matmuls, that would be pretty good. Tensor parallel is the idea that we can take a big matrix multiply and split it into a set of submatrices that can be multiplied separately. So if I have this matrix multiply at the top, X times A equals Y, what I can do instead is cut A in half, and also cut X in half, compute on the submatrices, sum them up, and I get my answer at the end, right? Conceptually, pipeline parallel cuts along the depth dimension, the layers; tensor parallel, which is what this is, cuts along the width dimension of your matrix multiplies. So we decompose into submatrices and then do partial sums. Here's an example of what it might look like in an MLP. Each GPU handles a different submatrix of, let's say, a big MLP matrix multiply, and then we have collective communications to synchronize the activations as we need them. So what are we going to do? The MLP is in the top half, and in the bottom half there are two different paths; these are the split-up matrices. I want to do the operation Y = GeLU(XA), so I split my matrix A into A1 and A2. And then on the right-hand side I want to compute Dropout(YB) and return the result as Z, so I also cut up B. So I've cut both of my parameter matrices, A and B, into two parts. In the forward pass, I take my inputs X and just copy them to both GPUs, right? Each GPU gets the same inputs and operates on them with A1 and A2; the row dimensions match, so that works out fine. XA1 and XA2 give you activations Y1 and Y2, those go into B1 and B2, and then I do an all-reduce to sum the results up. That's exactly the figure I showed you before, right? You copy, then you all-reduce, and you've got the answer Z. In the backward pass, it's actually the reverse: as the gradients come backwards, g is the identity, so I copy the derivatives to both sides and do the backward operations all the way through; and once I get to f, that's an all-reduce, because I've got two derivatives coming in from both paths and I sum them back up. So f and g are synchronization barriers. In the forward pass I do a single all-reduce, and in the backward pass I do a single all-reduce, just at two different places in the computation graph.
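Here is a minimal sketch of that two-way tensor-parallel MLP, simulated in one process. It is my own simplification of the Megatron-style scheme described above: A is split column-wise, B row-wise, and the final sum over partial outputs plays the role of the all-reduce g (in a real implementation, f and g are autograd functions so the backward all-reduce happens automatically).

```python
import torch
import torch.nn.functional as F

def tensor_parallel_mlp(X, A, B, tp: int = 2):
    """Simulate a tp-way tensor-parallel MLP: Z = GeLU(X A) B."""
    A_shards = torch.chunk(A, tp, dim=1)   # column-parallel: A1, A2, ...
    B_shards = torch.chunk(B, tp, dim=0)   # row-parallel:    B1, B2, ...
    partials = []
    for A_i, B_i in zip(A_shards, B_shards):
        Y_i = F.gelu(X @ A_i)              # f: X is copied to every rank
        partials.append(Y_i @ B_i)         # each rank holds a partial Z
    return torch.stack(partials).sum(0)    # g: all-reduce (sum) of the partials

# Sanity check against the unsharded computation.
h, ff = 64, 256
X = torch.randn(8, h)
A = torch.randn(h, ff)
B = torch.randn(ff, h)
Z_ref = F.gelu(X @ A) @ B
Z_tp = tensor_parallel_mlp(X, A, B)
print(torch.allclose(Z_ref, Z_tp, rtol=1e-4, atol=1e-4))  # True, up to float32 reduction order
```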
So now you can hopefully see how this is a very nice recipe: wherever you have a matrix multiply, you can just cut it up and parallelize it across different devices. Okay? And as you might imagine, this is actually somewhat expensive. We have a synchronization barrier that lives essentially per layer, and it needs to communicate roughly a residual-stream's worth of activations twice in a forward-backward pass. So tensor parallel, this very simple idea, requires very high-speed interconnects. And there's a rule of thumb, a very simple one to remember, which is that tensor parallel is applied within a single node. A single box of, let's say, NVIDIA GPUs ships with eight GPUs that live in that same box, and as I showed you at the beginning of today's lecture, they're connected at very high speed, so those eight GPUs can talk to each other very quickly. So it makes sense to use something as bandwidth-hungry as tensor parallel across those eight devices. What you will typically see is tensor parallel applied up to eight GPUs, where the eight GPUs live in the same machine, because that gives you the smallest drop in performance. This is an example from Hugging Face's parallelization tutorial, showing the throughput decreases at different levels of tensor parallelism. You see that there are hits, right? Ten and 12% hits to throughput as you do tensor parallelism. But up to eight, well, maybe this is manageable; this is the price you pay for being able to parallelize more nicely. But then you go to 16 devices and you get this astounding 42% drop in performance, and at 32 you see another 65% drop in throughput, right? So hopefully you can see visually that you really want to stop at eight for tensor parallel. That's the sweet spot because of the kinds of hardware interconnects you can get your hands on. Okay. So how do things compare to pipeline parallel? Well, compared to pipeline parallel, we don't have to deal with the bubble we had before, and we don't need to consume larger batch sizes in order to shrink it, which is nice. And there's relatively low complexity in applying tensor parallel: all you really need to know is where the big matrix multiplies are and whether you can split them up and make them live on different devices. The forward and backward operations otherwise remain the same. Compared to implementing something like zero-bubble or DualPipe pipeline parallel, you're going to be in much, much better shape. The con is the much larger communication overhead. In pipeline parallel you've got batch size times sequence length times residual dimension of point-to-point communication per micro-batch; in tensor parallel you've got roughly eight times that per layer, and it's all-reduce communication. So it's potentially a very large amount of communication. The rule of thumb, as I said before, is that tensor parallel is used whenever you have low-latency, high-bandwidth interconnects. You'll see anywhere from 2-way to 16-way tensor parallel out in the wild, depending on what kinds of machines people have, and I'll show you examples as I talk through the case studies at the very end. Okay.
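As a rough back-of-the-envelope comparison of the communication volumes just mentioned, here is a sketch with made-up model dimensions (not figures from the lecture); the "roughly eight times per layer" multiplier for tensor parallel is the lecturer's figure.

```python
# Rough per-layer communication volume in bf16 (2 bytes per element).
b, s, h, layers = 4, 4096, 8192, 80        # hypothetical batch/model shape
bytes_per_elem = 2

# Pipeline parallel: one point-to-point send of the residual-stream activations
# per micro-batch at each stage boundary.
pp_per_boundary = b * s * h * bytes_per_elem

# Tensor parallel (Megatron-style): on the order of 8x a residual-stream's worth
# of activations, all-reduced, per layer over a forward-backward pass.
tp_per_layer = 8 * b * s * h * bytes_per_elem

print(f"pipeline, per stage boundary: {pp_per_boundary / 1e9:.2f} GB")
print(f"tensor parallel, per layer:   {tp_per_layer / 1e9:.2f} GB")
print(f"tensor parallel, whole model: {layers * tp_per_layer / 1e9:.0f} GB")
```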
Any questions on pipeline or tensor parallel before we move on to the third kind, sequence parallel and activation sharding? Yes.
speaker 2: Can they both be used together?
speaker 1: Yeah. So the question was, can they be used simultaneously? The answer is yes, you do use them both. I think we'll get to examples later, but the typical thing you see is that for large-scale runs you very often see tensor parallel, and pipeline parallel is often used on top of that. I think the only example I know of that does pipeline but not tensor parallel would be DeepSeek V3, as far as I know.
speaker 2: So within a single machine, say you have five different machines: you'd have the first portion of the parameters on the first machine with tensor parallel, and then pipeline parallel into the second machine, where you have the next set?
speaker 1: Yeah. So the question was: do you do tensor parallel within a machine and pipeline parallel across machines, for example? Yeah. So you would do something like tensor parallel within a machine and a combination of data and pipeline parallel across machines, for example, right? And I'll show you the rule of thumb later, but basically you do pipeline parallel because your model won't fit. If you could fit your entire model, you'd just do data parallel plus tensor parallel, or maybe even just data parallel. Great. Okay, excellent. So we've been talking about memory, and memory is in some sense a very important part of parallelization because we're going to be training big models. And when you look at your memory, you realize that activations are actually a really big part of your memory usage. So if you look at a standard forward-backward pass, I think this was from one of the PyTorch tutorials, you see that memory usage is very dynamic, right? I'll just talk through this because I think it's an interesting plot in general. You always have your parameters as you're training, because those are static. But in iteration zero you don't have optimizer state yet, so you don't have that part of your memory use. As you do your forward and backward, you see activation memory grow and grow as you accumulate all the activations. And as you start your backward pass, activation memory goes back down, because you free activations as you use them up, while your gradient memory usage goes up as you accumulate gradients. The peak is actually somewhere partway through the backward pass, where you haven't freed all your activations yet and you're still building up your gradients. In iteration two you see the same thing, right? So the point of this diagram is to say: well, we've thought about all the other pieces, the parameters, the optimizer state, the gradients, but we have not thought very deeply about the activations. So let's do that. The final complexity I want to talk you through is activation memory. Tensor and pipeline parallel can linearly reduce basically everything else, but they can't actually reduce all of the activation memory usage. This is an example from one of the NVIDIA papers on reducing activation memory. One thing that's really interesting to see is that as you make your models bigger and bigger, from left to right, parameter and optimizer-state memory can stay flat if we parallelize aggressively, but activation memory just continues to grow, because some parts of it don't parallelize very cleanly. So no matter how many devices you have, you can't really get rid of the growth of activation memory per device, and I'll show you why in a moment. Whereas if you do slightly more clever things like recomputation, you can keep activation memory low, and that's really key to parallelizing some of the biggest models. Okay. So what's the activation memory per layer? You've done some of this transformer math before, so hopefully you're familiar with it, but we can compute the amount of activation memory we need per layer.
And there's a handy formula here for the amount of memory you need: it's sbh × (34 + 5·a·s/h), where s is the sequence length, b the batch size, h the hidden (residual) dimension, and a the number of attention heads. Some of these numbers look mystifying, but they're really not. You can see that there's a left term and a right term. The left term, the sbh × 34, comes from the MLP and other pointwise operations; those depend on the size of your residual stream, the h. On the right side you have a term that, if you multiply it out, is proportional to a·s²·b, because the h cancels; that's the memory you need for the softmax and the other quadratic terms in your attention. Of course, if you use FlashAttention with recomputation, you can drastically reduce that second term. So then let's say we do tensor parallel everywhere we can: in the MLPs, in the QKV computations, in the attention computation. We end up with something that looks like this, and it's looking pretty good, but not quite there. Most of the activation memory per layer is divided by t, the number of devices we're doing tensor parallel over. So if we're dividing by eight, ideally we would divide all the activation memory by eight. But you see there's this straggler term, sbh × 10, that has not been reduced. And if you think about what those are, they're the non-matmul components: the layer norms, the dropouts, the inputs to the attention and the MLP. All of these terms will unfortunately continue to grow with model size, and they are not parallelized nicely, right? So the very last thing we need to do is take those simple pointwise operations, which so far we have not parallelized, and just split them up. And there's a very simple way to do that: if we're doing a layer norm, the layer norms at different positions in the sequence do not interact with each other at all; they just don't care about anything else. So, say we have a 1024-long sequence: we cut it up, and each device handles a different part of that layer norm or that dropout. Those pointwise operations can be completely split up across the sequence dimension. And because we're now cutting things up across the sequence dimension, we have to do some synchronization so that the parallel computations can be aggregated back together. In the forward pass, the g's are all-gathers and the ḡ's (g-bars) are reduce-scatters, and in the backward pass the two are reversed; in some sense there's a duality between them. What's happening is that, for the layer norm, we've scattered things around, so we have to gather them back together to do our standard computation; and then whenever we get to the dropout, we want to scatter them back out into the parallel components that we have. In the backward pass we do the same thing in reverse, right? Okay. Hopefully that's clear. It's a very simple idea: we're just parallelizing the very last components that we had failed to parallelize.
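To put rough numbers on these formulas, here is a sketch with hypothetical model shapes. The base term sbh × (34 + 5as/h) is from the lecture; the tensor-parallel variant (the sbh × 10 straggler plus the sharded matmul inputs) and the sequence-parallel and recomputation variants follow my reading of the NVIDIA activation-recomputation paper, so treat the exact constants as approximate.

```python
def activation_memory_per_layer(s, b, h, a, t=1, seq_parallel=False,
                                selective_recompute=False, bytes_per_elem=2):
    """Approximate per-layer activation memory (bytes) for a transformer layer."""
    attn_quadratic = 0 if selective_recompute else 5 * a * s / h
    if t == 1:
        elems = s * b * h * (34 + attn_quadratic)
    elif seq_parallel:
        elems = s * b * h * (34 + attn_quadratic) / t           # everything sharded
    else:
        # the sbh*10 "straggler" (layer norms, dropouts, block inputs) is unsharded
        elems = s * b * h * (10 + 24 / t + attn_quadratic / t)
    return elems * bytes_per_elem

# Hypothetical 13B-ish shapes: s=2048, b=8, h=5120, a=40, 8-way tensor parallel.
cfg = dict(s=2048, b=8, h=5120, a=40)
for label, kw in [("no parallelism", dict(t=1)),
                  ("tensor parallel (t=8)", dict(t=8)),
                  ("+ sequence parallel", dict(t=8, seq_parallel=True)),
                  ("+ selective recompute", dict(t=8, seq_parallel=True,
                                                 selective_recompute=True))]:
    gb = activation_memory_per_layer(**cfg, **kw) / 1e9
    print(f"{label:26s} {gb:6.2f} GB / layer")
```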
And so now we can put all these different pieces together and get to the end. We started up here with no parallelism at all. We did tensor parallel, which lets us divide everything that's not pointwise by t. Then, if we apply this sequence-parallelism idea, we can divide that remaining component by t as well. And then we can do things like activation recomputation, which is the FlashAttention trick, to remove the second term. The minimal memory you can easily get away with is the thing on the bottom, which is sbh × 34 / t. If you look at different formulas for transformer arithmetic on how much activation memory you use, you often see something like sbh × 34, divided by t if you have t-way tensor parallel, because this is the easy minimum you can get for that kind of memory. Okay. Any questions on sequence parallel and activations?
speaker 2: Yes. Suppose the computation graph grows more and more involved, more than a single chain?
speaker 1: You're asking: if we have something that's a more complicated computation graph than a single linear chain, will that become a problem? It's a good question. I haven't thought about that. I would guess not. At least for tensor parallel, this operates purely layer-wise; it doesn't really care about the dependencies. Maybe for pipeline parallel there are opportunities for increased parallelization if there's more than one branch, but I'm not too sure.
speaker 2: Right, it's usually just a linear chain anyway, right?
speaker 1: Yes, great. Okay, cool. So there are a few other parallelism strategies that I'm not going to talk about in detail, just in the interest of time and of not fatiguing you, because I think I've already dragged you through a whole bunch of low-level details about how to do parallelization. The first one I want to mention is context parallel, or ring attention. You may have heard the term ring attention before. This is a way of splitting up both the computation and the activation cost of computing really large attention, where essentially you pass keys and values around different machines. Each machine is responsible for a different set of queries, and then keys and values travel from machine to machine in a ring-like fashion in order to compute your QKV inner products. The cool thing is that you already kind of know how to do this, because you've done the tiling for FlashAttention: you know that attention can be computed in this online, tile-by-tile way, and that's what's happening in ring attention. The other thing, which is pretty straightforward now that you know tensor parallel, is expert parallelism. Expert parallelism you can think of as almost like tensor parallel, in the sense that you're splitting up one big MLP into smaller expert MLPs and scattering them across different machines. The key difference with expert parallelism is that the experts are sparsely activated, so you have to think a little bit about routing. The routing is not going to be as predictable as the communication we had before in tensor parallel, because maybe one expert is overloaded, so your networking is going to be a little more complicated. But otherwise, conceptually, you're living in the same world as tensor parallel for expert parallelism. Okay. So just to recap everything we talked about, I've made a small table of the different strategies we have. We have DDP and ZeRO stage 1; this is the naive data parallelism approach. You have some overhead per batch, no memory scaling, reasonable bandwidth properties, but you consume batch size in order to do this, right? You need big batch sizes to have big data parallelism. You have FSDP, which is a nicer version of ZeRO stage 1 in the sense that you get memory scaling, but you pay overhead across the different layers, so now you've got higher communication costs and potentially synchronization barriers that lead to poor utilization. Pipeline parallel is nice in that you no longer have this per-batch dependence and you get linear memory scaling, but there's another issue: it also consumes batch size, and it's horrendous to set up and use, so a lot of people like to avoid pipeline parallelism if possible. Finally, tensor parallelism is very high cost in terms of bandwidth and the amount of synchronization you need to do, but it has the really nice property that it has no impact on batch size. It's the one parallelism strategy that has no cost in terms of your global batch size, which is nice, right? So we have to balance a number of limited resources. We have memory, which is one resource. We have bandwidth and compute, which is another resource.
And then we have batch size, which is kind of an unconventional resource, but one that you should really think of as a limited thing you can spend on different aspects of these to improve your efficiency. And there's a very nice TPU parallelism book, let's call it, from Google, which I referred to last week; they have a really nice parallelism section, and they have this great figure that I wanted to show you before moving on to some of the examples. The key quantity, as I was saying before, is the batch size, and depending on the ratio of batch size to the number of GPUs you have, different kinds of parallelism become optimal. They use a simplified formula for how much communication and computation you end up doing with each of these setups to generate this plot. And you can see that if your batch size is too small, if you have lots of GPUs and really tiny batch sizes, there is no way for you to be efficient; you're always communication-bound, which is the bottom half here, and in fact you're spending most of your time on communication. As you get more and more batch size, eventually you get to a point where, if you mix FSDP, that is ZeRO stage 3, and MP, which in this case is tensor parallel, you can get to a place where you're compute-bound. So now you're not wasting your FLOPs waiting for communication. And finally, if you get to a point where your batch sizes are big, you can just get away with pure data parallel: pure FSDP gets you into a regime where the time you spend doing computation is higher than the time you spend doing communication. So if your batch size is big enough, you can just get away with FSDP. So this is a nice illustration of the idea: why would you mix these, when would you mix these, and why is batch size the resource? Hopefully this shows you in a very visual way what's going on. Okay? And so when you put these all together, you end up with what people call 3D or 4D parallelism. I think I've heard the term 5D parallelism recently; I wasn't quite sure what the fifth dimension was yet, I'll have to read up on that. But now you can put it all together, the different dimensions of parallelism, and there's a really simple rule of thumb. I originally looked this up and put it together last year, but it turns out it's still the same this year, so you can follow it now. The first thing you have to do is fit your model and your activations in memory. If you don't do that, you just cannot train, so this is a requirement. So until your model fits in memory, we have to split up the model. We're going to do tensor parallelism, and we know that up to the number of GPUs per machine, that's very efficient and very fast, so we do tensor parallel up to that point. After that, depending on things like your willingness to deal with pipeline parallel and your bandwidth constraints, you're either going to use ZeRO stage 3 or pipeline parallel across machines until you can fit your model in memory. After that point, well, until you run out of GPUs, you can now run the whole thing, and your only goal is to increase the total FLOPs you have on hand.
So you're going to scale the rest of the way with data parallel, because data parallel works well on low-bandwidth communication channels and it is very simple, right? And that gives you a way of using all of your GPUs. Now, if your batch size per device is really small, there is a way of trading batch size for better communication efficiency: if you haven't consumed all of your batch size as a resource, you can use gradient accumulation on your devices, and that lets you have effectively larger batch sizes even if you're memory-constrained, and that lets you trade batch size for better communication efficiency, since you're synchronizing less often across machines. Okay, simple rule of thumb. This will let you train models with reasonable efficiency no matter what you're doing. And to make this concrete, I'll talk through a few examples at the very end here: a flash through this really lovely 2021 paper from Megatron-LM, which shows exactly these things in pictures along with a lot of ablations, as well as some of the models from last year. So this is a big table of how they trained models going from 1.7 billion parameters to 1 trillion parameters, and they get great utilization on all of them: you see the percentage of theoretical peak FLOPs they achieve, and it ranges from 40 to 52%, which is pretty good. You can see that tensor parallel starts at one, eventually goes up to eight, and then caps out at eight. So they use tensor parallelism first. Pipeline parallel stays at one, but once the models get big enough that they no longer fit, pipeline parallel has to increase to compensate. And the data-parallel size basically starts out as big as possible and then slowly goes down, because as they increase the amount of pipeline parallel, it is consuming, in some sense, the batch size; you can't have as big an effective batch size for data parallel if some of it is being used for pipeline parallel. Okay. So careful 3D parallelism gives you linear gains in aggregate FLOPs: you see very flat achieved FLOPs per GPU, which means that as you add more GPUs you get linear scaling in total aggregate throughput. That's great. Tensor parallel of eight is often optimal: looking at the pipeline-parallel size and the tensor-parallel size, going to (8, 8) with a batch size of 30 or a batch size of 128 is optimal, and even if you have a smaller batch size, a tensor-parallel size of eight remains optimal. And activation recomputation enables larger batch sizes; remember that larger batches can in turn help you mask the overhead of pipeline parallel. So activation recomputation, even though it costs more FLOPs, can pay for itself. We've already seen that story play out in FlashAttention. So the last part of this is recent language models: what do they do? I've gone through a few papers to look at examples of people's parallelization strategies. OLMo, in the OLMo and Dolma line of papers, does FSDP for a 7-billion-parameter model. DeepSeek's first paper does ZeRO stage 1 with tensor, sequence, and pipeline parallel; this is the vanilla setup I told you about. V3 actually does something slightly different.
They do 16-way pipeline parallel, 64-way expert parallel, which is kind of like tensor parallel, and then ZeRO stage 1 for their data parallelism strategy. Yi, which is another Chinese model, does, once again, ZeRO stage 1 with tensor and pipeline parallel. And Yi-Lightning, because they're doing MoEs, replaces tensor parallelism with expert parallelism. Finally, if you're interested in state-of-the-art distributed training with lots of details, Llama 3's report is actually really interesting to read. They have a lot of detail about how they do their networking and what sorts of things happen. And you see, once again, the kinds of things I said before: you see a tensor parallel of eight; you see CP, which is context parallel, but that's only relevant for long-context training, which is the very last stage, so you can ignore it; and you've got pipeline parallel and data parallel happening in the first two phases. You can also ignore the first stage here, because that's the small-batch-size training they did in order to be stable. And if you look at their rationale for their parallelism strategy, you see exactly what I said before: you want to order TP, CP, pipeline parallel, and DP by the amount of bandwidth they need, where data parallel can tolerate long network latencies because you can do asynchronous fetching of the sharded model weights, right? So they're using the strategy I told you about to train some of the biggest models. A funny side note about Llama 3, and you may have heard this not as rumor but in casual conversation with your friends: there are lots of GPU failures when you train models at a huge scale. They had 148 interruptions from faulty GPUs, about 30% of their total interruptions, and they had things like unplanned maintenance of machines, which accounted for another 32 interruptions of their training. So when you're training a model this big, I've talked about the algorithms, but you also need fault-tolerant infrastructure to deal with these kinds of things. And I've also heard stories of people saying the even scarier thing is not explicit failures but silent data corruption: GPUs can silently fail on you and give you garbage data, completely ruining your run. Okay. And the last example is Gemma 2. I wanted to end on this because it's a TPU example. They do ZeRO-3, which is roughly FSDP, and then they do model parallelism and data parallelism, right? And here, as I said before, TPUs allow them to stretch model parallelism a little bit further. Okay? So, putting it all together: scaling beyond a certain point requires multi-GPU, multi-node parallelism. There's no single solution, so you want to combine all of these approaches to leverage their strengths, and there are simple and interpretable rules of thumb for how you might execute this parallelism in practice, right? Thank you.

Latest Summary (Detailed Summary)

Generated 2025-05-13 19:24

Overview / Executive Summary

This lecture (Stanford CS336 Language Modeling from Scratch, Spring 2025, 07 Parallelism 1) takes a deep look at the parallelization strategies needed to train large language models across multiple machines. The core goal is to overcome the compute and memory limits of a single GPU by parallelizing across machines, so that very large models can be trained effectively. The lecture first covers the motivation for parallelism (compute demands and memory constraints) and the relevant hardware foundations, in particular the different tiers of network communication (e.g., the speed gap between intra-node NVLink/NVSwitch and inter-node InfiniBand, and the toroidal network topology of TPUs). It then focuses on three main parallelization paradigms:

  1. Data Parallelism: The core idea is to replicate the model parameters on multiple GPUs and split each data batch across them. The lecture walks through naive data parallelism and its memory overhead, then focuses on the three stages of the Zero Redundancy Optimizer (ZeRO-1, ZeRO-2, ZeRO-3/FSDP), which progressively shard the optimizer state, gradients, and model parameters to sharply reduce memory usage, along with their communication costs and implementation tricks (such as overlapping computation and communication in ZeRO-3).
  2. Model Parallelism: When the model itself is too large to fit on a single GPU, the model must be split. Two main approaches are discussed:
    • Pipeline Parallelism: Split the model by layers, with different GPUs owning different layer ranges and cooperating by passing activations. The main challenge is the "pipeline bubble" (idle GPUs), which can be mitigated with micro-batching and more sophisticated schedules (e.g., zero-bubble pipelining), at the cost of implementation complexity.
    • Tensor Parallelism: Split individual operations inside the model (e.g., matrix multiplies), distributing a single tensor operation across multiple GPUs. Communication overhead is high, so it is usually confined to GPUs with fast intra-node interconnects (e.g., 8 GPUs over NVLink).
  3. Activation / Sequence Parallelism: Addresses the large memory footprint of activations during training by splitting operations (e.g., LayerNorm, Dropout) along the sequence dimension, reducing per-GPU activation memory.

The lecture stresses that real large-scale training typically combines these strategies (so-called 3D or 4D parallelism) and gives rules of thumb for choosing and combining them: first use tensor parallelism to saturate intra-node bandwidth, then use pipeline parallelism or FSDP across nodes to fit the model, and finally scale out compute with data parallelism. Batch size is treated as a key resource that affects the efficiency of every strategy. Finally, case studies of Megatron-LM, Llama 3, Gemma 2, and others show how these parallelism strategies are applied in training today's leading models, and practical challenges such as hardware failures at scale are discussed.

Summary of the Lecture's Core Content

I. Introduction and Motivation for Parallelization

  • Goal: move from optimizing single-GPU throughput to understanding the complexities and details required to train truly large models with cross-machine parallelism.
  • Core challenges
    • Model size (memory concerns): large models have enormous parameter counts (billions to trillions), far beyond single-GPU memory.
      • Cited point: GPU memory is growing, but not as fast as model parameter counts.
    • Training speed (compute concerns): the compute of many servers must be harnessed to train models quickly.
      • Cited point: single-GPU FLOPS are growing super-exponentially but still cannot scale out fast enough; training frontier models relies on supercomputers with exaFLOP-scale compute.
    • Heterogeneous communication: communication speed differs sharply across tiers (within a GPU, within a machine, between machines), which shapes the choice of parallelization strategy.
  • Lecture structure
    1. Networking Basics
    2. Parallelization Strategies
    3. Case Studies

II. Hardware Foundations and Network Communication

  • GPU server hardware hierarchy
    • Intra-node: a machine typically holds several GPUs (e.g., 8) connected by high-speed interconnects (e.g., NVIDIA NVLink and NVSwitch), so communication is extremely fast.
    • Inter-node: GPUs communicate with GPUs on other machines through network switches (e.g., InfiniBand HDR), which is comparatively slow.
      • Cited figure: in one example the lecturer mentions, per-lane InfiniBand HDR throughput is roughly 8x slower than NVSwitch.
    • Large clusters: once GPU counts exceed a certain threshold (e.g., 256 GPUs), more complex network topologies (e.g., leaf-spine switches) may be needed, potentially adding communication bottlenecks.
  • TPU network design
    • Google TPUs take a different approach, such as a toroidal mesh in which each TPU chip has high-speed links to its neighbors.
    • This design is very efficient for collective communication operations (such as All-Reduce), even without an all-to-all topology.
  • Collective communication operations
    • These are the building blocks of parallel algorithms; understanding their performance characteristics is key to optimization.
    • All-Reduce: every machine holds data; a reduction (e.g., a sum) is performed and the result is distributed back to all machines. Communication cost is roughly 2 * N (N = data size).
    • Broadcast: data on one machine is copied to all other machines. Cost roughly 1 * N.
    • Reduce: data from different machines is reduced and sent to a single machine.
    • All-Gather: each machine holds a piece of the data and copies it to all other machines, so every machine ends up with the full data.
    • Reduce-Scatter: every machine holds the full data; the data is reduced (e.g., summed row-wise) and different pieces of the result are distributed to different machines.
    • Key equivalence: in the bandwidth-limited regime, an optimal All-Reduce is equivalent to a Reduce-Scatter followed by an All-Gather (see the sketch below).
      • The lecturer emphasizes: "this could be replaced with two operations, a reduced scatter. And I'll all gather... In the bandwidth limited regime, this is basically the best that you can do, all reduce."
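A minimal single-process sketch of that equivalence, written by me as an illustration (real implementations use NCCL-style ring algorithms over actual devices):

```python
import torch

def reduce_scatter(shards_per_rank):
    """shards_per_rank[r][c] is rank r's chunk c. After reduce-scatter,
    rank c holds the sum of chunk c over all ranks."""
    world = len(shards_per_rank)
    return [sum(shards_per_rank[r][c] for r in range(world)) for c in range(world)]

def all_gather(chunk_per_rank):
    """Every rank receives every rank's chunk, concatenated back together."""
    return torch.cat(chunk_per_rank)

world, n = 4, 8
data = [torch.randn(n) for _ in range(world)]          # one vector per rank

allreduce_result = sum(data)                           # all-reduce done directly

chunks = [list(torch.chunk(x, world)) for x in data]   # each rank splits its data
reduced_chunks = reduce_scatter(chunks)                # each rank owns one summed chunk
gathered = all_gather(reduced_chunks)                  # everyone reassembles the full sum

print(torch.allclose(allreduce_result, gathered))      # True
```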

III. Goals of Parallelization Strategies

  • Treat the entire data center as the new unit of compute.
  • Linear memory scaling: as the number of GPUs grows, the largest trainable model size should grow linearly.
  • Linear compute scaling: as the number of GPUs grows, the effective compute applied to training should grow linearly.
  • The algorithms are built by calling simple collective communications primitives.

IV. Data Parallelism

  • Core idea: model parameters are replicated on every GPU; each data batch is split across GPUs.
  • Naive data parallelism
    • Each GPU processes part of the data and computes gradients.
    • Gradients are synchronized across all GPUs with an All-Reduce.
    • Parameters are then updated.
    • Compute scaling: good, provided the per-GPU micro-batch size is large enough to saturate its compute units.
    • Communication cost: roughly 2 * P worth of traffic per batch (P = number of parameters) for the All-Reduce; large batches can hide the communication.
    • Memory scaling: poor. Every GPU stores the full model parameters and optimizer state.
  • Memory bottleneck analysis
    • Actual memory use far exceeds the parameters themselves, mostly because of optimizer state, especially for adaptive optimizers like Adam.
      • Cited figures: roughly 16 bytes per parameter may be needed, including the parameter itself (e.g., 2 bytes in BF16), the gradient (2 bytes), the Adam master weights (4 bytes), the Adam first moment (2-4 bytes), and the Adam second moment (2-4 bytes). "you need to store something like five copies of your weights."
      • For a 7.5B-parameter model on 64 accelerators, naive replication gives an enormous total footprint (120 GB in the example).
  • Zero Redundancy Optimizer (ZeRO): aims to remove redundant memory in data parallelism (a minimal ZeRO-3-style sketch follows this list).
    • ZeRO Stage 1: optimizer state sharding
      • The optimizer state (e.g., Adam's first and second moments) is sharded across GPUs. Each GPU still holds the full parameters and gradients.
      • Procedure
        1. Each GPU computes full gradients on its data shard.
        2. A Reduce-Scatter is applied to the gradients: GPU i receives the globally summed gradients for the parameter shard it owns (matching its optimizer-state shard).
        3. GPU i updates its parameter shard using its local optimizer-state shard and the received gradient sum.
        4. The updated parameter shards are All-Gathered so every GPU again holds the full, up-to-date parameters.
      • Communication cost: Reduce-Scatter + All-Gather, the same as an All-Reduce. In the bandwidth-limited regime the memory savings are "essentially free."
    • ZeRO Stage 2: gradient sharding
      • On top of Stage 1, the gradients are also sharded.
      • During the backward pass, as soon as a layer's gradients are computed they are sent via a Reduce to the GPU that owns that parameter shard, and the local gradient memory is freed.
      • Communication cost: still 2 * P in total, but the layer-by-layer synchronization can add overhead.
    • ZeRO Stage 3: parameter sharding, i.e., FSDP (Fully Sharded Data Parallel)
      • Model parameters, gradients, and optimizer state are all sharded.
      • Parameters are requested on demand during computation.
      • Forward pass: All-Gather the current layer's weights -> run the forward computation -> free the weights.
      • Backward pass: All-Gather the current layer's weights -> run the backward computation -> Reduce-Scatter the gradients -> free the weights.
      • Communication cost: increases to roughly 3 * P.
      • Key to efficiency: overlapping communication and computation to hide latency, i.e., prefetching the next layer's parameters while the current layer is computing.
      • Cited figure: in one example, FSDP raises the largest model that fits on an 8x A100 80GB node from roughly 6B to 50B parameters.
  • The key resource in data parallelism: batch size
    • The degree of data parallelism cannot exceed the batch size.
    • Batch size has diminishing returns (OpenAI's "critical batch size" work).
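A minimal single-process sketch of the ZeRO-3 / FSDP forward pattern summarized above, written by me with Python lists standing in for ranks; it shows only the gather-compute-free loop of the forward pass (the backward pass would all-gather again and reduce-scatter gradients), and it omits the communication/compute overlap that makes the real thing efficient.

```python
import torch

class ShardedLinear:
    """Each 'rank' permanently stores only a 1/world slice of the weight."""
    def __init__(self, d_in, d_out, world):
        full = torch.randn(d_in, d_out) * 0.02
        self.shards = list(torch.chunk(full, world, dim=1))  # per-rank slices

    def gather(self):
        # all-gather: reassemble the full weight just in time for compute
        return torch.cat(self.shards, dim=1)

def fsdp_forward(x, layers):
    for layer in layers:
        W = layer.gather()        # all-gather this layer's weights
        x = torch.relu(x @ W)     # compute with the full weight
        del W                     # free immediately; only the shards persist
    return x

world = 4
layers = [ShardedLinear(64, 64, world) for _ in range(3)]
out = fsdp_forward(torch.randn(8, 64), layers)
print(out.shape)  # torch.Size([8, 64])
```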

V. Model Parallelism

  • Core idea: when the model is too large for a single GPU, split the model itself across GPUs; mostly activations, rather than parameters, are communicated.
  • Pipeline Parallelism
    • Concept: split along model depth (layers); each GPU owns a contiguous block of layers, and stages cooperate by passing activations.
    • Challenge: the pipeline bubble, i.e., GPUs sitting idle while they wait.
      • In a naive implementation, GPU utilization is only 1/N (N = number of pipeline stages).
    • Optimizations
      • Micro-batching: split a large batch into micro-batches so the pipeline stages overlap, shrinking the bubble. The bubble is proportional to (stages - 1) / (number of micro-batches).
      • Zero-bubble pipelining (e.g., DeepSeek's DualPipe): more elaborate schedules that move deferrable work, such as weight-gradient computation, into the bubble time. Extremely complex to implement.
        • Lecturer's anecdote: "actually there's two people in the group that understand how the pipeline parallel in our infra works. One person left. And so there's a single load bearing person in our training infra."
    • Pros: saves parameter and activation memory; point-to-point communication is friendly to slow network links.
    • Cons: consumes batch size to shrink the bubble; complex to implement.
  • Tensor Parallelism
    • Concept: split along model width (the hidden dimension), decomposing a single large operation (such as a matrix multiply) into smaller ones executed in parallel on different GPUs.
    • MLP example: Y = GeLU(XA), Z = Dropout(YB)
      • Split the weight matrix A into A1, A2 and B into B1, B2.
      • Forward: copy the input X to both GPUs -> XA1, XA2 -> (GeLU) -> Y1, Y2 -> Y1B1, Y2B2 -> All-Reduce combines Z1, Z2 into Z.
      • Backward: similarly, the gradients require an All-Reduce at one point.
    • Communication: a synchronization point (e.g., an All-Reduce) per layer; traffic is heavy, so high-speed interconnects are required. Within a GPU node this bandwidth usually comes from full NVLink-style connectivity, while architectures like TPUs meet the requirement through their distinctive network design (e.g., a toroidal mesh that supports collectives efficiently), highlighting how hardware shapes how a parallelism strategy is realized.
    • Rule of thumb: tensor parallelism on GPUs is therefore usually confined within a single node (e.g., 8 GPUs over NVLink) to exploit the highest bandwidth; stretching it across nodes (e.g., over InfiniBand) causes a sharp performance drop due to the communication bottleneck.
      • Cited figures: a Hugging Face tutorial shows throughput dropping 42% when scaling TP to 16 devices and 65% at 32 devices.
    • Pros: no pipeline bubble, does not consume batch size, relatively simple to implement.
    • Cons: high communication cost.

VI. Activation Memory & Sequence Parallelism

  • The activation memory problem: even with model parallelism, activations can still occupy a large amount of memory, especially for large models and long sequences.
    • Some activations (e.g., the inputs to LayerNorm, Dropout, attention, and the MLP) do not shrink linearly with the parallel degree T under standard tensor parallelism.
  • Sequence Parallelism
    • Targets the pointwise operations left unsharded by tensor parallelism.
    • Splits along the sequence dimension: operations like LayerNorm and Dropout are independent across sequence positions, so a long sequence can be split across GPUs.
    • Requires extra All-Gather and Reduce-Scatter operations to synchronize results across the sequence shards.
  • Final activation memory: combining tensor parallelism (degree T), sequence parallelism, and activation recomputation (e.g., FlashAttention), the minimum per-layer activation memory is roughly SBH * 34 / T (S: sequence length, B: batch size, H: hidden size).

VII. Other Parallelism Strategies

  • Context Parallelism / Ring Attention
    • Used for attention over very long contexts.
    • Queries are assigned to different machines, while keys and values are passed around the machines in a ring; this reuses FlashAttention's tile-by-tile computation idea.
  • Expert Parallelism (MoE)
    • Replaces the MLP layer with several small "expert" MLPs that are sparsely activated, scattered across machines.
    • Conceptually similar to tensor parallelism (one large MLP is split into several small MLPs and distributed), but adds routing complexity, because expert activation is sparse and dynamic.

VIII. Combining Strategies (3D/4D Parallelism)

  • Resource trade-offs: memory, bandwidth, compute, and batch size are all finite resources that must be balanced.
  • The figure from Google's TPU book: the optimal parallelism strategy depends on the ratio of batch size to number of GPUs.
    • Ratio too small: communication-bound.
    • Moderate ratio: mixing FSDP with model parallelism (mainly tensor parallelism) can reach the compute-bound regime.
    • Ratio large enough: pure FSDP is already compute-bound.
  • General rule of thumb (a hypothetical configuration sketch follows this list)
    1. Fit the model and activations in memory
      • First use tensor parallelism, up to the fast intra-node interconnect (e.g., 8 GPUs per machine).
      • If the model still does not fit, use ZeRO Stage 3 (FSDP) or pipeline parallelism across machines.
    2. Scale compute
      • Use data parallelism (e.g., ZeRO Stage 1, or the DP dimension of FSDP) to scale out to all available GPUs.
    3. Small-batch-size optimization
      • If the global batch size is constrained, use gradient accumulation to increase the effective batch size, synchronize less often, and improve communication efficiency.
  • The lecturer mentions the term "5D parallelism" but is not yet sure what the fifth dimension refers to.
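A toy sketch of that rule of thumb as code, to make the decision order explicit. The function and its crude memory accounting are entirely hypothetical, not a tool from the lecture.

```python
def choose_parallelism(model_mem_gb, gpu_mem_gb, num_gpus, gpus_per_node=8):
    """Hypothetical helper: pick (tp, cross-node sharding, dp) following the rule
    of thumb above. The memory model is deliberately crude."""
    # 1) Tensor parallel within the node until the shard fits (or we hit 8-way).
    tp = 1
    while model_mem_gb / tp > gpu_mem_gb and tp < gpus_per_node:
        tp *= 2
    # 2) If it still doesn't fit, shard across nodes with ZeRO-3 (FSDP)
    #    or pipeline parallel; here we just flag that it is needed.
    shard_across_nodes = model_mem_gb / tp > gpu_mem_gb
    # 3) Spend every remaining GPU on data parallelism for more total FLOPs.
    dp = num_gpus // tp
    return {"tensor_parallel": tp,
            "fsdp_or_pipeline_across_nodes": shard_across_nodes,
            "data_parallel": dp}

# e.g. a model needing ~1.1 TB of parameters + optimizer state on 512 x 80 GB GPUs
print(choose_parallelism(model_mem_gb=1100, gpu_mem_gb=80, num_gpus=512))
```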

IX. Real-World Case Studies

  • Megatron-LM (2021)
    • Shows training configurations from 1.7B to 1T parameters, achieving 40-52% of theoretical peak FLOPS.
    • Strategy: tensor parallelism up to 8, then the pipeline-parallel degree grows with model size, with the data-parallel degree adjusted accordingly.
    • Activation recomputation enables larger batches, which in turn help hide pipeline-parallel overhead.
  • Recent language models
    • OLMo, DeepSeek (v1), Yi: typically FSDP (or ZeRO Stage 1) + tensor parallelism + pipeline parallelism.
    • DeepSeek (v3): 16-way pipeline parallelism, 64-way expert parallelism, ZeRO Stage 1 data parallelism.
    • Yi-Lightning (MoE): replaces tensor parallelism with expert parallelism.
    • Llama 3 (Meta)
      • Uses tensor parallelism (8), context parallelism (CP, for long sequences), pipeline parallelism, and data parallelism.
      • Parallelism priority (bandwidth requirement from high to low): TP -> CP -> PP -> DP.
      • Challenges of training at scale: 148 training interruptions from GPU failures, about 30% of all interruptions; silent data corruption is even scarier.
    • Gemma 2 (Google TPU)
      • Uses ZeRO Stage 3 (the TPU equivalent of FSDP) + model parallelism + data parallelism. TPU interconnects make it easier to push model parallelism further.

X. Conclusion

  • Scaling beyond a certain point requires multi-GPU, multi-node parallelism.
  • There is no single universal solution; data parallelism, model parallelism (pipeline, tensor), and activation/sequence parallelism must be combined.
  • Simple, interpretable rules of thumb guide the choice and combination of parallelism strategies in practice.