橘智橘智
FakeOrange
预计阅读时间:12分钟28秒

从原理出发理解 Flash Attention Part1

你需要掌握的所有基础知识!

0
0


前言


本文属于搬运内容,原作者:Sascha Kirch,原文链接。



正文


这是我新系列文章《从原理出发理解 FlashAttention》的第一部分。第二部分点击这里!


你可能已经听过很多次:“Attention 在序列长度 NNN 上的扩展性很差,计算复杂度是 O(N2)。”但到底是什么随着序列长度扩展性变差?为什么会这样?真的是序列长度本身造成的问题吗?在本系列中,我们将深入剖析 Tri Dao 等人在 2022 年发表的 FlashAttention 论文,探讨为什么 Attention 在现代 GPU 上运行缓慢且资源消耗巨大,以及 FlashAttention 是如何解决这一问题的。


这原本只是想写一篇简短的论文综述,但在写作和深入研究的过程中我意识到,目前很多资料都非常表面,往往默认你已经掌握了 Attention 机制、GPU 架构、CUDA 编程模型,以及大量的数学知识。因此我决定将文章拆分为两个部分。


在第一部分中,我们将介绍 Attention 机制、GPU 架构与 CUDA 编程模型,以及为什么 Attention 在现代 GPU 上如此缓慢且资源密集,并解释为什么我们无法直接使用 CUDA 编程中常见的优化方法。


在第二部分中,我们将深入介绍 FlashAttention,在第一部分的基础上,一步步详细剖析其前向传播与反向传播算法的细节。


🚀 那么,话不多说,让我们开始吧!



1. 为什么我们要关注这个问题?


在深入细节之前,让我们先思考一个问题:为什么这个话题值得我们关注?


Attention 机制是 Transformer 架构的核心,而 Transformer 又是如今几乎所有深度学习领域的基石。随着模型规模不断扩大,其训练和推理的时间与成本也随之飙升。因此,如果我们能让 Attention 更快、更高效,我们就能在更短的时间内、用更少的成本训练和部署更大的模型,甚至可能实现原本难以训练的模型。


我认为下图很好地说明了这个问题的重要性。它展示了 FlashAttention 相较于之前方法在性能上的提升,以及其在处理更长序列长度时的表现。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/89fbimage.png

图 1:FlashAttention 相比其他方法的提升情况。图像作者:Sascha Kirch


如图所示,FlashAttention 明显快于以往方法,而且在其他方法会因内存耗尽而失败的序列长度上,FlashAttention 仍能高效运行。


明白这一点后,让我们先简要回顾一下 Attention 机制。



2. Attention 简要回顾


Attention 机制是 Transformer 架构的核心部分,最早由 Vaswani 等人在论文《Attention is All You Need》(2017)中提出。从高层次来看,Attention 允许模型在处理某个特定 token 时,动态地聚焦输入序列中的不同部分,同时忽略其他部分。这里的 token 指的是输入序列中的单个元素,可以是一个单词、一块像素块,甚至是图中的一个节点。


对于输入序列中的每一个 token,Attention 机制都会计算它与所有其他 token 的关联得分(Attention Score)。


最初,Attention 是为了处理文本序列而提出的,但如今它已被广泛应用于各种任务和数据模态,如图像、视频和图结构数据,成为了一种通用的、用于捕捉序列元素间依赖关系的机制。


为了帮助理解,我们来看一个简单的示例:


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/49edimage.png

图 2:文本、图像和图结构数据中 Attention 机制的直观示意图。图像作者:Sascha Kirch


非常简化地说,我喜欢把 Attention 想象成一种向数据提问的机制,而 Attention Score 则是对这些问题的回答。Attention Score 越高,表示对于当前 token 来说,该信息越重要。多头注意力机制(Multi-Head Attention, MHA)可以理解为:对于同一个 token,同时提出多个不同的问题,并将这些问题的答案组合在一起。后续我们会详细讲到。


