AI 算力资源越发紧张的当下,斯坦福新研究将 GPU 运行效率再提升一波 —— 内核只有 100 行代码,让 H100 比使用 FlashAttention-2,性能还要提升 30%

怎么做到的?

研究人员从“硬件实际需要什么?如何满足这些需求?”这两个问题出发,设计了 一个嵌入式 CUDA DSL 工具,名为 ThunderKittens(暂且译为雷猫)。

雷猫可简化 AI 内核的编写,同时充分利用底层硬件能力。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

具体来说,雷猫的主要抽象是寄存器和共享内存中的小型张量块(tile),和目前 GPU 中对小矩阵乘法的优化相匹配。

通过操作这些 tile,开发者可相对简单地编写代码,充分利用张量核心、异步数据传输和共享内存等硬件特性。

使用雷猫实现的注意力机制内核,代码量少且能实现很高的硬件利用率,性能超过直接使用底层库(如 Cutlass)。

详细讨论过程以及雷猫是怎么设计出的,研究人员以“GPUs Go Brrr”为题,发在了斯坦福 Hazy Research 的 Blog 网站上。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

网友们对此讨论也十分热烈。

有网友表示读这篇 Blog 时,让他想起了初次了解超标量 CPU 架构时的惊讶感受:

GPU 真的达到了新高度。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

还有网友表示:

这篇文章重新点燃了我在 CS 149 并行编程课中所感受到的快乐。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

H100 里有什么?

斯坦福研究人员以 H100 为例,探讨了优化 GPU 的方法。

首先,回顾一下 H100 的硬件细节,这对于接下来的讨论非常重要。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

一个 H100 SXM GPU 包含:

(1)80GB 的 HBM3 内存,带宽为 3TB / s(实际带宽略低)。

(2)50MB 的 L2 缓存,带宽为 12TB / s,在 GPU 上分为两个 25MB 的部分,通过交叉开关连接(这个交叉开关表现不佳)。

(3)132 个流式多处理器(SM),每个包含:

高达 227KB 的共享内存位于 256KB 的 L1 缓存中(这些加起来的带宽大约 33TB / s)。

一个张量内存加速器(TMA)—— 这是英伟达 Hopper 架构中的一种新硬件组件,可进行异步地址生成和内存获取,还能促进片上内存网络。

4 个子单元,每个含:一个 warp scheduler;512 个向量寄存器(每个包含 32 个 4 字节的词);一个用于执行矩阵乘法的张量核心;一组内置指令,如求和、乘法等,这些指令能够并行操作这些向量寄存器。

除了这些,一个 GPU 还包括内存控制器、指令缓存…… 但对于这项研究而言不重要。

重要的是,所有的计算都发生在流式多处理器中,大部分计算是在寄存器中

H100 GPU 拥有 989 TFLOPs 的半精度矩阵乘法计算能力,以及约 60 TFLOPs 的“其他”计算能力。因此,每个周期内张量核心被使用时,至少能达到 94% 的硬件利用率。而张量核心不被使用时,硬件的利用率不会超过 6%。

换句话说:

H100 的利用率 = 张量核心活跃周期的百分比 +/- 6%。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

所以要充分发挥 H100 的能力,关键是保持张量核心持续运算

榨干 H100,要注意什么?

然鹅,要保持张量核心持续运行并不容易。

研究人员发现 GPU 硬件具有一些特性,对于保持矩阵乘法的运行非常重要:

WGMMA 指令虽然是必要的,但使用起来颇为麻烦。

共享内存的速度并不如预期的快,使用时还需格外注意。

生成地址的成本较高。

保持高占用率对于提升性能是有益的,寄存器至关重要。

这些特性在非 H100 GPU 上也有所适用,在 H100 上更加典型,就拿 RTX 4090 来说,相比 H100 处理起来简单得多。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

所以接下来还是以 H100 为例,展开探讨这几点特性。

WGMMA 指令

H100 引入了一套新的指令集,名为“warp group matrix multiply accumulate”(在 PTX 中为 wgmma.mma_async,在 SASS 中为 HGMMA / IGMMA / QGMMA / BGMMA)。

要理解这些指令的特点,需回顾以往张量核心的使用方式。

早期 GPU 中的张量核心指令如 wmma.mma.syncmma.sync,要求 SM 一个子单元内的 32 个线程的一个 warp 同步传输数据块至张量核心并等待结果。

