Stanford CS336 Language Modeling from Scratch | Spring 2025 | 08 Parallelism 2

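This lecture explores how to use multi-GPU and multi-node parallelism to speed up model training; the core challenge is overcoming data transfer bottlenecks in order to maximize GPU utilization. The lecture first recaps parallelism techniques within a single GPU and then turns to parallelism across GPUs and nodes. It introduces the data transfer hierarchy, from the L1 cache and high-bandwidth memory (HBM) inside a GPU, to NVLink between GPUs in the same node, to NVSwitch across nodes, and points out that data transfer is far slower than computation and is the main performance bottleneck.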

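The first part of the lecture covers collective operations in detail. These are the basic building blocks of distributed programming, for example broadcast, scatter, gather, reduce, all-gather, and reduce-scatter. It also explains related terminology such as "world size" (the total number of devices) and "rank" (the index of a device).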

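Next, the lecture contrasts generations of GPU communication hardware. The traditional setup communicates over the PCIe bus (within a node) and Ethernet (between nodes), and is often limited by CPU overhead and bandwidth. Modern NVIDIA systems use NVLink to connect GPUs directly within a node and NVSwitch to connect GPUs directly across nodes, bypassing the CPU and Ethernet and greatly increasing communication bandwidth and efficiency; the connections of an H100 node are used as an example. The rest of the lecture discusses how these operations are implemented in NCCL and PyTorch, and distributed training strategies (data parallelism, tensor parallelism, and pipeline parallelism).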

Media details

Upload date
2025-05-13 17:47
Processing status
Completed
Transcription status
Completed
Latest LLM Model
gemini-2.5-pro-exp-03-25

Transcript

speaker 1: This is week two of the systems lectures, where we try to get the most out of the hardware we have to make models train faster. Last week we talked about parallelism within a single GPU, and this week we're talking about parallelism across multiple GPUs. So this is the picture you should have in your head. We have a bunch of nodes. These are basically computers that each have a number of GPUs, usually eight. Within each GPU, there's a bunch of streaming multiprocessors, or SMs, which actually do the work. And what you see in green here is essentially the memory and the communication. Within each SM, you have a very small L1 cache. On a GPU, you have high-bandwidth memory, HBM, which is bigger, and then you have these links that connect the different GPUs. The way to think about it is that compute has to happen within the SM, on these ALUs. Compute needs inputs and needs to write outputs, and generally the inputs and outputs can be relatively far away. If you're lucky, they're in the L1 cache. If you're less lucky, they're in HBM. And this week, we're talking about multi-GPU and multi-node training, where the data you need might be on another GPU. So the name of the game is: how do you structure all your computation to avoid data transfer bottlenecks? Because, remember, we want to keep the arithmetic intensity high. We want to saturate our GPUs and make them hum along. And generally, data transfer is going to be a lot slower, so that's going to be the bottleneck. Last week, we saw a bunch of different techniques to do that within a GPU, including fusion and tiling. The idea is that instead of reading from and writing to HBM, you load into L1 cache, or I guess shared memory, which has the same speed, work there on your local scratchpad, and then write out to HBM only judiciously.
And this week, we're looking at communication across GPUs and nodes, where we have to replicate and shard our models, parameters, and optimizer state. The way we do that will determine the cost. Here I'm taking a little bit of liberty to put everything in one hierarchy. You can think of it as going from small and fast to big and slow. The smallest and fastest is on a single node, on a single GPU: you have the L1 cache, which is extremely fast but very small. Then you have HBM on a single GPU. Then, between GPUs on the same node, we have NVLink, and finally we have NVSwitch. Of course, this is all in the NVIDIA ecosystem. The idea is that many of the core concepts of minimizing data transfer are really the same, but the mechanics are a bit different, because L1 behaves differently from these NVSwitches. This lecture is going to be mostly about concretizing the concepts from the last lecture in code. There are going to be a few new things, but Tatsu did an excellent job of giving you an overview of all the different types of parallelism. I'm going to try to anchor it in code so we can understand more deeply what's going on. I'm also going to refer to the standard output file here, which is the output of running this lecture. There were some minor issues, which I'll spare you, where this framework doesn't quite work if you have multiprocessing. Okay, so this lecture has two parts. In part one, we're going to look at the building blocks: collective operations, which we discussed last time, and how they are implemented in NCCL and PyTorch, and then we're going to do some benchmarking. In part two, we're going to look at actual distributed training: data, tensor, and pipeline parallelism. Okay, so let's start with collective operations. Collective operations are primitives that are used generally for distributed programming.
Collective means that you have many nodes. These are actually quite old, going back to at least the 1980s in the parallel programming literature. Generally, they provide a better abstraction than trying to manage the point-to-point communication yourself. So these are tried-and-true primitives that have stood the test of time. A bit of terminology: world size refers to the number of devices, for example four, and rank, somewhat confusingly if you're used to linear algebra, just refers to a device. So we have rank 0, rank 1, rank 2, and rank 3 if you have four devices. The collective operations are as follows. Starting with broadcast: you have t0 on one of the ranks and you just want to put it on all the other ranks, or all ranks. That's very straightforward. Scatter is similar, but you have four values and you want to put each value on a different rank, so each rank gets a different value, not the same value. Gather is the inverse of scatter: each rank has a different value, and you bring them all together on one rank. Reduce is the same as gather, except that instead of concatenating, you add them all up. All-gather is the same as gather, except you do it for all the destinations. Gather was just rank 0, or rank 1, or rank 2, or any individual rank; all-gather does it for all of them. And finally, reduce-scatter. I couldn't find a good picture of this, so I'm reusing the one from last time. It's like reduce, where you take a bunch of different values and add them, or perform another commutative operation on them, and put the result on one rank. But like scatter, you put different pieces of the vector or tensor on different ranks. And remember that all-reduce is equivalent to reduce-scatter plus all-gather.
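To make the definitions above concrete, here is a minimal single-process sketch in plain Python. It is not how NCCL or PyTorch implement these operations: each rank is simply modelled as one slot in a list, and the function names are hypothetical helpers chosen for illustration.

```python
from functools import reduce as fold
import operator

# Each "rank" is simulated as one slot in a Python list.
# No real communication happens; this only illustrates the semantics.

def broadcast(value, world_size):
    """One rank's value ends up on every rank."""
    return [value for _ in range(world_size)]

def scatter(values):
    """values[i], all initially on one rank, ends up on rank i."""
    return list(values)

def gather(per_rank_values):
    """Inverse of scatter: one destination rank collects every rank's value."""
    return list(per_rank_values)

def reduce_to_one(per_rank_values, op=operator.add):
    """Like gather, but the values are combined with `op` instead of concatenated."""
    return fold(op, per_rank_values)

def all_gather(per_rank_values):
    """Like gather, but every rank receives the full collection."""
    return [list(per_rank_values) for _ in per_rank_values]

world_size = 4
values = ["t0", "t1", "t2", "t3"]
print(broadcast("t0", world_size))   # same value on every rank
print(scatter(values))               # rank i holds values[i]
print(gather(values))                # one rank holds all four values
print(reduce_to_one([1, 2, 3, 4]))   # one rank holds the sum
print(all_gather(values))            # every rank holds all four values
```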
The way to remember this terminology is as follows, because it can be confusing which one is all-gather and which one is reduce-scatter. Reduce just means you're performing some associative and commutative operation, like sum, min, max, or average. Broadcast/scatter is the inverse of gather. And all just means the destination is all devices. Okay, so this is all review from last time. Any questions before I move on? We're going to build on these primitives, so it's useful if everyone understands.
Okay, so now let's see how this is actually implemented, starting with the hardware. Here's what hardware for GPUs classically looks like. This is kind of what you'd have at home. You have a computer with your CPUs, and generally you have your GPUs on one node that communicate via a PCIe bus. If you have to communicate between different nodes, then this is all connected by Ethernet. This is typically how machines were built. If you buy a GPU for gaming or something, this is probably what your setup looks like. As we'll see, this is suboptimal because there's a lot of overhead. When data needs to get shipped from GPU to GPU, it has to go through the kernel, get copied into buffers, and then go through this transport over Ethernet, and that introduces a lot of overhead. What has happened in modern times with scientific computing and deep learning is that, if you know you're going to string a bunch of GPUs together and have them work together, you just hook the GPUs up directly. In the NVIDIA ecosystem, we have NVLink, which directly connects the GPUs and therefore bypasses the CPU. You don't need to go through the kernel of the host machine. And even across nodes, we can connect the GPUs directly via NVSwitch. So we're bypassing Ethernet, because Ethernet was developed a long time ago, clearly not for these types of applications. NVSwitch and NVLink skip all of that and optimize directly for the type of workloads we're interested in. If you look at an H100 node, each GPU has 18 NVLink generation-4 links coming out, which gives you a total bandwidth of 900 GB/s.
If you compare, it's certainly a lot faster than PCIe, and it's certainly way faster than Ethernet. In comparison, if you think about the cost of going from the SM to reading from high-bandwidth memory, that's still quite a bit faster, by a factor of four or so. And of course, these numbers are constantly changing. With the new Blackwells, this number is two or three times higher, I believe. Okay.
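As a rough back-of-envelope check on those figures, the sketch below derives the per-link bandwidth from the quoted 18 links and 900 GB/s, and a best-case transfer time for one tensor. The tensor size (100 million float32 values) is a hypothetical example, 1 GB is taken as 10^9 bytes, and the estimate ignores latency and protocol overhead, so real transfers will be slower.

```python
# Figures quoted in the lecture for one H100 GPU.
NUM_NVLINKS = 18
TOTAL_NVLINK_GB_PER_S = 900

# Bandwidth each individual link must contribute to reach the total.
per_link_gb_per_s = TOTAL_NVLINK_GB_PER_S / NUM_NVLINKS

def ideal_transfer_ms(num_bytes, gb_per_s):
    """Lower bound on transfer time in milliseconds (no latency, no overhead)."""
    return num_bytes / (gb_per_s * 1e9) * 1e3

# Hypothetical payload: 100 million float32 values at 4 bytes each.
tensor_bytes = 100_000_000 * 4

print(f"per link: {per_link_gb_per_s:.0f} GB/s")
print(f"ideal transfer: {ideal_transfer_ms(tensor_bytes, TOTAL_NVLINK_GB_PER_S):.3f} ms")
```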
speaker 2: Does this go through the CPU and then to another GPU, or is it direct?
speaker 1: So the question is, for PCIe, how does the data get transferred? I think it still has to go through the CPU. Was there another question? And PCIe was developed for other things that are connected to it as well, like your sound card or your SSD. So it's really a general-purpose bus for communication between devices. Yeah?
speaker 2: [NVLink] also has a connection with [the CPU]?
speaker 1: So the question is whether NVLink also connects to the CPU. We're going to see a bit later, I think maybe just in the slide, how things are connected. You still need to talk to your CPU, of course. Yeah.
speaker 1: Okay. So there's this command that you can run, and it produces some output that lets you see how the GPUs are actually connected. I ran this on our cluster. There are eight GPUs. I guess you won't be able to get eight GPUs, but if you could, this is what it would look like. You see that between every pair of GPUs, there's NV18 connecting them. There are also these network cards and other things. Oh, yeah, the network cards are basically what gives you the PCIe connection and the CPUs. Okay, so that's the hardware. How do you use the hardware? NVIDIA has spent a lot of time developing really good software on top of their, I guess, really good hardware. There is a collective communication library by NVIDIA called NCCL. It essentially translates the collective operations we looked at before, like all-reduce, into low-level packets that need to be sent between GPUs. This library does a lot of work, because it allows the programmer to operate at the level of "I need this tensor to appear on all the machines," and it just happens. Just a little bit on what happens: when you set up NCCL, you bring up a bunch of devices, and there's some communication that happens to figure out the topology of the hardware. It optimizes the path between the GPUs. Then, when you actually call these collective communication operations, it launches CUDA kernels to send and receive data. Okay, so that's NCCL. It's provided as a library, but NCCL is still a bit too low-level for us, because most of what we're doing is in Python. PyTorch has the torch.distributed library, which provides a clean interface to these collective operations. From the comfort of your PyTorch program, you can just call all_gather_into_tensor on a tensor and it will appear on the different ranks.
It also has the useful feature that it supports multiple backends for different hardware. NCCL, remember, was for GPUs, but you can also run collective operations elsewhere. Remember, this is not GPU-specific; it's for any set of devices. So you can also do it on CPU using a backend called Gloo. If you're debugging on your laptop for your assignment, for example, you can use Gloo and still be able to run things without even having a GPU. That's another advantage of having these high-level primitives: they're much more portable than having something that's very GPU-specific. Of course, the performance is really going to depend on the hardware, but at least logically, you can make sure your code runs. torch.distributed also supports other high-level things like FSDP, which Tatsu talked about last lecture, but we're not going to use that in this class, because in the spirit of developing things from scratch, that's just what we're going to do. Okay, so let's look at some examples of how torch.distributed collective operations work. There's this utility function I wrote, which you can look at in the code if you want, which takes a function and runs it. Basically, it's a wrapper around Python multiprocessing that runs four processes that execute this function. So when you're in this function, you should think of it as there being world-size processes running this identical function, where the rank goes from 0, 1, all the way to world size minus one. Right now I'm stepping through just one of the ranks, because lectures are not parallel. Generally, the first thing a process needs to do is initialize itself. The processes need to find each other, right? Because you're running a lot of processes, they need to connect to a single host so that they know that each other exists.
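The pattern described above (every worker runs the same function and only the rank differs) can be sketched with a toy launcher. This is a stand-in rather than the lecture's utility: it uses threads instead of processes so that it runs anywhere, and `spawn_workers` and `describe` are hypothetical names. In real PyTorch code each process would additionally call `torch.distributed.init_process_group` with a backend such as `"nccl"` or `"gloo"`.

```python
import threading

def spawn_workers(fn, world_size):
    """Run fn(rank, world_size) once per rank; return results indexed by rank.

    Threads stand in for the separate processes a real launcher would start.
    """
    results = [None] * world_size

    def worker(rank):
        results[rank] = fn(rank, world_size)

    threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(world_size)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return results

def describe(rank, world_size):
    # Every worker executes this same function; only `rank` differs.
    return f"rank {rank} of {world_size}"

outputs = spawn_workers(describe, world_size=4)
print(outputs)
```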
Note that this is not where all of the data goes. The data goes through NCCL; this is just for coordination. And since we have a GPU, we can use NCCL; otherwise you would use Gloo. So after setup, we're going to do some stuff. There's this useful function called barrier, which waits for all the processes in your process group to get to this point. Remember, everything is running asynchronously, and in some cases you just want a synchronization point. Barrier does that. The reason I put it here is actually for trivial reasons: I want all these print statements to be grouped together. But there are other reasons you might want to use barrier, which we'll get to later. For each of these ranks, I'm going to construct a tensor. The tensor is 0, 1, 2, 3 plus the rank. I'm going to print out, for each rank, what it looks like before the all-reduce. Here's what it looks like. Can people read that in the back? Yes? Okay, good. On rank 0, it's 0, 1, 2, 3. On rank 1, it's 1, 2, 3, 4, and so on. Notice that because it's asynchronous, the output comes in whatever order it happens to print. So each rank has a different tensor. Then you all-reduce. You pass in that tensor and say you want to sum it. In this case I'm not going to do async, but you can, which is useful for overlapping communication and computation. And afterwards, as advertised: for the first component you add them up and get 6, then 10, 14, and 18. So after the all-reduce, this tensor gets overwritten with the corresponding sum. It's very nice and simple to use.
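The numbers in this example can be reproduced with a small single-process simulation. This is only a sketch of the all-reduce semantics using Python lists; `all_reduce_sum` is a hypothetical helper, not the `torch.distributed` call, and no communication takes place.

```python
def all_reduce_sum(per_rank_tensors):
    """Simulate all-reduce with SUM: every rank ends up with the elementwise sum."""
    summed = [sum(column) for column in zip(*per_rank_tensors)]
    return [list(summed) for _ in per_rank_tensors]

world_size = 4
# Rank r starts with [0, 1, 2, 3] + r, as in the lecture example.
tensors = [[i + rank for i in range(4)] for rank in range(world_size)]
for rank, tensor in enumerate(tensors):
    print(f"before, rank {rank}: {tensor}")

result = all_reduce_sum(tensors)
for rank, tensor in enumerate(result):
    print(f"after,  rank {rank}: {tensor}")  # [6, 10, 14, 18] on every rank
```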
So, reduce-scatter. I'm going to create an input whose dimension is the world size, which in this case is four, and I'm going to allocate an output, because reduce-scatter is not going to operate in place. The output is just going to be a scalar. Before the reduce-scatter, this is what it looks like. I have my input as before. The output happens to be zeros, but it could be any value, since I didn't initialize it. Then, for the reduce-scatter, we pass in the input and the output, and I'm going to sum. What happens is that for the first component, I sum and that goes on rank 0. For the second component, I sum and it goes on rank 1, and so on. As you notice, it is performing the same operation as all-reduce, except that the output is scattered across all the different ranks. Okay, so now let's do all-gather. I'm going to directly use the output of reduce-scatter, which is this, as the input, and then I'm going to allocate an empty array for the output. Before the all-gather, the input is this, and the output is just arbitrary values. After I do the all-gather, all these tensors show up on all the devices. So this is just an example. Hopefully now you're convinced that reduce-scatter plus all-gather is just all-reduce, because I computed exactly the same quantity as I did for all-reduce. Okay. Questions? Is that clear? Yeah? So the question is, in reduce-scatter, do you keep track of which index goes to which GPU? By convention, the dimensionality has to be the world size. I mean, it could be a general tensor, but one of the dimensions is the world size. The input has to be of size world size, and then it knows that the corresponding computations go to each of the outputs.
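The equivalence described here can be checked on the same numbers with a single-process sketch. The two helpers are hypothetical list-based stand-ins for the real collectives, written under the assumption used in the lecture example that each rank's input has exactly world-size components and each rank's output is one scalar.

```python
def reduce_scatter_sum(per_rank_inputs):
    """Rank i receives the sum, over all ranks, of component i (one scalar per rank)."""
    world_size = len(per_rank_inputs)
    return [
        sum(per_rank_inputs[src][i] for src in range(world_size))
        for i in range(world_size)
    ]

def all_gather(per_rank_values):
    """Every rank receives the list of all ranks' values."""
    return [list(per_rank_values) for _ in per_rank_values]

world_size = 4
# Same inputs as the all-reduce example: rank r holds [0, 1, 2, 3] + r.
inputs = [[i + rank for i in range(world_size)] for rank in range(world_size)]

scattered = reduce_scatter_sum(inputs)  # one scalar per rank
gathered = all_gather(scattered)        # full summed vector on every rank

print(scattered)
for rank, tensor in enumerate(gathered):
    print(f"rank {rank}: {tensor}")
```

Component i of the input always lands on rank i, which is why one dimension of the input has to equal the world size.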
Yeah, you have to be a bit careful about making sure the dimensionality aligns, so going through these small examples can be helpful. Is there another question? Okay. Finally, we're in this process that's running, and when you're done, you just clean up. So far, we've talked about these collective operations and a bit about how they're implemented in NCCL and PyTorch. Let's do a bit of benchmarking, in the spirit of what we did in the assignment, or the first lecture, or rather the second lecture. We're going to focus on one node for now. Let's do all-reduce. I'm going to have a tensor of 100 million elements and a world size of four. I'm going to allocate a tensor. And generally, as I hope you can appreciate by now, when you benchmark you really have to be careful to clean your palate in some sense. In this case, I'm going to warm up: basically run the operation once, then synchronize and do a barrier. Some of this is probably a bit defensive, but it's just to be safe, so that all the kernels get loaded and whatever needs to be computed gets computed. Then I'm going to start the clock, all-reduce, synchronize again, and stop the clock.
So now I can look at how long that took. If I scroll down here, I guess this is not that informative; I should probably have printed it in microseconds. It was very quick, some fraction of a second. Now let's measure the bandwidth, which is the number of gigabytes that were actually transferred in aggregate per second. To do that, we have to think about what actually gets transferred here. There's a tensor with that many elements, and the size of each element, I think this is float32, is four bytes. So that's the size in bytes. Now this is a little bit subtle: how many bytes are actually sent or received? Each tensor sitting on a rank has size_bytes, and it needs to be sent to world size minus one other ranks. But there's also a factor of two. Why is there a factor of two? Because you're doing an all-reduce, remember. You need to send all the distinct elements into basically one place, they need to get summed up, and then the result needs to go back to everyone. So a rank needs to send the input out and then receive the output. That's why there's a factor of two. The total duration is the world size times the actual duration that passed. We're just assuming that if there are four processes, that's like four times as much wall-clock time. And the bandwidth is just the bytes sent over the duration. So what do we get here? About 277 GB per second. For the H100 above, I think I claimed something like 900 GB per second. Of course, as we know, your mileage varies depending on the size of the tensors, the exact number of devices, and the weather. No, not the weather, but various factors. So it's always good to benchmark to see how many gigabytes per second you're actually getting.
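The byte accounting just described can be written out as a short calculation. The function below is a sketch of that arithmetic only; the duration passed in is a hypothetical placeholder rather than a measurement from the lecture, so the printed number is illustrative and should not be compared with the 277 GB/s quoted above.

```python
def all_reduce_bandwidth(num_elements, element_size, world_size, duration_s):
    """Aggregate bytes per second for one all-reduce, using the lecture's accounting."""
    size_bytes = num_elements * element_size        # bytes held by each rank
    sent_bytes = size_bytes * 2 * (world_size - 1)  # x2: send inputs out, receive sums back
    total_duration = world_size * duration_s        # wall-clock time summed over ranks
    return sent_bytes / total_duration

# Hypothetical timing chosen only to exercise the formula.
example_duration_s = 0.01
bandwidth = all_reduce_bandwidth(
    num_elements=100_000_000,  # tensor size used in the lecture
    element_size=4,            # float32
    world_size=4,
    duration_s=example_duration_s,
)
print(f"{bandwidth / 1e9:.1f} GB/s")
```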
Okay, so reduce-scatter is going to be very similar. Let's go through this very quickly. We create an input, which is world size times the number of elements, so each rank is going to have this matrix. We warm up, then start the clock, reduce-scatter, stop the clock, and see how long it took. Well, okay, that's not helpful. Then let's look at the bandwidth. For the number of sent bytes, there's no factor of two here, because in reduce-scatter, remember, all you're doing is sending your inputs into one place. If you just think about reduce, all the elements go into one place and that's it. Scatter just means that different components of your tensor are going to different places, but effectively it's like a reduce. If you do the same calculation, I get 70 in this case. I don't exactly know why it's 70 as opposed to some other number. One could speculate that with all-reduce there's generally more traffic, and all-reduce is potentially more optimized. I think NVIDIA hardware has this SHARP acceleration that actually does some of these computations in the network itself, which shaves off a factor of two. But I don't know if that completely accounts for the difference here. There's a lot that happens in NCCL, so it's a little bit hard to reason about the performance exactly. Hence benchmarking.
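The difference in byte accounting between the two operations comes down to one factor, as the sketch below shows. `sent_bytes` is a hypothetical helper, and the per-rank data size is an example value, not a figure reported in the lecture.

```python
def sent_bytes(data_bytes, world_size, op):
    """Bytes counted as transferred for one collective, following the lecture's accounting.

    all_reduce:     inputs go out and sums come back, hence a factor of 2.
    reduce_scatter: inputs only go into the place where they are reduced, factor of 1.
    """
    factor = {"all_reduce": 2, "reduce_scatter": 1}[op]
    return data_bytes * factor * (world_size - 1)

world_size = 4
data_bytes = 100_000_000 * 4  # example: 100 million float32 values per rank

all_reduce_bytes = sent_bytes(data_bytes, world_size, "all_reduce")
reduce_scatter_bytes = sent_bytes(data_bytes, world_size, "reduce_scatter")
print(all_reduce_bytes, reduce_scatter_bytes, all_reduce_bytes / reduce_scatter_bytes)
```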
speaker 2: A question about the sent bytes: specifically, it looks like that's just the output. But what about the inputs to the reduction? I'm wondering how it gets the inputs.
speaker 1: So the question is, it seems like this is just the bytes for the output, and what about the input? To be clear, I am assuming that the inputs are already on the device. I'm not counting that time; I'm just counting what needs to happen to do the reduce-scatter.
speaker 2: Is this just a scatter, or is this a...
speaker 1: This is a reduce-scatter operation. This function does reduce-scatter, so it's one operation.
speaker 2: We counted it twice in the...
speaker 1: So you're saying that for all-reduce there was a 2x, because you needed to reduce and then you needed to spread the result out again. For reduce-scatter, I mean, it's just a name. It's called reduce-scatter, but it's really just a reduction. You can also see from this that if you do reduce-scatter and then all-gather, neither of those has the factor of two, so when you add them up, you get a factor of two, which is another way to see that all-reduce costs twice as much. And there are some references you can read about how to benchmark these collective operations. Okay, so let's now talk about the distributed training piece. The general approach is that I'm going to walk through a bare-bones implementation of each strategy on deep MLPs. Recall that you are generally in the regime where the MLPs are the compute bottleneck in transformers, not the attention. So even though this is a very simple architecture, it's fairly representative of the type of workloads you'll see. Okay, so let's start with data parallelism. Actually, just one note: data, tensor, and pipeline parallelism can be thought of as different ways of cutting up either your model or your data, which hopefully I'll depict visually here. In data parallelism, here's your model. Assume it has four layers; each layer of the MLP is just a matrix multiply, where this is the hidden dimension. The data is also a matrix, with the batch dimension and then the hidden dimension, and data parallelism just cuts along the batch dimension into smaller pieces. So each rank is going to get a different slice of the data. Let's do an example. I'm going to generate some sample data. Let's say I have a batch size of 128 and a hidden dimension of 1024, and I just generate some random data.
So I have batch size by number of dimensions, and I'm going to run this data parallel algorithm, or DDP. Here I get passed this data. There's a batch size and a dimension, as claimed before. Now I divide the batch size by the world size to get the local batch size, which is how big the batch is on a given rank. Then, based on the rank, I figure out which start and end indices, spanning the local batch size, I need to access, and I get the corresponding data. So basically I'm just reaching in and grabbing some subset of the rows based on the rank. Now I'm setting up the MLP, and this is done in a very bare-bones way, you could say. Here I am creating the MLP parameters. Each layer has a matrix which is num_dim by num_dim, and remember, num_dim is 1024. And I'm going to create the optimizer. Remember, this function is running asynchronously on each rank, so each of the four ranks is going to be running this with rank equal to 0, 1, 2, 3. Now I'm going to start training. For a number of steps, I do a forward pass through the layers: matrix multiply, nonlinearity, matrix multiply, nonlinearity. There are four layers here. I'm going to compute some loss. I don't care what the loss is; it's just something made up. And I'm going to do the backward pass. So far, this just looks like I'm implementing SGD, right? And that's kind of the point. The only difference, to implement DDP, is that you inject this line here, which synchronizes the gradients across workers. For each of the layers, you call all-reduce, where you're averaging, and the thing you're averaging is param.grad. So it's as if you've hijacked someone's SGD code and said, wait, I'm actually going to mix all the gradients after the backward pass.
After you do that, you just update the parameters as usual. From the SGD perspective, it seems like nothing is happening: I'm just running SGD, but someone has mixed my gradients. Okay, so let me print out some things. For data parallel, I'm printing out the loss. One thing to note is that the losses are different between the ranks, because they have different data. But after the all-reduce, all the parameters are the same. So this is a prime, textbook application of all-reduce in an ML setup.
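The training loop described above can be sketched end to end in a single process. This is a toy under stated assumptions, not the lecture's code: the model is a scalar `y = w * x` with a squared-error loss instead of a four-layer MLP, the ranks are simulated by a Python loop, and `all_reduce_avg` is a hypothetical stand-in for an averaging all-reduce over each parameter's gradient.

```python
def local_loss(w, xs, ys):
    """Mean squared error of the scalar model y = w * x on one shard."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def local_gradient(w, xs, ys):
    """Derivative of local_loss with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def all_reduce_avg(values):
    """Simulated all-reduce with averaging: every rank receives the mean."""
    average = sum(values) / len(values)
    return [average for _ in values]

world_size = 4
batch_size = 8
xs = [float(i + 1) for i in range(batch_size)]
ys = [3.0 * x for x in xs]                    # toy targets
local_batch_size = batch_size // world_size   # rows handled by each rank
weights = [0.0] * world_size                  # identical initial parameter on every rank
learning_rate = 0.01

for step in range(3):
    grads, losses = [], []
    for rank in range(world_size):
        # Each rank grabs its own slice of the rows.
        start = rank * local_batch_size
        end = start + local_batch_size
        losses.append(local_loss(weights[rank], xs[start:end], ys[start:end]))
        grads.append(local_gradient(weights[rank], xs[start:end], ys[start:end]))
    grads = all_reduce_avg(grads)  # the only data-parallel-specific step
    weights = [w - learning_rate * g for w, g in zip(weights, grads)]
    print(f"step {step}: losses={losses} weights={weights}")
```

The per-rank losses differ because the shards differ, while the weights stay identical because every rank applies the same averaged gradient.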
speaker 2: But each rank runs this through all-reduce [partly inaudible]?
speaker 1: So the question is, if all of these processes are just running asynchronously, how do you make sure that each of them is actually, for example, on the same step? This is because all-reduce is a synchronization point. It stops everyone and does the all-reduce. So you have to be careful, because if one of your ranks is missing an all-reduce, then it will just hang. Yeah?
speaker 2: Why does getting the initial parameters depend on the rank?
speaker 1: The question is, why does getting the initial parameters depend on the rank? They're the same. The reason is just that the code for this puts them on the appropriate GPU. Okay, any other questions? DDP is something you implement in assignment two, which maybe some of you have looked at, or maybe not. There it will be done in the context of a transformer, but this is the most bare-bones version, so you can see very clearly what's happening. Okay, so that's DDP. Losses are different across ranks, but the gradients are reduced to be all the same, and therefore the parameters on all the ranks are the same. So actually, you're doing world-size SGD runs, but because they're synchronized, they're doing the same thing. You can think of this as an analog of activation checkpointing, where sometimes you just do extra compute because you don't want to store things. In this case, we could have, for example, shipped the optimizer state around, but that would be a bad idea, because it's much faster to just update the optimizer state than to actually move the optimizer parameters around. Okay, so last year I did try to do FSDP, but that was sort of a hairball, so I'm going to skip that and do tensor parallel. Here the picture is that we leave the data the same, and we cut the model along the hidden dimension. So each rank is going to get every layer, but only part of each layer, and what we're going to end up doing is transferring all the data and the activations around. We're generating the same sample data, so let's look at tensor parallel. I have the batch size and number of dimensions as before. Before, I was cutting the batch size, but now I'm cutting num_dim. So I have local_num_dim equals 1024 divided by the world size, and that's 256. Each rank gets a part of the model, which is a one-over-world-size fraction of the parameters. And remember, the whole reason we're doing parallelism at all is that the model won't fit on a single GPU, so we're going to shard it across multiple GPUs. The parameter matrices are now num_dim by local_num_dim. I'm only going to implement the forward pass here, not the whole training loop. I start going through all the layers. First I compute the activations. This looks pretty normal, except, remember, the activations are actually batch_size by local_num_dim rather than num_dim, because each rank only has a fraction of the activations now. Once I get the activations, I need to communicate. What I have to do is allocate memory for all the activations. At this point, everyone has an x, but that x represents a different part of the activations. So I'm going to allocate batch_size by local_num_dim, but world-size of them. Basically, each rank is going to have world-size matrices of shape batch_size by local_num_dim, and then I'm going to do an all-gather. I'm going to send all the activations, and this is fairly simple. Remember, x is batch_size by local_num_dim, but x is different on every rank.
So when I do that all-gather, I put the result in activations, which has world-size entries of the same shape as x. Now every rank has the same activations, the activations of the whole model. Then I just concatenate them together to get x, so x is now again batch_size by num_dim. And now I repeat. As you can see, there's quite a bit of communication that happens, which is why, remember, Tatsu said that for tensor parallel you need pretty high-bandwidth interconnects; otherwise you'll be passing a lot of these activations around. Then you do it for the next layer, and the next layer, and you get the idea. And just to print out some output, let's see here: for tensor parallel, the forward pass produces activations of the full size, and everyone has the same activations at the end. The backward pass I'm going to skip, because that's kind of annoying to do. All right.
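The forward pass just described can be checked on a tiny example in plain Python. This is a single-process sketch with assumptions that go beyond the lecture: integer matrices, a ReLU as the elementwise nonlinearity, two simulated ranks, and hypothetical helper names. It verifies that sharding each weight matrix by columns and then all-gathering and concatenating the activations gives the same result as the unsharded computation.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(matrix):
    return [[max(0, value) for value in row] for row in matrix]

def column_shard(weight, rank, world_size):
    """Columns of `weight` owned by `rank` (num_dim x local_num_dim)."""
    local_num_dim = len(weight[0]) // world_size
    return [row[rank * local_num_dim:(rank + 1) * local_num_dim] for row in weight]

def all_gather_concat(per_rank_activations):
    """Simulated all-gather followed by concatenation along the hidden dimension."""
    batch_size = len(per_rank_activations[0])
    return [
        [value for shard in per_rank_activations for value in shard[row]]
        for row in range(batch_size)
    ]

world_size = 2
x = [[1, -2, 3, 0], [0, 1, -1, 2]]  # batch_size=2, num_dim=4
layers = [
    [[1, 0, -1, 2], [0, 1, 1, -1], [2, -1, 0, 1], [1, 1, -2, 0]],
    [[0, 1, 1, -1], [1, -1, 0, 2], [-1, 0, 1, 1], [2, 1, -1, 0]],
]

# Reference: the unsharded forward pass.
reference = x
for weight in layers:
    reference = relu(matmul(reference, weight))

# Tensor parallel: each rank computes its slice, then activations are gathered.
activations = x
for weight in layers:
    shards = [
        relu(matmul(activations, column_shard(weight, rank, world_size)))
        for rank in range(world_size)
    ]
    activations = all_gather_concat(shards)

print(activations == reference)
```

One all-gather per layer is needed because the next layer's matrix multiply consumes the full hidden dimension, which is where the communication cost of this scheme comes from.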
speaker 2: any questions .
speaker 1: about that? Yeah, I was just wondering, why is it hard to do the backward pass? I don't think it's necessarily hard; in the constrained time and space, it just requires a bit more work. Okay, so now let's .
speaker 2: go to pipeline .
speaker 1: parallelism. So in this case we're cutting the model by layers. So all the ranks get all the data, and each rank gets all of a layer, but different ranks get different layers. Okay, so sample the data and run this function on all the ranks. So here I'm going to figure out how many layers go on each rank, which is two here. I have a four-layer network and two ranks, so each rank gets two of the layers, just like this picture. And here I'm going to allocate the parameters just for the layers that I need. Okay, so I'm going to do the forward pass. Remember there's a further optimization: if you just do it naively, you get these pipeline bubbles that Tatsu talked about before. One way to mitigate that is to break up the batch into micro-batches. So here I'm going to divide this batch into micro-batches of size 32, so four micro-batches of size 32. And now the idea is that every rank is going to wait for the previous rank to pass it the activations, apply its layers, and then forward the result to the next rank. So starting at the base case, we have rank equals zero; that's just the data. So I'm just chunking the data into a bunch of micro-batches and going through each of the micro-batches. First, I receive the tensor. I'm using these point-to-point primitives now instead of the collective primitives, and I receive the tensor x, and then I compute the layers that are assigned to this rank. In this case, there are only two of them. And then I send it to the next rank. And again, send is a point-to-point operation. And then for the next micro-batch I do the same thing, so I'm going to skip that. Okay, so that's basically it. So pipeline parallel, at least the very naive version of it, is relatively conceptually simple. 
But as mentioned last time, there are many things missing from this basic implementation. Overlapping the communication and computation is something we're not doing at all here. For example, receive and send are synchronous, but you should really make them async. And also the order in which you do the forward... actually, this is just the forward, not even the backward. But once you have the backward, then you have to figure out how to interleave the forward and the backward steps.
speaker 2: I'm wondering, about what you just mentioned regarding the async being shown here: I guess the GPU is sort of listening for whether another one passes something to it. And it's kind of this game. Like, it only starts processing once the layer before it passes something into it.
speaker 1: So the question is, is this kind of like event-driven programming, where you're just waiting for things to happen? In event-driven programming, you basically write these handlers, and then whenever stuff happens, maybe you get a mouse click, maybe you get a file-ready event, a piece of code runs. That's quite different, I think, from this style of coding, where everything has to work in lockstep. It is true that you're waiting for the previous rank to send you the information. But at least in this implementation, there's no flexibility in where it's coming from. It's not like it's waiting for arbitrary data to come from anywhere. I think there are ways to do asynchronous training, which was quite popular more than ten years ago, which is more event-driven: you have a server that sends data, and whenever the gradients are ready, a worker just uploads them and the gradients get accumulated. And if workers die, that's handled more robustly. But in modern training, despite scaling up quite a bit, everything seems to be in a synchronous paradigm. So it is true that when I say the workers and the ranks are operating asynchronously, that's just because they're different processes, but you're still imposing quite rigid synchronization so that everything is working in lockstep. Yeah.
speaker 2: How would you change this program to overlap communication and computation?
speaker 1: So the question is how would you change this to overlap communication and computation? So for example, when you send this, there's no reason to just wait for the data to be sent. You just fire off the send. Remember that the send actually happens on the GPU via some kernel launch. So that's sort of independent, and the rank can just go and process another micro-batch right away. So the way I think you would do this is that there's another function called isend, which is asynchronous (actually, this should say asynchronous), and which returns a handle. So you basically do all the sends, and then at the end you wait for all the sends to complete. And then, for overlapping when you actually have the backward step, you basically have to schedule that in here. Yeah .
speaker 2: for svia a multiple ch. So the question is if .
speaker 1: you have multiple sends and multiple receives, how do you know which is which? So here, the tensor name doesn't matter; it's just whatever variable is there. What you specify is the source. So if I'm at a node and I'm receiving, then whatever the next message coming from that rank is, I'm just going to put it in this tensor and continue executing. What if you want to do two sends from the same rank to the same destination?
speaker 2: So I'm not .
speaker 1: quite sure about this, but I think if you have two sends, they're sort of put in a stream, so the order of the sends is still preserved. It's just that other stuff can happen at the same time. I think if you have a pair of ranks and two sends between them, then that order is preserved. But the order relative to some other rank sending to another rank is not fixed; that can happen at any time. Yeah, if .
speaker 2: you the water just gets off there or like.
speaker 1: So what happens if you send and no one receives it? I think it would just wait, because, I mean, the other process could just be running, and it's just code executing, so you don't know if it's never going to get there or if it's just going to be a matter of time. Yeah. So the question is, what happens to the last rank? At the end, the last rank has all the activations, so it has basically the results of a full forward pass. And then if you implement the backward pass, you would now be computing the gradient with respect to the loss, and then you would go back down and send from rank to rank minus one, and so on. Okay. I guess I was afraid I was going to run out of time, but it looks like I actually have time. Maybe next year I should do the backward pass. Okay. So actually, I'm going to finish quite early today, so if you have any other questions, you should ask. So far, we've gone through three simple examples of data, tensor, and pipeline parallel. Of course, this is for simple MLPs. You would actually want to do this with your own fancier model, like a transformer. I did argue that at least the core ideas you can understand through the MLP. But of course, when you want to train, you want to train a transformer, not a deep MLP, so you still have to implement the full complexity. What's also missing is the communication and computation overlap, which is not really handled very carefully here. And there is generally more complex code with bookkeeping. I encourage you to check out Megatron-LM or PyTorch's FSDP. It gets fairly hairy. 
And one of the things that I think makes some of the bookkeeping necessary, at least for, let's say, FSDP, and you'll be exposed to this in A2 a bit, is that if you want something that handles arbitrary architectures, then you have to figure out the parameters and do a bunch of bookkeeping, and figure out where the layers are and so on. Whereas in the MLP case, I've just made the decision that I'm going to split the model in this particularly simple way. One other thing I'll just mention as an aside is that all of what we're doing in this course is PyTorch, but it is useful to be aware of this whole other ecosystem around JAX and TPUs, which is actually kind of nice in some ways. And the idea here is that JAX allows you to define the model and define the sharding strategy, and then the JAX compiler handles the rest. So there's this toolkit that we developed called Levanter, based on JAX. And I'll just show you a snippet. So this is FSDP in ten lines of code. Basically you have your model, and then you just say shard with this particular... I mean, I don't expect you to read this exactly, but basically you define which dimension you're going to shard by, and that's it. And similarly for tensor parallel, you're just saying, I'm going to shard the model along... you can shard on the head dimension for attention, and you can also shard based on the model dimension. So in some sense this gives you a conceptual simplicity: what you have is basically a computation graph, but it has these dimensions, the model dimension, the embedding dimension, the attention sequence dimension, and JAX allows you to just specify which dimensions you want to cut by and also define a mapping from that onto the actual TPUs. And then the JAX compiler magically figures out how to compile that down into the primitives that shuffle things around. 
So this is much higher level than operating with the collective communication directly. But we're sticking with PyTorch because it allows you to see underneath the hood what's actually happening. But if you're actually doing this in the real world, obviously you don't need to, and you probably shouldn't, implement all of this from scratch. Okay. So that's the end of the JAX digression. So just to summarize: we've seen many ways to parallelize so far. And each of these ways of parallelizing you can think of as splitting either the model or the data along some dimension: the batch dimension, the width dimension, the depth dimension, or the context length dimension. We also see this recurring theme of recomputation. You can recompute something from scratch, or you can store it in memory and suffer the data transfer cost. Or now, in the multi-GPU, multi-node setting, you can actually store it in another GPU's memory and then communicate, which is even slower. So there are these trade-offs here, and often recomputation can actually be better. But obviously you can't recompute the whole thing, and often you're either communication or memory limited. A final word is that it is the case that hardware is getting better. So you might think that maybe none of this is really necessary, because in five years everything will fit in L1 or HBM. That is not going to be the case, because although those might grow quite a bit (there are still physical limits), we will always end up with bigger models that are at the limit of what the hardware can do. So this hierarchical structure, ever since computer systems was a thing, has always been with us and it will always be there. Okay, that's all I have for you today, so I can take any questions. Yeah .
speaker 2: set of five the bottom fast very different because if you're not verification might be a function of the important from example then that's that there any impasor?
speaker 1: So the question is, in data parallel, even though the parameters are all synchronized, there could be other things that depend on the data, like in batch norm. I don't know exactly how you would do that off the top of my head. At least in the LLM world, that doesn't really show up, because layer norm is used, and as long as you initialize all the parameters using the same random seed, you'll be fine. I mean, there could be non-determinism issues on the GPU, but hopefully those are minor. Yeah. So the question is, does PyTorch have some niceties as well, kind of like what JAX offers? So, I mean, PyTorch does have the FSDP library, which you should absolutely use if you're not taking this class, which is basically a wrapper: you define any model and it just does FSDP on it. Now if you're asking how well it allows you to do more custom sharding, I think there are some things that are coming, but it's not, I think, as developed. I think there's sort of a spectrum. There's the JAX world, where you declaratively define things, and I think the Google infrastructure, if you stay within the JAX/TPU system, is pretty well developed. But then if you look at DeepSeek, which is kind of at the opposite end, where you have these GPUs with actually really bad interconnect, they have to go in and hack. They actually go to the NCCL level and do a bunch of things, which I don't quite understand, to eke out the performance. Whereas if you're writing JAX, you just kind of, from on high, declare your model and then stuff happens. So the way that you leverage hardware, I think, really depends on what ecosystem .
speaker 2: you're operating in. Yeah, can we manage the amount of activations and recompute some subset of .
speaker 1: the activations. Yeah. So the question is about activation checkpointing. There is an API that basically allows you, I guess in PyTorch and JAX, to specify which parts you want to recompute, because clearly you don't want to recompute everything or nothing. Probably every few layers, probably right after big matmuls. For example, if you have, let's say, a matmul and then a pointwise nonlinearity, I don't think you need to store two copies. Basically, if you have two things where it's trivial to get from one to the other, then you might as well just store one version. Yeah .
speaker 2: over there, hardware, we're looking more specialized.
speaker 1: So the question is, are GPUs ever going to be replaced by transformer-specific hardware? You're seeing this in the inference space quite a bit already, with Groq and Cerebras, which have specialized hardware that can do inference and also, I guess, training. Cerebras does training. So basically that hardware essentially gives you just a lot more on-chip memory. I mean, that's basically the name of the game. I think Cerebras has essentially a huge, effectively, L1 cache, so you don't have to move things off. And I think a lot of simplifications can happen because GPUs carry a lot of baggage, actually, if you think about it, because they were designed in an era where you had to do a lot of branching and various types of ad hoc computations, which are not really needed in the deep learning regime. So I think there are quite a few opportunities to improve the hardware as well. I think there was a hand back there.
speaker 2: This is like the right question. I'm thinking that. But in the context of the lecture, it's basically a model that's been changing one, though, that's been optimizing the 21, the tastep. We're talking about communities to preventing and train a model dle, for example, I. Not just like White doom, but actually need to kind of recatculate everything about her.
speaker 1: Yeah. So the question is, can these techniques be used to essentially do continued training? Yeah, absolutely. If you think about it, the unit of what we're working with is just doing gradient steps, right? So if you take a half-trained checkpoint, you can just continue doing this. There's nothing specific about starting from scratch here. I think there was a question over there.
speaker 2: So, following the previous question: presumably there's a physical, technical reason not to make nodes much larger than they are currently. Like, what's the change that you're talking about? If you could just make GPU nodes as big as you wanted, people would do that. So presumably there's a hardware reason that's not possible. What's the actual advancement being done for specialized hardware?
speaker 1: So the question is... there are physical limits for sure for a GPU. So you can't make GPUs infinitely large or infinitely dense, obviously. I mean, there are also power issues, you have to get rid of all the heat, and there's only so much bandwidth that can fit. So I don't know the exact details, but at least in the Cerebras case, they have this way of manufacturing the chips so that the memory is on the chip. So I guess it's just a way of putting it on there. And I think there are obviously trade-offs, because it comes at the cost of not having as much flexibility. But in general, I think the way to think about this more broadly is that GPUs were still developed in the CPU era, where it's much more control focused: I have code I'm executing, that's the first-class citizen, and then data needs to be moved to execute the code. But the big difference with deep learning workloads is that it's all data flow. The computation graph, if you look at these, is static. You know from the beginning exactly all the computations that are going to be done until essentially the end of training, right? So using that knowledge, you should be able to lay out your computation in a much smarter way than having to deal with the flexibility and uncertainty of ad hoc computation. Okay. Maybe a few more questions .
speaker 2: about and Yeah.
speaker 1: So the question is, where is the computation graph stored? Well, all this code is running on the CPU, but when you call something like a PyTorch function that needs to run on the GPU, it launches kernels under the hood, and the kernels are code that runs on the GPU. Yeah.
speaker 2: I'm not sure of that.
speaker 1: So I guess maybe another answer is that the computational graph is more of a conceptual thing. It's not like there is a graph literally that's being... I mean, I guess there sort of is, but it's not like the graph gets put on the GPU, if that makes sense. Okay.
speaker 2: so these cations, remember Yeah we cu in rewachen.
speaker 1: So the question is about the communication primitives: are they on the CPU or the GPU? So these collective operations are in some sense an abstract specification of what types of operations need to happen. If you remember, PyTorch distributed has different backends, so it could happen on the GPU or on the CPU .
speaker 2: first having a visit it like it is the cview sort of scheduthem or what is it on the Yeah. So what the .
speaker 1: cpu sort of drives basically is the sort of the master still. And then when you do a collective cooperation, it calls a nickel library, which launches, which is you know it's still cpu and then it launches some kernels that move data around. Yeah. Okay, maybe this is a good place to end. All right. I will see you next Monday.

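Latest Summary (Detailed Summary)

Generated on 2025-05-13 19:25

CS336 Lecture Review: Parallelism 2 - Multi-GPU and Multi-Node Training Strategies

This lecture (Stanford CS336, Parallelism 2) covers how to use multi-GPU and multi-node parallelism to speed up model training; the core problem is overcoming data transfer bottlenecks so as to maximize GPU utilization. It first reviews parallelism within a single GPU and then turns to parallelism across GPUs and nodes. It describes the data transfer hierarchy, from the L1 cache and high-bandwidth memory (HBM) inside a GPU, to NVLink between GPUs in the same node, to NVSwitch across nodes, and points out that data transfer is much slower than compute and is therefore the main performance bottleneck.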

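Overview / Executive Summary

This lecture (Stanford CS336, Parallelism 2) examines parallel computing strategies for speeding up model training in multi-GPU and multi-node settings. The central goal is to structure computation so that data transfer bottlenecks are avoided and hardware utilization is maximized. The lecture first reviews the compute and communication hierarchy, from L1 cache and HBM within a GPU, to NVLink between GPUs in a node, to NVSwitch across nodes, stressing that data transfer is far slower than compute and is where the bottleneck lies. It then introduces collective operations (broadcast, scatter, gather, reduce, all-gather, reduce-scatter, all-reduce) and explains how they are implemented and used in Nvidia's NCCL library and in PyTorch Distributed. Code examples demonstrate these operations, followed by preliminary benchmarks and an analysis of how bandwidth is computed. The core of the lecture covers the three main parallelism strategies for distributed training: data parallelism (DDP), which replicates the model on every GPU, splits the data batch, and synchronizes gradients with all-reduce after the backward pass; tensor parallelism, which splits model parameters (such as an MLP's weight matrices) along the hidden dimension across GPUs and communicates activations between layers with all-gather; and pipeline parallelism, which assigns different layers to different GPUs and streams data through the GPUs in order as micro-batches. The lecture emphasizes that these are simplified implementations and that real systems must also handle communication/computation overlap and more complex bookkeeping. It closes by mentioning higher-level frameworks such as JAX and concludes that parallelization is a lasting necessity as model sizes keep growing.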

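Lecture Background and Core Challenge

The lecturer notes that this is the second part of the systems lectures, focused on parallelism across multiple GPUs and nodes, with the aim of maximizing hardware utilization to speed up training. Last week covered parallelism within a single GPU; this week extends to multiple GPUs.

  • Core challenge: data transfer is the main bottleneck. Computation happens on the GPU's streaming multiprocessors (SMs), while inputs and outputs may live in L1 cache, HBM (high-bandwidth memory), or even on another GPU.
    > "The name of the game is how do you structure all your computation to avoid data transfer bottlenecks?"
  • Goal: keep arithmetic intensity high so that the GPUs stay saturated.
  • Recap of last week: within a single GPU, techniques such as fusion and tiling load data into L1 cache or shared memory for local computation, reducing reads and writes to HBM.
  • This week's focus: communication across GPUs and nodes, involving replication and sharding of the model and parameters, and management of optimizer state.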

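Compute and Communication Hierarchy

The lecturer outlines the storage and communication hierarchy from small and fast to large and slow, stressing that the core idea of minimizing data transfer is the same at every level, although the mechanisms differ.

  • Hierarchy (from small and fast to large and slow)
    1. Single node, within one GPU
      • L1 cache: extremely fast but extremely small.
      • HBM (high-bandwidth memory): larger than L1 and somewhat slower.
    2. Same node, between GPUs
      • NVLink: Nvidia's high-speed GPU-to-GPU interconnect.
    3. Across nodes, between GPUs
      • NVSwitch: connects multiple NVLinks to interconnect GPUs at larger scale.
  • Traditional hardware vs. modern scientific computing hardware
    • Traditional: GPUs talk to the CPU over the PCIe bus, and nodes talk to each other over Ethernet. Moving data between GPUs goes through the CPU kernel and buffer copies, which is expensive.
    • Modern
      • NVLink connects GPUs directly, bypassing the CPU and the host kernel.
      • NVSwitch connects GPUs across nodes directly, bypassing Ethernet.

        > "NVSwitch and NVLink kind of skip all of that and just optimize directly for the type of workloads that we're interested in."

  • Hardware numbers
    • H100 GPU: each GPU has 18 fourth-generation NVLinks, for a total bandwidth of 900 GB/s.
    • By comparison, HBM access (for example between an SM and HBM) is still much faster, roughly 4x the NVLink bandwidth.
    • The lecturer notes that these numbers will change as new hardware (such as Blackwell) ships, with bandwidth expected to grow 2-3x.
  • Inspecting GPU connectivity: a command such as nvidia-smi topo -m shows the connection topology between GPUs.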

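Collective Operations

Collective operations are the basic primitives for managing communication across nodes and devices in distributed programming; they are preferable to managing point-to-point communication by hand.

  • Terminology
    • World size: the number of devices (e.g. 4 GPUs).
    • Rank: a device's index (e.g. rank 0, 1, 2, 3).
  • Main collective operations
    • Broadcast: copies the data on one rank to all ranks.
    • Scatter: distributes different pieces of the data on one rank (e.g. the elements of a list) to different ranks.
    • Gather: the inverse of Scatter; collects the data on different ranks onto a single rank.
    • Reduce: like Gather, but applies an operation (e.g. a sum) while collecting.
    • All-Gather: like Gather, but every rank receives the complete collected data.
    • Reduce-Scatter: combines Reduce and Scatter; reduces the data and distributes different pieces of the result to different ranks.
    • All-Reduce: equivalent to a Reduce followed by a Broadcast, or to a Reduce-Scatter followed by an All-Gather (the lecture phrases it loosely as reduce + all-gather).
  • Mnemonics
    • Reduce: applies an associative and commutative operation (e.g. sum, min/max, average).
    • Scatter: the inverse of Gather.
    • All: the destination is every device.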
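The one-to-many and many-to-one operations listed above can be made concrete with a small single-process simulation. This is only a sketch: it does not use NCCL or torch.distributed, the function names are illustrative, and "per-rank state" is assumed to be a plain Python list whose index is the rank.

```python
from functools import reduce as fold
from operator import add

# Simulation only: "per-rank state" is a Python list whose index is the rank.
# None means "this rank holds nothing after the operation".

def broadcast(value, world_size):
    """Copy the source rank's value to every rank."""
    return [value for _ in range(world_size)]

def scatter(chunks):
    """The source rank holds `chunks`; rank i ends up with chunks[i]."""
    return list(chunks)

def gather(per_rank, dst=0):
    """Inverse of scatter: only rank `dst` ends up with every rank's value."""
    return [list(per_rank) if rank == dst else None
            for rank in range(len(per_rank))]

def reduce(per_rank, dst=0, op=add):
    """Like gather, but the values are combined with `op` on the way."""
    combined = fold(op, per_rank)
    return [combined if rank == dst else None
            for rank in range(len(per_rank))]

WORLD_SIZE = 4
bcast = broadcast("w", WORLD_SIZE)
scattered = scatter([10, 11, 12, 13])
gathered = gather(scattered, dst=0)
reduced = reduce(scattered, dst=0)

print(bcast)      # ['w', 'w', 'w', 'w']
print(gathered)   # [[10, 11, 12, 13], None, None, None]
print(reduced)    # [46, None, None, None]
```

The "All" variants differ only in that every rank, rather than a single destination rank, ends up holding the result.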

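Software Implementation of Collective Operations

  • NCCL (Nvidia Collective Communications Library)
    • A library from Nvidia that translates high-level collective operations (such as all-reduce) into the low-level packets sent and received between GPUs.
    • NCCL probes the hardware topology and optimizes the communication paths between GPUs.
    • When a collective operation is called, NCCL launches CUDA kernels to send and receive data.
  • PyTorch Distributed (torch.distributed):
    • Provides a clean Python interface to the collective operations.
    • Supports several backends:
      • NCCL: for Nvidia GPUs.
      • Gloo: for CPUs, convenient for debugging without a GPU.
    • Higher-level features such as FSDP (Fully Sharded Data Parallel) exist, but this course builds from scratch and does not use them directly.
  • PyTorch code examples
    • Initialization: dist.init_process_group(), specifying the backend (e.g. 'nccl' or 'gloo'); all processes connect to the same host to coordinate.
    • Barrier: dist.barrier(), a synchronization point that waits for every process in the group to arrive.
    • All-Reduce example:
      1. Each rank creates a different tensor (e.g. tensor = [0, 1, 2, 3] + rank).
      2. Call dist.all_reduce(tensor, op=dist.ReduceOp.SUM).
      3. Result: the tensor on every rank is overwritten with the sum of all the original tensors.
    • Reduce-Scatter example:
      1. The input tensor has world_size * num_elements elements.
      2. The output tensor is a scalar (or has num_elements elements).
      3. Call dist.reduce_scatter(output_tensor, input_list, op=dist.ReduceOp.SUM). (Note: reduce_scatter_tensor is the more common form in the PyTorch API; this call is inferred from the description, and the lecture may have simplified it or used a particular form of input_list.)
      4. Result: the pieces of the input are reduced across ranks, and the reduced i-th piece is stored in the output tensor on rank i.
    • All-Gather example:
      1. The input is each rank's partial data (e.g. the output of reduce-scatter).
      2. The output list holds the data collected from all ranks.
      3. Call dist.all_gather(output_list, input_tensor).
      4. Result: output_list contains the input_tensor of every rank.
    • Check: reduce_scatter + all_gather is equivalent to all_reduce.
    • Cleanup: dist.destroy_process_group().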
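The worked example above (each rank starts with [0, 1, 2, 3] + rank) and the equivalence check can be traced by hand with a single-process simulation. This is a sketch under the assumption that a list indexed by rank stands in for the real processes; it deliberately does not call torch.distributed, and the three functions are simplified stand-ins that only support summation.

```python
WORLD_SIZE = 4

def all_reduce(per_rank_vectors):
    """Every rank ends up with the element-wise sum over all ranks."""
    total = [sum(column) for column in zip(*per_rank_vectors)]
    return [list(total) for _ in per_rank_vectors]

def reduce_scatter(per_rank_vectors):
    """Rank i ends up with only element i of the element-wise sum."""
    return [sum(vector[i] for vector in per_rank_vectors)
            for i in range(len(per_rank_vectors))]

def all_gather(per_rank_values):
    """Every rank ends up with the list of all ranks' values."""
    return [list(per_rank_values) for _ in per_rank_values]

# Rank r starts with [0, 1, 2, 3] + r, as in the example above.
inputs = [[x + rank for x in range(4)] for rank in range(WORLD_SIZE)]

direct = all_reduce(inputs)
two_step = all_gather(reduce_scatter(inputs))

print(direct[0])           # [6, 10, 14, 18]
print(two_step == direct)  # True
```

Each column sums four consecutive integers (0+1+2+3, 1+2+3+4, and so on), which is where 6, 10, 14, 18 come from.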

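Benchmarking Collective Operations (Single Node)

The lecturer shows how to benchmark all_reduce and reduce_scatter on a single node (4 GPUs).

  • General procedure
    1. Create the tensor.
    2. Warm-up: run the operation once to load CUDA kernels and so on.
    3. Synchronize (dist.barrier() and torch.cuda.synchronize()).
    4. Record the start time.
    5. Run the collective operation under test.
    6. Synchronize (torch.cuda.synchronize()).
    7. Record the end time.
    8. Compute the duration and the bandwidth.
  • All-Reduce bandwidth calculation
    • num_elements = 100_000_000, world_size = 4, element_size = 4 bytes (float32).
    • bytes_sent_received_per_rank = num_elements * element_size * 2
      • The factor of 2: the data has to be sent somewhere to be aggregated (reduce), and the result then has to be sent back to all ranks (the broadcast/all-gather part).
    • Bandwidth is computed from the measured wall-clock time: bandwidth = bytes_sent_received_per_rank / measured_wall_clock_duration.
    • The lecturer mentions world_size * measured_wall_clock_duration as a way to measure total compute resources consumed, but this is not used directly to compute the effective bandwidth of a single operation.
    • Measured bandwidth: about 277 GB/s. This is below the H100's theoretical 900 GB/s; the lecturer notes that real performance depends on many factors (tensor size, number of devices, and so on).
  • Reduce-Scatter bandwidth calculation
    • The input tensor has world_size * num_elements elements (on each rank, this tensor is treated as the pieces to be reduced and scattered).
    • bytes_sent_or_received_per_rank_effectively = num_elements * element_size (each rank ends up responsible for sending or receiving only one piece of the result, unlike all-reduce, where every rank receives the full aggregated result and sends its full data).
      • There is no factor of 2 here, because data only flows one way: it is aggregated and scattered to the destination ranks.
    • Bandwidth is computed from the measured wall-clock time: bandwidth = bytes_sent_or_received_per_rank_effectively / measured_wall_clock_duration.
    • Measured bandwidth: about 70 GB/s. The lecturer speculates that all_reduce may be more heavily optimized (for example SHARP acceleration in Nvidia hardware, which may perform part of the computation in the network), giving it relatively better bandwidth, but the exact cause is complicated.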
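The byte accounting and the warm-up-then-time procedure above can be sketched in a few lines. This is an illustration under stated assumptions: the helper names are made up, the timed workload is a CPU stand-in rather than a collective, and HYPOTHETICAL_DURATION_S is an invented value used only to exercise the arithmetic, not a measurement from the lecture.

```python
import time

def all_reduce_bytes(num_elements, element_size):
    # Factor of 2, per the accounting above: data goes out to be reduced,
    # and the reduced result comes back.
    return num_elements * element_size * 2

def reduce_scatter_bytes(num_elements, element_size):
    # No factor of 2: data only flows one way.
    return num_elements * element_size

def bandwidth_gb_per_s(num_bytes, duration_s):
    return num_bytes / duration_s / 1e9

def time_op(op, warmup=1):
    """Warm up, then time one call with a wall clock.
    On a GPU you would also synchronize before reading each timestamp."""
    for _ in range(warmup):
        op()
    start = time.perf_counter()
    op()
    return time.perf_counter() - start

NUM_ELEMENTS = 100_000_000
ELEMENT_SIZE = 4                   # bytes per float32
HYPOTHETICAL_DURATION_S = 0.004    # made-up value, not a measurement

ar_bytes = all_reduce_bytes(NUM_ELEMENTS, ELEMENT_SIZE)
rs_bytes = reduce_scatter_bytes(NUM_ELEMENTS, ELEMENT_SIZE)
print(ar_bytes, rs_bytes)          # 800000000 400000000
print(bandwidth_gb_per_s(ar_bytes, HYPOTHETICAL_DURATION_S))

elapsed = time_op(lambda: sum(range(100_000)))  # stand-in workload
print(f"stand-in workload took {elapsed:.6f} s")
```

For the same tensor size, the all-reduce accounting counts twice as many bytes as the reduce-scatter accounting, which is worth keeping in mind when comparing the two reported bandwidths.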

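Distributed Training Strategies

The lecturer presents simplified implementations of three basic parallelism strategies on a deep MLP, arguing that the MLP's compute bottleneck resembles that of a Transformer (setting attention aside).

1. Data Parallelism (DDP)

  • Concept: the model is replicated on every rank, and the data is split along the batch dimension.
    • Each rank processes one shard of the data.
  • Implementation steps (MLP example)
    1. Data preparation: the total batch size is batch_size, and each rank's local batch size is local_batch_size = batch_size / world_size. Each rank takes its slice of the data according to its rank.
    2. Model initialization: each rank independently creates the same MLP model and optimizer.
    3. Training loop
      • Forward pass: run on the local data shard.
      • Loss: computed from the local outputs.
      • Backward pass: computes the local gradients.
      • Gradient synchronization:
        > "synchronizes the gradients across workers. So what you do is for each of the layers, you call all_reduce where you're averaging. And the thing you're averaging is param.grad."
      • Parameter update: the optimizer steps using the synchronized gradients.
  • Key points
    • The loss differs across ranks (because the data differs).
    • After all_reduce, the gradients are identical on every rank, so the parameter updates are identical and the model parameters stay in sync.
    • The all_reduce call is itself a synchronization point that keeps the ranks on the same step.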
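The data-parallel loop above can be traced on a one-parameter model. This is a single-process sketch, not the lecture's PyTorch code: the model y = w * x, the learning rate, and the toy data are all assumptions chosen so the arithmetic is easy to follow, and all_reduce_mean is a simulated stand-in for an averaging all-reduce on param.grad.

```python
WORLD_SIZE = 4
LR = 0.01

def local_grad(w, shard):
    """Gradient of mean((w*x - y)**2) over this rank's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Simulated all-reduce with averaging: every rank gets the same mean."""
    mean = sum(values) / len(values)
    return [mean for _ in values]

def ddp_step(weights, shards, lr=LR):
    # Local gradients differ per rank because the data shards differ.
    grads = [local_grad(w, shard) for w, shard in zip(weights, shards)]
    # After synchronization every rank holds the same gradient.
    synced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, synced)]

# Toy data generated from y = 3x; batch_size = 8, local_batch_size = 2.
batch = [(float(x), 3.0 * x) for x in range(1, 9)]
local_batch_size = len(batch) // WORLD_SIZE
shards = [batch[r * local_batch_size:(r + 1) * local_batch_size]
          for r in range(WORLD_SIZE)]

weights = [0.0] * WORLD_SIZE       # same initialization on every rank
for _ in range(50):
    weights = ddp_step(weights, shards)

print(weights)                     # all four replicas hold the same value, near 3.0
```

Because the shards have equal size, the averaged gradient equals the gradient of the full batch, so the replicas follow the same trajectory a single large-batch run would.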

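2. Tensor Parallelism (TP)

  • Concept: the data is replicated (identical) on every rank, and the model parameters are split along the hidden dimension.
    • Each rank holds part of the parameters of every layer.
  • Implementation steps (MLP forward pass example)
    1. Parameter sharding: with hidden dimension num_dim, each rank's local hidden dimension is local_num_dim = num_dim / world_size. Each parameter matrix becomes num_dim x local_num_dim (or local_num_dim x num_dim, depending on how it is split; this example splits by columns).
    2. Layer-by-layer compute and communication
      • Local compute: the input activations X (batch_size x num_dim) are multiplied by the local parameter shard W_local (num_dim x local_num_dim), giving partial activations X_partial (batch_size x local_num_dim).
      • Activation communication (All-Gather)
        1. X_partial is different on each rank.
        2. Allocate space to collect X_partial from every rank.
        3. Run all_gather to collect X_partial from every rank.
        4. Concatenate the collected partial activations to get the layer's full output activations X_full (batch_size x num_dim).
      • X_full becomes the input to the next layer, and the process repeats.
  • Key points
    • Activations must be communicated heavily between layers, so the bandwidth requirement on the GPU interconnect is high.
    • The lecture's example implements only the forward pass, not the backward pass.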
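The column-sharded forward pass above can be checked against an unsharded reference on tiny integer matrices. This is a pure-Python sketch under stated assumptions: lists of lists stand in for tensors, ReLU is an assumed element-wise nonlinearity, the weights are arbitrary integers, and all_gather is simulated within one process rather than calling torch.distributed.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    return [[max(0, v) for v in row] for row in m]

def column_shards(w, world_size):
    """Split a num_dim x num_dim matrix into world_size column blocks."""
    local_num_dim = len(w[0]) // world_size
    return [[row[r * local_num_dim:(r + 1) * local_num_dim] for row in w]
            for r in range(world_size)]

def all_gather(per_rank):
    """Simulated all-gather: every rank receives every rank's block."""
    return [list(per_rank) for _ in per_rank]

def concat_columns(blocks):
    """Concatenate batch_size x local_num_dim blocks into batch_size x num_dim."""
    return [sum((block[i] for block in blocks), []) for i in range(len(blocks[0]))]

def tensor_parallel_forward(x, layers, world_size):
    sharded = [column_shards(w, world_size) for w in layers]
    xs = [x for _ in range(world_size)]       # every rank starts with the same data
    for shards in sharded:
        # Each rank computes only its slice of the layer's activations.
        partial = [relu(matmul(xs[r], shards[r])) for r in range(world_size)]
        gathered = all_gather(partial)
        xs = [concat_columns(blocks) for blocks in gathered]
    return xs

def reference_forward(x, layers):
    for w in layers:
        x = relu(matmul(x, w))
    return x

NUM_DIM, WORLD_SIZE = 4, 2
layers = [[[(i * NUM_DIM + j + k) % 5 - 2 for j in range(NUM_DIM)]
           for i in range(NUM_DIM)] for k in range(2)]
x = [[1, 2, 3, 4], [0, -1, 2, 1]]

outputs = tensor_parallel_forward(x, layers, WORLD_SIZE)
print(outputs[0] == reference_forward(x, layers))   # True
```

An element-wise nonlinearity can be applied before the all-gather because each output column depends only on the matching weight column, so slicing and applying ReLU commute.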

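3. Pipeline Parallelism (PP)

  • Concept: the model is split by layers, with different (consecutive) layers assigned to different ranks. Data (or micro-batches) flows through the ranks in order.
  • Implementation steps (MLP forward pass example)
    1. Layer assignment: for example, with a 4-layer network and 2 ranks, each rank is responsible for 2 layers.
    2. Parameter initialization: each rank initializes parameters only for the layers it owns.
    3. Micro-batching: split a large batch into several small micro-batches to reduce pipeline bubbles.
    4. Per-micro-batch compute and communication
      • For each micro-batch:
        • The current rank (unless it is the first) receives activations from the previous rank via point-to-point communication (dist.recv).
        • It applies the layers assigned to this rank to the received activations.
        • It sends the result to the next rank (unless it is the last) via point-to-point communication (dist.send).
  • Key points
    • The lecture shows a very naive implementation.
    • Missing optimizations
      • Overlapping communication and computation (for example with the asynchronous isend and irecv).
      • More careful scheduling of forward and backward steps to further optimize the pipeline.
    • The last rank holds the result of the full forward pass. The backward pass would then send gradients back in the opposite direction.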
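The receive, compute, send loop above can be mimicked with threads and queues. This is a sketch, not the lecture's code: threads stand in for ranks, queue.Queue stands in for blocking dist.send/dist.recv, make_layer is a toy layer, and rank 0 is fed through an extra channel instead of reading the data directly.

```python
import queue
import threading

def make_layer(scale):
    # Toy stand-in for a real layer: an element-wise affine map.
    return lambda xs: [scale * x + 1 for x in xs]

def stage(layers, inbox, outbox, num_micro_batches):
    for _ in range(num_micro_batches):
        x = inbox.get()            # blocking receive from the previous rank
        for layer in layers:
            x = layer(x)
        outbox.put(x)              # send to the next rank

def pipeline_forward(batch, layers, world_size, micro_batch_size):
    layers_per_rank = len(layers) // world_size
    micro_batches = [batch[i:i + micro_batch_size]
                     for i in range(0, len(batch), micro_batch_size)]
    # channels[r] feeds rank r; channels[world_size] collects the final output.
    channels = [queue.Queue() for _ in range(world_size + 1)]
    workers = [
        threading.Thread(
            target=stage,
            args=(layers[r * layers_per_rank:(r + 1) * layers_per_rank],
                  channels[r], channels[r + 1], len(micro_batches)),
        )
        for r in range(world_size)
    ]
    for worker in workers:
        worker.start()
    for micro_batch in micro_batches:
        channels[0].put(micro_batch)
    outputs = [channels[world_size].get() for _ in micro_batches]
    for worker in workers:
        worker.join()
    return [y for micro_batch in outputs for y in micro_batch]

layers = [make_layer(s) for s in (1, 2, 3, 4)]   # 4 layers over 2 ranks
batch = list(range(128))                         # 4 micro-batches of size 32
result = pipeline_forward(batch, layers, world_size=2, micro_batch_size=32)

expected = batch
for layer in layers:
    expected = layer(expected)
print(result == expected)   # True
```

Each channel has one sender and one receiver and is first-in, first-out, so micro-batches arrive in the order they were sent. With micro-batches, the second stage can start on the first micro-batch while the first stage is already working on the next one.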

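Limitations of the Implementation and Further Complexity

The lecturer stresses that the MLP implementations shown in the lecture are "bare bones".

  • Real models: the techniques need to be applied to more complex models such as Transformers.
  • Communication/computation overlap: not handled carefully in the lecture code, but critical for performance.
  • Code complexity: real parallelism libraries (such as Megatron-LM and PyTorch FSDP) contain a great deal of bookkeeping code to handle arbitrary model architectures, locate parameters, and so on.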

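JAX: A Higher-Level Abstraction

The lecturer briefly mentions the JAX ecosystem as an alternative to PyTorch.

  • JAX's idea: the user defines the model and the sharding strategy, and the JAX compiler handles the remaining low-level parallelization details.
  • Example: Levanter (a JAX-based toolkit) can implement FSDP or tensor parallelism in very little code (about 10 lines) by declaratively specifying which dimensions to shard along (model dimension, embedding dimension, sequence dimension, and so on) and how to map them onto TPU devices.
  • Contrast: this course uses PyTorch so that students understand the underlying mechanisms, but in practice implementing all of this parallelism logic from scratch is generally not recommended.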

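Summary and Outlook

  • Many ways to parallelize: split the data (batch dimension), the model width (hidden dimension), the model depth (layer dimension), or the context length dimension.
  • Recomputation vs. storage/transfer: a recurring theme. You can recompute locally, load from memory (with a transfer cost), or load from another GPU's memory (with a higher transfer cost). Recomputation is sometimes the better choice.
  • Hardware progress and the need for parallelism: although hardware keeps improving (larger L1 and HBM), models grow faster, so parallelization techniques will remain important for a long time.
    > "This hierarchical structure, ever since computer systems was a thing, has always been with us and it will always be there."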

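Q&A Highlights

  • Q&A: data parallelism and batch norm: the lecturer is unsure how to correctly handle data-dependent batch norm under data parallelism (LLMs typically use layer norm, which is unaffected). As long as the parameters are initialized with the same random seed, layer norm behaves consistently across ranks.
  • Q&A: higher-level parallelism abstractions in PyTorch: PyTorch has the FSDP library, but for more customized sharding strategies JAX currently offers a more declarative interface. Different ecosystems (JAX/TPU vs. DeepSeek's low-level GPU work) take different approaches to optimization.
  • Q&A: activation checkpointing: there are APIs for specifying which parts to recompute, typically placing checkpoints after large matrix multiplications.
  • Q&A: continued training: the techniques discussed apply fully to continuing training from an already-trained checkpoint.
  • Q&A: hardware limits and specialized hardware: GPU size and density have physical limits (power, heat dissipation, bandwidth). Specialized hardware (such as Groq and Cerebras) is optimized for the data-flow nature of deep learning, for example through more on-chip memory, and sheds baggage that traditional GPUs carry from being designed for general-purpose computation.
  • Q&A: where the computation graph lives and runs: the computation graph is mostly conceptual. CPU code drives the whole process; when a PyTorch function that runs on the GPU is called, it launches GPU kernels. Collective operations are issued and scheduled from the CPU, and the actual data movement is carried out on the GPU by libraries such as NCCL.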
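The activation checkpointing item in the list above trades memory for recomputation, which a toy example can make concrete. This is a minimal sketch with made-up scalar "layers"; PyTorch and JAX provide their own checkpointing APIs (for example torch.utils.checkpoint), which this sketch does not use, and the function names here are illustrative.

```python
def forward_with_checkpoints(x, layers, every):
    """Run all layers but keep only the input and every `every`-th activation."""
    checkpoints = {0: x}
    for i, layer in enumerate(layers, start=1):
        x = layer(x)
        if i % every == 0:
            checkpoints[i] = x
    return x, checkpoints

def activation_at(i, layers, checkpoints):
    """Recompute the activation after layer i from the nearest earlier checkpoint."""
    start = max(k for k in checkpoints if k <= i)
    x = checkpoints[start]
    for layer in layers[start:i]:
        x = layer(x)
    return x

# 8 toy layers; layer s maps v to v * s + 1.
layers = [lambda v, s=s: v * s + 1 for s in range(1, 9)]
output, checkpoints = forward_with_checkpoints(2, layers, every=4)

print(sorted(checkpoints))                     # [0, 4, 8]
print(activation_at(6, layers, checkpoints))   # 2677
```

Here 3 stored values replace the 9 that storing every activation would need, at the cost of re-running up to 3 layers when a backward pass asks for an activation that was not kept.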