带着这样的直觉,我们继续深入了解 Transformer 编码器(Encoder)结构及 Attention 机制的细节。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/5858image.png

图 3:将 Attention 机制放置在 Transformer 编码器的整体结构中。图像作者:Sascha Kirch


Transformer 的编码器由若干个相同结构的堆叠模块(Stage)组成。每一个模块包括两个子部分:一个多头自注意力(MHA)层和一个多层感知机(MLP)层。

MHA 子模块实现了多个并行的注意力头(Attention Heads),它们会被拼接(Concatenation)并通过线性变换(Linear Projection)恢复到期望的隐藏维度。


进一步深入到单个 Attention 头内部,我们可以看到 Attention 机制的基本架构,它可以通过下面的公式描述:


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/6312image.png

图 4:Attention 机制的数学表达式。图像作者:Sascha Kirch


注意:Dropout 和 Masking 是可选项,不过因为 FlashAttention 中提到了它们,所以这里我们也一并提及。


由于本文并不是专门讲 Attention 机制本身,因此这里不会展开详细讨论。如果你希望进一步深入了解 Attention,建议阅读 Vaswani 等人的原论文(2017)Jay Alammar 的著名博客文章。


那么,我们真正关心的问题是:


为什么 Attention 机制在现代 GPU 上会如此缓慢且资源消耗巨大?


为了解答这个问题,让我们先来看看 Attention 机制中涉及到的矩阵的形状(Shape)。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/3343image.png

图 5:标注 Attention 机制中各矩阵的形状。图像作者:Sascha Kirch


这里最关键的地方在于:Attention 机制在中间计算中会生成大小为 N×N的矩阵,而最终我们关心的输出只是大小为 N×d 的矩阵,其中 N 是序列长度,d 是隐藏层的维度。这就是为什么 Transformer 通常被认为在序列长度 NNN 上具有 O(N2) 的计算和内存复杂度。


虽然图中以 N=4096 举例(那时的常规规模),但论文中提到,FlashAttention 将其扩展到了 N=65,536,这会导致中间结果高达 42.9 亿 个数值!


需要特别注意的是,图 5 展示的只是单个 Attention 头在单层 Transformer 中的情况。而在实际应用中,一个 Transformer 通常有多个头(Multi-Head)和多层堆叠(Multi-Layer),这会使计算复杂度进一步指数级增长。


为了进一步理解为什么 Attention 机制在现代 GPU 上运行缓慢且资源密集,

我们接下来需要简要了解一下 GPU 硬件架构CUDA 编程基础


别担心,这部分我们会保持简明易懂;我保证。😉



3. 现代 GPU 与 CUDA 设备架构


3.1 GPU 架构


我们先从比较 CPUGPU 的架构开始。


CPU 由少量核心组成,主要优化于顺序处理(sequential processing);而 GPU 拥有大规模并行(massively parallel)的架构,由数千个更小、更高效的核心组成,专门用于同时处理多个任务

此外,CPU 和 GPU 都拥有一套分层内存体系(memory hierarchy),不同层级的内存在容量和访问速度上存在显著差异。这一点在后面讲到 FlashAttention 如何实现极大加速时会非常关键。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/03b7image.png

图 6:CPU 与 GPU 架构对比。图像作者:Sascha Kirch


这种通用 GPU 架构(General Purpose GPU,GPGPU)配合 CUDA 编程语言,最初是在 2006 年由 Nvidia 的 Tesla 架构引入的,其中的许多核心概念至今仍然是现代 GPU 设计的基础。

我强烈推荐阅读 Nvidia 关于 Tesla 架构的官方白皮书,以获取更深入了解。


GPU 的核心构件是 流式多处理器(Streaming Multiprocessor,简称 SM),

每个 SM 包括:

  • 多个 CUDA 核心 和 Tensor 核心
  • 共享内存(Shared Memory)
  • 寄存器(Registers)

每个 SM 能够并行执行多个线程块(blocks)。

一个线程块内的线程可以通过共享内存互相通信,而不同线程块之间的线程不能直接通信