wgmma.mma_async 指令则不同。它允许 128 个连续线程跨 SM 所有子单元协作同步,并从共享内存及寄存器(可选)异步启动矩阵乘法。这使得这些 warp 在等待矩阵乘法结果时可以处理其他任务。

研究人员通过微观基准测试,发现这些指令是充分发挥 H100 计算能力所必需的。没有这些指令,GPU 的峰值利用率大约只有 63%。

他们推测,这是由于张量核心需要从本地资源维持一个深度硬件 pipeline。

然而,这些指令的内存布局极其复杂。未重排的共享内存布局合并性差,需要额外的 L2 带宽。重排的内存布局记录不准确,研究人员花费了大量时间才弄明白。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

最终发现,这些布局只适用于特定矩阵形状,并与 wgmma.mma_async 指令的其他部分不兼容,例如硬件仅在未重排的布局下转置子矩阵。

此外,未重排的 wgmma 布局内存合并性差且有 bank conflicts。尽管 TMA 和 L2 缓存在如 flash attention 这类内核上能较好地掩盖这些问题,但要充分利用硬件,必须精心控制内存请求的合并和避免 bank conflicts。

尽管有这些问题,但这些指令对于充分利用 H100 是必不可少的。没有它们,GPU 的潜在性能就损失了 37%。

共享内存

共享内存的单次访问延迟约为 30 个周期(这也与研究人员观察的相符),这看似不多,但在这段时间内,SM 的张量核心几乎能完成两次完整的 32×32 方阵乘法。

以前的研究,如 Flash Attention,研究人员更多关注的是 HBM-SRAM 的瓶颈。但随着 HBM 速度的提升和张量核心的快速发展,即使是共享内存的相对较小延迟也变得尤为关键。

由于共享内存被分为 32 个独立的存储单元,处理不当可能会引发 bank conflicts,即同一个内存 bank 同时被多个请求访问,这种情况会导致请求被序列化。研究人员实验后认为,这会显著拖慢内核速度,且 wgmma 与 mma 指令需要的寄存器布局容易受到 bank conflicts 的影响。

解决方法是通过各种“重排”模式调整共享内存的配置,避免 bank conflicts,但细节要处理得当。

此外研究人员发现,尽可能避免在寄存器和共享内存之间的移动数据非常重要。可能的话,可使用内置硬件(如 wgmma 和 TMA 指令)进行异步数据传输。实在没法子了,再使用 warp 进行同步数据传输。

地址生成

H100 还有一个有趣的特性,其张量核心和内存都足够快,以至于仅生成用于获取数据的内存地址就占用了芯片的大量资源,特别是加入复杂的交错或重排模式时,这种情况更为明显。

研究人员表示,英伟达提供了张量内存加速器(TMA),似乎就是已经意识到了这个问题。

TMA 允许用户在全局和共享内存中指定多维张量布局,命令其异步提取张量的一部分,并在完成后触发一个屏障。这大大节省了地址生成的开销,并简化了 pipelines 的构建。

研究人员认为,TMA 对于充分发挥 H100 的潜力至关重要,可能比 wgmma.mma_async 更为关键。

它不仅节省了寄存器资源和指令派发,还提供了如异步在全局内存上执行归约等实用功能 —— 这在处理复杂的反向内核时尤其有用。

虽然 TMA 的重排模式解读有一定难度,需要进行一些逆向工程,但研究人员表示,相比之下,他们在这上面遇到的问题要少得多。

占用率

占用率指的是在 GPU 的相同执行硬件上同时调度的线程数。每个周期,SM 的某一子单元的 warp scheduler 会尝试向准备就绪的 warp 线程发出指令。

研究人员认为,英伟达采用这种模型可以更容易地保持硬件的满负荷运行。例如,当一个线程 warp 等待执行矩阵乘法时,另一个可以被指派执行使用快速指数运算的指令。

在某些方面,H100 对占用率的依赖程度低于前几代硬件。

它的异步特性使得即使单一指令流也能使多个硬件部分同时持续运行,包括读取内存、执行矩阵乘法、进行共享内存的归约,同时还能在寄存器上进行计算。

但高占用率容易隐藏缺陷或同步问题,一个设计良好的 pipeline 即使在占用率不高的情况下也能运行得相当快。

据研究人员观察,英伟达在设计 GPU 时确实考虑到了占用率。且由于存在足够多的同步操作和足够多的错误可能性,根据他们的经验,提高占用率通常能显著增加硬件的实际利用率。

