Latent Space: The AI Engineer Podcast

FlashAttention 2: making Transformers 800% faster w/o approximation - with Tri Dao of Together AI

FlashAttention was first published by Tri Dao in May 2022 and it has had a deep impact on the large language model space. Most open models you’ve heard of (RedPajama, MPT, LLaMA, Falcon, etc.) leverage it for faster inference. Tri came on the podcast to chat about FlashAttention, the newly released FlashAttention-2, the research process at the Hazy Research lab, and more.

This is the first episode of our “Papers Explained” series, which will cover some of the foundational research in this space. Our Discord also hosts a weekly Paper Club, which you can sign up for here.

How does FlashAttention work?

The paper is titled “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”. There are a couple keywords to call out:

* “Memory Efficient”: standard attention’s memory usage is quadratic in sequence length (i.e. O(N^2)), while FlashAttention’s is linear (i.e. O(N)).

* “Exact”: the opposite of “exact” in this case is “approximate”, e.g. “sparse” attention, as in “sparse networks” (see our episode with Jonathan Frankle for more). This means that you’re not giving up any precision.

* The “IO” in “IO-Awareness” stands for “Input/Output” and hints at a write/read related bottleneck.

Before we dive in, look at this simple GPU architecture diagram:

The GPU has access to three memory stores at runtime:

* SRAM: this is on-chip memory co-located with the actual execution core. It’s limited in size (~20MB on an A100 card) but extremely fast (19TB/s total bandwidth).

* HBM: this is off-chip but on-card memory, meaning it’s in the GPU but not co-located with the core itself. An A100 has 40GB of HBM, but only a 1.5TB/s bandwidth.

* DRAM: this is your traditional CPU RAM. You can have TBs of this, but you can only get ~12.8GB/s bandwidth, which is way too slow.
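To get a feel for why this hierarchy matters, here’s a quick back-of-the-envelope calculation using the A100 numbers above (the sequence length, head count, and number of round trips below are illustrative assumptions, not measurements):

```python
# Rough arithmetic: how big is the N x N attention matrix, and how long
# does it take just to move it through HBM? (Illustrative numbers only.)

seq_len = 8192          # hypothetical sequence length
num_heads = 16          # hypothetical number of attention heads
bytes_per_elem = 2      # fp16 / bf16

# One N x N score matrix per head
scores_bytes = seq_len * seq_len * num_heads * bytes_per_elem
print(f"Attention scores: {scores_bytes / 1e9:.1f} GB")  # ~2.1 GB

# That's far larger than ~20MB of SRAM, so standard attention has to
# round-trip it through HBM several times (write scores, read them for
# softmax, write probabilities, read them again for the final matmul).
hbm_bandwidth = 1.5e12  # bytes/s
round_trips = 4         # rough count of reads + writes of the matrix
transfer_ms = round_trips * scores_bytes / hbm_bandwidth * 1e3
print(f"HBM traffic for the score matrix alone: ~{transfer_ms:.0f} ms")
```

Several milliseconds per layer spent just shuttling one intermediate matrix around is exactly the kind of IO overhead the paper is concerned with.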

Now that you know what HBM is, look at how the standard Attention algorithm is implemented:
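In PyTorch-like pseudocode, it looks roughly like this (a simplified single-head sketch rather than the paper’s exact algorithm; the comments mark where a naive implementation reads and writes HBM):

```python
import torch

def standard_attention(Q, K, V):
    """Naive single-head attention. S and P are full N x N matrices
    that a naive implementation materializes in HBM between steps."""
    d = Q.shape[-1]
    # Step 1: read Q, K from HBM; compute S = Q K^T / sqrt(d); write S to HBM
    S = (Q @ K.transpose(-2, -1)) / d ** 0.5
    # Step 2: read S from HBM; compute P = softmax(S); write P to HBM
    P = torch.softmax(S, dim=-1)
    # Step 3: read P, V from HBM; compute O = P V; write O to HBM
    O = P @ V
    return O
```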

As you can see, all 3 steps include a “write X to HBM” step and a “read from HBM” step. The core idea behind FlashAttention boils down to this: instead of storing each intermediate result, why don’t we use kernel fusion and run every operation in a single kernel in order to avoid memory read/write overhead? (We also talked about kernel fusion in our episode with George Hotz and how PyTorch / tinygrad take different approaches here)

The result is much faster, but much harder to read:
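Here’s a rough Python sketch of that fused, tiled computation: block-wise attention with an online softmax, in the spirit of FlashAttention. It’s a readability-oriented reference, not the actual CUDA kernel, and the block size is an arbitrary illustrative choice:

```python
import torch

def flash_attention_reference(Q, K, V, block_size=128):
    """Block-wise attention with an online softmax. Computes the same
    output as standard_attention without ever materializing the full
    N x N score matrix."""
    N, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros_like(Q)

    for i in range(0, N, block_size):              # loop over query blocks
        Qi = Q[i:i + block_size]
        # Running softmax statistics for this query block
        m = torch.full((Qi.shape[0], 1), float("-inf"),
                       dtype=Q.dtype, device=Q.device)   # running row maxes
        l = torch.zeros(Qi.shape[0], 1,
                        dtype=Q.dtype, device=Q.device)  # running row sums
        acc = torch.zeros_like(Qi)                       # unnormalized output

        for j in range(0, N, block_size):          # loop over key/value blocks
            Kj, Vj = K[j:j + block_size], V[j:j + block_size]
            S = (Qi @ Kj.T) * scale                # small tile; stays on-chip in the real fused kernel
            m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
            P = torch.exp(S - m_new)
            correction = torch.exp(m - m_new)      # rescale the old statistics
            l = correction * l + P.sum(dim=-1, keepdim=True)
            acc = correction * acc + P @ Vj
            m = m_new

        O[i:i + block_size] = acc / l              # final softmax normalization
    return O
```

Because each tile is only block_size x block_size, the full N x N matrix never has to be written to HBM, which is where both the linear memory usage and most of the wall clock speedup come from. For small inputs you can sanity-check this against the naive version above (e.g. with torch.allclose and a loose tolerance).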

The end result is a very meaningful speed improvement over traditional attention (on the order of 2 to 4x in wall clock time), and it’s easy to understand why it’s becoming the standard for most models.

This should be enough of a primer before you dive into our episode! We talked about FlashAttention-2, how Hazy Research Group works, and some of the research being done in Transformer alternatives.

Show Notes:

* FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (arXiv)

* FlashAttention-2

* Together AI

* From Deep Learning to Long Learning

* The Hardware Lottery by Sara Hooker

* Hazy Research

* Is Attention All You Need?

* Nvidia CUTLASS 3

* SRAM scaling slows

* Transformer alternatives:

  * S4

  * Hyena

  * Recurrent Neural Networks (RNNs)

Timestamps:

* Tri's background [00:00:00]

* FlashAttention’s deep dive [00:02:18]

* How the Hazy Research group collaborates across theory, systems, and applications [00:17:21]

* Evaluating models beyond raw performance [00:25:00]

* FlashAttention-2 [00:27:00]

* CUDA and The Hardware Lottery [00:30:00]

* Researching in a fast-changing market [00:35:00]

* Promising transformer alternatives like state space models and RNNs [00:37:30]

* The spectrum of openness in AI models [00:43:00]

* Practical impact of models like LLaMA 2 despite restrictions [00:47:12]

* Incentives for releasing open training datasets [00:49:43]

* Lightning Round [00:53:22]

Transcript:

Alessio: Hey everyone, welcome to the Latent Space podcast. This is Alessio, Partner and CTO-in-Residence at Decibel Partners. Today we have no Swyx, because he's in Singapore, so it's a one-on-one discussion with Tri Dao. Welcome! [00:00:24]

Tri: Hi everyone. I'm Tri Dao, excited to be here. [00:00:27]

Alessio: Tri just completed his PhD at Stanford a month ago. You might not remember his name, but he's one of the main authors of the FlashAttention paper, which is one of the seminal works of the Transformers era. He's got a lot of interests, from efficient transformer training and inference to long-range sequence models, a lot of interesting stuff. And now you're going to be an assistant professor in CS at Princeton next year. [00:00:51]

Tri: Yeah, that's right. [00:00:52]

Alessio: Yeah. And in the meantime, just to get, you know, a low pressure thing, you're Chief Scientist at Together as well, which is the company behind RedPajama. [00:01:01]

Tri: Yeah. So I just joined this week actually, and it's been really exciting. [00:01:04]

Alessio: So what's something that is not on the internet that people should know about you? [00:01:09]

Tri: Let's see. When I started college, I was going to be an economist, so I was fully on board. I was going to major in economics, but the first week I was at Stanford undergrad, I took a few math classes and I immediately decided that I was going to be a math major. And that kind of changed the course of my career. So now I'm doing math, computer science, AI research. [00:01:32]

Alessio: I had a similar thing. I started with physics and then I took like a programming course and I was like, I got to do computer science. I don't want to do physics. So FlashAttention is definitely, everybody's using this. Everybody loves it. You just released FlashAttention 2 last week. [00:01:48]

Tri: Yeah. Early this week on Monday. Yeah. [00:01:53]

Alessio: You know, AI time. Things move fast. So maybe let's run through some of the FlashAttention highlights, some of the innovation there, and then we can dive into FlashAttention 2. So the core improvement in FlashAttention is that traditional attention is quadratic in sequence length, N to the two, while FlashAttention is linear, which obviously helps with scaling some of these models. [00:02:18]

Tri: There are two factors there. So of course the goal has been to make attention go faster or more memory efficient. And ever since attention became popular in 2017 with the Transformer paper, lots and lots of folks have been working on this. And a lot of approaches have been focusing on approximating attention. The goal is you want to scale to longer sequences. There are tons of applications where you want to do that. But scaling to longer sequences is difficult because attention scales quadratically in sequence length on both runtime and memory, as you mentioned. So instead of trying to approximate attention, we were trying to figure out, can we do the same computation and maybe be more memory efficient? So in the end, we ended up with memory that is linear in sequence length. In terms of computation, it's still quadratic, but we managed to make it much more hardware friendly. And as a result, we do get wall clock speedup on the order of 2 to 4x, which really helps because that just means that you'll be able to train with 2 to 4x longer sequence length for the same cost without doing any approximations. As a result, lots of folks have been using this. The thing is available in a lot of libraries that do language model training or fine tuning. [00:03:32]

Alessio: And the approximation thing is important because this is an exact thing versus a sparse. So maybe explain a little bit the difference there. [00:03:40]

Tri: For sure. So in attention, essentially you compute pairwise similarity between every single element in a sequence against each other. So there's been other approaches where instead of doing all that pairwise computation, you only compute similarity for some pairs of elements in the sequence. So you don't do a quadratic number of comparisons. And this can be seen as some form of sparsity. Essentially you're ignoring some of the elements. When you write down the matrix, you essentially say, OK, I'm going t