这里提前剧透一下:
因为有了成千上万个线程,GPU 允许我们将像矩阵乘法(matrix multiplication)和卷积(convolution)这样的大规模重计算任务进行极致并行化。

具体做法是:将每个线程分配到处理一小块数据的一小部分计算上,从而加速整个大计算任务。



3.2 CUDA 编程模型


Nvidia 的 GPU 使用 CUDA 编程模型进行编程,它是 C++ 编程语言的扩展。

通过 CUDA,开发者可以编写 CUDA 内核函数(kernel),这些内核函数由 GPU 上的每个线程独立执行。

这些线程被组织成线程块(blocks),而线程块又组成了网格(grid)


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/8a72image.png

图 7:CUDA 线程层次结构。图像作者:Sascha Kirch


这种线程层次结构非常巧妙,因为它模拟了我们要处理的数据结构的形状,比如向量(vector)、矩阵(matrix)和张量(tensor)。

此外,每个线程都可以识别自己在网格和线程块中的位置,这使得线程能够准确地计算出自己要处理的数据位置。


我们还可以将 CUDA 编程模型映射到 GPU 硬件架构上,来看数据在不同层级中是如何存储和访问的。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/5273image.png

图 8:CUDA 编程模型与 GPU 架构的对应关系。图像作者:Sascha Kirch


为了引入 FlashAttention 中的内核融合(Kernel Fusion)概念,我们需要理解 CPU 如何与 GPU 交互,以及数据是如何在它们之间流动的。 首先来看一下所谓的异构编程模型(Heterogeneous Programming Model)

也就是在不同设备上(如 CPU 和 GPU)执行程序。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/37bcimage.png

图 9:使用 CPU 和 GPU 的异构编程模型。图像作者:Sascha Kirch


在这种模型下:

  • CPU 执行串行(顺序)代码,负责数据在主机(Host)和设备(Device)之间的传输,并且负责启动(launch)GPU 内核。
  • GPU 负责并行执行内核,通过在设备上启动网格中的多个线程块来完成。

执行完毕后,GPU 会将结果返回给 CPU,CPU 再继续自己的串行处理,并可能启动下一个内核。


这里的关键是:
在朴素(naive)实现中,每一个内核(kernel)都需要在主机和设备之间进行数据传输。


结合 Attention 机制来看,这些顺序执行的内核可能分别对应:

  • 矩阵乘法(matrix multiplication)
  • 缩放(scaling)
  • 掩码(masking)
  • SoftMax
  • Dropout


进一步放大来看,单个内核执行时,数据在 CPU 和 GPU 之间是如何流动的?

以及 GPU 上的核心(cores,也就是线程)是如何访问存储在 GPU 全局内存(global memory)中的数据?data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/00cfimage.png

图 10:单个 CUDA 内核执行时的数据流动示意图。图像作者:Sascha Kirch


当然,上图是极度简化版,但它传达了一个重要的信息:

数据在 CPU 和 GPU 之间、以及 GPU 内部,需要频繁传递,而且核心经常在等待数据到达。



3.3 CUDA 中的矩阵乘法(Matrix Multiplication)


呼~前面信息量有点大,现在我们来看一个例子来加深理解。
矩阵乘法 是深度学习中最重要的基本操作之一,也是 CUDA 加速的典型场景。


如你所知,矩阵乘法(以 C=A×B 为例)意味着:

  • 取矩阵 A 的每一行
  • 与矩阵 B 的每一列做点积(dot product)

在最朴素(且极其低效)的 CUDA 实现中,我们可以简单地:

  • 给每个线程分配一个任务,计算结果矩阵 C 中的一个元素。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/aaf2image.png

图 11:矩阵乘法示意图。图像作者:Sascha Kirch,灵感来源于 Simon Boehm


注意:
计算结果矩阵 C 的一整行,需要使用矩阵 A 的一行和矩阵 B 的所有列。


为了更深入理解内核执行,我们需要稍微绕远一点,了解**不同内存带宽(memory bandwidth)**的概念。


