mlc-llm 浅析
学习tvm是如何解决LLM推理问题.
1. Model Arch Generator
LLM有一个特点就是其动态与自回归的特性, 传统CNN的模型的计算通路都保存在模型中, 对于DL Compiler来说只需要将固定shape下的模型进行编译优化即可, 而LLM的计算通路并没有体现在模型中, 万幸的是没有多少厂商会大改LLM的模型结构, 所以DL Compiler的前端去手动去处理也问题不大.
使用mlc.build
对模型进行编译,
进入build_model_from_args
函数:
def build_model_from_args(args: argparse.Namespace): |
目前tvm是基于relax的分支支持LLM的, 构建模型的过程主要就是使用relax的主要特性按原始模型结构重新构造了一遍tvm的ir module:
首先是构造BlockBuilder
的scope,
然后在其中构造整个模型运行的每个阶段.
def get_model(args, hf_config): |
在relax中支持同时包含构造relay的数据流以及tir,
所以下面会使用nn.emit
以及nn.emit_te
,
同时还可以使用一些手动优化的vm函数relax.extern("vm.builtin.paged_attention_kv_cache_append")
以及直接编写的prim_func
.
class Linear(nn.Module): |
在源代码中检索了一下, 发现是在vm中是直接实现了kv cache, 同时将kv cache的接口进行了封装, 让relax可以进行调用.
class AttentionKVCacheObj : public Object { |
其实tvm这种直接在module中构造操作的方式也是很方便的, 如果是传统的编译流程对于每个层还需要写pattern去切子图, 并且一些kv cache相关的优化可能还需要通过一些选项去在某些位置强行添加.
2. Module Transform
如果开启了量化还需要更新全部的参数,
然后对构造好的IR.Module
进行优化,
这里也是一些比较有针对性的优化pass: def mod_transform_before_build(
mod: tvm.IRModule,
param_manager: param_manager.ParamManager,
args: argparse.Namespace,
config: Dict,
) -> tvm.IRModule:
#
mod = param_manager.transform_dequantize()(mod)
mod = relax.transform.BundleModelParams()(mod)
use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod)
if max_seq_len:
num_key_value_heads = config.get_num_key_value_heads()
mod = fuse_split_rotary_embedding(
config.num_attention_heads // args.num_shards,
num_key_value_heads // args.num_shards,
config.hidden_size // args.num_shards,
config.position_embedding_base,
)(mod)
if args.target_kind == "cuda":
# ...
mod = mlc_llm.transform.FuseTransposeMatmul()(mod)
mod = relax.pipeline.get_pipeline()(mod) # pylint: disable=no-value-for-parameter
mod = mlc_llm.transform.FuseDecodeMatmulEwise()(mod)
mod = mlc_llm.transform.FuseDecodeTake()(mod)
mod = relax.transform.DeadCodeElimination(model_names)(mod)
mod = mlc_llm.transform.CleanUpTIRAttrs()(mod)
mod_deploy = mod
return mod_deploy
修改后的Module如下, 相比原本的Module多了许多Fused的算子.
3. Module Build
build的过程就是调用原本tvm中的编译下降进行处理,
这里我的target为m1-metal
:
def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None: |
relax中的原生支持动态shape,
所以在decode
过程中是通过dataflow
的形式来执行:
def decode(input_ids1: R.Tensor((1, 1), dtype="int32"), all_seq_len: R.Shape(["n"]), kv_cache:...):
cls = Module
with R.dataflow():
# ...
lv1897 = R.call_tir(cls.transpose5, (lv1894,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1898 = R.call_tir(cls.transpose5, (lv1895,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv722 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv1896, lv1897, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
# ...
例如decode
中的transpose5
函数在before build
阶段,
tir中是以动态的方式进行构造的:
|
在after build
阶段,
经过编译下降之后的block中的iterVar
被映射到了thread
和block
两个层级.
我估计在tvm中对于动态申请的内存默认都是连续的,
所以这里match buffer
也没有特别的stride
.
|
编译后的模型如下:
4. Chat
chat 其实经过之前的编译过程后会非常的精简,
只需要获取对应编译后模型的packed func
然后反复调用即可.
class ChatModule { |