mlc-llm 浅析
学习tvm是如何解决LLM推理问题.
1. Model Arch Generator
LLM有一个特点就是其动态与自回归的特性, 传统CNN的模型的计算通路都保存在模型中, 对于DL Compiler来说只需要将固定shape下的模型进行编译优化即可, 而LLM的计算通路并没有体现在模型中, 万幸的是没有多少厂商会大改LLM的模型结构, 所以DL Compiler的前端去手动去处理也问题不大.
使用mlc.build
对模型进行编译,
进入build_model_from_args
函数:
def build_model_from_args(args: argparse.Namespace): |
目前tvm是基于relax的分支支持LLM的, 构建模型的过程主要就是使用relax的主要特性按原始模型结构重新构造了一遍tvm的ir module:
首先是构造BlockBuilder
的scope,
然后在其中构造整个模型运行的每个阶段.
def get_model(args, hf_config): |
在relax中支持同时包含构造relay的数据流以及tir,
所以下面会使用nn.emit
以及nn.emit_te
,
同时还可以使用一些手动优化的vm函数relax.extern("vm.builtin.paged_attention_kv_cache_append")
以及直接编写的prim_func
.
class Linear(nn.Module): |
在源代码中检索了一下, 发现是在vm中是直接实现了kv cache, 同时将kv cache的接口进行了封装, 让relax可以进行调用.
class AttentionKVCacheObj : public Object { |
其实tvm这种直接在module中构造操作的方式也是很方便的, 如果是传统的编译流程对于每个层还需要写pattern去切子图, 并且一些kv cache相关的优化可能还需要通过一些选项去在某些位置强行添加.
2. Module Transform
如果开启了量化还需要更新全部的参数,
然后对构造好的IR.Module
进行优化,
这里也是一些比较有针对性的优化pass:
def mod_transform_before_build( |
修改后的Module如下, 相比原本的Module多了许多Fused的算子.
3. Module Build
build的过程就是调用原本tvm中的编译下降进行处理,
这里我的target为m1-metal
:
def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None: |
relax中的原生支持动态shape,
所以在decode
过程中是通过dataflow
的形式来执行:
|
例如decode
中的transpose5
函数在before build
阶段,
tir中是以动态的方式进行构造的:
|
在after build
阶段,
经过编译下降之后的block中的iterVar
被映射到了thread
和block
两个层级.
我估计在tvm中对于动态申请的内存默认都是连续的,
所以这里match buffer
也没有特别的stride
.
|
编译后的模型如下:
4. Chat
chat 其实经过之前的编译过程后会非常的精简,
只需要获取对应编译后模型的packed func
然后反复调用即可.
class ChatModule { |