执行一个内核时,一般流程是:

  1. 创建线程块网格(thread block grid)
  2. 将内核加载到处理器核心中
  3. 从全局内存(global memory)读取数据到处理器的寄存器(registers)
  4. 执行计算
  5. 将结果写回全局内存


一般来说,内存容量越大,离处理器越远,速度越慢

因此:

  • 全局内存(global memory):最大但最慢
  • 寄存器(registers):最小但最快


让我们用一个例子来体会一下。


以 Nvidia A100 GPU 为例,它的内存带宽达到 1.5 TB/s(听起来已经很快了)。

但即使如此,实际数据量也是巨大的。


假设我们要计算两个大小为 N×N 的矩阵相乘,设 N=4096,且每个元素是 float32(4字节):

  • 读取矩阵 A:4096×4096×4 字节
  • 读取矩阵 B:4096×4096×4 字节
  • 写入矩阵 C:4096×4096×4 字节

总共:

  • 读取数据:2 × 4096 × 4096 × 4 bytes = 134MB
  • 写入数据:4096 × 4096 × 4 bytes = 64MB


但是,如果按照朴素实现方法——
每处理一行 A,就重新读取整个矩阵 B —— 那么:

  • 134MB × 4096 = 549GB 的数据流量! 😅

而且,这还是针对小窗口 N=4096N=4096N=4096 的情况!



❓那么,知道了这些之后,问题来了:


我们要怎样做,才能让 GPU 上的计算更快、更高效呢?



3.4 内核执行优化(Optimizations for Kernel Execution)


你可能已经猜到了,我们的主要关注点是:

减少对全局内存(global memory)的读写次数, 因为全局内存是整个内存层级中最慢的,且核心(cores)经常在等待数据加载。


为了达到这个目标,有几种常见的技术,包括:

  • 共享内存(Shared Memory) —— 把常用数据保存在核心附近;
  • 块切分(Block Tiling) —— 重复利用已经加载到内存中的数据,计算多个输出;
  • 内核融合(Kernel Fusion) —— 避免把中间结果写入后又读取;


需要注意的是,为了保持简明,这里我们会适当地简化一些细节,

但足够为理解 FlashAttention 的贡献打下坚实基础。

如果你想了解更深入的 CUDA 矩阵乘法优化细节,我推荐阅读 Simon Boehm 的精彩博客,非常直观且深入。


接下来,我们分别介绍上述优化技术,并结合矩阵乘法(Matrix Multiplication)来说明它们的应用:


共享内存(Shared Memory)


把常用数据保存在核心附近

首先,我们要充分利用 GPU 内存层级中的不同存储层。
共享内存(Shared Memory) 是一种小容量、极快速度的内存,它在同一个线程块(block)内的所有线程之间共享(位于同一个流式多处理器 SM 中)。


根据 FlashAttention 论文中关于 A100 GPU 的数据:

  • 全局内存(40GB)带宽:1.5 TB/s
  • 共享内存(20MB)带宽:19 TB/s


显然,共享内存的带宽远远大于全局内存


我们知道,同一个块(block)内的线程可以通过共享内存互相通信,

而不同块(blocks)可以并行运行在不同的流式多处理器(SM)上。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/fb0fimage.png

图 12:矩阵乘法中共享内存的使用示意图。图像作者:Sascha Kirch,灵感来源于 Simon Boehm


因此,对于每一个处理块,我们可以:

  • 先从全局内存把数据加载到共享内存中;
  • 然后,线程块内的线程就可以快速地访问共享内存中的数据;
  • 处理完这个块后,再继续加载下一个块的数据。


块切分(Block Tiling)


重复利用已经加载的数据,计算多个输出

接下来,我们可以让每个线程计算多个输出元素,而不仅仅是一个。

这就是所谓的**切分(Tiling)**技术,在 CUDA 优化中非常常见。

其基本思想是:

  • 将从共享内存加载到寄存器的数据重复利用,
  • 在一次加载之后进行多次计算,从而减少读写全局内存的次数。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/6c3fimage.png