此外,相比 H100,A100 和 RTX 4090 更依赖同步指令调度,占用率更重要。

用雷猫优化 GPU

鉴于以上情况,如何才能更轻松地编写所需的内核类型,同时充分发挥硬件的全部潜力?

雷猫(ThunderKittens)登场了。

这是一个嵌入在 CUDA 中的 DSL,本是斯坦福研究人员设计出来给自己内部使用的,后来发现还真挺好使。

Ps:起这么个名,一是他们觉得小猫很可爱,二来他们觉得大伙儿在代码中输入 kittens:: 会很有趣。

具体来说,雷猫包含四种模板类型:

  • 寄存器 tiles:在寄存器文件上表示二维张量。

  • 寄存器向量:在寄存器文件上表示一维张量。

  • 共享 tiles:在共享内存中表示二维张量。

  • 共享向量:在共享内存中表示一维张量。

tiles 通过高度、宽度和布局进行参数化;寄存器向量通过长度和布局进行参数化;而共享向量仅通过长度进行参数化,通常不会遇到 bank conflicts 问题。

此外,研究人员提供了一系列操作来处理这些张量,既可在 warp 级别使用,也可用于多个 warp 协作,包含初始化器,如将共享向量清零;一元操作,如 exp;二元操作,如 mul;行 / 列操作,例如行求和。

雷猫作为一个嵌入到 CUDA 中的库,其提供的抽象层在遇到不支持的功能时能够很好地处理。如果雷猫缺少某些功能,可以直接扩展它来实现你想要的效果。

以 Tri 的 flash attention 算法为例,在实际应用中,即使是使用英伟达的 Cutlass 库,实现起来也是相当复杂。

以下是一个在 RTX 4090 上使用雷猫编写的简单 flash attention 内核的示例。

总共约 60 行 CUDA 代码,硬件利用率达到了 75%。代码复杂性主要在于算法本身,而非交织模式或寄存器布局。

#define NUM_WORKERS 16 // This kernel uses 16 workers in parallel per block, to help issue instructions more quickly.

using namespace kittens; // this kernel only handles headdim=64 for simplicity. Also n should be a multiple of 256 here.
__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {

    auto warpid        = kittens::warpid();
    auto block_start   = blockIdx.x*(n*64);
    const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start;
          bf16 *_o = __o__ + block_start;

    extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory
    shared_allocator al((int*)&__shm[0]);

    // K and V live in shared memory -- this is about all that will fit.
    st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
    st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();

    // Initialize all of the register tiles.
    rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg need to be swed into col_l
    rt_fl_1x1<> att_block;
    rt_bf_1x1<> att_block_mma;
    rt_fl_1x4<> o_reg;
    rt_fl_1x1<>::col_vec max_vec_last, max_vec; // these are column vectors for the attention block
    rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // these are column vectors for the attention block

    int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);

    for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {

        // each warp loads its own Q tile of 16x64, and then multiplies by 1/sqrt(d)
        load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
        mul(q_reg, q_reg, __float2bfloat16(0.125f)); // temperature adjustment

        // zero flash attention L, M, and O registers.
        neg_infty(max_vec); // zero registers for the Q chunk
        zero(norm_vec);
        zero(o_reg);

        // iterate over k, v for these q's that have been loaded
        for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {

            // each warp loads its own chunk of k, v into shared memory
            load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
            load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
            __syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase

            // now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg.
            for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {

                load(k_reg, k_smem[subtile]); // load k from shared into registers

                zero(att_block); // zero 16x16 attention tile
                mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T

                copy(norm_vec_last, norm_vec);
                copy(max_vec_last,  max_vec);

                row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec
                sub_row(att_block, att_block, max_vec); // subtract max from attention -- now all <=0
                exp(att_block, att_block); // exponentiate the block in-place.

                sub(max_vec_last, max_vec_last, max_vec); // subtract new max from old max to find the new normalization.
                exp(max_vec_last, max_vec_last); // exponentiate this vector -- this is what we need to normalize by.
                mul(norm_vec, norm_vec, max_vec_last); // and the norm vec is now normalized.

                row_sum(norm_vec, att_block, norm_vec); // accumulate the new attention block onto the now-rescaled norm_vec
                div_row(att_block, att_block, norm_vec); // now the attention block is correctly normalized

                mul(norm_vec_last, norm_vec_last, max_vec_last); // normalize the previous norm vec according to the new max
                div(norm_vec_last, norm_vec_last, norm_vec); // normalize the previous norm vec according to the new norm

                copy(att_block_mma, att_block); // convert to bf16 for mma_AB

                load(v_reg, v_smem[subtile]); // load v from shared into registers.
                rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg

                mul_row(o_reg, o_reg, norm_vec_last); // normalize o_reg in advance of mma_AB'ing onto it
                mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul.
            }
            __syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk
        }

        store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/
    }
}

