Tri Dao, Stanford: On FlashAttention and sparsity, quantization, and efficient inference

February 8, 2024


Tri Dao is a PhD student at Stanford, co-advised by Stefano Ermon and Chris Ré. He’ll be joining Princeton as an assistant professor next year. He works at the intersection of machine learning and systems, currently focused on efficient training and long-range context.

Below are some highlights from our conversation as well as links to the papers, people, and groups referenced in the episode.

Some highlights from our conversation

“I think there are many paths to a high-performing language model. So right now there’s a proven strategy and people follow that. I think that doesn’t have to necessarily be the only path. I think my prior is that as long as your model architecture is reasonable and is hardware efficient, and you have lots of compute, and you have lots of data, the model would just do well.”

“So we’ve seen that sparsity now is proven to be more useful as people think about hardware-friendly sparsity. I would say the high-level point is we show that there are ways to make sparsity hardware-friendly and there are ways to maintain quality while using sparsity.”
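For readers who want a concrete picture of what “hardware-friendly” sparsity can look like, one well-known pattern is 2:4 semi-structured sparsity, where exactly two of every four consecutive weights are zeroed so that sparse tensor hardware can skip them with regular memory access. The NumPy sketch below is purely illustrative and is not the specific method discussed in the episode.

```python
import numpy as np

def prune_2_to_4(weights: np.ndarray) -> np.ndarray:
    """Zero out the 2 smallest-magnitude entries in every group of 4.

    This 2:4 pattern is "hardware-friendly": the nonzeros are evenly
    distributed, so sparse tensor units can skip the zeros without
    irregular memory access. Assumes the total number of weights is
    divisible by 4 (illustrative sketch only).
    """
    w = weights.reshape(-1, 4).copy()
    # Indices of the 2 smallest-magnitude entries in each group of 4.
    drop = np.argsort(np.abs(w), axis=1)[:, :2]
    np.put_along_axis(w, drop, 0.0, axis=1)
    return w.reshape(weights.shape)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    dense = rng.normal(size=(4, 8)).astype(np.float32)
    sparse = prune_2_to_4(dense)
    print((sparse == 0).mean())  # -> 0.5: exactly half the entries are zero
```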

“So I think there’s gonna be a shift towards focusing a lot on inference. How can we make inference as efficient as possible from either model design or software framework or even hardware? We’ve seen some of the hardware designs are more catered to inference now—think, for example, Google TPU has a version for inference, and has a different version for training where they have different numbers of flops and memory bandwidth and so on.”

“So we want to understand, from an academic perspective, when or why do we need attention. Can we have other alternatives that scale better in terms of sequence length? Because the longer context length has been a big problem for attention for a long time. Yes, we worked on that. We spent tons of time on that. I looked around and maybe it’s a contrarian bet that I wanna work on something that maybe scaled better in terms of sequence length that, maybe in two to three years, would have a shot at not replacing transformer but augmenting transformer in some settings.”
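As background on why sequence length is the sticking point: standard attention materializes an n-by-n score matrix, so compute and memory grow quadratically with context length. The NumPy sketch below makes that cost explicit; it is deliberately the naive formulation, not FlashAttention, which computes the same output without storing the full matrix in slow GPU memory.

```python
import numpy as np

def naive_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Standard scaled dot-product attention for a single head.

    q, k, v: (n, d) arrays for a sequence of length n. The (n, n)
    score matrix is what makes cost grow quadratically with sequence
    length; alternatives that scale better avoid forming it.
    """
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)                 # (n, n): O(n^2) compute and memory
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                            # (n, d)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d = 1024, 64
    q, k, v = (rng.normal(size=(n, d)) for _ in range(3))
    out = naive_attention(q, k, v)
    print(out.shape)  # (1024, 64); the hidden (1024, 1024) score matrix is the bottleneck
```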

Referenced in this podcast

Thanks to Tessa Hall for editing the podcast.