图 13:矩阵乘法中的块切分示意图。图像作者:Sascha Kirch,灵感来源于 Simon Boehm


需要特别强调的是:

  • 计算总量没有减少,
  • 只是减少了内存访问次数,
  • 花更多时间在计算上,减少了等待数据的时间。

这提高了所谓的算术强度(Arithmetic Intensity)

即:运算次数/内存访问次数的比值


内核融合(Kernel Fusion)


避免写入和读取中间结果


最后一种优化策略是:

  • 将多个内核融合成一个内核,
  • 这样可以避免把中间结果写到全局内存后,又被下一步重新读回的过程。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/97d0image.png

图 14:内核融合对总运行时间和内存消耗的影响示意图。图像作者:Sascha Kirch


需要特别注意的是:

只有在多个内核操作的是同一数据块(tile)时,内核才能被融合。 这在后面讲解 Attention 机制时会非常关键。


到这里为止,关于 Attention 机制、CUDA 编程模型和 GPU 架构的基础知识已经打牢了!


我们终于准备好可以理解:

  • 为什么传统 Attention 机制在现代 GPU 上既慢又耗资源,
  • 以及
  • FlashAttention 是如何通过改良来解决这些问题的了!



4. 现代 GPU 上的 Attention


4.1 为什么 Attention 慢且资源消耗巨大?


到现在为止,你可能已经猜到了:

Attention 的问题在于,我们对全局内存的读写太频繁了。

让我们再次审视一下 Attention 机制中涉及的矩阵形状,

不过这次我们从 GPU 架构CUDA 编程模型 的视角来看。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/08edimage.png

图 15:Attention 层中每个内核的读写示意。图像作者:Sascha Kirch


从中可以观察到几点关键现象:

  • 输入有三组矩阵(Q、K 和 V),每个尺寸为 N×dN \times dN×d,输出矩阵 O 的尺寸也是 N×dN \times dN×d。
  • Attention 机制在中间过程生成了巨大的 N×NN \times NN×N 矩阵,所有这些中间结果都需要存储,然后在后续内核执行时再次读取。
  • 像矩阵乘法(MatMul)和 SoftMax 这样的操作,在计算每个输出元素时会重复使用相同的输入数据。
  • 而像 Dropout 和 Masking 这样的操作则是逐元素应用的(element-wise operations)。


可以看到:
存在大量昂贵的全局内存读写,而这些中间结果实际上并不是最终我们关心的,只是过程中的临时产物。


那么,我们能否应用前面提到的优化方法来改善 Attention 的执行呢?

更准确地说,我们为什么不能直接应用?

(之后我们会讲 FlashAttention 是如何突破这些限制的。)


小知识拓展:
你知道 FlashAttention 的第一作者 Tri Dao 也是 Mamba 状态空间模型(Mamba State Space Models) 的作者之一吗?

Mamba 是一种替代 Transformer 进行序列建模的方法,能够降低 Transformer 的 O(N2)O(N^2)O(N2) 计算复杂度。

如果感兴趣,可以看看我关于 Mamba 的另一系列内容。


4.2 我们能如何改进 Attention 机制?


Attention 机制的大问题在于:

  • 必须在中间过程物化(materialize)巨大的 N×N 矩阵,
  • 并且需要将它们写入和读取全局内存。


那么,为什么不直接将这些涉及的内核(kernels)融合(fusion)在一起,
避免产生中间结果的写入和读取呢?


这确实是个很棒的想法!

但是,存在两个主要问题:

  1. 并不是所有内核都在处理同一块(block-tile)数据, ——而我们之前提到过,要实现内核融合,必须操作相同数据块。
  2. 反向传播(backward pass)阶段需要用到前向传播(forward pass)中的中间结果来计算梯度。


慢下来,我们先看第一个问题:
SoftMax 是怎么在 Attention 中实现的?


来看 N×N矩阵的处理方式。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/e383image.png