关于 TMA、WGMMA、交织模式和描述符的复杂性,这里展示了一个使用雷猫编写的,针对 H100 的 FlashAttention-2 算法的前向传递示例。

template<int D>
__global__  __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 2)
void fwd_attend_ker_dim(int N, const CUtensorMap* tma_q, const CUtensorMap* tma_k, const CUtensorMap* tma_v, CUtensorMap* tma_o) {
    extern __shared__ int __shm[]; // this is the CUDA shared memory
    tma_swizzle_allocator al((int*)&__shm[0]);

    constexpr int tile_width = fwd_attend_ker_tile_dims<D>::tile_width; // constants
    constexpr int qo_height  = fwd_attend_ker_tile_dims<D>::qo_height;
    constexpr int kv_height  = fwd_attend_ker_tile_dims<D>::kv_height;

    st_bf<qo_height, tile_width, layout_q>          (&q_smem)   [NUM_WARPGROUPS] = al.allocate<st_bf<qo_height, tile_width, layout_q>,          NUM_WARPGROUPS>();
    st_bf<kv_height, tile_width, layout_k>          (&k_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_k>, 2,       NUM_WORKERS_KV>();
    st_bf<kv_height, tile_width, layout_v>          (&v_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_v>, 2,       NUM_WORKERS_KV>();

    int tic = 0, toc = 1;

    rt_fl<1, kv_height> att_block;
    rt_bf<1, kv_height> att_block_mma;
    rt_fl<1, qo_height> o_prev;
    col_vec<rt_fl<1, kv_height>> max_vec_last, max_vec;
    col_vec<rt_fl<1, kv_height>> norm_vec_last, norm_vec;

    int warpid      = kittens::warpid();
    int warpgroupid = warpid/kittens::WARPGROUP_WARPS;

    int kv_blocks = N / (NUM_WORKERS_KV*k_smem[0][0].rows);

    __shared__ uint64_t qsmem_barrier, kvsmem_barrier;//, vsmem_barrier;

    int q_phasebit = 0;
    int kv_phasebit = 0;

    if (threadIdx.x == 0) {
        tma::init_barrier<st_bf<qo_height, tile_width, layout_q>, NUM_WARPGROUPS>(qsmem_barrier, 1);
        tma::init_barrier<st_bf<kv_height, tile_width, layout_k>, NUM_WORKERS_KV*2>(kvsmem_barrier, 1); 
    }

    if (warpid == 0) {
        for (int wg = 0; wg < NUM_WORKERS/kittens::WARPGROUP_WARPS; wg++) { // load q
            int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + wg;
            tma::load_async((q_smem[wg]), tma_q, qsmem_barrier, tile_idx); 
        }
        for (int w = 0; w < NUM_WORKERS_KV; w++) { // load k, v      
            int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + (0 * NUM_WORKERS_KV) + w; 
            tma::load_async((k_smem[tic][w]), tma_k, kvsmem_barrier, tile_idx); 
            tma::load_async((v_smem[tic][w]), tma_v, kvsmem_barrier, tile_idx); 
        }
    }

    neg_infty(max_vec); // zero registers for the Q chunk
    zero(norm_vec);
    zero(o_prev);
    __syncthreads();

    tma::arrive_and_wait(qsmem_barrier, q_phasebit);
    q_phasebit ^= 1;

    if constexpr (D == 64) { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.125f)); } 
    else { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.08838834764f)); }

    for (auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++, tic ^= 1, toc ^= 1) {
        tma::arrive_and_wait(kvsmem_barrier, kv_phasebit);
        kv_phasebit ^= 1;

        __syncthreads();
        if (warpid == 0) {
            tma::set_bytes(kvsmem_barrier, 2 * NUM_WORKERS_KV * k_smem[0][0].num_elements * sizeof(bf16));

            if (kv_idx + 1 < kv_blocks) {
                for (int w = 0; w < NUM_WORKERS_KV; w++) {        
                    int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + ((kv_idx + 1) * NUM_WORKERS_KV) + w; 
                    tma::load_async((k_smem[toc][w]), tma_k, kvsmem_barrier, tile_idx); 
                    tma::load_async((v_smem[toc][w]), tma_v, kvsmem_barrier, tile_idx);
                }
            }
        }

        warpgroup::mma_fence(att_block);
        warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[tic][0]);
        warpgroup::mma_commit_group();

        copy(norm_vec_last, norm_vec);
        copy(max_vec_last,  max_vec);

        warpgroup::mma_async_wait();

        row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec
        sub_row(att_block, att_block, max_vec);
        exp(att_block, att_block);

        sub(max_vec_last, max_vec_last, max_vec);
        exp(max_vec_last, max_vec_last);
        mul(norm_vec, norm_vec, max_vec_last);

        row_sum(norm_vec, att_block, norm_vec); // accumulate onto the norm_vec
        div_row(att_block, att_block, norm_vec);

        mul(norm_vec_last, norm_vec_last, max_vec_last);
        div(norm_vec_last, norm_vec_last, norm_vec);

        copy(att_block_mma, att_block); // convert to bf16 for mma
        mul_row(o_prev, o_prev, norm_vec_last); // normalize o_prev in advance of mma'ing onto it

        warpgroup::mma_fence(o_prev);
        warpgroup::mma_AB(o_prev, att_block_mma, v_smem[tic][0]);
        warpgroup::mma_commit_group();
    }

    auto (*o_smem) = reinterpret_cast<st_bf<qo_height, tile_width, layout_o>(*)>(q_smem); // reuse q memory
    warpgroup::store(o_smem[warpgroupid], o_prev); 
    __syncthreads();

    if (warpid % 4 == 0) { // store o
        int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + warpgroupid;
        tma::store_async(tma_o, (o_smem[warpgroupid]), tile_idx); 
        tma::store_commit_group(); 
    }

    tma::store_async_wait();
}

