speaker 1: Hello everyone. I assume you know how transformers are everywhere in AI and in almost all LLMs these days. The secret sauce is the attention mechanism. However, it can be very slow and eat up tons of memory, especially as models get larger and larger. That's where Flash Attention comes in. It's a game changer and, in my opinion, one of the most important breakthroughs in recent years: it makes training large models dramatically faster and more memory efficient. Let's take a look together.

Before we start, I want to say that this talk is relatively math heavy, at least in parts. If you don't follow all of the math, don't worry about it. It won't affect your ability to understand the high-level mechanism, and it won't stop you from using Flash Attention modules in development. It will still be fun and helpful to follow even if you don't catch every detail, so just try to enjoy it and have fun.

Alright, let's start with how the CPU and memory work. As you can see, here's the pyramid: the hierarchical memory structure of a typical computer. At the top are the processor registers, which are very fast and very expensive. At the bottom is tape backup, which most of us don't use anymore; it's very slow, but very cheap. So storage speed goes from fast to slow as you move from the top of the pyramid to the bottom, and at the same time the cost goes from very high to very low as you move from the top to the base.

If you have experience with data structure or algorithm interviews, you must be pretty familiar with big-O notation. It's used very frequently for time and space complexity. However, it can give you the false impression that memory is homogeneous, when it is actually hierarchical, like the pyramid here. It also focuses on compute while neglecting I/O, which can be more critical for I/O-bound performance. The computer memory hierarchy is a system where memory is organized into levels by I/O speed, cost and capacity. It typically consists of CPU registers, cache memory, main memory, secondary storage, and sometimes tape storage. This hierarchical structure allows fast access to frequently used data while using cheaper, slower storage for less frequently accessed information.

This is a block diagram of a basic CPU. The black lines indicate data flow; the red lines indicate control flow. To finish a program or instruction, the CPU basically needs to do two things. First, fetch data from the different levels of memory, whether that's a register (which is the closest), cache, main memory, flash, or disk. Second, after the data is fetched, it needs to process it. If a program or task is bottlenecked on fetching data from the different levels of memory, we call it I/O bound. To optimize an I/O-bound task, we can improve data locality, improve the algorithm to reduce reads and writes, or increase bandwidth with better hardware. If a program or task is bottlenecked on the actual computation or data processing, we call it CPU bound. To optimize a CPU-bound task, we optimize the algorithm to reduce the computation and processing needed.

Now let's move from the CPU to the GPU. How does GPU memory work? It's actually pretty similar. GPU memory is also hierarchical: SRAM has very fast I/O and is expensive, with very small capacity. HBM (high-bandwidth memory) is cheaper and slower, but still faster than traditional DRAM, and its capacity is much larger than SRAM's.
When we say a GPU is 40GB or 80GB — like this one, an NVIDIA A100 40GB GPU — that usually refers to the HBM. This is because most of the operations we expect a GPU to perform rely on HBM and are I/O bound. Imagine if we could do more with SRAM and less with HBM: it would be a lot faster, because SRAM's bandwidth is more than ten times that of HBM on modern GPUs. Compute speed has outpaced memory speed, and most operations in transformers are bottlenecked by memory accesses. This means that if you're training a transformer, it's most likely I/O bound. This is the main motivation for Flash Attention.

Let's do a quick recap of the transformer. This is the typical encoder-decoder architecture. In the encoder there is a multi-head attention block; in the decoder there's a cross-attention block and a masked attention block. If we dive into the attention block, this is basically what it does: given matrices of queries, keys and values, it computes the attention scores with a matrix multiplication between queries and keys, scales and masks if needed, applies a softmax to turn the scores into probabilities, and then does a matrix multiplication with the values. This is the standard implementation of the attention mechanism.

The problem statement: given input sequences Q, K and V, all real-valued matrices of dimension N by d, we want to compute the attention output O, also a real-valued N by d matrix. Let S = QK^T and P = softmax(S); both S and P are of dimension N by N. The output O = PV has dimension N by d. In these formulas, N is the sequence length, and for large language models the maximum sequence length can be very large — Gemini already supports one million tokens — so N is a very large number. d is the head dimension for multi-head attention, which is usually much smaller, say 64 or 128.

So this is the standard attention implementation. The matrices Q, K and V are usually stored in HBM. In the first step, we load Q and K from HBM and compute S, then write S back to HBM, because S is too big to keep in SRAM — its size is N by N. Then we read S from HBM, compute the softmax probabilities, and write P back to HBM, because P is also N by N. Then we load P and V by blocks from HBM, compute the output, and write the output back to HBM.

Let's do a memory I/O analysis of the standard attention implementation. Note that we're not doing a FLOPs analysis; we're doing a memory I/O analysis, which focuses on reads and writes. First step, load the inputs: the Q, K and V matrices are read from HBM into SRAM, so that's three HBM reads of N·d elements each. For computing the attention scores S = QK^T, the matrix multiplication produces an N by N matrix, and because it's too large to fit in SRAM for subsequent operations like softmax, we need to write it back to HBM. This is especially true for large N, and for today's large language models you can consider this to always be the case. So that's N² elements of HBM writes for S. Next we compute the softmax: we read the S matrix from HBM, calculate the softmax probabilities — also an N by N matrix — and for the same reason write P back to HBM for subsequent operations like calculating the output. So we have N² HBM reads and N² HBM writes for this step. A quick code sketch of this standard path is below, and then we'll finish the I/O count with the output step.
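Here's a minimal sketch of the standard implementation just described, assuming NumPy. It's my own illustration, not code from the talk: I've included the usual 1/sqrt(d) scaling and a row-wise max subtraction in the softmax (the "safe softmax" trick covered a bit later). The important point is that it materializes the full N by N matrices S and P — exactly the intermediates that cost so much HBM traffic.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Standard (naive) attention: materializes the full N x N score and
    probability matrices, mirroring the HBM reads/writes described above."""
    d = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(d)               # N x N scores (written to HBM in the standard algorithm)
    m = S.max(axis=-1, keepdims=True)        # row-wise max for a numerically safe softmax
    P = np.exp(S - m)
    P /= P.sum(axis=-1, keepdims=True)       # N x N probabilities (also written to HBM)
    return P @ V                             # N x d output

# Toy shapes: sequence length N = 8, head dimension d = 4
rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = rng.normal(size=(3, N, d))
print(standard_attention(Q, K, V).shape)     # (8, 4)
```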
And lastly, to compute the output, we read the P matrix and the V matrix from HBM and do the calculation. The final output O of size N by d is written back to HBM. So the reads are N² plus N·d, and the HBM writes are N·d for writing O. In total, the I/O complexity is Θ(N² + N·d) HBM accesses. By the way, big theta denotes a tight bound, whereas big O denotes an upper, worst-case bound. So the total I/O is basically Θ(N²), since N is usually much larger than d.

This is a high-level comparison I took from the Hugging Face Flash Attention intro. It compares the standard attention implementation with the Flash Attention implementation in a pretty intuitive way. In the standard implementation, we have to load from and write to HBM multiple times, which is very costly from an I/O perspective: to calculate S we have to load Q and K and then write S back to HBM, because S is too big (N by N); to calculate the softmax probabilities we need to load S from HBM again, and after the computation we have to write P back to HBM for the same reason (it's N by N); and for the output we have to load both P and V and write O back after the calculation.

In comparison, Flash Attention tries to avoid storing those intermediate results, especially S and P, which are N by N. In order to fit within SRAM's memory limits, Flash Attention divides Q, K and V into blocks. As you can see, Kj and Vj are the j-th blocks of K and V, and Qi is the i-th block of Q; l and m are intermediate statistics that we'll cover in detail later. In this way we avoid writing S and P back to HBM: we just load K, V and Q from HBM, do some extra calculation on intermediate state like O, l and m, and after the computation we write only the final result back. So we successfully avoid storing the N by N matrices S and P in HBM.

Before we dive into the details of Flash Attention, I'd like to clarify a few things briefly and provide some extra context to help with the content later. First, I want to distinguish training from inference. Flash Attention's mechanism and principles can be used in both training and inference; however, in this talk the context is mainly the training side, not the inference side — for example flash decoding, which we might cover in a later episode. The second thing is I/O analysis versus FLOPs analysis. Flash Attention aims to improve memory I/O rather than FLOPs (floating-point operations), so we are analyzing memory accesses — reads and writes — instead of computation. Remember, we did a FLOPs analysis when we went through KV caching for LLM inference in my transformer video. It looks very similar, but it has nothing to do with Flash Attention. By the way, KV caching is used in LLM inference only, not in training: inference for an LLM is usually autoregressive, which is where a KV cache is useful, but in training all the sequences are processed in parallel, since we already know the input, so a KV cache is much less useful.

The first piece of extra background I'd like to cover is the per-row softmax in the attention mechanism. Specifically, in scaled dot-product attention, the softmax is calculated per row of the score matrix S, since each row corresponds to a single query vector and its relationship with all key vectors; the purpose is to determine how much attention that specific query should pay to every key. Q is the matrix of query vectors, with dimension N by d — that is, sequence length by head dimension.
K is the matrix of key vectors, also of dimension N by d. For each query vector in Q, we need to calculate the attention scores against all key vectors in K. The resulting matrix S has dimension N by N, and in S the i-th row s_i contains the dot-product scores of the i-th query vector q_i with all key vectors K_1, K_2, ..., K_N, where s_ij = q_i · k_j. This represents the raw relevance, or attention score, between the i-th query and the j-th key. So the softmax is applied to every row, and the elements of a single row sum to one, representing a probability distribution.

Just to remind you of the softmax definition: for a given input vector s with elements s_1 through s_K, the softmax function calculates the probability p_i for each element as p_i = e^{s_i} / Σ_k e^{s_k}.

Given that definition, you might think the implementation should be very simple: we just go through s twice. In the first pass, we compute the denominator, the sum of exponentials; in the second pass, we calculate the actual softmax for each position, since we already have the denominator. However, this naive algorithm won't fly in most production environments, because e^{s_i} will likely overflow the hardware limit. That's why we need safe softmax in most production environments. The idea is also very straightforward: since we don't want the values to overflow, we subtract the maximum of all the s values, which guarantees no overflow. So instead of two passes, we need three: in the first pass we find the maximum value, in the second pass we compute the denominator, and in the third pass we calculate the softmax for each position. This is called the safe softmax algorithm.

With this extra background, I think we're ready to look at the Flash Attention details. The essence of Flash Attention is tiling. Standard attention needs multiple HBM accesses, since S and P are too big for SRAM; Flash Attention improves on this by doing the attention calculation with tiling, which means breaking big data into smaller tiles. This diagram shows what tiling is: we have a d by N matrix K (transposed), an N by d matrix Q, and an N by d matrix V. Note that K, Q and V are broken into blocks. The block size is not N and it's not d — it's something else that I'll cover later. Just by looking at this graph you might think the block size is d by d, but it's not. In standard attention, QK^T is N by N and we need to store the whole thing in HBM; to avoid that, we calculate only blocks. We copy a single block of Q, K and V from HBM to SRAM, compute on the block in SRAM, and then write the output to HBM.

There are several key points for understanding tiling and Flash Attention. The first is divide and conquer: the input matrices Q, K and V are divided into smaller blocks, or tiles. The next is iterative processing: instead of computing the full attention matrix at once (the dotted square), Flash Attention processes these tiles iteratively, loading a block of Q and a block of K into the much faster SRAM. The third is partial computation: it computes the attention scores and the weighted sum of values for these blocks within SRAM. The next one is online softmax, which is also very important; before we get to it, a quick sketch of the safe softmax we just covered is below.
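This is my own minimal sketch of the three-pass safe softmax, assuming NumPy and written with explicit loops so the three passes are visible; a real implementation would of course be vectorized.

```python
import numpy as np

def safe_softmax(s):
    """Three-pass safe softmax over a single row s, as described above."""
    m = -np.inf
    for si in s:                      # pass 1: running maximum m_N
        m = max(m, si)
    l = 0.0
    for si in s:                      # pass 2: denominator l_N = sum_i exp(s_i - m_N)
        l += np.exp(si - m)
    return np.array([np.exp(si - m) / l for si in s])   # pass 3: p_i

row = np.array([1000.0, 1001.0, 1002.0])   # a naive exp(s_i) would overflow here
print(safe_softmax(row))                    # ~[0.090, 0.245, 0.665]
```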
So, online softmax: Flash Attention recomputes the softmax normalization on the fly. As it processes more blocks, it correctly normalizes the attention scores without approximation and without needing the complete QK^T matrix. The last key point is accumulation: the results from processing each pair of Q and K blocks are accumulated to produce the final output without ever forming the full intermediate attention matrix in HBM.

We have already gone through the safe softmax algorithm. However, it won't work for Flash Attention as-is, since safe softmax requires the complete row s_i to compute the maximum value, which is then subtracted when calculating the softmax to avoid overflow. Flash Attention does tiling, meaning it breaks K, V and Q into blocks that fit into SRAM, so it no longer has access to the whole row. What should we do in this case?

This is the three-pass safe softmax we went through on the previous slides. As you can see, the first pass is friendly to tiling, since for each position i it only uses information from positions i and i−1. The second pass is not friendly to tiling due to the dependency on m_N, the maximum over the whole row, and the third pass is also not friendly to tiling due to the dependency on both m_N and l_N.

So what if we can find a substitute, denoted l̂_i, for l_i, such that l̂_i depends only on positions less than or equal to i, and l̂_N = l_N? If we can build this substitute, we can avoid the second pass entirely and make the algorithm closer to tiling compatible, although the maximum m is still involved. The substitute turns out to be l̂_i = Σ_{j≤i} e^{s_j − m_i}, and it satisfies the recurrence l̂_i = l̂_{i−1} · e^{m_{i−1} − m_i} + e^{s_i − m_i}. This is the deduction process on the slide — pause and take a look if you're interested — and from it we easily get the new two-pass safe softmax algorithm. Note that this is still not fully tiling friendly, mainly because of m_N. In the first pass, for i = 1 to N, we calculate m_i and l̂_i for each position, and each step only needs information from positions i and i−1. In the second pass, since we already have l̂_N, we calculate p_i for each position using l̂_N and m_N.

Recall that S = QK^T, P = softmax(S), and O = PV. We have already gone through S and P, so now it's time to get the final result O, and in the process we'll make the algorithm tiling friendly and improve it further. The straightforward algorithm to calculate the output based on the previous two-pass safe softmax is as follows: in the first pass, for i = 1 to N, we calculate s_i, m_i and l̂_i; in the second pass, we calculate p_i from l̂_N and m_N, and then o_i from p_i and V. This two-pass algorithm uses the k-th row of O as an example for simplicity; the computation of different rows is independent, so the other rows work identically. If we store the intermediate statistics m_i and l̂_i — and yes, we can store them in HBM given their small footprint and I/O cost — then this algorithm is tiling friendly. We call it online safe softmax.

Can we improve this algorithm further so it needs only one pass? The answer is yes. The main reason we still have two passes is that in the second pass, when we calculate o_i, it depends on m_N and l̂_N, which requires the first pass to finish. What if we do the same thing for O and construct an ô_i that meets the following criteria?
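Before constructing ô_i, here's a small sketch, assuming NumPy, of the two-pass online safe softmax we just derived — my own illustration to make the l̂ recurrence concrete.

```python
import numpy as np

def online_safe_softmax(s):
    """Two-pass online safe softmax: the running maximum m_i and the substitute
    denominator l_hat_i are maintained together in a single pass, using
    l_hat_i = l_hat_{i-1} * exp(m_{i-1} - m_i) + exp(s_i - m_i)."""
    m = -np.inf
    l_hat = 0.0
    for si in s:                                   # pass 1: m_i and l_hat_i together
        m_new = max(m, si)
        l_hat = l_hat * np.exp(m - m_new) + np.exp(si - m_new)
        m = m_new
    return np.exp(s - m) / l_hat                   # pass 2: p_i = exp(s_i - m_N) / l_hat_N

s = np.array([3.0, -1.0, 2.5, 0.7])
reference = np.exp(s - s.max()) / np.exp(s - s.max()).sum()
assert np.allclose(online_safe_softmax(s), reference)   # matches the standard safe softmax
print(online_safe_softmax(s))
```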
The criteria: ô_i depends only on information from positions less than or equal to i, and ô_N = o_N. And yes, we can find such a formula — ô_i = Σ_{j≤i} (e^{s_j − m_i} / l̂_i) · V[j,:]. We can now merge the two passes into one if we know the recurrence relationship between ô_i and ô_{i−1}, and again, it's pretty straightforward to find it. The math deduction details are on the slide; pause the video and take a look if you're interested.

Finally, we have the one-pass Flash Attention algorithm that is tiling friendly. In that single pass, for each position i from 1 to N, we calculate s_i first, then m_i, the maximum value up to the i-th position, and then l̂_i and ô_i. After N steps, ô_N is the final, exact output for the k-th row. And this is tiling friendly, because each calculation only depends on information from position i or i−1. So now let's apply tiling to it.

This is the detailed forward pass of Flash Attention V1. It's pretty much what we've gone through in the previous slides, with a few added details. For example, it introduces the block sizes Bc and Br and divides Q, K and V into blocks: each block of Q is Br by d, and each block of K and V is Bc by d. Correspondingly, we also divide O into Tr blocks. Then there's an outer loop and an inner loop. The outer loop runs j from 1 to Tc: we iterate over the key and value blocks and load them from HBM to on-chip SRAM. For each K and V block, we go through all the query blocks along with the output blocks and the l and m blocks, loading them from HBM to on-chip SRAM. The rest of the calculation is the same as in the previous slides. The diag(·) function here constructs a diagonal matrix from a vector, and the bracket notation ⌈x⌉ is the ceiling function, which rounds a real number up to the nearest integer — for example, ⌈2.7⌉ = 3. The outer loop is over K and V, and the inner loop is over Q; notice that this was later swapped for better parallelism in Flash Attention V2, which I'll cover in maybe the next episode. The intermediate statistics m and l are written to HBM for recomputation in the backward pass, which I'll cover in later slides.

Now let's take a look at the block sizes before we dive into the memory I/O analysis for Flash Attention. The block sizes Bc and Br are picked deliberately so that all four blocks Qi, Kj, Vj and Sij fit into SRAM of size M. The dimension of Sij is Br by Bc, and from the algorithm we know Bc and Br are defined with a ceiling function of M/(4d). If we plug those in, Br · Bc is at most M/4, which means we strictly control the size of S so it never exceeds 25% of the total SRAM size. You can design other block size formulas, as long as they make sure all four matrices fit into SRAM.

Now let's look at the memory I/O analysis for Flash Attention. In the outer loop, each element of K and V is loaded from HBM once, and both are N by d, resulting in 2·N·d HBM reads, i.e. Θ(N·d) reads. Next, since K and V are broken into Tc blocks in the outer loop, we make Tc passes in the inner loop over Q and O, each pass loading all of Q and all of O from HBM. This results in 2·Tc·N·d HBM reads, i.e. Θ(N·d·Tc). The total I/O complexity is Θ(N·d·Tc) + Θ(N·d) = Θ(N·d·Tc).
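Here's a compact single-head sketch of this tiled forward pass, assuming NumPy. It's my own simplification, not the paper's kernel: the block sizes are fixed constants instead of being derived from the SRAM budget M, there's no masking or dropout, and the m and l statistics live in ordinary arrays. It follows the V1 loop order (K/V blocks outer, Q blocks inner) and checks the result against the standard implementation.

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=4, Bc=4):
    """Simplified FlashAttention-V1-style forward pass: outer loop over K/V
    blocks, inner loop over Q blocks, with running statistics m (row max) and
    l (row sum) used to rescale the output online."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    l = np.zeros(N)                    # running softmax denominators
    m = np.full(N, -np.inf)            # running row maxima

    for j in range(0, N, Bc):                      # outer loop: key/value blocks
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]
        for i in range(0, N, Br):                  # inner loop: query blocks
            Qi = Q[i:i+Br]
            Sij = (Qi @ Kj.T) * scale              # Br x Bc block of scores
            m_block = Sij.max(axis=1)
            P_block = np.exp(Sij - m_block[:, None])
            l_block = P_block.sum(axis=1)

            m_new = np.maximum(m[i:i+Br], m_block)
            l_new = np.exp(m[i:i+Br] - m_new) * l[i:i+Br] + np.exp(m_block - m_new) * l_block

            # rescale the old accumulator and add this block's contribution
            O[i:i+Br] = (np.exp(m[i:i+Br] - m_new)[:, None] * l[i:i+Br, None] * O[i:i+Br]
                         + np.exp(m_block - m_new)[:, None] * (P_block @ Vj)) / l_new[:, None]
            m[i:i+Br], l[i:i+Br] = m_new, l_new
    return O

# Sanity check against the standard (materialize-everything) implementation
rng = np.random.default_rng(1)
N, d = 16, 8
Q, K, V = rng.normal(size=(3, N, d))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
assert np.allclose(flash_attention_forward(Q, K, V), P @ V)
```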
Back to the I/O bound: we can further simplify Θ(N·d·Tc) if we express Tc as a function of N, d and M. Since Kj and Vj need to fit into SRAM and their dimension is Bc by d, Bc is O(M/d). Similarly, Qi needs to fit into SRAM with dimension Br by d, so Br is also O(M/d); and Sij needs to fit into SRAM with dimension Br by Bc, which gives the bound for Br shown on the slide. Then Tc = N/Bc = Θ(N·d/M). So the final I/O complexity for the Flash Attention V1 forward pass is Θ(N²·d²/M). Compare that with the Θ(N²) I/O complexity of standard attention: the typical value of d is very small, 64 or 128, while the typical value of M is large, in the kilobytes-to-megabytes range, so d² is still much smaller than M. Flash Attention therefore requires many times fewer HBM accesses, leading to faster execution and a lower memory footprint.

This is a great intuition graph drawn by the original author in his blog, and I recommend taking a look. Let's assume the whole K matrix has only two vectors. In the outer loop we loop through all of K, which is K1 and K2, and for each K we go through all of Q. When we process K1, we have S1 = Q·K1^T, we apply the softmax here, multiply by the value, and get the first output O1. Since we have to iterate N times, and in this case N = 2, we then come to K2: S2 = Q·K2^T, we compute the softmax, multiply by V2, and get O2 by rescaling with the correct denominator from step one. Now imagine N is not 2 but a very large number: this rescaling-to-the-correct-denominator process works exactly the same for big N; the rescaling formula just gets longer and longer.

Now that we've gone through the forward attention process, we should know Flash Attention's motivation — reducing memory I/O complexity — and its mechanism, including tiling and the online softmax algorithm, and we should be comfortable using Flash Attention modules in development. The remaining part about the backward pass is optional; stay if you feel like it.

Let's recap a little on what backpropagation is. The goal of training a neural network is to minimize the loss function — in this case, call it φ. According to the chain rule, we use backpropagation to update the weights of all intermediate neurons in the network. Once the forward pass is done and the loss has been calculated for the output, backpropagation starts at the end of the network: the gradient of the loss with respect to the network's final output is calculated first. Let's call it dO_final, where O_final is the final output and d denotes the derivative. Any given intermediate layer then takes the succeeding layer's dO as input and computes the updates to the current layer's parameters.

So for the backward attention pass, the problem statement is: given loss function φ, let the output gradient be dO, a real-valued matrix of dimension N by d, where dO is the derivative of the loss with respect to O. We want to compute the input gradients dQ, dK and dV, all real-valued N by d matrices — the loss function's derivatives with respect to Q, K and V respectively. Now let's briefly go through the chain rule.
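One quick aside before the chain rule: here's the two-key rescaling from the intuition graph above, written out as a tiny sketch with toy numbers I made up (and without the 1/sqrt(d) scaling), just to show that rescaling with the running max and denominator lands exactly on the full softmax result.

```python
import numpy as np

q = np.array([1.0, 2.0])
K = np.array([[0.5, -1.0], [1.5, 0.5]])      # two key vectors K1, K2
V = np.array([[1.0, 0.0], [0.0, 1.0]])       # two value vectors V1, V2

s1, s2 = K @ q                                # raw scores s1 = q.K1, s2 = q.K2
# step 1: only K1 seen so far
m1, l1 = s1, np.exp(0.0)                      # running max and denominator
o1 = np.exp(s1 - m1) / l1 * V[0]
# step 2: K2 arrives -> rescale the old output to the new max and denominator
m2 = max(m1, s2)
l2 = l1 * np.exp(m1 - m2) + np.exp(s2 - m2)
o2 = (l1 * np.exp(m1 - m2) * o1 + np.exp(s2 - m2) * V[1]) / l2

p = np.exp(K @ q - (K @ q).max())
p /= p.sum()
assert np.allclose(o2, p @ V)                 # matches standard attention for this row
print(o2)
```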
For the chain rule: we already know the Q, K and V matrices are N by d, O is also N by d, and dQ, dK, dV are also N by d. We know S = QK^T and P = softmax(S), both real-valued N by N matrices, and O = PV is N by d. So dφ/dV = (dφ/dO)·(∂O/∂V), and since O = PV, this gives dV = P^T·dO. Similarly, dP = dO·V^T, dQ = dS·K, and dK = dS^T·Q.

Now we need a little extra knowledge about the derivative of the softmax function. Softmax takes a vector input and produces a vector output, so its derivative is a Jacobian matrix containing all partial derivatives of each output element with respect to each input element. Let y = softmax(z). We need to compute ∂y_i/∂z_j for all i and j, and there are two cases to consider: case one where i = j, and case two where i ≠ j. The detailed derivation is on the slide; pause the video if you want to take a look. Combining both cases: when i = j the derivative is y_i·(1 − y_i), and when i ≠ j it is −y_i·y_j. So the full Jacobian matrix J of the softmax function can be written compactly in vector notation as J = diag(y) − y·y^T.

Back to the original problem statement: we have P = softmax(S) applied row by row, and given that the Jacobian for y = softmax(s) is diag(y) − y·y^T, we get the formulas on the slide — a detailed calculation of the derivative with respect to S at each position (i, j). Pause the video for a bit if you need to take a look.

With all this extra knowledge, we're ready for the standard attention backward implementation. When we do the backward pass, we already have the matrices Q, K, V, dO and P in HBM. We first load P and dO from HBM and compute dV. Then we load dO and V from HBM and compute dP, and write dP back to HBM. Then we read P and dP from HBM and compute dS according to the Jacobian we just went through, and afterwards write dS back to HBM. Next, we load dS and K by blocks from HBM and compute dQ. The last step is to load dS and Q by blocks from HBM and compute dK. Now we have all of dQ, dK and dV.

This backward pass requires the matrices S and P, whose size is N by N, and just like the forward pass, that size forces a lot of reads and writes to HBM. Recall that one of Flash Attention's goals is to avoid storing O(N²) intermediate values for the backward pass — so what should we do here? Remember when we talked about I/O-bound versus CPU-bound tasks: transformer training is I/O bound, which means that if we sacrifice a little on computation and improve a little on I/O, it's actually a win. That's exactly what we do here. By storing the output O and the softmax normalization statistics m and l, we can recompute the attention matrices S and P in the backward pass from blocks of Q, K, V, O, m and l in SRAM. This results in more FLOPs due to recomputation, but it still speeds up the backward pass because of the reduced HBM accesses. This is the detailed backward pass algorithm, and the recomputation happens here: we recompute P from m and l, using the same formula as in the forward pass.
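As a quick sanity check on that Jacobian formula, here's a tiny numerical sketch of my own (not from the slides) comparing diag(y) − y·y^T against finite differences.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

z = np.array([0.3, -1.2, 2.0, 0.5])
y = softmax(z)
J_analytic = np.diag(y) - np.outer(y, y)        # diag(y) - y y^T

eps = 1e-6
cols = [(softmax(z + eps * np.eye(len(z))[j]) - softmax(z - eps * np.eye(len(z))[j])) / (2 * eps)
        for j in range(len(z))]                 # column j = d softmax / d z_j
J_numeric = np.stack(cols, axis=1)

print(np.max(np.abs(J_analytic - J_numeric)))   # ~1e-10, i.e. the formula checks out
```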
These are all the mathematical formulas we use in the Flash Attention backward pass. dV = P^T·dO is pretty straightforward: we have P and we have dO, so it's just the matrix multiplication here. dP = dO·V^T we already know from previous slides. If we define D_i = Σ_j dO_ij·O_ij (the row-wise sum of dO ∘ O), then we can simplify dS_ij to P_ij·(dP_ij − D_i), and we can rewrite dQ and dK using that definition of D_i. Pause the video for a bit if you want to look at the details.

Now let's look at the I/O complexity. It's actually very similar to the forward pass. In the outer loop, each element of K and V is loaded from HBM once, and both are N by d, so that's Θ(N·d) HBM reads. Given that K and V are broken into Tc blocks in the outer loop, we make Tc passes in the inner loop over Q, O and dO, each pass loading all of Q, O and dO from HBM. This results in 3·Tc·N·d HBM reads, i.e. Θ(N·d·Tc). Reusing the previous calculation for Tc, the total I/O complexity is Θ(N·d·Tc) + Θ(N·d) = Θ(N·d·Tc). So the final I/O analysis for the forward pass and the backward pass is the same, and as we discussed, it's much smaller than standard attention's, since d² is usually much smaller than M.

Alright, this is the last slide I have for Flash Attention V1. Hopefully this was helpful, and I promise I'll go through the later versions of Flash Attention, as well as Flash Attention usage in inference, in later episodes. If you like my video, please subscribe, comment and like. Alright, see you later. Bye.