现在针对 LLM的部署推理框架都会花很多精力优化IO,公认这种大的decoder-only的模型推理瓶颈在memory侧。对此,我们简单分析一下原因。

首先LLM进行text-generation的流程如下:

  1. 加载模型架构
  2. 填充weight
  3. 输入tokenization化 (tokenizer encode prompt)
  4. 计算+后处理 (auto-regression, top_p_top_k, beam_search等)
  5. 输出字符(tokenizer decode tokens)

  • ⭐架构层面,基本都是transformer decoder搭建的,也就是embedding → decoder_layer (attention_block, norm, residual add, activation op, softmax, …) → lm_head. 其中除了attention_block和lm_head中的GEMM外,其余可以认为都是IO密集型算子 (norm, rope, add, softmax等)。不过现在的推理框架基本都是支持算子 融合(op fusion),比如 residual add可以作为gemm的post op,reduction维度算完即可相加然后存储,softmax通过flash attention在reduction维度滑动更新, gelu, relu, silu 这种FFN里面的激活函数也可以算出一个block就直接计算等等。总得来说,fusion op来减少反复从memory中反复读写的操作,让核心(core)一次算多点。因此IO OPs方面相对来说比较好解决(profiling看哪个op好事多就想办法fuse起来,压力来到kernel这边)。
  • ⭐⭐填充weight,也就是 参数量层面。另外模型的层数多,隐藏维度大,导致整体的weight比较大,load起来也比较费时。所以这个是memory bound的第一个原因。但是,现在低精度,混合精度推理相对成熟( k_quant , weight_only 4 bit 等),甚至更激进的2 bit也有,因此只要保证精度损失在可接受范围内 ( smooth quant , AWQ , GPTQ 等,压力来到PTQ量化算法这边),weight就可以压缩好几倍,大幅度减少memory load的时间。
  • 输入tokenization化,这个是标准操作,每个模型有自己独特的tokenizer进行prompt编码,这个词表查询一般没什么优化技巧,也只在first_token推理中进行
  • ⭐⭐⭐计算+后处理。这个应该算是最大的IO瓶颈了,大头就是 kv cache ,我们知道LLM都是自回归式(前向注意力)的来计算下一个token,除了prompt的first-token步骤,后面的都必须要串行计算,也就是说下一个token依赖上一个token,以此类推。其中 kv cache 是为了不重复计算之前出现过的token的 K,V activation值进行的优化,也就是**用空间换时间,**防止计算量逐iteration增长。以llama-7b 为例,layer=32, n_hidden_size=4096, context_window=2048 , 如果存储fp32 dtype,需要的内存大小为 32x4096x2048x2x4 /1024 /1024=2048M=2GB ,这个还是单batch的,如果想要实现多batch,比如continuous batching,还得乘上max_bs数。第二个是LLM的next_token计算截断,输入的是尺寸是 bsx1 ,如果是单batch下,尺寸就非常小,对GEMM来说,计算量小了很多,bottleneck自然就慢慢往IO bound那边倾斜,然后计算阶段还要额外load kv cache 然后进行memory_cpy 最新计算出的 K, V activation值,然后再统一送给下面的 MHA或者flash attention截断。当然,我们可以增大输入的尺寸,比如经典的ORCA 论文的continuous batching通过把多个request拼成一个sentence来增加计算量,充分利用 core 核心,但是相应的 kv cache 也会成倍增大。因此这一块是最大的瓶颈,没法很好地解决 (tradeoff),当然 kv cache 也是可以使用低精度来存储,然后计算时候再根据需求反量化回去,比如常见的F16, INT8等,减少IO存储,另一方面比如vllm的 paged attention 来分散其存储位置,减少空间不够的内存搬运等问题。后处理上如果是top_p_top_k这种增加多样性生成的采样方法,应该还好,只是算一次top_k的时间,如果是像 beam_search这种关注句子质量的采样方法,由于要保持多个beam,等于是变向增加了batch_size,因此瓶颈也是在 kv_cache 这里,虽然可以通过common_perfix这种规避掉一些内存,但是可能还要涉及到beam之间的kv_cache 的内存拷贝等,因此memory_bound也很明显。
  • 输出字符方面和输入类似,只能在tokenizer方面做文章,因为暂时没什么优化,大家都差不多。

综上,LLM推理的memory_bound最根本的原因还是 kv cache ,而kv cache 是由LLM架构中的前向注意力自回归导致的。如果未来有新的架构可以前向一次prompt吐出多个带前后顺序的tokens,那么kv cache 的问题就能得到很好的缓解了。