那么,它的表现如何?

这个内核只有 100 行代码,实际上它在 H100 上的性能比 FlashAttention-2 高出约 30%。雷猫负责包装布局和指令,提供了一个可以在 GPU 上使用的迷你 pytorch 环境。

FA2(通过 Pytorch 实现)与 TK 在 H100 SXM 上的多种配置比较

此外,研究人员还发布了基于线性注意力和其他新架构的内核。其中基于线性注意力的内核的运行速度可达 215 TFLOPs,如果考虑到算法中固有的重计算,速度可超过 300 TFLOPs。

尽管线性注意力在理论上效率更高,但此前在实际硬件上表现并不佳。因此,研究人员认为这可能促进一系列高吞吐量应用的发展。

斯坦福让“GPU 高速运转”的新工具火了,比 FlashAttention2 更快-风君雪科技博客

small tile 符合 AI 和硬件发展趋势

最后,雷猫研究团队总结了开发雷猫的一些思考。在他们看来,雷猫之所以有效,是因为它的目标并不是试图做所有事:

CUDA 的确比雷猫表达能力更广,雷猫小而简单,功能有限。但雷猫的 small tiles 抽象设计符合 AI 和硬件的发展趋势。

虽然雷猫不支持小于 16 的维度,但研究人员认为这并不重要,因为硬件也不倾向于支持过小的维度。

如果你的矩阵乘法小于 16×16,你确定你正在做的是 AI 吗?

从理论出发,研究人员认为需要进行一种框架转变。

“寄存器当然不应该像旧 CPU 那样 32 位字。CUDA 使用的 1024 位宽向量寄存器确实是朝着正确方向迈出的一步。但对我们来说,寄存器是 16×16 的数据 tile。我们认为 AI 需要这样的设计,毕竟,它仍然只是矩阵乘法、规约和重塑。我们认为硬件也需要这样的设计,小型矩阵乘法迫切需要超出系统级 MMA 的硬件支持。”

研究人员认为,应该根据硬件特性来重新定义 AI 的设计理念。例如,循环状态应该有多大?应该足够大以适应一个 SM。计算的密度应该有多高?不应低于硬件的需求。

我们未来工作的一个重要方向是利用我们对硬件的了解来帮助我们设计与之匹配的 AI。

参考链接:

  • [1]https://hazyresearch.stanford.edu/blog/2024-05-12-tk

  • [2]https://github.com/HazyResearch/ThunderKittens

  • [3]https://news.ycombinator.com/item?id=40337936