Hardware-aware Algorithms for Sequence Modeling - Tri Dao | Stanford MLSys #87

This lecture is episode 87 of the Stanford MLSys Seminar, in which Tri Dao discusses hardware-aware algorithms for sequence modeling. The talk has two parts. The first part focuses on improving self-attention in Transformers, whose time and memory complexity grow quadratically with sequence length. IO-aware algorithms such as FlashAttention make attention computation significantly faster and more memory-efficient, enabling longer context and better model quality; optimizations for long-context large language model inference are also introduced. The second part explores subquadratic model architectures such as recurrent neural networks (RNNs), gated convolutions, and structured state space models (SSMs). A key weakness of these models is their lack of content-based reasoning, which is addressed with a selection mechanism. Although this breaks the efficient convolutional formulation, a hardware-aware parallel algorithm is designed for it. Integrating these selective SSMs into a simplified architecture, Mamba, yields a model that matches or exceeds modern Transformers on language modeling while offering faster inference and longer-context capability.

Media details

Upload date
2025-05-16 20:59
Source
https://www.youtube.com/watch?v=foG0ebzuw34
Processing status
Completed
Transcription status
Completed
Latest LLM Model
gemini-2.5-pro-exp-03-25

Transcript

speaker 1: So I guess I mean.
speaker 2: One of us should monitor it — maybe you do that here?
speaker 1: And I'll just control the camera as needed, so you really have to do it, because I can't be using my laptop for that. Okay. Hello, testing.
speaker 2: Testing, testing, testing.
speaker 1: All right, hello everyone. Welcome to — was it 528? Is that our number? Yes, 528. It's great to have you here. We are delighted to have Tri Dao with us today. Michael will say more about him in a second. But first I want to cover some very quick logistics. This is being streamed on YouTube as we speak, so feel free to ask questions or whatnot, but know that it will come through there. What else — we're going to have a poll at the end about how much we want to do virtual versus in-person sessions going forward. So please stick around after Tri finishes for that, just so we can figure out exactly how we want to coordinate this going forward. We want to do the best thing for you, and we're going to try to make this as interactive a session today as possible. So do not be shy — ask whatever questions you have on the fly. Tri has certainly made time for questions at the end, but it's much more fun if we have questions during as well. So with that, any logistics questions regarding the class that I should handle right now? Yes?
speaker 2: Can we rely on the stream?
speaker 1: If you can't make it in person? Yes, you can rely on the stream. Maybe we'll screw up the technology somehow, but in principle it's supposed to work.
speaker 2: So, I mean, the streams are reliable on YouTube now, so they
speaker 1: should just be fine, yeah — they'll also stay up afterwards. Other questions? Yes? Okay, an easy one. Any other questions? Awesome. All right, Michael, do you want to take it away? Today we're very delighted to have Tri Dao, who has gotten quite a lot of attention himself.
speaker 2: Tri is here with us right now. He's an incoming assistant professor at Princeton University and currently Chief Scientist at Together AI. He did his PhD in computer science at Stanford, advised by Christopher Ré and Stefano Ermon, and he works at the intersection of machine learning and systems. His research interests include sequence models with long-range memory and structured matrices. His work has received the ICML 2022 Outstanding Paper runner-up award. So, over to you, Tri. — Thank you, Michael. Hi everyone, thank you for having me; I'm excited to be here. This talk, or seminar, is really just a bunch of exciting results at the intersection of machine learning and systems, and I want to share some of these results with you. So feel free to stop me and ask questions — it's not a super formal talk or anything. Yeah, so let's get started. Just to motivate — you've probably seen a lot of this — machine learning has been making a lot of progress, especially recently. There's a bunch of tasks that for a while we thought were really, really hard, and now we have machine learning systems that can do these things. For example, fixing bugs: you can go to ChatGPT or GPT-4, paste in a snippet of code and ask what's wrong, and it will tell you what's wrong. And this is being built into products — I think Microsoft just released Copilot Pro and things like that. So millions or hundreds of millions of people are using this stuff, but two or three years ago we didn't think these things were possible. Or you can generate art. This is my attempt at using Stable Diffusion — I think this was a while ago. You type in a prompt, you wait a couple of seconds, and the model does a pretty decent job; these models are getting so much better now. In other domains, things like AlphaFold have done such an amazing job at predicting protein structure that now a bunch of companies are using these models to accelerate drug design. And so what we want to ask is: what enabled these advances? What are some of the outstanding problems? And how do you approach these outstanding problems? One argument has been that it is scale that has brought about quality and capabilities. Maybe some of you have seen this graph where model size and data size have increased about a thousandfold in the past five years or so. Back in 2018 we had BERT at around 300 million parameters; more recently Megatron-Turing and PaLM at around 500 billion parameters; and GPT-4 is reportedly at the trillion-parameter scale. As you scale up the model size and train these models on more data, they just get better — they get better on existing benchmarks. That's great. But what's really amazing is that they seem to gain new capabilities. Here's my attempt at using language models to explain jokes. The joke is: "I tried 10,000 random restarts of my neural network, but I was accused of overfitting. I guess no good seed goes unpunished." I asked a 1-billion-parameter model to explain this joke, and it didn't get it — when you ask it to explain the joke, it just tries to repeat the punch line without understanding it. But if you ask a 175-billion-parameter model, it does get the joke; it understands that this is a pun on "no good deed goes unpunished." So this is amazing. It seems like scale is now more closely tied to advances in machine learning than ever before. 
And with scale comes a core challenge: efficiency. I assume throughout this seminar you've seen this issue brought up a lot. There are a couple of reasons to care about efficiency. One is that efficiency can make training and deployment so much easier, and it can facilitate research. Right now we have a bunch of large, slow models that do really well and are really accurate. Can we get to the point where we have smaller and faster models that do just as well? We've seen a lot of progress, especially in just the last year or so, where now we have 7B models doing as well as 70B models, and models running on phones that are super capable. For example, Stability just released Stable Code, a 3B model that runs on a phone and is about as good as some of the larger Code Llama models. So we've seen amazing progress here — that's a very practical reason to care about efficiency. Another reason is that efficiency can unlock new capabilities. Here's an example from, I think, the GPT-3 playground. Last year, if you asked it to write a 4,000-word essay on the best ice cream flavor, it just couldn't do it — that was last year, with GPT-3. Now GPT-3.5 can deal with 16K context, thanks to some of the work that we and the community have been doing; GPT-4 can scale to 128K context; and so on. So this is amazing, and I'll argue that some of these advances are tied to better efficiency — I'll touch on this in some of the work I'll present. So: you might care about efficiency from a practical point of view, making it easier to train and deploy models, and from a capability point of view, since efficiency can unlock new capabilities. My approach to understanding efficiency is to understand both the algorithms and the systems side. On the algorithms side, you can spend a bunch of time understanding things like matrix-vector multiply, which is the fundamental operation in a lot of these neural networks, or you can try to understand the attention mechanism, which is the heart of the Transformer architecture that has brought about a lot of these advances. On the systems side, we want to understand things like hardware accelerators and distributed systems — basically whatever these models are running on. For things like GPUs, you want to understand how these devices work, and that has a bunch of implications for how you should design your models and your algorithms. And when you wire a bunch of these accelerators together, you want to understand that they have an asymmetric memory hierarchy, and you can exploit that to make your system five to ten times more efficient. I'll talk about some of the ideas for how to do that. So, as an example of how we can think about both the algorithm and the hardware, one of the main ideas of this talk is hardware-aware algorithms: algorithms that take advantage of the hardware they run on. For this talk I'll mention two examples of this idea. One is IO-awareness: for GPU memory, for example, you want to reduce the amount of reads and writes to GPU memory, because that is what really slows you down, and reducing it can bring significant speedups. An example is FlashAttention. 
This is joint work with Dan Fu, Stefano Ermon, Atri Rudra, and Chris Ré, some of whom are at Stanford. It's a fast and memory-efficient attention algorithm with no approximation, and that's the first half of the talk. For the second half, I'll talk about how some of these ideas apply to new model architectures — not just the Transformer; these ideas are more general — and we'll show how they apply to state space models. I'll show an example of how you can expand the recurrent state — this is a recurrent neural network-style architecture — in SRAM to avoid the memory cost. This leads to Mamba, which is joint work with Albert Gu, who's now a professor at CMU. It's a selective state space model, and the exciting thing is that it's not a Transformer, but it matches Transformers on language modeling with much faster inference and much longer context. So I'll touch on some of these points. We've been very lucky to have some of our research adopted by folks in the community. For example, FlashAttention has been part of PyTorch for a while — I think it's called scaled dot product attention, so you can call it directly from PyTorch. The folks at OpenAI really liked it, so they reimplemented it in their language, OpenAI Triton. Folks at Meta and Microsoft and a bunch of other companies have been using FlashAttention for training and inference, and Hugging Face has been integrating FlashAttention into some of its models as well. So that's the overview of what I'm going to talk about. For the first half, I'll talk about FlashAttention — the main idea is to reduce memory reads and writes. In the second half, I'll talk about how some of these ideas apply to other kinds of models, such as state space models. I'll pause here to see if there are questions on the motivation and the overview. All right, so let's get started on the first half of the talk. I'll focus on the attention layer, which is the core primitive in the Transformer architecture that's being used everywhere. We want to understand what the bottlenecks are — I'll show that the bottlenecks are memory reads and writes — and we want to understand some of the approaches to reduce these bottlenecks, namely tiling and recomputation. Then I'll show some applications: how this lets you train Transformers faster, or with longer context and better quality. The motivation for why we want to do this is that we want to model long sequences. This is motivation on the AI side, the capability side. One anchoring application is of course natural language processing — I think a lot of you are familiar with this — where we need long context to process things like books, plays, instruction manuals, and code bases. One thing I really like from, for example, the GPT-4 demo is that they just put in a whole 50 pages of documentation for a new library, and the model figures out how to use the new library because it can process 50 or 100 pages of documents. So these are new capabilities that can be unlocked if models can process long context. In computer vision, we want to close the reality gap, and we've seen a lot of this in, for example, diffusion models for image generation, where we know that modeling images at high resolution generally leads to better and more robust models. 
But if you're using something like a Vision Transformer, which is what most people use, or diffusion models, high resolution tends to lead to longer sequences, and that just makes it harder on the systems side, the efficiency side, to deal with high-resolution images. So those are pretty popular application domains. But in other domains, I would argue that if you can model long sequences, you can unlock a bunch of new areas: things like time series, audio, video, and medical imaging, where data is naturally modeled as sequences of up to millions of steps. One example is histopathology, where some folks from Chris Ré's lab study these cell images, trying to predict things like cancer and tumors, where you need very high resolution and people never downscale the images. The images are so high resolution that with a Vision Transformer you can't deal with, like, a million-length sequence. So that's the motivation, with some of the applications that can benefit from modeling long sequences. So why can't we just do it — what is the problem with modeling long sequences? I would argue that efficiency is one of the main bottlenecks for modeling long sequences, especially with the Transformer, which is the dominant architecture right now. In terms of terminology, I'll use context length to mean: how many other elements in the sequence does the current element interact with? As you increase the context length, training can slow down or completely stop, and I'll give some examples to illustrate. So here I'm showing a plot of the training speed of two models, a 1B model and a 3B model, using Megatron-LM, which is a state-of-the-art library for training these language models — it's from NVIDIA and has been used to train models of up to, I think, 500 billion parameters. If you use a context length of 2000, meaning you're processing documents of at most 2000 tokens, then your training speed — plotted here on an A100 — is pretty reasonable. It doesn't matter so much what the absolute scale is; it's proportional to how many tokens you can process per second, and higher is better. In absolute terms these are very reasonable processing speeds. But as you increase the context length just from 2K to 8K, your training slows down. For the 1B model it slows down by 2x, and for the 3B model you go out of memory and can't train at all. This makes sense: now the sequence is longer, so each element has to interact with way, way more elements. One, you slow down because you need to do more floating point operations; two, you might go out of memory because you need more memory to store the intermediate matrices. I'll go into the specifics of why, so it's not trivial to just increase the context length. This has been a problem ever since the Transformer came out in 2017 — and attention was already a thing before that as well — so there have been lots and lots of papers working on it. In terms of terminology, I'll go over a quick recap of what the Transformer architecture looks like. I know a lot of you are familiar with this; it's just so that we're on the same page. 
The Transformer architecture is a neural network architecture consisting of a bunch of blocks. Each block is called an encoder block, or sometimes a decoder block, and each block has two main layers: the attention layer and the MLP layer, the multilayer perceptron. I'll focus on the attention layer, because for the MLP layer scaling up the sequence length is relatively straightforward, while for attention it's more difficult. So I'll go over a little bit of background on attention to set up the notation, so we're on the same page. Here's what single-head attention looks like. You have inputs, the query and the key, and these are of shape N by d, where N is the sequence length — as I mentioned, on the order of thousands — and d is the head dimension, usually on the order of around 100. Given the query and the key, you multiply them together to get the similarity scores. You're essentially doing a pairwise comparison between the queries and the keys: for every pair of query and key, you compute a score that says how similar they are. Then you normalize them using softmax per row, which just means you exponentiate the vector and divide by the sum, so that everything is positive and sums to one — it looks like a probability distribution. Once you have this normalized matrix A, you multiply by the value, which is also an input of size N by d, and you get the output, N by d. Written more compactly, the output equals softmax(Q K^T) multiplied by V. Written this way it's relatively simple — you can write it in a couple of lines of PyTorch — but it is the core, the heart, of the Transformer architecture. And we see that attention scales quadratically in sequence length, and this is the core of the problem. If you double the sequence length N, you have 4x the number of pairs for which you need to do this pairwise comparison. So doubling the sequence length increases the amount of computation by 4x and the amount of memory by 4x. This is why it's hard to scale to longer sequences with attention. Ever since the Transformer came out in 2017 and became wildly popular, there's been lots and lots of work on approximating attention or trying to scale it to longer sequences — there are probably, I don't know, ten papers every week on this stuff, and a lot of them are pretty cool. I've personally worked in this area as well. So here's a little bit of context on how these methods work. The goal is to trade off quality for speed: we're going to approximate attention. We say, well, attention scales quadratically in sequence length N, it's too computationally expensive — can we approximate it, do less computation or use less memory? The quality might be worse, but maybe there's an okay trade-off. There are two main classes of approaches: one is sparsity, and one is using low rank. On the sparsity side, the core idea is: instead of doing all N² pairwise comparisons, we'll just decide to ignore some of them and only do the comparison for some of the pairs. In the diagram, I'm showing the computed cells in purple. 
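Since the speaker notes this fits in a couple of lines of PyTorch, here is a minimal sketch of the single-head attention just described — a sketch only, with illustrative tensor names and sizes, and without the usual 1/sqrt(d) scaling so that it matches the formula above:

```python
import torch

def attention(q, k, v):
    # q, k, v: (N, d) -- N is the sequence length, d the head dimension.
    scores = q @ k.T                        # S = Q K^T, the (N, N) pairwise similarity scores
    probs = torch.softmax(scores, dim=-1)   # row-wise softmax: exponentiate and normalize each row
    return probs @ v                        # O = A V, shape (N, d)

# Tiny usage example with illustrative sizes.
N, d = 8, 4
q, k, v = (torch.randn(N, d) for _ in range(3))
print(attention(q, k, v).shape)  # torch.Size([8, 4])
```

The (N, N) scores tensor here is exactly what grows quadratically with sequence length, in both compute and memory.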
So let's say you decide some of these cells you want to compute, and some of these other cells you pretend are zero — you ignore some of the pairwise comparisons. If you use this kind of sparsity, then you can do less computation and hopefully get some speedup or memory savings. That's the idea of sparsity. There are many, many ways to choose which entries to compute and which entries to ignore, which is why there are tons of papers on how to choose them, either statically or dynamically and so on. So that's one class of methods. The other class of methods assumes some low-rank structure on the attention matrix. You say: the attention matrix, even though it's N by N, maybe we can assume it's actually low rank or close to low rank. If you can factor it as a product of two matrices so that it's low rank, then you can do the multiplication the other way: multiply K^T by V first, and then multiply by Q. Different methods have different ways to do this low-rank approximation. The goal is to trade off quality for speed, and I've worked on this as well. There's an excellent survey — well, benchmark — called Long Range Arena, which benchmarks a lot of these efficient Transformer methods. So it turns out, after I'd worked on this for a while, we went and talked to some of the practitioners who train large-scale models, and I asked them, how do you use this method or that method and so on. And generally they say no — we don't really use these methods. There are two core reasons. One, the quality is worse. Okay, that makes sense: we're approximating attention, we're not computing all the entries, so maybe the quality is worse. But more importantly, a lot of them are not even faster, and don't even save memory. This was surprising to me — my mindset was more on the theoretical side: hey, if you're doing fewer operations, it should be faster. But in practice, these methods are not generally faster. The core issue is that they do fewer floating point operations, but that might not be the bottleneck, so fewer floating point operations might not translate to faster wall-clock time — and faster wall-clock time is what people really care about. So we asked the question: is there a fast, memory-efficient, and exact attention algorithm? I'll show at least one way to get there. But before we get there, we want to understand what the bottlenecks are. Previously I had assumed the bottleneck was that we have to do a bunch of floating point operations, so if we can reduce that, we should be faster. But from talking to some of the systems folks, the advice they gave me was: did you profile your code? This was a surprise to me, because I came from a machine learning background, a math background — I'm like, hey, I can prove this thing, I approximate this, it should go faster, right? And the systems folks are like, did you profile your code? Where is the bottleneck? It turns out that if you profile your code on GPUs, the bottleneck is in memory reads and writes. As I mentioned, we have these large N-by-N matrices that store the pairwise similarity between the queries and the keys, and it turns out that just reading and writing 
these large matrices takes most of the time. The time is not being spent on floating point operations, which are useful; the time is spent just reading and writing to memory. So this was a surprise, but in hindsight it kind of makes sense — the biggest cost is in moving the bits. It turns out the standard implementation requires repeated reads and writes to GPU memory, and that's what causes the slowdown. In order to tackle this problem, we need to understand the hardware just a little bit, so I'll do a one-slide crash course on how GPUs work. Here I'm showing a diagram of a streaming multiprocessor, which you can think of as one part of a GPU; a GPU has on the order of a hundred of these. The green box is HBM, high bandwidth memory. This is what you think of as GPU memory: if you're using a GPU and you type nvidia-smi, it shows you, I don't know, 24 gigs, 40 gigs, 80 gigs — that's the HBM, the high bandwidth memory. But there are other components, such as the compute units. These are specialized hardware units that actually do the computation — things like matrix multiplies, adding things, multiplying things — and I'm showing them in purple. And then there's SRAM, which you can think of as a cache, shown here in orange. SRAM sits very close to the compute units; it acts like a cache, it's quite small, but it's quite a bit faster than HBM. This is the A100, where the HBM — which in absolute terms is a marvel of technology — is actually really fast, 1.5 TB per second (the H100 is quite a bit faster than that), and it's quite large, 40 gigs or 80 gigs. The SRAM is an order of magnitude faster than that, but three orders of magnitude smaller. You see this kind of memory hierarchy popping up everywhere in computer science: on a CPU you have DRAM and the L1 and L2 caches. Here on the GPU you have the same situation, and it makes sense — it comes down to physics. If you have a compute unit built in silicon, there's only so much area around it where you can put silicon for SRAM; there's a limited amount of area very close to the compute unit. That part is fast, but there's not much of it. As you go further and further away from the compute unit, you have more space to put silicon for the HBM, but it's sitting physically further away, so it's going to take some time to move the data. This all comes down to physics. Okay, so how does a GPU work? The inputs start out in HBM, also known as GPU memory, then you need to move the inputs to the compute units and SRAM for computation. So there's a data transfer, then the computation happens, and then you move the output back and write it to HBM. That's a crash course on how GPUs work; other accelerators are similar. If you're interested, I highly recommend a blog post by Horace He — he's a PyTorch developer — called Making Deep Learning Go Brrrr From First Principles. You spend ten or fifteen minutes reading the blog post and you understand machine learning performance way better; I think it's one of the highest bang-for-your-buck things you can do. 
So now that we understand the hardware a little bit, we ask: can we exploit the memory asymmetry to get a speedup, now that we know there's HBM and there's SRAM? The trick is IO-awareness: we're going to try to reduce the reads and writes to HBM and do lots of reads and writes to SRAM instead. That's the core idea. Once you understand the bottleneck and the hardware model, the approaches no longer look so alien — they turn out to be quite natural. So that's a little bit of background on attention and hardware and where the bottlenecks are. I'll pause here to see if there are questions.
speaker 1: Yeah, so I have a question, which is — I mean, the tricks that go into FlashAttention, and profiling code, these are fairly standard performance engineering tricks. Why, in your opinion, did no one at, say, NVIDIA bother to profile their attention code and figure this stuff out?
speaker 2: Right, right. Yeah, that's a great question. So the Transformer came out in 2017 and lots of people switched over, and attention performance, or speed, is one of the core problems. So folks at NVIDIA and Microsoft and so on have certainly optimized attention — they spent a great deal of engineering effort on it. I'll go into some of the issues there. The issue was that I think they were approaching it from a more purely hardware perspective, where they think, well, I need to do a softmax and I can fuse it with the next operation. Sometimes there's causal masking, so people have done things like fusing softmax and causal masking and so on, but they weren't thinking about it from a mathematical perspective — I'll show that you kind of need this mathematical trick of softmax rescaling in order to make the algorithm work. So I think if you approach it purely from a systems perspective, it's a little bit hard; and if you approach it purely from a mathematical perspective, you kind of miss what the bottlenecks are. You need to understand both what the bottleneck is from a systems perspective and how to change the algorithm a little bit — rewrite things mathematically, get the same answer, but be much more hardware friendly. Yeah.
speaker 1: My understanding is that GPUs came out and then people realized, oh, these could be really useful for neural nets — so they weren't really designed for that use case, or for Transformers. And now you see a lot of companies coming out creating AI-specific chips. How are they specifically changing the architecture?
speaker 2: Yeah. So the question is, how are new companies making chips that are more tailored to AI applications? One characteristic we've seen for AI applications is that they have a lot of parallel processing — things like matrix multiplication — which is quite different from graphics workloads. So when these machine learning workloads became popular, Google, for example, came out with the TPU, and what they did was put in specialized units that do matrix multiplication really fast. These are called systolic arrays; you can read about them. The TPU came out around 2015, and it has these systolic arrays specialized for accelerating matrix multiplication. And of course NVIDIA responded: they put tensor cores, which are also essentially systolic arrays, in the V100. So the shift has been: let's put in more specialized units that do matrix multiplication. I think back then people weren't specializing to the architectures themselves; they were specializing to matrix multiplication. There's an interesting trend where, more recently, people feel that, okay, the Transformer architecture has been dominant for five or six years, maybe we want to specialize to the Transformer architecture, and we've seen a little bit of that. It's kind of a risky bet — if you make chips that only work for Transformers, what if something else comes out? I'll mention that in the second half of the talk — but it's a somewhat sensible bet. If you want to specialize to Transformers: for example, NVIDIA has this thing called the Transformer Engine for the H100. It's not really a hardware thing, it's more of a software thing, where they do some of the operations in lower precision, FP8, and they integrate it with Transformers to make it easy to use. Who knows — for the next chip, the B100, maybe they'll put a specialized attention unit on the hardware. That part I don't know. Yeah.
speaker 1: So you mentioned these approximate solutions like sparsity and low rank — they reduce FLOPs, but didn't really improve wall-clock time. So I'm wondering, why do you think the research community focused on this area if in reality the runtime was not reduced?
speaker 2: Yes. So I think there are different groups with different motivations. There are folks who actually train large-scale models, and for the most part they don't use these methods. They worry more about, okay, how do I scale this up, make sure it's scalable, how do I communicate between different GPUs, things like that — those are the folks dealing with large-scale training. Then there are folks focusing more on the algorithms; they're interested in, as I was saying, fundamental questions like how things should scale — things should scale linearly instead of quadratically, something like that. That's more of an academic question; they're not so concerned with training large-scale models. They're more motivated by, hey, I can write this cool paper using this cool method — and I've certainly done that as well. But more recently I feel like there's been a convergence of interests. Large language models are so useful everywhere that people are starting to pay more attention to whether these fancy methods actually run fast. That's something I've been really, really happy about. One example: more recently there's a paper called HyperAttention, which does some of this sparse approximation, but they also incorporate some of the ideas from FlashAttention — they do a sparse approximation but then use FlashAttention as a core subroutine to solve a subproblem. So I think nowadays more people care about performance, because if you want your stuff to be deployed or be useful, ultimately you do have to care about performance and efficiency. Yeah, so that is changing. That's a good question. Yeah.
speaker 1: I was going to ask you to talk more about the differences between the architecture of TPUs versus GPUs. TPUs are sort of the only main competitor to GPUs for training these models.
speaker 2: Yeah, so what are the differences between TPUs and GPUs? I think TPUs are in some sense a little bit simpler on the architecture side; GPUs are more parallel. Here I'm showing one unit, but you can think of a GPU as having hundreds of these units working in parallel. So when you program it, you need to decompose the problem into parallel pieces that the GPU can then handle. This is natural from a graphics point of view — the motivation for graphics is, say you want to render games or something: you have millions of pixels, so you can do a lot of computation in parallel. So GPUs are much more parallel. TPUs, on the other hand, say: let's just tackle the core problem, which is matrix multiplication, and make matrix multiplication really, really fast. So their programming model is in some sense simpler — you don't need to worry as much about parallelism. But ultimately, NVIDIA's GPUs have also been shifting to cater to AI workloads, so everyone is putting more of these matrix multiplication units in their accelerators. Yeah.
speaker 1: Besides matrix multiplication, are there other operations that it would make sense to build specialized hardware for?
speaker 2: Yeah — besides matrix multiplication, are there other things worth specializing for? I think it's very much matrix multiplication. Matrix multiplication accounts for, I don't know, 90-95% of the floating point operations, so if you speed that up, you're pretty much done. There are some applications that aren't so heavy on matrix multiplication — if you do things like physics simulation, or oil and gas, where you're doing more signal processing and image processing, then things like convolution and FFT become more important. But so far that hasn't been important enough for vendors to put specialized hardware units on it. Yeah. Okay, these are all great questions. So I'll move on a little bit to talk about some of the approaches you can use to reduce the memory reads and writes, which is the bottleneck we have. As Ben mentioned, some of these tricks are pretty old, and the systems folks have known about them for a while, but there are challenges to applying them here. Computing things by block is a very old trick — it's from the 1970s, when people started writing blocked matrix multiplication. But the softmax in attention really screws things up if we want to compute things by block, because the softmax has a normalization factor that couples the entire row together, so it's hard to just compute things block by block. And in the context of training models, you do need to compute the gradient, and to compute the gradient you need these intermediate tensors, these intermediate matrices — so it sounds like you would need to store them anyway, which means storing a quadratic amount of memory. Our approaches are, in some sense, really classical. One is tiling: we break things up by block, but with a little bit of a twist — you need to restructure the algorithm a little so that you can compute things by block: load one block from HBM to SRAM, do all the computation, and then write out the result. For the backward pass, the approach is even simpler, at least conceptually: instead of storing the big N-by-N matrices from the forward pass to compute the gradient in the backward pass, we don't store them — we recompute them in the backward pass. And it turns out there's a way to recompute things that is even faster than storing them and loading them back. I'll go into a little bit of detail there. Okay, so for the next two or three slides there will be a little bit of math on how this actually works. If you don't follow everything, that's fine — feel free to ask questions — and in three slides or so I'll pop back up and talk at a higher level again. Okay. So this is the core mathematical trick that makes things work, at least for attention. Here I'm drawing block diagrams of how the attention computation proceeds. You have the query and the key, Q and K, and you multiply them together to get the score matrix. Then, to do softmax, you exponentiate this matrix entrywise and compute the row-wise normalization constant: you exponentiate and sum each row to get this constant l. Then, to multiply by V, we take A divided by l — that's the softmax normalization — and multiply by V. So this is matrix multiplication. 
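To mirror that block diagram in code, here is a minimal sketch of the same computation with the normalization constant l made explicit — mathematically identical to softmax(Q K^T) V, and with the max-subtraction for numerical stability omitted, as on the slides:

```python
import torch

def attention_decomposed(q, k, v):
    s = q @ k.T                          # score matrix S = Q K^T, shape (N, N)
    a = torch.exp(s)                     # entrywise exponential
    l = a.sum(dim=-1, keepdim=True)      # row-wise normalization constant l
    return (a / l) @ v                   # softmax normalization, then multiply by V
```

This is the form that the next slides split into blocks.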
So far I haven't changed anything — I'm just rewriting things in this block-diagram form. Okay. So if we want to compute things by block: the queries are embarrassingly parallel, so you can split by query, but the keys and values, K and V, are coupled because of this softmax normalization constant — you need to sum across the entire row; you can't just do local computation for a couple of columns. So that is the challenge. Okay, so let's just try to split K and V anyway and see what happens. This is our first attempt. Say you split K into two parts, K1 and K2, and similarly V into two parts. The goal is to see what we can do with only local computation: we want to load one block from HBM to SRAM and then just compute locally. So we proceed as before — we haven't changed anything, we're just rewriting in two blocks. You have Q times K1 to get S1 on the left, and Q times K2 to get S2 on the right. No change so far. Then you exponentiate to get A1 = exp(S1) and A2 = exp(S2). Still no change. The trouble comes when you want to compute the normalization constant, because you need to sum across the first block and then sum with the second block — this couples the two blocks together; you can't just do one block at a time. Then, when you multiply with V, if V is split into V1 and V2, the output is A1 divided by l times V1, plus A2 divided by l times V2, where l is the normalization constant. So here's the challenge: the softmax normalization couples the two blocks together, so you can't just do one block at a time, and the output requires contributions from both blocks. This is annoying — and this is how the systems folks had been thinking about it: okay, this is softmax, I can't really break it into blocks, so it's hard to optimize and reduce the memory reads and writes. Okay, so here comes the trick. This is our second attempt, and the trick is softmax rescaling. I've rewritten what we want in the top right corner: we want this normalization constant l, which is the sum of exp(S1) plus the sum of exp(S2), and the output we want is A1 divided by l times V1, plus A2 divided by l times V2. I'll show a way to rewrite the algorithm that gets the same answer but allows us to do local computation. Okay, so let's see. I'll use blue boxes to show what's stored in HBM, and white and orange boxes to show what's computed in SRAM but not stored in HBM. This is important — we have to control what's stored in HBM and what's not, because, as I mentioned, reading and writing to HBM is what's slowing us down. Okay. As before, we take Q times K1 to get S1 and exponentiate it to get A1. Nothing has changed so far. But here's the first change: we're going to compute the normalization constant with just the first block. This is l1, the row sum of A1. Then we multiply with V1 to get this output: O1 equals A1 divided by l1, times V1. On one hand, this is great: we're only doing local computation — we load in a block of K, we load in a block of V, and we do this local computation without touching the second block at all. This is great. 
But on the other hand, what's not so great is that we're getting the wrong answer: we're dividing by this normalization constant l1, but what we really want is to divide by the normalization constant l, which we don't have right now — we only have l1. Ideally we'd divide by l, so we're getting the wrong denominator. But that's okay — we'll come back and fix it up. So let's move to the second block. We proceed as before: we take Q times K2 to get S2 and exponentiate it to get A2. And here, because we have l1, we can compute l2 by taking l1 and adding it to the row sum of A2. Now, if you look closely, l2 is actually equal to l — l2 is the right normalization constant. Okay, great, we now have l2 equal to l, the right normalization constant. So how do we proceed? Before adding the contribution from V2, we rescale O1: previously we divided by l1, which was the wrong normalization constant, so now we multiply by l1 and then divide by l2. Now we're scaling by the right thing. And then for the second block, we add in the contribution A2 divided by l2, times V2. This kind of mathematical rewriting allows us to still do local computation and still get the right answer. This softmax rescaling trick has actually been known to the ML community for a while — I think there's a paper that formally wrote it down around 2018 — so it's not a new trick either. And this is great: it lets us do local computation in SRAM without having to write to HBM, and still get the right answer. One caveat is numerical stability: for softmax, you need to subtract the max before exponentiating. I'm not showing it here, but if you don't subtract the max, it's not numerically stable. In practice, in the actual code, we do subtract the max — we subtract the local max instead of the global max, and there are ways to rescale so you still get the right answer. So there's a little caveat about numerical stability where you have to subtract the max; in practice it still works — there's a sketch of the blocked computation below. Yep — yeah, you have a question?
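Putting the two-block argument together, here is a minimal sketch of the blocked computation with softmax rescaling, generalized to many blocks and including the running-max subtraction just mentioned as the numerical-stability caveat. The block size is arbitrary, and this shows only the math — in the actual kernel, o, l, and the running max live in SRAM and registers rather than Python variables:

```python
import torch

def blockwise_attention(q, k, v, block=128):
    # q, k, v: (N, d). Process K/V in blocks with softmax rescaling,
    # so the full N x N score matrix is never materialized at once.
    N, d = q.shape
    o = torch.zeros_like(q)                   # running (rescaled, unnormalized) output
    l = torch.zeros(N, 1)                     # running softmax denominator
    m = torch.full((N, 1), float("-inf"))     # running row-wise max (numerical stability)
    for start in range(0, N, block):
        k_blk, v_blk = k[start:start + block], v[start:start + block]
        s = q @ k_blk.T                                        # scores against this block only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        a = torch.exp(s - m_new)                               # exponentiate after subtracting the running max
        scale = torch.exp(m - m_new)                           # rescale previously accumulated results
        l = l * scale + a.sum(dim=-1, keepdim=True)
        o = o * scale + a @ v_blk
        m = m_new
    return o / l                                               # final softmax normalization

# Sanity check against the standard implementation (agrees up to numerical error).
N, d = 256, 64
q, k, v = (torch.randn(N, d) for _ in range(3))
ref = torch.softmax(q @ k.T, dim=-1) @ v
print((blockwise_attention(q, k, v) - ref).abs().max())
```

It only ever touches one block of K and V at a time, yet matches the standard implementation up to numerical error.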
speaker 1: Doesn't this mean you still have the same problem — you write out O1, and then you have to read it back in to rescale it when you process block two?
speaker 2: Yes — so O1 we're going to keep in SRAM; we're not going to write it out yet. In practice, if you're familiar with GPUs, we store it in registers, and we just update those registers without ever writing it out.
speaker 1: Is it the same worker that's processing block one and then moving on to block two?
speaker 2: Yeah — I'll talk about how to parallelize this a little bit later. Parallelism is important; we'll talk about that. Any more questions? Okay. So that's more on the details of how this actually works. The takeaway is that you need to rewrite things mathematically for this to work — that's why I think it took a while for people to figure it out. Okay, so popping back up: what does this look like at a high level? We load inputs block by block from HBM to SRAM, we compute things on-chip without writing to HBM, and only at the end do we update HBM with the rescaled output. Here's a little animation by Francisco Massa — he leads the xFormers team at Meta; I think he created torchvision, and he's been working on this for a while as well. He made this little animation where we have the queries on the left, the keys at the top, and the values and the output on the right, and you move block by block, sequentially from left to right. That's, at a high level, how the computation proceeds. Okay. So that's the forward pass — we figured out how to do the forward pass. For the backward pass, as a reminder, you need to compute the gradient, and to compute the gradient you would naively need to store the attention matrices to get the gradients of the query, key, and value. But what did we do in the forward pass? We didn't store things — in the forward pass we only compute S and A in SRAM, and we don't store them. So how do we compute the gradient in the backward pass? The answer is we just recompute things. If you store the softmax normalization constant, it's actually not so hard to recompute the intermediate matrices. The intuition here is that computation is cheap, and memory reading and writing is expensive. Here we're comparing two implementations, the standard attention implementation and FlashAttention. In terms of the amount of compute, FlashAttention actually incurs more computation — in this case 13% more. So this goes the opposite way of those approximation methods: they were trying to reduce FLOPs; here we actually increase the FLOPs. We do more computation in order to reduce memory reads and writes. In this particular case, we reduce memory reads and writes by nine times, and as a result the runtime is six times faster in wall-clock time. This is a little counterintuitive, but once you understand that the bottleneck is memory access, it becomes clearer. So we can speed up the backward pass even with an increased amount of computation. Okay, so more benchmarks. We benchmarked this: it's about two to four times faster than the state-of-the-art optimized baselines, for different sequence lengths and different scenarios — whether you have masking or dropout and whatnot. This is with no approximation; the answers are exact up to numerical error. And in terms of memory, memory now scales linearly instead of quadratically in sequence length. The memory savings — I'm plotting memory reduction — grow as you increase the sequence length, so memory is now linear in sequence length. And coming back to the motivation of training language models with longer context. 
Previously we saw that when you increase the context length from 2K to 8K, training slows down or goes out of memory. But now with FlashAttention, you can train with 8K context reasonably well — for example, about 2.4x faster end-to-end for the 1B model — and the 3B model no longer runs out of memory; you can train it somewhat comfortably. And when you train models with longer context, the models do get better. I won't go into the details too much, but we've seen that FlashAttention is able to speed up Transformer training, language model training, by 2x and increase the context length by 4x, and that leads to better models. As you go to larger scale, I think this has enabled a lot of models to scale to longer sequences — a lot of that happened last year. Okay. So that's version one, which we did about a year and a half ago, and then we came back and optimized it further. This is a bit more low-level detail: we had better parallelism and work partitioning between different warps on the GPU — that's a more hardware-level detail — and we reduced the amount of non-matmul FLOPs. Again, it comes back to the intuition that there are specialized units that do matrix multiplication, so anything that's not matrix multiplication is relatively slow; you want to spend most of your time doing matrix multiplication, and you want to reduce communication between parallel workers. As a result, FlashAttention-2 is generally about 2x faster than FlashAttention-1. I think FlashAttention-1 is in PyTorch 2.0 or 2.1, and FlashAttention-2 is going to be in PyTorch 2.2, which is going to be released, I think, at the end of this month; it's also been in Hugging Face Transformers for a while. Okay, so that was for Transformer training. But nowadays a lot of the workload is in Transformer inference, and I'll show some ideas for how we can optimize for inference — it involves parallelism. This is joint work with the folks at Meta,
speaker 1: the xFormers team.
speaker 2: Yeah, the xFormers team. Just 20 more minutes? Okay — yeah, thanks. So for inference, it turns out the bottleneck is again in IO, reads and writes, but it's about loading the KV cache. If you're familiar with Transformer inference, there's this KV cache where you store the history — the key and the value. The new query is only a single token: you keep the entire history, you predict one more token, and then for the next step you compare that one token against the history. So the query is very short, but the key and the value can be very long, and the challenge is how to load the KV cache as fast as possible. In the previous method, as I showed you, there's a sequential dependency: we process one block at a time, moving from one block of key and value to the next block, and the next, and the next, to compute the output. This is fine if the query is very long, because you still have a lot of parallel work to do. But for Transformer inference, or generation, the query is usually very short — like one token or a couple of tokens — so there's not enough parallelism; most of the time, most of the GPU is just sitting idle because of this sequential dependency. So with the folks from Meta we have a better algorithm for decoding; we call it flash decoding. You want more parallelism, so, using the same trick, we load the KV cache using parallel workers, each computes a local output just as before, and then a separate step combines the local outputs into the final output, again using the same softmax rescaling trick. By being more aware of the parallelism on GPUs, this makes generation quite a bit faster — two to eight times faster generation on Code Llama, where the context length is very long, 32K to 100K. A lot of these use cases — for example, you want your Copilot to understand your entire code base and generate the next tokens very, very fast while keeping a very long history — make attention inference the bottleneck, and some of this work has gone into speeding up that core workload. So, in summary, FlashAttention is a fast and memory-efficient algorithm for exact attention. The couple of ideas behind it are classical: tiling and recomputation. The upshot is that you get faster training, faster inference, and better models with longer sequences. This is the end of the first half of the talk, so I'll pause to see if there are questions.
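A rough sketch of the split-KV idea behind flash decoding as described above: each worker computes a partial output and partial softmax statistics over its chunk of the KV cache, and a final reduction combines them with the same softmax rescaling trick. Here the "parallel workers" are just a Python loop, and the chunk count and sizes are illustrative assumptions, not the actual kernel:

```python
import torch

def flash_decoding_step(q, k_cache, v_cache, num_chunks=4):
    # q: (1, d) the single new query token; k_cache, v_cache: (T, d) cached history.
    partials = []
    for k_c, v_c in zip(k_cache.chunk(num_chunks), v_cache.chunk(num_chunks)):
        # On a GPU each chunk would be handled by a separate parallel worker.
        s = q @ k_c.T                                  # (1, chunk_len) scores
        m = s.max(dim=-1, keepdim=True).values         # local max for numerical stability
        a = torch.exp(s - m)
        partials.append((a @ v_c, a.sum(dim=-1, keepdim=True), m))  # local output, denom, max
    # Reduction: combine partial outputs with softmax rescaling.
    m_all = torch.cat([m for _, _, m in partials], dim=-1).max(dim=-1, keepdim=True).values
    o = sum(o_c * torch.exp(m_c - m_all) for o_c, _, m_c in partials)
    l = sum(l_c * torch.exp(m_c - m_all) for _, l_c, m_c in partials)
    return o / l                                        # (1, d) attention output for the new token

# Example usage with an illustrative cache length.
d, T = 64, 4096
q = torch.randn(1, d)
k_cache, v_cache = torch.randn(T, d), torch.randn(T, d)
print(flash_decoding_step(q, k_cache, v_cache).shape)  # torch.Size([1, 64])
```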
speaker 1: Yeah — what would you say is the current bottleneck then? Is it still memory bandwidth?
speaker 2: Bandwidth — yes, I think so. I've spent the last five months at Together AI, where we've been building the inference stack, and the bottlenecks are different for different use cases. If you're running on your desktop or laptop — we call that the small-batch-size regime, where you're serving one user — the bottleneck is loading the weight matrices, not even attention, but loading the weight matrices from HBM to the compute units. So it's still bottlenecked by memory. That's why things like quantization are so widespread: people love running 4-bit Llama or Mistral or whatnot on their laptops, so llama.cpp is really popular, and I think there's a subreddit, LocalLLaMA, where any time there's a 4-bit model release people get very, very happy. Quantization reduces the size of the weight matrices, so you either fit into a smaller device or run quite a bit faster. That's generation for a single batch, a single user, on your local machine. Then there's the data-center regime where, say, you're a provider — OpenAI, Anthropic, Google, or something — and you have tons of users using ChatGPT, all sending requests. What you do is batch these requests together and process them at the same time. Depending on the regime, sometimes you're still bottlenecked by memory reads and writes, but sometimes you're bottlenecked by computation. With batching, you still need to load the weight matrix, but instead of multiplying by a single vector (batch size one), you multiply by, say, 100 vectors (batch size 100). In terms of loading the weight matrix, you still load approximately the same amount, but in terms of computation, you perform roughly 100x more. So the balance shifts between memory and computation depending on your batch size, and there are lots of different use cases I won't go into, but I would think about it in terms of these two regimes, small batch versus large batch. Small batch is generally bottlenecked by memory; large batch is sometimes still memory-bound, but sometimes more bottlenecked by compute. Yeah, great question. All right, so we've seen — yeah, question: do you have any outlook for, let's say, some of these attention approximations, some of the kernel-based stuff? Yes, great question. So I brought up these approximation methods and said people don't really use them. Within the last couple of months or so, we've actually seen a resurgence of these methods, thanks to more efficient implementations. Just last month I saw a bunch of blog posts and code releases on making linear attention or low-rank attention actually go fast. The idea, again, is that you want to use things like the matrix multiplication units — how do you reformulate the algorithm so that you can use the matrix multiplication units? So there's a resurgence of these methods, and I know Michael has been working on linear attention for a while. That's on the efficiency side. On the quality side, we've made a lot of progress on these approximation methods as well — thanks to some of the work, for example from Michael, the quality is now getting close, sometimes matching full attention. So it's quite exciting. 
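As a rough illustration of the reformulation these linear-attention methods rely on (the same reordering mentioned earlier for the low-rank approach): if the softmax is replaced by a feature map, the product can be computed as Q'((K')^T V), which is linear in N and maps onto large matrix multiplications. The elu+1 feature map here is an illustrative assumption, not any particular paper's method:

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    # q, k, v: (N, d). Illustrative feature map; methods differ in how they choose it.
    q_f, k_f = F.elu(q) + 1, F.elu(k) + 1
    kv = k_f.T @ v                                   # (d, d): multiply K'^T V first
    z = q_f @ k_f.sum(dim=0, keepdim=True).T         # (N, 1) normalizer
    return (q_f @ kv) / z                            # (N, d); the N x N matrix is never formed
```

The cost is O(N d^2) rather than O(N^2 d), and the two matrix products are exactly the kind of work tensor cores are good at.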
We're making progress on both the quality and the efficiency side for these approximation methods. Yeah. Okay. So for the last ten, fifteen minutes or so, I'll talk about how some of these ideas generalize to other models. So transformers are dominant; 90-95% of people are using transformers for these applications. But at the end of the day, transformers still scale quadratically in sequence length. We reduced the memory, but we're not actually reducing the amount of computation. And with inference, you still need to keep around this KV cache. We have different ways to deal with that, but it's still kind of a headache. For example, at the startup I mentioned, Together AI, we're building an inference stack, and dealing with the KV cache is, you know, 80% of our problems. So I'll talk about a different, non-transformer model architecture. We've seen some progress recently where these models do almost as well, or as well, maybe sometimes beating transformers on language modeling. So I'll talk about Mamba, a selective state space model. This is joint work with Albert Gu, who's now a professor at CMU; he finished his PhD at Stanford last year, and a lot of the credit goes to him. Okay, so I'll do an overview of what these model architectures look like. The core computation, the core layer, is going to be a state space model, an SSM. It's similar to things like a CNN, where the core layer is convolution and you wrap it with normalization and linear layers and residuals and so on to get the CNN architecture; for the transformer, you have the attention layer, you wrap around kind of the same things, and you get the transformer architecture. Similarly with state space: you have the state space model, the SSM, you wrap around the same building blocks, and you get a state space neural network, an SSNN. So I'll focus mostly on the SSM side and just touch on the architecture a little bit. The architecture stays largely the same; we're kind of swapping out the core layer, so things like layer norm and residuals stay the same. Okay, so recurrent neural networks were pretty popular back in 2015, 2016 or so. They were state of the art, they were used for translation. I think some of the text-to-speech models running on phones are still using a recurrent neural network from before the transformer came out. So the transformer came out in 2017, and most people switched over from RNNs, but RNNs are still great in some sense: they're natural at modeling this kind of sequential process, because what they do is a recurrence. They summarize the history into a hidden state, and then given that hidden state, you can generate the next token. So they're good for inference, but they're not great for training. They're not friendly to accelerators because they have a sequential dependency. Also on the quality side, they're harder to optimize because they have vanishing gradient problems; you can read more about that if you're interested. And then attention, you know, the transformer came out, and it has this dense interaction where all elements interact with each other, kind of pairwise. So they perform really well on language, and they're parallelizable.
But the downside, as I mentioned, is that they have quadratic time scaling during training, and during inference, to generate a new token you kind of have to keep around the entire history. So one step of inference scales linearly in the sequence length; there are things like FlashDecoding to make that fast, but fundamentally things still scale quadratically in sequence length. Selective state space models, this work with Albert, is our attempt to get the best of both worlds. On the efficiency side, we still have parallelizable training and fast inference. And on the performance, the quality side, at least at the scales we've tested, we're able to match transformers on language modeling. So we're matching at the 3B scale, and with long context, the model continues to improve at up to million-length sequences. Okay. So just a little bit of overview on how state space models work. They're actually very classical; they're from the 1960s. For example, if you've heard of Kalman filters in signal processing, that's the work that kind of pioneered state space models. And they're mathematically quite elegant; they're defined by two very simple equations. You have the input x, the hidden state h, and the output y. The h is defined by a differential equation, where h varies by some factor that's controlled by the current h and the input x, and the output is in some sense just a projection of the hidden state. So they've been around since the sixties. They hadn't really worked well for deep learning, but Albert and Karan and Chris Ré, with the S4 paper, showed that they can work really well for deep learning for some applications. They figured out a way to put them in a deep learning architecture: using the state space model as a core layer and then putting things like layer norm and linear layers and residuals together to make an architecture, S4, that works really well, especially for things like audio and images. So they have a bunch of advantages. They have a continuous representation, so they're great for continuous domains like audio. They have a recurrent representation, so they connect to RNNs. But they also have a convolutional representation that allows fast training. So I won't go into the details; the idea is you're still using this very simple primitive defined by these two equations, but if you figure out a way to run them efficiently during training, run them efficiently during inference, and put all the building blocks of modern deep learning around them, you get an architecture that performs reasonably well. I do want to highlight one piece of work on the systems side: how do we make them fast? It turns out they're equivalent to long convolutions, convolving with a very, very long filter. And the way you do that on GPU, on any accelerator, is with the Fourier transform: to do a long convolution, you do a Fourier transform, a pointwise multiplication, and then an inverse Fourier transform. I think that's called the convolution theorem in signal processing. In signal processing this is pretty classical, but it turns out you can make convolution quite a bit faster, even faster than cuFFT, which is the vendor library from Nvidia. And so I just want to highlight some of the work done by Dan Fu and Hermann.
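For reference, the "two very simple equations" described verbally here are the standard continuous-time state space formulation (written in the usual S4-style notation; the D x(t) feed-through term is often treated as a skip connection):

```latex
\begin{aligned}
h'(t) &= A\,h(t) + B\,x(t) \\
y(t)  &= C\,h(t) + D\,x(t)
\end{aligned}
```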
Maybe Dan will one day give a talk at this seminar about taking this kind of hardware-aware thinking from FlashAttention. Dan and I collaborated on FlashAttention, and he took some of that thinking and applied it to the problem of speeding up convolution. For example, with cuFFT you get hardware utilization of around 2% because, again, it's very much dominated by memory reads and writes. Dan and Hermann used the same ideas from FlashAttention, and they made FlashFFTConv, whose utilization is way higher, close to FlashAttention. So I just want to highlight the excellent work from these folks, using the same ideas to show that these ideas are actually general: they don't just apply to attention, they also apply to convolution. So I'll skip some of this. But the main feature of state space models is that during training you use this convolution mode, and during inference you switch to this recurrent mode. They're equivalent; they compute the same thing. Convolution is great for training, and you can use, for example, some of the work that Dan has done making convolution really fast. During inference, where you're generating one token at a time, you use the recurrent mode. So with Mamba, we're changing that a little bit. The issue with state space models is that they didn't work super well for language, at least so far. I've personally worked on this as well, and they hadn't been able to match the strongest transformers until more recently. So the motivation here is that you can understand the quality in terms of how much of the history is being summarized. A state space model or an RNN kind of summarizes the entire history into a fixed-size state vector. This is efficient, this is great, because given this fixed-size state vector, you can generate the next tokens, since you've kind of summarized the entire history already. But they generally don't do well on information-dense modalities like language. On the other hand, one way to view why attention performs so well is that it actually keeps around the entire history. So it's very easy to look back at what happened in the sequence, what happened in the document, you know, 2,000 steps before. That's great for performance, but at the same time it's inefficient because you need to keep around this history. For training, this is the cost of the quadratic scaling; for inference, it's the cost of having to keep around this KV cache. So we want to get the best of both worlds, and the selection mechanism in Mamba is one way to get there. The idea is that we want the model to still summarize the history into a fixed state vector, but we want it to be smarter about what to put in the history. So the model is able to pick what to put in the history based on the input it sees. And the core idea is actually very, very simple. With S4 and so on, you have these A, B, C, D matrices that control the dynamics; I won't go into details too much. With S4, these are parameters; they don't depend on the input, they're parameters of the network. With Mamba, what we do is we just say, hey, let the model control how much to put in the history. And that means you let these A, B, C, D matrices depend on the input; they're functions of the input.
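Below is a minimal sketch of the convolution-theorem trick just described: FFT, pointwise multiplication, inverse FFT. It shows only the mathematical primitive that FlashFFTConv accelerates, not the GPU kernel itself, and the zero-padding convention for the causal case is an assumption:

```python
import numpy as np

def long_conv_fft(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution of a length-N signal with a length-N filter in O(N log N),
    via the convolution theorem: FFT -> pointwise multiply -> inverse FFT."""
    n = x.shape[-1]
    fft_size = 2 * n                      # zero-pad so circular convolution matches linear convolution
    x_f = np.fft.rfft(x, n=fft_size)
    k_f = np.fft.rfft(k, n=fft_size)
    y = np.fft.irfft(x_f * k_f, n=fft_size)
    return y[..., :n]                     # keep the causal part

# Sanity check against direct convolution on a short sequence.
rng = np.random.default_rng(0)
x, k = rng.standard_normal(8), rng.standard_normal(8)
assert np.allclose(long_conv_fft(x, k), np.convolve(x, k)[:8])
```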
And just this simple change actually allows the model to do very well on language modeling. On the efficiency side, once you have A, B, C, D depending on the input, you can no longer use convolution. So, for example, the excellent work that Dan has done, we can no longer use his kernel. We had to find some other trick to do this, and it's actually the same trick: we keep this large hidden state in SRAM, and we never write it down. We only load the input, and we only expand into the large hidden state in SRAM. So again, using the same memory-hierarchy idea, you can still do this recurrent mode really fast without having to write down this big hidden state. The same idea shows up. So I'll just show you some of the results. We did a scaling-law analysis comparing the transformer, and this is the vanilla transformer that was used in GPT-3, with some of the recent transformer-free models like H3 and Hyena and RWKV and RetNet, and with Transformer++, which is the strongest transformer architecture that we know of, based on the Llama model and training recipe, and compared all of that to Mamba. Mamba here is the purple curve. The x-axis is FLOPs, how much computation you perform, and the y-axis is perplexity, so lower is better. And we see Mamba here scaling about as well as Transformer++, sometimes beating Transformer++. This is really exciting to us; Mamba is kind of the first attention-free model to compete with strong, modern transformer models. We also trained models up to 3B and compared to some of the models out there like Pythia and OPT, and Mamba is matching or beating transformers of similar size. So we're very excited about it: maybe finally we have a non-transformer architecture that does really well on language modeling. Of course, the caveat is that it's still validated at relatively small scale, at 3B, and there are still a lot of open questions about how to do inference efficiently with these state space models, whether they can follow instructions, whether you can quantize them efficiently and run them on device, and so on. So there are lots of open questions still, but this is something I'm very excited about, so I will stop here. I think we're at time. The takeaway message is that if you think about both the algorithm, on the model side, and the hardware, that kind of gives you a superpower: you can speed things up, you can devise new architectures that run really well, and you can unlock a bunch of capabilities. So I think it's an excellent choice you guys made to enroll in this seminar. I think you'll see more of these ideas coming up, and they're becoming more and more relevant to modern AI. So with that, I'd like to thank you for your attention, and I'll stick around and answer questions.
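To illustrate the selection idea concretely, here is a slow reference recurrence for a selective SSM in plain NumPy. The exact parameterization below (how the input produces the step size and the B and C matrices) is a simplified assumption rather than Mamba's published design, and the real kernel fuses this loop on the GPU, keeping the expanded hidden state in SRAM instead of materializing it:

```python
import numpy as np

def selective_ssm_reference(x, A, W_B, W_C, w_dt):
    """Slow reference recurrence for a (diagonal) selective SSM.
    x:    (L, d)  input sequence
    A:    (d, n)  per-channel state decay parameters (input-independent)
    W_B:  (d, n)  used to make the "write" matrix B_t depend on x_t   (assumed form)
    W_C:  (d, n)  used to make the "read" matrix C_t depend on x_t    (assumed form)
    w_dt: (d,)    used to make the step size depend on x_t
    The structural point: B_t, C_t and the step size are functions of the current
    input, so the model chooses what to write into / read out of a fixed-size state."""
    L, d = x.shape
    n = A.shape[1]
    h = np.zeros((d, n))
    ys = np.empty((L, d))
    for t in range(L):
        xt = x[t]
        dt = np.log1p(np.exp(xt * w_dt))            # softplus: positive, input-dependent step size
        B_t = np.tanh(xt)[:, None] * W_B            # input-dependent "write" matrix
        C_t = np.tanh(xt)[:, None] * W_C            # input-dependent "read" matrix
        A_bar = np.exp(dt[:, None] * A)             # discretized transition for this step
        h = A_bar * h + (dt[:, None] * B_t) * xt[:, None]   # selective state update
        ys[t] = (h * C_t).sum(axis=-1)              # project the state back to d output channels
    return ys

rng = np.random.default_rng(0)
L, d, n = 32, 8, 4
y = selective_ssm_reference(rng.standard_normal((L, d)),
                            -np.exp(rng.standard_normal((d, n))),   # negative A -> stable decay
                            rng.standard_normal((d, n)),
                            rng.standard_normal((d, n)),
                            rng.standard_normal(d))
print(y.shape)   # (32, 8)
```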
speaker 1: Let me check with the stream.

Latest Summary (Detailed Summary)

Generated 2025-05-17 13:30

Overview / Executive Summary

At the Stanford MLSys seminar, Tri Dao gave a talk titled "Hardware-aware Algorithms for Sequence Modeling," a deep dive into how hardware/software co-design can improve the efficiency and capability of long-sequence models. The talk centers on two innovations: FlashAttention and Mamba.

FlashAttention is an exact attention algorithm that achieves large speedups by identifying and attacking the real bottleneck: GPU memory reads and writes, not floating-point operations. Using IO-aware techniques such as tiling, softmax rescaling, and recomputation, FlashAttention speeds up attention by 4-8x and reduces memory from quadratic to linear without sacrificing accuracy (it computes exact attention), enabling Transformers to handle 4-16x longer context and train higher-quality models. FlashAttention, its successor FlashAttention 2, and the inference-oriented FlashDecoding have been widely adopted in industry.

Mamba is a new attention-free selective state space model (SSM) designed to address the Transformer's inherent quadratic compute complexity and the difficulty of managing the KV cache at inference time. Mamba's core idea is a selection mechanism that lets the SSM parameters adapt to the input, so the model can compress and use its history more intelligently. Although this breaks the fast convolutional training mode, Mamba uses a hardware-aware parallel recurrent algorithm that processes the hidden state efficiently in GPU SRAM, avoiding large amounts of HBM traffic. Experiments show Mamba matching or exceeding the strongest contemporary Transformer models on language modeling, with the potential to handle million-length sequences, offering a new path beyond the Transformer architecture. Tri Dao emphasizes that understanding and combining algorithms with hardware characteristics is key to advancing modern AI.

Introduction and Background

The Speaker and His Research

  • Tri Dao: incoming assistant professor at Princeton University and chief scientist at Together AI. He received his PhD in computer science from Stanford, advised by Christopher Ré and Stefano Ermon.
  • Research focus: the intersection of machine learning and systems, in particular sequence models with long-range memory and structured matrices for compact deep learning models. His work received the ICML 2022 Outstanding Paper runner-up award.

Progress in Machine Learning and the Challenge of Scale: Efficiency

  • Recent progress: Tri Dao notes that machine learning has advanced dramatically in recent years, e.g., bug fixing, art generation (Stable Diffusion), and protein structure prediction (AlphaFold).
  • Scale as the driver: he argues that "scale has brought quality and capabilities"; model and data sizes have grown roughly 1000x in the past five years (e.g., from BERT's 300M parameters to GPT-4's reported trillion-parameter scale). Larger models not only do better on existing benchmarks but also exhibit new capabilities, such as a language model's ability to explain jokes improving with parameter count.
  • The central challenge: efficiency
    • Practically: better efficiency simplifies model training and deployment and helps research (e.g., small, fast models that match the quality of larger ones).
    • In terms of capabilities: efficiency gains unlock new features, such as longer context (from GPT-3 to GPT-3.5's 16K context to GPT-4's 128K context). Tri Dao stresses: "the key to these advances is improved efficiency."

Why Hardware-Aware Algorithms Matter

Understanding Efficiency from Both the Algorithm and the System

Tri Dao's core methodology is to understand both the algorithmic level (e.g., matrix-vector multiplication, attention) and the systems level (e.g., GPU accelerators, distributed systems). He stresses: "You need to understand these block-oriented devices (such as GPUs) and their implications for model design and algorithm design."

IO-Awareness: Reducing GPU Memory Reads/Writes

  • Core idea: a hardware-aware algorithm should exploit the characteristics of the hardware it runs on.
  • IO-awareness: for GPU memory, the key is to reduce reads/writes to high-bandwidth memory (HBM), since that is usually the performance bottleneck.
    • FlashAttention: the canonical IO-aware example, delivering fast, memory-efficient, exact attention.
    • In Mamba: the recurrent state is expanded in SRAM, avoiding the heavy memory cost.

FlashAttention: Fast, Exact Attention

Motivation: The Need for Long Sequences and the Transformer Bottleneck

  • Long sequences are everywhere
    • NLP: books, screenplays, codebases. Tri Dao mentions the GPT-4 demo in which the model understands 50-100-page documents.
    • Computer vision: higher resolution usually gives better results, but also means longer sequences.
    • Other domains: time series, audio, video, medical imaging (e.g., histopathology images, which require extremely high resolution and sequence lengths up to the millions).
  • The Transformer bottleneck: although Transformers are the dominant architecture, they are inefficient on long sequences.
    • Quadratic complexity: attention's time and memory both scale quadratically in the sequence length N (O(N^2)). "Doubling the sequence length means 4x the compute and 4x the memory."
    • Practical impact: Tri Dao shows that MegatronLM's training throughput drops sharply when the context length grows from 2K to 8K, and it can even run out of memory (OOM).

Limitations of Traditional Approximation Methods

  • Background: to tackle the quadratic complexity, the research community proposed many approximate-attention methods, such as sparse attention and low-rank approximations, trading some quality for speed.
  • Problems in practice: Tri Dao observes that practitioners training large models usually do not adopt these approximations, for two reasons:
    1. Quality degradation: the approximation hurts model performance.
    2. Little real speedup: "More importantly, they often aren't actually faster and don't save memory," because "fewer FLOPs do not necessarily translate into shorter wall-clock time."

The Real Bottleneck: Memory Reads/Writes, Not FLOPs

  • The value of profiling: by profiling GPU code, Tri Dao found that the main cost of the standard attention implementation is memory reads/writes, not floating-point operations. "The biggest cost is moving bits around."
  • Why: the standard implementation repeatedly reads and writes the N x N attention matrix (and its intermediates) to GPU HBM, which consumes most of the time.

GPU Memory Hierarchy and the IO-Aware Strategy

  • GPU memory structure
    • HBM (high-bandwidth memory): what is usually called GPU memory; large capacity (40GB-80GB) and high bandwidth (about 1.5TB/s on an A100), but slow relative to SRAM.
    • SRAM (on-chip static RAM): on-chip memory; much smaller, but roughly an order of magnitude faster than HBM and adjacent to the compute units.
    • Compute units: where the actual arithmetic happens.
  • IO-aware strategy: the core is to "try to reduce reads/writes to HBM and do lots of reads/writes to SRAM."

Core Techniques in FlashAttention

  • Goal: a fast, memory-efficient, and exact (no approximation) attention algorithm.
  • Main challenges
    1. The softmax normalization factor couples an entire row, which makes blockwise computation hard.
    2. Computing gradients in the backward pass normally requires storing the N x N attention matrix produced in the forward pass.
  • Classical techniques FlashAttention applies in new ways (a sketch of the rescaling algebra follows this list)
    1. Tiling / kernel fusion: decompose the computation into blocks; load a block from HBM into SRAM, do all the related computation in SRAM, then write back the result.
      • Softmax rescaling: the key mathematical trick that makes blockwise computation produce exactly the same result as a standard softmax.
        • Procedure: compute a local softmax over the first block (with local normalizer L1) to get an intermediate output O1. For each subsequent block, update the normalizer (e.g., L2 = L1 + the sum of exponentials from the second block), rescale O1 (multiply by L1/L2), and add the current block's contribution.
        • Implementation: the intermediate O1 lives in SRAM or registers and is not written to HBM until all blocks have been processed.
        • Numerical stability: the trick also accounts for the "subtract the max" step used to keep the softmax numerically stable.
    2. Recomputation for the backward pass: instead of storing the full N x N matrices S and P (the softmax output) in the forward pass, recompute them in the backward pass from the stored output O and the softmax normalization statistics (row sums and row maxes).
      • Rationale: "compute is cheap; memory reads/writes are expensive." FlashAttention actually performs more FLOPs (about 13% more) but cuts memory traffic dramatically (about 9x), yielding roughly a 6x runtime speedup.
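A minimal NumPy sketch of this rescaling algebra, for a single query row processed block by block, is shown below. It mirrors the procedure described above (running max, running normalizer, rescaled accumulator); the actual FlashAttention kernel does the same arithmetic per tile inside SRAM and registers:

```python
import numpy as np

def attention_one_query_tiled(q, K, V, block_size=128):
    """Exact attention for one query, streaming over K/V blocks while keeping only
    a running max m, a running normalizer l, and a running (unnormalized) output."""
    d = q.shape[-1]
    m, l = -np.inf, 0.0
    acc = np.zeros(V.shape[-1])
    for start in range(0, K.shape[0], block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        s = Kb @ q / np.sqrt(d)                 # attention scores for this block
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)          # rescale what was accumulated so far
        p = np.exp(s - m_new)                   # numerically stable exponentials
        l = l * correction + p.sum()
        acc = acc * correction + p @ Vb
        m = m_new
    return acc / l                              # final normalization

# Sanity check against the naive (materialize-everything) softmax attention.
rng = np.random.default_rng(0)
q, K, V = rng.standard_normal(64), rng.standard_normal((1000, 64)), rng.standard_normal((1000, 64))
s = K @ q / 8.0
p = np.exp(s - s.max())
assert np.allclose(attention_one_query_tiled(q, K, V), (p / p.sum()) @ V)
```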

FlashAttention Results and Follow-Ups

  • Speed: 2-4x faster than the best baseline implementations at the time.
  • Memory: footprint reduced from O(N^2) to O(N).
  • Enabling long context: makes training Transformers at 8K or even longer context feasible, with accompanying quality gains.
  • FlashAttention 2
    • Further optimizations: better parallelism, better work partitioning across warps, fewer non-matmul FLOPs, and less communication between parallel workers.
    • Performance: typically about 2x faster than FlashAttention 1.
    • Integration: available in PyTorch 2.2 and the Hugging Face Transformers library.

FlashDecoding: Optimizing Long-Context LLM Inference

  • Inference bottleneck: for long-context inference (e.g., code generation), the bottleneck is loading the huge key-value (KV) cache. The query is typically very short (e.g., one token), while the key/value history can be very long.
  • Limitation of FlashAttention: the standard kernel processes blocks sequentially, which leaves too little parallelism for a short query.
  • The FlashDecoding approach (in collaboration with Meta's xFormers team; a sketch follows this list):
    • Load KV-cache blocks in parallel.
    • Compute local (partial) outputs in parallel.
    • Merge the local outputs into the final result using the softmax rescaling trick.
  • Result: 2-8x faster generation on Code Llama at 32K-100K context.
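A hedged sketch of the split-KV idea described above: each chunk of the cache produces a partial output plus a log-sum-exp, and the partials are merged with softmax weights. The chunk size and function names are illustrative, not the xFormers API:

```python
import numpy as np

def partial_attention(q, K_chunk, V_chunk):
    """Attention of one query against one KV chunk, returning the chunk's partial
    output and the log-sum-exp of its scores, which is all that is needed to merge."""
    s = K_chunk @ q / np.sqrt(q.shape[-1])
    m = s.max()
    p = np.exp(s - m)
    return (p @ V_chunk) / p.sum(), m + np.log(p.sum())

def merge_partials(partials):
    """Combine per-chunk results with weights exp(lse_i) / sum_j exp(lse_j).
    In FlashDecoding the chunks are handled by different thread blocks in parallel;
    here they are simply items in a Python list."""
    outs, lses = zip(*partials)
    lses = np.array(lses)
    w = np.exp(lses - lses.max())
    w /= w.sum()
    return sum(wi * oi for wi, oi in zip(w, outs))

# Sanity check: merging four 1024-token chunks reproduces full attention over 4096 tokens.
rng = np.random.default_rng(0)
q, K, V = rng.standard_normal(64), rng.standard_normal((4096, 64)), rng.standard_normal((4096, 64))
chunks = [partial_attention(q, K[i:i + 1024], V[i:i + 1024]) for i in range(0, 4096, 1024)]
s = K @ q / 8.0
p = np.exp(s - s.max())
assert np.allclose(merge_partials(chunks), (p / p.sum()) @ V)
```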

Mamba: Selective State Space Models

Motivation: Overcoming the Transformer's Quadratic Compute and KV-Cache Problems

  • What remains: even with FlashAttention's memory optimizations, the Transformer's compute is still O(N^2). At inference time, managing the KV cache remains "a headache"; "dealing with the KV cache accounts for 80% of our problems (at Together AI)."
  • Goal: explore non-Transformer architectures with better scaling.

State Space Model (SSM) Basics

  • Background: SSMs are classical models from control theory (going back to the Kalman filter in the 1960s), defined by a state equation (h' = Ah + Bx) and an observation equation (y = Ch + Dx).
  • SSMs in deep learning (S4): the S4 paper by Albert Gu et al. showed the potential of SSMs as a core deep learning layer; wrapping them with modern building blocks (layer norm, linear layers, residual connections) yields the S4 architecture, which excels on continuous signals such as audio and images.
  • Advantages of SSMs (the corresponding equations are sketched after this list)
    • A continuous representation.
    • A recurrent representation (like an RNN).
    • A convolutional representation (fast, parallel training via the FFT).
  • Hardware-aware SSM acceleration (FlashFFTConv): Dan Fu and Hermann [surname not given in the talk] applied FlashAttention's hardware-aware thinking to speeding up long convolutions (the key operation in SSM training); their optimized FFT implementation achieves far higher hardware utilization than Nvidia's cuFFT.
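Following the notation of the state equation above, the recurrent and convolutional views correspond to the standard discretized form (a sketch in S4-style notation; Δ is the step size and Ā, B̄ the discretized matrices, with the D skip term omitted):

```latex
\begin{aligned}
h_t &= \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t \\
y   &= x * \bar{K}, \qquad \bar{K} = \big(C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^{2}\bar{B},\; \dots\big)
\end{aligned}
```

Because A, B, C do not depend on the input in this classical setting, the whole sequence map is one long convolution with kernel K̄, which is what the FFT trick (and FlashFFTConv) accelerates.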

Mamba's Core Innovation: The Selection Mechanism

  • Why SSMs fall short on language: traditional SSMs (e.g., S4) compress the entire history into a fixed-size hidden state, which can be insufficient for information-dense modalities like language, making it hard to match Transformer quality.
  • Mamba's insight: part of why Transformers are so strong is that they can "look back over the entire history."
  • The selection mechanism: Mamba's core idea is to let the model decide more intelligently which information should be kept in the (fixed-size) state vector.
    • How: make the SSM's core parameter matrices A, B, C, D depend on the current input x, i.e., A(x), B(x), C(x), D(x).
    • Effect: "this simple change makes the model perform very well on language modeling."

Efficient Computation in Mamba

  • Challenge: with A, B, C, D depending on the input, the SSM is no longer linear time-invariant, so the fast convolutional training mode (e.g., FlashFFTConv) no longer applies.
  • Mamba's solution (hardware-aware recurrent computation; a sketch of the scan's combine operator follows this list)
    • The computation must be done efficiently in recurrent mode.
    • Use a parallel scan algorithm, keeping the (input-dependent, expanded) hidden state H in GPU SRAM for iterative updates instead of writing it back to HBM in full.
    • As Tri Dao puts it: "it's the same memory-hierarchy idea."
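As a sketch of why a scan applies here: an input-dependent linear recurrence h_t = ā_t·h_{t-1} + b_t is a composition of affine maps, and composing affine maps is associative, which is exactly what a parallel (Blelloch-style) scan needs. The snippet below shows the combine operator and checks it sequentially; it illustrates the principle only and is not Mamba's fused kernel:

```python
import numpy as np

def combine(e1, e2):
    """Associative operator for the recurrence h_t = a_t * h_{t-1} + b_t.
    Each element (a, b) represents the affine map h -> a*h + b; composing the map
    for step t1 with the map for step t2 gives another affine map."""
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2

def inclusive_scan(elems):
    """Inclusive scan with `combine`, written sequentially. Because `combine` is
    associative, a tree-structured parallel scan computes the same prefixes."""
    out, acc = [elems[0]], elems[0]
    for e in elems[1:]:
        acc = combine(acc, e)
        out.append(acc)
    return out

# Check against the direct recurrence (starting from h = 0) for scalar a_t, b_t.
rng = np.random.default_rng(0)
a, b = rng.uniform(0.5, 1.0, 16), rng.standard_normal(16)
h, hs = 0.0, []
for t in range(16):
    h = a[t] * h + b[t]
    hs.append(h)
assert np.allclose(hs, [bb for _, bb in inclusive_scan(list(zip(a, b)))])
```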

Mamba Architecture and Performance

  • Architecture: Mamba is a simplified end-to-end neural network architecture, "without attention and even without MLP blocks," built directly around the selective SSM.
  • Results
    • Scaling: in the scaling-law analysis, Mamba (the purple curves) scales on par with the strongest Transformer++ baseline (based on the Llama model and training recipe), and better at some points.
    • Comparison: "Mamba is the first attention-free model to compete with strong, modern Transformer models."
    • At up to 3B parameters, Mamba matches or exceeds Transformers of similar size (e.g., Pythia, OPT).
    • Long-context potential: model quality keeps improving up to million-length sequences.

Outlook and Open Questions for Mamba

  • Why it matters: "we may finally have a non-Transformer architecture that does really well on language modeling."
  • Open questions
    • Validation so far is at the relatively small 3B scale.
    • Further optimization of efficient inference.
    • Instruction-following ability.
    • Quantization and efficient on-device execution.

Q&A Highlights

  • Earlier attention optimization: asked why companies like Nvidia had not adopted FlashAttention-like ideas earlier, Tri Dao said they did optimize attention, but mostly from a pure hardware angle (e.g., fusing operations), whereas FlashAttention's breakthrough combines the systems perspective with the key mathematical trick (softmax rescaling).
  • Specialized AI chips: there is a trend toward specialization (e.g., matrix-multiply units). Building hardware for a specific architecture like the Transformer is "a risky bet," but it is happening (e.g., the Nvidia Transformer Engine).
  • Shifting research focus: on why the research community long focused on reducing FLOPs rather than the real bottleneck, Tri Dao explained that motivations differ, but with LLMs now so useful, attention to actual wall-clock speed has grown substantially.
  • TPU vs. GPU: TPUs have a simpler architecture focused on matrix multiplication; GPUs are more parallel (a legacy of graphics workloads) but are increasingly optimized for AI workloads.
  • Current inference bottlenecks
    • Small batch (e.g., on a desktop): the main bottleneck is loading the model weights, which is why quantization is so popular.
    • Large batch (e.g., in a data center): the bottleneck can be memory reads/writes or compute, depending on the batch size.
  • Approximate attention: Tri Dao noted a recent "resurgence" of approximation methods thanks to more efficient implementations (e.g., using matrix-multiply units to accelerate linear attention), and their quality is improving as well.

Key Takeaway

Tri Dao closed by saying: "When you think about both the algorithm/model level and the hardware level, it gives you a superpower: you can speed things up, design new architectures that run really well, and unlock a whole set of new capabilities." He encouraged the audience to follow developments at the intersection of machine learning and systems.