FlashAttention V1 Deep Dive By Google Engineer | Fast and Memory-Efficient LLM Training
FlashAttention V1 optimizes the attention mechanism with tiled computation and online softmax, dramatically improving training speed and memory efficiency for large models.
Tags
Media details
- Upload date
- 2025-06-15 21:21
- Source
- https://www.youtube.com/watch?v=_941CUKbL6A
- Processing status
- Completed
- Transcription status
- Completed
- Latest LLM Model
- gemini-2.5-pro-preview-06-05
Transcript
speaker 1: Hello everyone. I assume you know how transformers are everywhere in AI and in almost all LLMs these days. The secret sauce is the attention mechanism. However, it can be super slow and eat up tons of memory, especially as models get larger and larger. That's where FlashAttention comes in. It's a game changer and, in my opinion, one of the most important breakthroughs in recent years. It makes training large models dramatically faster and more memory efficient. Let's take a look together. Before we start, I want to say that this talk is relatively math heavy, at least in parts. If you don't understand all of it, don't worry; it won't affect your ability to understand the high-level mechanism, and it won't stop you from using FlashAttention modules in development. It should still be fun and helpful to follow even if you don't fully understand every detail, so just try to enjoy it. Alright, let's start with how the CPU and memory work. As you can see, here's the pyramid. This is the hierarchical structure of a typical CPU system. At the top, it has the processor registers. They are very fast and very expensive. At the bottom, there's tape backup, which most of us don't use anymore. It's very slow, but it's very cheap. So the speed of the storage goes from fast to slow as you move from the top of the pyramid to the bottom, and at the same time the cost goes from very high to very low. If you have experience with data structure or algorithm interviews, you must be pretty familiar with big-O notation. It's used pretty frequently for time and space complexity. However, it can give you the false impression that memory is homogeneous when it is actually hierarchical, like the pyramid here. It also focuses on compute while neglecting I/O, which can be more critical for I/O-bound performance. The computer memory hierarchy is a system where memory is organized into levels by I/O speed, cost, and capacity. It typically consists of CPU registers, cache memory, main memory, secondary storage, and sometimes also tape storage. This hierarchical structure allows fast access to frequently used data while utilizing cheaper and slower storage for less frequently accessed information. This is a block diagram of a basic CPU. The black lines indicate data flow; the red lines indicate control flow. In order to execute a program or instruction, the CPU basically needs to do two things. The first is fetch data from the different levels of memory, whether it's the registers, which are closest, or cache, or main memory, or flash, or disk. The second is, after it fetches the data, it needs to process it. So if a program or task is bottlenecked on fetching data from the different levels of memory, we call it I/O bound. To optimize an I/O-bound task, we can improve data locality, improve the algorithm to reduce reads and writes, or increase bandwidth with better hardware. If a program or task is bottlenecked on actual computation or data processing, we call it CPU bound. To optimize a CPU-bound task, we can optimize the algorithm to reduce the computation and processing needed. Now let's move from the CPU to the GPU. How does GPU memory work? Well, it's actually pretty similar. GPU memory is also hierarchical: SRAM is fast and expensive, and it has very small capacity. HBM, high-bandwidth memory, is cheaper and slower, but it's still faster than traditional DRAM. The capacity of HBM is also larger than SRAM.
When we say a GPU is 40 GB or 80 GB, like this NVIDIA A100 40GB GPU, it usually refers to the HBM. This is because most of the operations we expect a GPU to finish rely on HBM, so they are I/O bound. Imagine if we could do more with SRAM and less with HBM; it would be a lot faster, because SRAM's bandwidth is more than ten times that of HBM on a modern GPU. Compute speed has outpaced memory speed, and most operations in transformers are bottlenecked by memory accesses. This means if you're training a transformer, most likely it's I/O bound. This is the main motivation for FlashAttention. Let's do a quick recap on the transformer. This is the typical encoder-decoder architecture. In the encoder, there is a multi-head attention block. In the decoder, there's a cross-attention block and a masked attention block. If we dive into the attention block, this is basically what it does. Given matrices of queries, keys, and values, it computes the attention scores with a matrix multiplication between queries and keys, applies scaling and a mask if needed, applies a softmax to transform the scores into probabilities, and then does a matrix multiplication with the values. This is the standard implementation of the attention mechanism. The problem statement is: given input sequences query, key, and value, all real-valued matrices of dimension N by d, we want to compute the attention output O, also a real-valued matrix of dimension N by d. Let the query-key matrix multiplication be S, so we have P equals softmax of S, and both S and P are of dimension N by N. The output equals P multiplied by V, with dimension N by d. In those formulas, N is the sequence length, and for large language models the maximum sequence length can be very large; for example, Gemini can already support 1 million tokens, so N is a very large number. d is the head dimension for multi-head attention; it's usually a lot smaller, say 64 or 128. So this is the standard attention implementation. The matrices Q, K, V are usually stored in HBM. In the first step, we load Q and K from HBM, compute S, and then write S to HBM. This is because S is too big to store in SRAM, since its size is N by N. Then we read S from HBM, compute the softmax probabilities, and write P to HBM, because P is also N by N. Then we load P and V by blocks from HBM, compute the output, and write the output to HBM. Let's do a memory I/O analysis for the standard attention implementation. Note that we're not doing a FLOPs analysis; we're doing a memory I/O analysis, which focuses on reads and writes. The first step loads the inputs: the Q, K, V matrices are read from HBM into SRAM, so we have three HBM reads, each of N by d elements. For computing the attention scores, S equals Q multiplied by K transpose. This matrix multiplication produces an N by N matrix, and because it's too big, we need to write it back to HBM, since it's too large to fit in SRAM for subsequent operations like softmax. This is especially true for large N, and for today's large language models you can consider this to always be the case. So the HBM writes are N squared elements for writing S. Afterwards, we compute the softmax: we read the S matrix from HBM and calculate the softmax probabilities P, which is also N by N, and for the same reason we write it back to HBM for subsequent operations like calculating the output. So we have N squared HBM reads and N squared HBM writes.
And lastly, when we compute the output, we read the P matrix and V matrix from HBM and do the calculations. The final output O of size N by d is written to HBM. So the read complexity is N squared plus N multiplied by d, and the HBM writes are N multiplied by d for writing O. In total, the I/O complexity is big Theta of N squared plus N d HBM accesses. By the way, big Theta means a tight bound, where big O means the worst case. So the total I/O is basically big Theta of N squared, since N is usually a lot larger than d. This is a high-level comparison I took from the Hugging Face FlashAttention intro. It compares the standard attention implementation with the FlashAttention implementation in a pretty intuitive way. For the standard attention implementation, we have to load and write multiple times between HBM and SRAM, and this is very costly from an I/O perspective. In order to calculate S, we have to load Q and K and then write S back to HBM, because the size of S is too big; it's N by N. To calculate the softmax probabilities, we need to load S from HBM again, and after the computation, we have to write P back to HBM for the same reason: it's N by N. And for the output, we have to load both P and V and write O back after the calculation. In comparison, FlashAttention tries to avoid storing those intermediate results, especially S and P, which are N by N. In order to fit within SRAM's memory limits, FlashAttention divides Q, K, V into blocks. As you can see, we have Kj and Vj, the j-th blocks of K and V, and we also have Qi, the i-th block of Q. l and m are intermediate results that we're going to cover in detail later. In this way, we avoid writing S and P back to HBM. We just need to load K, V, Q from HBM, do some extra calculation on intermediate statistics like l and m, and after the computation, we just write the final result back. In this way, we successfully avoid storing the N by N matrices S and P in HBM. Before we dive into the details of FlashAttention, I would like to clarify a few things briefly and then provide some extra context to help us understand the content later. First, I want to clarify training versus inference. The FlashAttention mechanism and principle can be used in both training and inference. However, in this talk, the context is mainly the training side, not the inference side (for example, FlashDecoding); we might cover that in later episodes. The second thing is I/O analysis versus FLOPs analysis. FlashAttention aims to improve memory I/O instead of FLOPs, floating point operations. So we are analyzing memory accesses, reads and writes, instead of computation. Remember, we did a compute-style analysis when we went through KV caching for LLM inference in my transformer video; it looks similar, but it has nothing to do with FlashAttention. By the way, KV caching is used in LLM inference only and not in training, because inference for LLMs is usually autoregressive, and the KV cache is useful in that setting. But in training, all the sequence positions are processed in parallel, since we already know the input, so a KV cache is a lot less useful. The first piece of extra background I would like to talk about is the per-row softmax in the attention mechanism. Specifically, in scaled dot-product attention, softmax is calculated per row of the score matrix S, since each row corresponds to a single query vector and its relationship with all key vectors. The purpose is to determine how much attention that specific query should pay to every key. Q is the matrix of query vectors; its dimension is N by d, that is, the sequence length by the attention head dimension.
K is the matrix of key vectors; its dimension is also N by d. For each query vector in Q, we need to calculate the attention scores against all key vectors in K. The resulting matrix S has dimension N by N, and in the matrix S, the i-th row contains the dot-product scores of the i-th query vector qi with all key vectors from k1, k2, all the way to kn, where sij equals qi dot kj. This represents the raw relevance or attention score between qi and kj, that is, the i-th query and the j-th key. So softmax is applied to each row, and the sum of all elements in a single row of the result equals one, representing a probability distribution. Just to remind you of the softmax definition: for a given input vector s from s1 to sK, the softmax function calculates the probability pi for each element si as pi equals e to the power of si over the sum of e to the power of sk. Now, given the softmax definition, you might think the implementation should be very simple. We just need to go through the vector two times. In the first pass, we get the denominator, which is the exponential sum, and in the second pass, we calculate the actual softmax for each position, since we already have the sum in the denominator. However, this naive algorithm won't fly in most production environments. This is because e to the power of si will likely overflow the hardware limit. That's why we need to use safe softmax in most production environments. The idea is also very straightforward: we don't want the value to overflow, so we subtract the maximum from all the values. In this way, it will definitely not overflow. So instead of doing two passes, we have to do three passes. In the first pass, we calculate the maximum value; in the second pass, we calculate the denominator; and in the third pass, we calculate the softmax for each position. This is called the safe softmax algorithm. With this extra background, I think we're now ready to take a look at the FlashAttention details. The essence of FlashAttention is tiling. Standard attention needs multiple HBM accesses, since S and P are too big for SRAM. FlashAttention aims to improve this by doing the attention calculation with tiling, which means breaking big data into smaller tiles. This is the diagram of tiling. We have a d-by-N matrix (K transposed), an N-by-d matrix Q, and an N-by-d matrix V. Note that K, Q, and V are broken into blocks. The block size is not N and it's not d; it's something else that I will cover later. You might think the block size is d by d just by looking at this graph, but it's not. In standard attention, Q K transpose is N by N, and we would need to store this whole thing in HBM. In order to avoid that, we calculate only the blocks: we copy a single block of Q, K, and V from HBM to SRAM, and we compute on that block in SRAM before writing the output to HBM. There are several key points for understanding tiling and FlashAttention. The first one is divide and conquer: the input matrices Q, K, V are divided into smaller blocks, or tiles. The next one is iterative processing: instead of computing the full attention matrix at once, which is the dotted square, FlashAttention processes these tiles iteratively. It loads a block of Q and a block of K into the much faster SRAM. The third one is partial computation: it computes the attention scores and the weighted sum of values for these blocks within SRAM. The next one is online softmax; this is also very important.
FlashAttention recomputes the softmax normalization on the fly. As it processes more blocks, it correctly normalizes the attention scores without approximation and without needing the complete Q K transpose matrix, that is, the S matrix. The last one is accumulation: the results from processing each pair of Q and K blocks are accumulated to produce the final output without ever forming the full intermediate attention matrix in HBM. We have already gone through the safe softmax algorithm. However, it won't work for FlashAttention as-is, since safe softmax requires the complete row si to compute the maximum value and subtract that maximum when calculating the softmax to avoid overflow. FlashAttention is doing tiling, which means breaking K, V, and Q into blocks to fit into SRAM, so it no longer has access to the whole row. Then what should we do in this case? This is the three-pass safe softmax that we went through on previous slides. As you can see, the first pass is friendly to tiling, since for each position i it only uses information about i and i minus 1. The second pass is not friendly to tiling due to the dependency on mN, meaning it needs the global maximum m. The third pass is also not friendly to tiling due to the dependency on both mN and lN. So what if we can find a substitute, denoted li hat, for li, such that li hat only depends on positions smaller than or equal to i and lN hat equals lN? If we can build this substitute, then we can totally avoid the second pass and make the algorithm closer to tiling compatible; we still have mN here, so it's not quite there yet. The substitute is found to be the li hat given by this formula, and from the recurrence relationship between li minus 1 hat and li hat (this is the derivation, if you're interested), we can easily get the new two-pass safe softmax algorithm. Note that this is still not tiling friendly, mainly because of mN. Here, in the first pass, for i equals 1 to N, for each position we calculate mi and li hat, and we only need information from position i or position i minus 1. In the second pass, since we already have lN hat, we can calculate pi for each position with lN hat and mN. Recall that S equals the Q K transpose matrix multiplication, P equals softmax of S, and O equals the P V matrix multiplication. We have already gone through S and P, so now it's time to get the final result, O, and in this process we'll make it tiling friendly and improve the algorithm further. The straightforward algorithm to calculate the output based on the previous two-pass safe softmax algorithm is as follows. In the first pass, for i equals 1 to N, we calculate si, mi, and li hat. In the second pass, we calculate pi according to lN hat and mN, and then we calculate oi according to pi and V. This two-pass algorithm uses the k-th row of O as an example for simplicity; however, the computations of different rows are independent, so the other rows are identical. If we store the intermediate states mi and li hat, and yes, we can store them in HBM given their small footprint and I/O cost, then this algorithm is tiling friendly. We call it online safe softmax. Can we improve this algorithm further so it has only one pass? The answer is yes. The main reason we still have two passes is that in the second pass, when we calculate oi, it has a dependency on mN and lN hat, which requires the first pass to finish. What if we do the same thing for O and construct an oi hat that meets the following criteria?
oi hat only depends on information from positions smaller than or equal to i, and oN hat equals oN. And yes, we can find this formula; this is oi hat. We can now merge the two passes into one if we know the recurrence relationship between oi hat and oi minus 1 hat. And again, it's pretty straightforward to find that recurrence. This is the derivation detail; pause the video and take a look if you're interested. Finally, we have the one-pass FlashAttention algorithm that is tiling friendly. In that single pass, for each position from 1 to N, we calculate si first, then mi, which is the maximum value up to the i-th position, and then li hat and oi hat. After we do this N times, oN hat is the final, exact output for the k-th row. And this is tiling friendly, because each calculation only depends on information from position i or i minus 1. So now let's apply tiling to it. This is the detailed forward process for FlashAttention V1. It's pretty much the same as what we've gone through on previous slides, only adding a few details. For example, it adds the block sizes Bc and Br, and we divide Q, K, V into blocks: each block of Q is Br by d, and each block of K and V is Bc by d. Correspondingly, we also divide O into Tr blocks. Then it has an outer loop and an inner loop. The outer loop is for j from 1 to Tc: we iterate through the key and value blocks and load those blocks from HBM to on-chip SRAM. For each of the K and V blocks, we go through all the query blocks along with the output blocks and the l and m blocks, and we load them from HBM to on-chip SRAM. The rest of the calculation is the same as what we have gone through previously. The diag function here constructs a diagonal matrix from a vector, and the bracket notation here is the ceiling function, which rounds a real number up to the nearest integer; for example, the ceiling of 2.7 equals 3. The outer loop is over K and V, and the inner loop is over Q. Notice that this was later swapped for better parallelism in FlashAttention V2, which I will cover in maybe the next episode. The intermediate statistics m and l are written to HBM for backward-pass recomputation, which I will cover in later slides. Now let's take a look at the block sizes before we dive into the memory I/O analysis for FlashAttention. The Bc and Br block sizes are picked intentionally to fit all four blocks Qi, Kj, Vj, and Sij into SRAM of size M. The dimension of Sij is Br by Bc, and based on the previous algorithm detail, we know that Bc and Br are defined with a ceiling function of this expression. If we put those into this formula, Br times Bc is smaller than or equal to M over 4, which means we strictly control the size of S so it won't exceed 25% of the total SRAM size. You can design other block-size schemes as long as they make sure all four matrices fit into SRAM. Now let's take a look at the memory I/O analysis for FlashAttention. In the outer loop, each element of K and V is loaded from HBM once; both are N by d, resulting in 2 times N times d HBM reads. So this is big Theta of N d reads. Next, given that K and V are broken into Tc blocks in the outer loop, we make Tc passes in the inner loop over Q and O, each pass loading all of Q and all of O from HBM. This results in 2 times Tc times N times d HBM reads, so this is big Theta of N d Tc. The total I/O complexity is big Theta of N d Tc plus N d, which equals big Theta of N d Tc.
We can further simplify this if we express Tc as a function of N, d, and M. Since Kj and Vj need to fit into SRAM, and the dimension of Kj and Vj is Bc by d, we have this formula: Bc is big O of M over d. Similarly, Qi needs to fit into SRAM, and the dimension of Qi is Br by d, so Br is also big O of M over d. Similarly, Sij needs to fit into SRAM, and the dimension of Sij is Br by Bc, which gives the constraint on Br. Then we have Tc equals N over Bc, whose complexity is big Theta of N d over M. So the final I/O complexity for the FlashAttention V1 forward process is big Theta of N squared d squared over M. Compare this with the I/O complexity of standard attention, big Theta of N squared: the typical value of d is very small, 64 or 128, and the typical value of M is much larger, measured in kilobytes to megabytes, so d squared is still a lot smaller than M. So FlashAttention requires many times fewer HBM accesses, leading to faster execution and a lower memory footprint. This is a great intuition graph done by the original author in his blog, so I recommend you take a look. Let's assume the whole K matrix has only two vectors. In the outer loop, we're looping through all of K, which is k1 and k2, and for each pass, we go through all of Q for each k. So when we process k1, we have s1 equals Q k1 transpose, we have the softmax function here and the value here, and we get o1, output one, equal to this. Now, since we have to iterate N times, and in this case N equals 2, we come to k2: we have s2 equals Q k2 transpose, calculate the softmax, and then with v2 we have o2 equal to this, rescaling with the correct denominator from o1. Imagine now that N doesn't equal 2 but is a very large number; this rescaling-to-the-correct-denominator process works the same for big N, just with the rescaling formula getting longer and longer. Now we have gone through the forward attention process. We should already know FlashAttention's motivation, in the sense that it reduces memory I/O complexity. We should also know FlashAttention's mechanism, including tiling and the online softmax algorithm. We should be comfortable using FlashAttention modules in development. Now, this part about the backward process is optional; stay if you feel like it. Let's recap a little bit on what backward propagation is. The goal of training a neural network is to minimize the loss function, in this case, say phi. According to the chain rule, we use backpropagation to update the weights for all intermediate neurons in the network. Once the forward pass is done and the loss function has been calculated for the output, backpropagation starts at the end of the network: the gradient of the loss function with respect to the network's final output is calculated first. In this case, let's call it dO_final, where O_final is the final output and d denotes the derivative. For any given intermediate layer, it takes the succeeding layer's gradient as input and calculates the current layer's updated parameters. So for the backward attention process, we have the following problem statement: given loss function phi, let the output gradient be dO. This is a matrix of real values, and the dimension is N by d; dO means the loss function's derivative with respect to O. We want to compute the input gradients dQ, dK, and dV. All of them are real-valued matrices of the same dimension N by d; dQ, dK, dV are the loss function's derivatives with respect to Q, K, and V, respectively. Now let's briefly go through the chain rule.
We already know the Q, K, V matrices are N by d, O is also N by d, and dQ, dK, dV are also N by d. We know S equals Q K transpose and P equals softmax of S, both real-valued N by N matrices, and O equals the P V matrix multiplication, which is N by d. So d phi over dV equals d phi over dO times dO over dV. Since O equals P multiplied by V, the dependence of O on V goes through P, so we have dV equals P transpose times dO. Similarly, we calculate dP equals dO times V transpose, dQ equals dS times K, and dK equals dS transpose times Q. Now we need a little extra knowledge about the derivative of the softmax function. The softmax function takes a vector input and produces a vector output; its derivative is a Jacobian matrix. The Jacobian matrix contains all possible partial derivatives of each output element with respect to each input element. Let the vector y equal the softmax of the vector z. We need to compute dyi over dzj for all i and j. There are two cases to consider: case one, where i equals j, and case two, where i does not equal j. The detailed derivation is as follows; pause the video if you need to take a look. Combining both cases, the final result is this: when i equals j, this is the derivative, and when i does not equal j, this is the derivative. So the full Jacobian matrix J of the softmax function is this. It can also be written more compactly in vector notation: it's basically the diagonal matrix of y minus y multiplied by y transpose. Now, back to the original problem statement: we have P equals softmax of S, and given that the Jacobian for y equals softmax of s is diag(y) minus y y transpose, we have the following formulas. It's a detailed calculation of the derivative with respect to S for each position i, j; pause the video a little bit if you need to take a look. Now, with all that extra knowledge, we should be ready for the standard attention backward implementation. When we do the backward pass, we already have the matrices Q, K, V, dO, and P in HBM. We first load P and dO from HBM and compute dV. Then we load dO and V from HBM and compute dP; we need to write dP back to HBM. Then we read P and dP from HBM and compute dS according to the Jacobian matrix we have gone through, and afterwards we need to write dS back to HBM. Next, we load dS and K by blocks from HBM and compute dQ. The last step is to load dS and Q by blocks from HBM and compute dK. Now we have all of dQ, dK, and dV. This backward process typically requires the matrices S and P. Their size is N by N, and just like the forward pass, due to their size we need to do a lot of reads and writes to HBM. Recall that one of FlashAttention's goals is to not store N-squared intermediate values for the backward pass. So what should we do in this case? Remember when we talked about I/O-bound versus CPU-bound tasks: transformer training is I/O bound, which means if we can sacrifice a little bit of compute and improve a little bit on I/O, it's actually a win. That's exactly what we're doing. By storing the output O and the softmax normalization statistics m and l, we can recompute the attention matrices S and P in the backward pass from blocks of Q, K, V, O, and l, m in SRAM. This results in more FLOPs due to recomputation; however, it still speeds up the backward pass due to the reduced HBM accesses. This is the detailed backward-pass algorithm, and the recomputation happens here: we recompute P with m and l, and the formula is the same as in the forward pass.
These are all the mathematical formulas we use in the FlashAttention backward pass. dV equals P transpose dO; this is pretty straightforward: we have P, we have dO, we just do the multiplication here. This we already know from previous slides. If we define Di to get this formula, then we can simplify dSij to this, so we can rewrite dQ and dK with the definition of Di. Pause the video a little bit if you need to take a look at the details. Now let's take a look at the I/O complexity. It's actually very similar to the forward process. In the outer loop, each element of K and V is loaded from HBM once; both are N by d, so this is big Theta of N d HBM reads. And given that K and V are broken into Tc blocks in the outer loop, we make Tc passes in the inner loop over Q, O, and dO, each pass loading all of Q, O, and dO from HBM. This results in 3 times Tc times N times d HBM reads, so it's big Theta of N d Tc. We can reuse the previous calculation for Tc: the total I/O complexity is big Theta of N d Tc plus N d, which equals big Theta of N d Tc, that is, big Theta of N squared d squared over M. So the final I/O analysis for the forward process and the backward process is the same, and as we discussed, it's a lot smaller than standard attention, since d squared is usually a lot smaller than M. Alright, this is the last slide I have for FlashAttention V1. Hopefully this was helpful, and I promise I will go through the later versions of FlashAttention as well as FlashAttention usage in inference in later episodes. If you like my video, please subscribe, comment, and like. Alright, see you later. Bye.
Latest Summary (Detailed Summary)
FlashAttention V1: A Technical Deep Dive
Subtitle: FlashAttention V1 optimizes the attention mechanism with tiled computation and online softmax, dramatically improving training speed and memory efficiency for large models.
Overview / Executive Summary
This lecture dives into the core technical principles and algorithmic details of FlashAttention V1 and its importance for efficient large language model (LLM) training. Standard attention produces huge intermediate attention matrices when processing long sequences (the $N \times N$ matrices $S$ and $P$). These matrices far exceed the capacity of the GPU's fast SRAM, forcing frequent, costly reads and writes to HBM (high-bandwidth memory) and creating a severe I/O bottleneck. FlashAttention V1 attacks this problem with two core techniques: Tiling and Online Softmax. Tiling splits the input matrices $Q, K, V$ into small blocks and computes over them iteratively in fast SRAM, avoiding the need to store and repeatedly read the full intermediate matrices in HBM. Online Softmax is a clever single-pass algorithm that, while iterating over the blocks, dynamically updates the statistics softmax needs (the running maximum and the normalization denominator), so softmax can be computed exactly under the tiling constraint without ever needing a complete row at once. These optimizations let multiple computation steps be fused and executed inside SRAM, greatly reducing HBM traffic. As a result, FlashAttention V1 reduces HBM access complexity well below the $\Theta(N^2)$ of standard attention. For the backward pass, FlashAttention V1 uses a recomputation strategy: using a small amount of data saved during the forward pass (the output $O$ and the softmax statistics $m, l$), it recomputes blocks of the intermediate matrices in SRAM on demand, further reducing HBM accesses. Although this adds some extra computation (FLOPs), the large gain in I/O efficiency still yields a substantial overall training speedup and a lower memory footprint, while producing exactly the same result as standard attention.
Preface: Before We Start
The speaker begins by noting that this talk is relatively math heavy. Even without fully following every derivation, listeners can still grasp the high-level mechanism of FlashAttention and use FlashAttention modules in real development. The key is to follow the line of reasoning and enjoy the process.
Part 1: Core Motivation: Why FlashAttention Is Needed
1.1 The Computer Memory Hierarchy and I/O Considerations
Computer memory is not a single homogeneous store; it is a hierarchy of media with different speeds, costs, and capacities. The hierarchy typically runs from extremely fast but expensive processor registers and caches down to slower but large and cheap RAM (main memory), solid-state drives, and lower storage tiers. This layered design balances fast access to hot data against economical storage of large volumes of data. When analyzing algorithmic complexity, the usual big-O notation tends to gloss over the hierarchical nature of memory and may ignore the real performance impact of I/O, especially for data-intensive workloads.
The bottleneck of a program can lie in data movement (I/O bound) or in data processing (compute bound).
* I/O bound: performance is limited mainly by how fast data can be read from or written to memory. Optimization strategies include improving data locality, redesigning the algorithm to reduce reads and writes, or increasing bandwidth with better hardware.
* CPU bound: performance is limited mainly by how fast the processor can execute the actual computation. Optimization strategies include improving the algorithm to reduce the amount of computation.
1.2 The GPU Memory Bottleneck: The Catalyst for FlashAttention
Like CPU systems, GPU memory is also hierarchical. Taking an NVIDIA A100 GPU (e.g., the 40 GB version) as an example:
* SRAM (static random-access memory): extremely high bandwidth (around 19 TB/s), but very limited capacity (around 20 MB) and expensive.
* HBM (high-bandwidth memory): bandwidth far higher than conventional DRAM (around 1.5 TB/s), with much larger capacity (40 GB or 80 GB). The advertised GPU memory size usually refers to HBM.
* DRAM (dynamic random-access memory, i.e., system main memory): relatively low bandwidth (around 12.8 GB/s), but capacity can be very large (over 1 TB).
The core problem: on modern GPUs, compute performance has grown faster than memory access speed. For large models like transformers, many operations (especially on long sequences) are therefore bottlenecked by memory access, i.e., they are I/O bound. As the speaker puts it: "Compute speed has outpaced memory speed, and most operations in transformers are bottlenecked by memory accesses."
FlashAttention's core motivation is precisely to resolve this I/O bottleneck. The goal is to do as much computation as possible in fast SRAM and drastically reduce reads and writes to the comparatively slow HBM. If more work can be done in SRAM with less reliance on HBM, overall speed improves substantially.
Part 2: The Bottleneck of Standard Attention
2.1 Transformer and Attention Recap
One of the core components of the Transformer is its self-attention mechanism (or attention more generally). The standard formula is:
$$Attention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
where $Q, K, V$ are the query, key, and value matrices, and $d_k$ is the key dimension.
2.2 The I/O Bottleneck of the Standard Implementation
Given input matrices $Q, K, V$ (each of dimension $N \times d$, where $N$ is the sequence length and $d$ the head dimension), a typical standard attention implementation proceeds as follows:
1. Load $Q$ and $K$ from HBM into SRAM and compute the score matrix $S = QK^T$. Because $S$ is $N \times N$, it becomes enormous for large $N$, usually far exceeding SRAM capacity, so it must be written back to HBM.
2. Read $S$ from HBM into SRAM and compute $P = \text{softmax}(S)$. $P$ is also $N \times N$ and likewise has to be written back to HBM.
3. Load $P$ and $V$ from HBM into SRAM (usually block by block), compute the final output $O = PV$ (dimension $N \times d$), and write $O$ back to HBM.
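For concreteness, here is a minimal NumPy sketch of this unfused procedure (the variable names and toy sizes are illustrative, not from the paper); the explicit `S` and `P` arrays stand in for the $N \times N$ intermediates that would live in HBM:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Unfused attention sketch: materializes the full N x N matrices S and P,
    mirroring the three HBM round trips described in steps 1-3 above."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                   # step 1: N x N score matrix
    S = S - S.max(axis=-1, keepdims=True)      # safe softmax: subtract row max
    P = np.exp(S)
    P /= P.sum(axis=-1, keepdims=True)         # step 2: N x N probability matrix
    return P @ V                               # step 3: N x d output

# toy usage
rng = np.random.default_rng(0)
N, d = 16, 8
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
print(standard_attention(Q, K, V).shape)       # (16, 8)
```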
Problem analysis:
* In large language models, the sequence length $N$ can be very large (tens of thousands or even millions of tokens), so the intermediate matrices $S$ and $P$ (of size $N \times N$) become enormous.
* Because the fast SRAM on a GPU holds only a few tens of MB, these huge intermediate matrices cannot reside entirely in SRAM, forcing massive data transfer between SRAM and HBM.
* The head dimension $d$ is usually comparatively small (e.g., 64 or 128).
I/O complexity analysis:
The total HBM traffic of the standard implementation is on the order of $\Theta(N^2 + Nd)$. Since $N \gg d$ in typical large-model settings, the $N^2$ term dominates and is the main I/O bottleneck. The speaker uses $\Theta$ (Theta) to denote a tight bound.
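As a rough worked example (assumed numbers, not from the lecture): with $N = 16384$ and fp16 storage, a single $S$ or $P$ matrix already occupies
$$16384^2 \times 2\ \text{bytes} = 512\ \text{MB},$$
far more than the roughly 20 MB of on-chip SRAM cited above, so both matrices must spill to HBM on every pass.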
Part 3: Core Principles of FlashAttention V1
FlashAttention resolves the I/O bottleneck of standard attention with two key techniques: Tiling and Online Softmax. The speaker clarifies that FlashAttention's principles apply to both training and inference, but this talk focuses on the training setting.
3.1 Tiling: Divide the Work, Compute Efficiently in SRAM
The core idea of tiling is divide and conquer: split the large input matrices $Q, K, V$ into many smaller blocks (tiles) and process those tiles iteratively instead of computing the whole attention matrix at once.
The overall flow is roughly as follows:
1. Load one block of $Q$ ($Q_i$) and the corresponding blocks of $K$ and $V$ ($K_j, V_j$) from HBM into fast SRAM.
2. Inside SRAM, carry out the attention-score computation and the weighted sum over values for these blocks.
3. Accumulate and correct the intermediate results with the Online Softmax technique (next section).
By finishing a relatively complete unit of attention work (scores, softmax, and value weighting) for each block inside SRAM and writing back only what is necessary, the algorithm avoids materializing and storing the full, enormous $N \times N$ intermediate matrices $S$ and $P$ in HBM.
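A minimal partitioning sketch under assumed conventions (NumPy; `M_elems` is a hypothetical SRAM budget in elements, and the block-size choice follows a $\lceil M/4d \rceil$-style rule like the one discussed in the transcript):

```python
import numpy as np

def make_blocks(Q, K, V, M_elems):
    """Split Q, K, V into tiles small enough that Q_i, K_j, V_j and the
    B_r x B_c score block S_ij can plausibly share an SRAM budget of M_elems."""
    d = Q.shape[1]
    Bc = max(1, int(np.ceil(M_elems / (4 * d))))   # rows per K/V block
    Br = max(1, min(Bc, d))                        # rows per Q block
    Q_blocks = [Q[i:i + Br] for i in range(0, Q.shape[0], Br)]
    K_blocks = [K[j:j + Bc] for j in range(0, K.shape[0], Bc)]
    V_blocks = [V[j:j + Bc] for j in range(0, V.shape[0], Bc)]
    return Q_blocks, K_blocks, V_blocks
```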
3.2 Online Softmax: Exact Computation Under the Tiling Constraint
A standard softmax implementation (in particular Safe Softmax, which subtracts the maximum to avoid numerical overflow) needs access to every element of the input row to compute the maximum and the normalization denominator. Safe Softmax, for example, makes three passes over the data: the first finds the row maximum, the second computes the normalizer, and the third computes each element's probability. Under tiling, however, only one small block is in SRAM at a time, the full row is never available, and the standard Safe Softmax no longer applies directly.
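To make the data dependencies explicit, here is a deliberately unvectorized sketch of the three passes (plain NumPy, illustrative only):

```python
import numpy as np

def safe_softmax_3pass(s):
    """Three-pass safe softmax over a 1-D score vector s."""
    n = len(s)
    m = -np.inf
    for i in range(n):                # pass 1: row maximum
        m = max(m, s[i])
    l = 0.0
    for i in range(n):                # pass 2: normalization denominator
        l += np.exp(s[i] - m)
    p = np.empty(n)
    for i in range(n):                # pass 3: probabilities
        p[i] = np.exp(s[i] - m) / l
    return p
```

Passes 2 and 3 both depend on the global maximum `m`, which is exactly the dependency that tiling breaks.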
FlashAttention introduces the Online Safe Softmax algorithm (Online Softmax for short). It is a single-pass (one-pass) algorithm: while iterating over blocks (or individual elements), it dynamically and exactly updates the statistics softmax needs, mainly the running maximum $m$ so far and the running normalization denominator (or a related statistic) $l$.
The speaker derives step by step how to go from the multi-pass Safe Softmax to the tiling-friendly Online Softmax. The key is to establish recurrences for the intermediate statistics $m$ (the maximum), $l$ (the normalization factor), and the partial output $o$, so that each step depends only on the current block and the state from the previous iteration. The formulas below illustrate how, when processing the $i$-th block (or step), these quantities are updated from the previous state ($i-1$) and the current block (notation follows the lecture; $\tilde{l}$ and $\tilde{o}$ denote the online-updated quantities):
1. Update the running maximum $m_i$:
$$m_i \leftarrow \max\left(m_{i-1},\ \max_{k \in \text{current block}} s_k\right)$$
2. Update the running normalization factor $\tilde{l}_i$:
$$\tilde{l}_i \leftarrow \tilde{l}_{i-1}\, e^{m_{i-1}-m_i} + \sum_{k \in \text{current block}} e^{s_k-m_i}$$
3. Update the running weighted output $\tilde{o}_i$:
$$\tilde{o}_i \leftarrow \tilde{o}_{i-1}\cdot\frac{\tilde{l}_{i-1}\, e^{m_{i-1}-m_i}}{\tilde{l}_i} + \frac{\sum_{k \in \text{current block}} e^{s_k-m_i}\, V_k}{\tilde{l}_i}$$
With these updates, once all blocks have been processed, the final output $\tilde{o}_N$ is exactly the result computed by standard attention; there is no approximation.
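The recurrences above can be checked with a small sketch that processes one key/value at a time for a single query row (NumPy; names are illustrative):

```python
import numpy as np

def online_attention_row(q, K, V):
    """One-pass online-softmax attention for one query row q of shape (d,).
    m is the running max, l_t the running normalizer, o_t the running output."""
    m, l_t = -np.inf, 0.0
    o_t = np.zeros(V.shape[1])
    for k in range(K.shape[0]):
        s_k = q @ K[k] / np.sqrt(len(q))
        m_new = max(m, s_k)
        l_new = l_t * np.exp(m - m_new) + np.exp(s_k - m_new)
        # rescale the previous partial output, then fold in the new position
        o_t = (o_t * l_t * np.exp(m - m_new) + np.exp(s_k - m_new) * V[k]) / l_new
        m, l_t = m_new, l_new
    return o_t
```

Processing keys block by block rather than one at a time only turns the last term into a sum over the block; the rescaling logic is unchanged, and the final `o_t` matches the exact softmax-weighted average.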
3.3 FlashAttention Algorithm Overview (Forward Pass)
Combining Tiling and Online Softmax, the FlashAttention forward pass proceeds roughly as follows:
1. Initialization: in HBM, initialize the $N \times d$ output matrix $O$ to zero, along with two auxiliary length-$N$ vectors: $l$ (the running softmax denominators) and $m$ (the running row maxima).
2. Blocking: split $Q, K, V$ by the chosen block sizes $B_R, B_C$. For example, a block $Q_i$ of $Q$ has dimension $B_R \times d$, and blocks $K_j, V_j$ of $K, V$ have dimension $B_C \times d$.
3. Outer loop: iterate over the block pairs $(K_j, V_j)$, loading $K_j, V_j$ from HBM into SRAM on each iteration. $K, V$ are split into $T_c = N/B_C$ blocks.
4. Inner loop: for the $(K_j, V_j)$ currently in SRAM, iterate over all blocks $Q_i$ of $Q$. On each iteration, load $Q_i$ together with the corresponding output block $O_i$ and the statistics $l_i, m_i$ from HBM into SRAM.
5. Compute in SRAM: using the current blocks $Q_i, K_j, V_j$ and the old values of $O_i, l_i, m_i$:
* Compute the score block $S_{ij} = Q_i K_j^T$ (dimension $B_R \times B_C$).
* Apply the Online Softmax recurrences, together with $S_{ij}$ and $V_j$, to produce the new values of $O_i$ and of the statistics $l_i, m_i$.
This sequence of block-level computations done entirely in SRAM avoids repeated HBM round trips and embodies the idea of kernel fusion: several logical operations are merged into a single compute kernel for efficiency.
6. Write back: write the updated $O_i, l_i, m_i$ from SRAM back to HBM.
The speaker notes that the statistics $m$ and $l$ are written back to HBM so they can be used for recomputation in the backward pass, and that FlashAttention V2 swaps the order of the inner and outer loops to further improve parallelism.
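Putting steps 1-6 together, a compact single-head NumPy sketch of this loop structure might look as follows (a readability sketch under stated simplifications: no masking or dropout, illustrative fixed block sizes rather than SRAM-derived ones, and of course no actual kernel fusion):

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=16, Bc=16):
    """Tiled forward pass in the spirit of FlashAttention V1 (single head).
    O, l, m play the role of the HBM-resident accumulators."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    l = np.zeros(N)                       # running softmax denominators
    m = np.full(N, -np.inf)               # running row maxima
    for j in range(0, N, Bc):             # outer loop over K_j, V_j
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, N, Br):         # inner loop over Q_i, O_i, l_i, m_i
            Qi = Q[i:i + Br]
            Sij = Qi @ Kj.T * scale                    # B_r x B_c score block
            m_blk = Sij.max(axis=1)
            P_blk = np.exp(Sij - m_blk[:, None])       # unnormalized block probs
            l_blk = P_blk.sum(axis=1)
            m_new = np.maximum(m[i:i + Br], m_blk)
            a = np.exp(m[i:i + Br] - m_new)            # rescale old statistics
            b = np.exp(m_blk - m_new)                  # rescale new block
            l_new = a * l[i:i + Br] + b * l_blk
            O[i:i + Br] = ((a * l[i:i + Br])[:, None] * O[i:i + Br]
                           + b[:, None] * (P_blk @ Vj)) / l_new[:, None]
            m[i:i + Br], l[i:i + Br] = m_new, l_new
    return O
```

Comparing its output against an unfused reference (e.g. with `np.allclose`) confirms that the tiled result is exact rather than approximate.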
3.4 Algorithm Intuition: Incrementally Accumulating the Output
The speaker uses a simplified example (a $K$ matrix with only two blocks $K^{(1)}, K^{(2)}$ and corresponding $V^{(1)}, V^{(2)}$) to show intuitively how the output $O$ is updated iteratively. After processing the first block $(K^{(1)}, V^{(1)})$, we obtain a partial output $O^{(1)}$ and statistics $l^{(1)}, m^{(1)}$. When processing the second block $(K^{(2)}, V^{(2)})$, the algorithm first rescales the previously computed $O^{(1)}$ and $l^{(1)}$ against the new maximum $m^{(2)}$ so they align with the current accumulated statistics, and then adds the contribution of the second block computed from $S_{i2}$. This "rescale and accumulate" step is exactly what the Online Softmax output update $\tilde{o}_i$ expresses, and it guarantees an exact final result.
3.5 I/O Complexity Analysis: Far Fewer HBM Accesses
- Block size selection: the block sizes $B_C$ (for $K, V$) and $B_R$ (for $Q$) must be chosen so that $Q_i, K_j, V_j$ and the intermediate $S_{ij}$ all fit in SRAM (of size $M$). A common constraint is that $S_{ij}$ (dimension $B_R \times B_C$) take at most $M/4$, from which $B_C = O(M/d)$ and $B_R = O(M/d)$.
- HBM access analysis:
- The outer loop loads every block of $K, V$ once: $\Theta(Nd)$ HBM reads in total.
- In the inner loop, for every block of $K, V$ (there are $T_c = N/B_C$ of them), all of the $Q$ blocks, the $O$ blocks, and the $l, m$ vectors are read and written once, for roughly $T_c \times \Theta(Nd)$ HBM accesses.
- Substituting $B_C = O(M/d)$ into $T_c$, the total HBM access complexity is approximately $\Theta(Nd \cdot (Nd/M) + Nd) = \Theta(N^2 d^2 / M + Nd)$.
- Comparison: standard attention's I/O complexity is $\Theta(N^2 + Nd)$. Since $d$ (e.g., 64 or 128) is relatively small in typical large models and the SRAM size $M$ is far larger than $d^2$ (e.g., $M$ on the order of tens of MB versus $d^2$ in the thousands to low tens of thousands), we have $d^2/M \ll 1$. FlashAttention therefore needs far fewer HBM accesses than standard attention; in particular, for large $N$, $N^2 d^2 / M$ is much smaller than $N^2$, as the worked example below illustrates.
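To get a feel for the magnitude (assumed, illustrative numbers): with $d = 128$ and an SRAM budget of, say, $M = 2^{17}$ elements,
$$\frac{d^2}{M} = \frac{16384}{131072} = 0.125,$$
so the dominant $N^2 d^2 / M$ term is roughly one eighth of the $N^2$ cost that standard attention pays, and the gap widens as $M$ grows or $d$ shrinks.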
Part 4: FlashAttention's Backward Pass
Training also requires an efficient backward pass to compute gradients.
4.1 The Challenge in the Standard Backward Pass
The backward pass of standard attention faces the same severe I/O bottleneck: it typically reads the enormous $N \times N$ matrices $P$ (softmax probabilities) and/or $S$ (attention scores) that were computed and stored in HBM during the forward pass, incurring heavy HBM traffic. The speaker briefly reviews the chain-rule steps involved and the derivative of the softmax function (its Jacobian matrix).
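For reference, written out under the notation above (standard chain-rule results; the $1/\sqrt{d_k}$ scaling is omitted to match the lecture's problem statement $S = QK^T$):
$$dV = P^T\, dO, \qquad dP = dO\, V^T, \qquad dS_{ij} = P_{ij}\Big(dP_{ij} - \sum_k dP_{ik} P_{ik}\Big), \qquad dQ = dS\, K, \qquad dK = dS^T\, Q.$$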
4.2 FlashAttention's Efficient Backward Pass: Recomputation
FlashAttention's backward pass cleverly avoids storing and re-reading the huge intermediate matrices $S$ and $P$. Instead, it uses a recomputation strategy.
* Core idea: transformer training is typically I/O bound rather than compute bound, so trading some extra computation (FLOPs) for a large reduction in HBM accesses is a net win in overall efficiency.
* The process (a code sketch follows after this list):
1. During the forward pass, besides the final output $O$, FlashAttention saves only two relatively small statistic vectors: $m$ (the pre-softmax row maxima) and $l$ (the per-row softmax normalizers), each of length $N$.
2. During the backward pass, whenever a block of $S$ or $P$ is needed for a gradient computation, FlashAttention recomputes the corresponding $S_{ij}$ and $P_{ij}$ blocks on the fly in SRAM from the saved $O, m, l$ and the original input blocks $Q_i, K_j, V_j$.
* Result: although recomputation increases the total number of floating-point operations, the backward pass is still much faster than the standard implementation because accesses to slow HBM are drastically reduced.
* I/O complexity: the backward pass has the same I/O complexity as the forward pass, $\Theta(N^2 d^2 / M + Nd)$, far better than the $\Theta(N^2)$ of the standard backward pass.
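As referenced above, a minimal recomputation sketch (NumPy; names are illustrative) showing how a probability block is rebuilt on the fly from the saved statistics rather than reloaded from HBM:

```python
import numpy as np

def recompute_block(Qi, Kj, m_rows, l_rows):
    """Backward-pass recomputation: rebuild the S_ij and P_ij blocks in SRAM
    from the saved per-row forward statistics m (maxima) and l (normalizers).
    m_rows, l_rows are the slices of the saved vectors matching Qi's rows."""
    Sij = Qi @ Kj.T / np.sqrt(Qi.shape[1])
    Pij = np.exp(Sij - m_rows[:, None]) / l_rows[:, None]   # same formula as forward
    return Sij, Pij
```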
Conclusion
FlashAttention is a landmark optimization of the transformer attention mechanism. Through Tiling, the clever Online Softmax trick, and the fused kernels these enable, it effectively removes the I/O bottleneck that standard attention hits on long sequences.
- Key benefit: it drastically reduces reads and writes to GPU high-bandwidth memory (HBM), significantly speeding up training and inference while lowering memory usage.
- How: it avoids materializing and storing the huge $N \times N$ score and probability matrices in HBM; instead it computes over blocks in fast SRAM and uses the Online Softmax algorithm to perform the normalization exactly and incrementally during iteration.
- Outcome: FlashAttention produces exactly the same result as standard attention, with no approximation, and has become a foundational optimization for training and deploying modern large language models.
The speaker closes by mentioning that later episodes may cover subsequent versions (such as FlashAttention V2) and inference-specific applications such as FlashDecoding.