图 16:Attention 机制中 SoftMax 的计算方式示意图。图像作者:Sascha Kirch


注意两点:

  • SoftMax 是**按行(row-wise)**应用的,而不是对整个矩阵整体应用。
  • 更重要的是,为了计算 SoftMax 矩阵 PPP 中的一个元素,需要读取对应输入矩阵 S 的整行元素。


而之前提到的矩阵乘法内核(MatMul kernel):

  • 是用 块切分(block tiling) 和 滑动窗口(sliding window) 的方式,
  • 只加载局部数据到共享内存,
  • 这种方式无法兼容 SoftMax,
  • 因为 SoftMax 需要访问整行数据来计算每一个输出元素。


为什么 SoftMax 是按行应用的呢?

  • 因为 Attention 中的每一行对应输入序列中的一个 token。
  • 通过对每一行应用 SoftMax,生成的是每个 token 对其他 token 的注意力分数,
  • 也就是:当前 token 关注其他 token 的程度。


那如果我们假设一下,

如果 SoftMax 也能在同一块数据上完成呢?

  • 那么矩阵乘法和 SoftMax 内核就可以融合,
  • 避免写入中间结果到全局内存,
  • 从而进一步加速。


但不幸的是,即使解决了这个问题,还有第二个问题

  • 我们仍然需要中间结果来进行反向传播,计算梯度!


在训练神经网络时,一般流程是:

  • 输入一批数据(batch),
  • 计算损失(loss),
  • 应用**反向传播(Backpropagation)**来更新模型参数,
  • 以便模型下一次表现更好。


应用链式法则(chain rule),

我们可以逐层计算每个操作的梯度。


来看 Attention 机制中一小块示例:


data/78df2c1f-e442-415d-a382-fa7925af0c4b/3f57f03e-ca23-4c8f-8495-4bacb7a63570/827eimage.png

图 17:前向传播 vs 反向传播示意图。图像作者:Sascha Kirch


可以看到:

  • 反向传播需要使用前向传播时产生的中间结果。

所以:

  • 我们不得不将中间结果写入到全局内存,
  • 然后在反向传播时再次读取它们。

这正好就是我们一开始想要避免的事情……
结果又回到了起点。


幸运的是,FlashAttention 成功地解决了这些问题!

但在深入 FlashAttention 的细节之前,
让我们先总结一下第一部分的内容,消化一下今天的信息量。


这只是深入理解 FlashAttention 工程原理的开始!

我很期待听到其他在 Attention 机制或 GPU 优化领域工作的同学们的经验分享:


  • 你遇到过哪些挑战?
  • 有探索过什么替代方法吗?


欢迎在评论区留言讨论!

如果你觉得这份总结对你有帮助,点几个赞(Clap)可以帮助更多人发现它。
敬请期待第二部分!🚀



5. 第一部分总结(Wrapping Up Part 1)


5.1 总结(Summary)


我们想要实现的目标:

让 Attention 机制更快、更高效。

那么,
❓ 为什么 Attention 如此缓慢且资源消耗巨大?

  • 因为 Attention 需要在中间过程中物化(materialize)巨大的 N×NN \times NN×N 矩阵,
  • 这些中间结果又必须写入并从全局内存读取,极大增加了内存访问开销。


❓ 为什么我们不能简单地把所有内核(kernels)融合在一起,从而避免中间结果的读写?

  • 因为并不是所有内核都在处理同一数据块(block-tile), ——而内核融合(kernel fusion)要求操作同一块数据。
  • 而且我们需要在反向传播(backward pass)阶段使用中间结果来计算梯度, ——无法彻底避免中间结果的存储。


5.2 继续阅读 Part 2


在第二部分(Part 2)中,我们将正式介绍 FlashAttention,在本部分打下的基础之上,


一步一步深入剖析:


  • FlashAttention 的前向传播(forward pass)和
  • 反向传播(backward pass) 的详细算法机制。

🚀 每一步都循序渐进,助你彻底理解 FlashAttention!

评论