A Brief Look at mlc-llm

Compiler
Published November 1, 2023

A study of how TVM tackles the LLM inference problem.

1. Model Arch Generator

A defining trait of LLMs is that they are dynamic and autoregressive. A traditional CNN encodes its entire compute path in the model itself, so a DL compiler only has to compile and optimize the graph under fixed shapes; an LLM's compute path, by contrast, is not fully captured in the model. Fortunately, few vendors make drastic changes to the standard LLM architectures, so it is acceptable for the DL compiler frontend to handle them by hand.

The model is compiled with mlc_llm.build, which leads into the build_model_from_args function:

def build_model_from_args(args: argparse.Namespace):
    # handle the various configuration options

    # select the generator used to parse this model family
    model_generators = {
        "llama": llama,
        "mistral": llama,
        "stablelm_epoch": stablelm_3b,
        "gpt_neox": gpt_neox,
        "gpt_bigcode": gpt_bigcode,
        "minigpt": minigpt,
        "gptj": gptj,
        "rwkv": rwkv,
        "rwkv_world": rwkv,
        "chatglm": chatglm,
    }

    # ... (the selected generator's get_model is invoked below)

At the moment TVM supports LLMs on the relax branch. Building the model essentially means using relax's core features to reconstruct a TVM IRModule that mirrors the original model structure:

First a BlockBuilder scope is created, and then every stage of the model's execution is constructed inside it.

def get_model(args, hf_config):
    # handle configuration ...
    param_manager = ParamManager()
    bb = relax.BlockBuilder()

    if sep_embed:
        create_embed_func(bb, param_manager, config, args.quantization)
    # construction of the batching variants omitted ...
    create_prefill_func_for_single_seq(bb, param_manager, config, args.quantization, sep_embed)
    create_decoding_func_for_single_seq(bb, param_manager, config, args.quantization)
    create_kv_cache_func(bb, config)
    create_softmax_func_for_single_seq(bb, config)

    create_metadata_func(
        bb,
        model_name=model_name,
        max_window_size=config.max_sequence_length,
        stop_tokens=[2],
        add_prefix_space=False,
    )
    # set an upper bound for the dynamic dims n and m
    mod = bb.get()
    for gv in mod.functions:
        func = mod[gv]
        if isinstance(func, relax.Function):
            mod[gv] = func.with_attr(
                "tir_var_upper_bound",
                {"n": config.max_sequence_length, "m": config.max_sequence_length},
            )

    if args.build_model_only:
        return mod, param_manager, None, config

    return setup_params(mod, param_manager, dtype, config, args)
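
To make the scoping pattern concrete, here is a minimal standalone sketch of the relax.BlockBuilder workflow that the create_*_func helpers above build on; the function name and shapes are illustrative and not taken from mlc-llm:

import tvm
from tvm import relax, tir

bb = relax.BlockBuilder()
n = tir.Var("n", "int64")  # symbolic sequence length, like the "n" bounded above
x = relax.Var("x", relax.TensorStructInfo([1, n, 4096], "float16"))
w = relax.Var("w", relax.TensorStructInfo([4096, 4096], "float16"))

# Open a function scope, then a dataflow scope, and emit ops into it.
with bb.function("linear", [x, w]):
    with bb.dataflow():
        y = bb.emit(relax.op.linear(x, w))  # same op the Linear module below emits
        out = bb.emit_output(y)
    bb.emit_func_output(out)

mod = bb.get()  # an IRModule, just like the bb.get() inside get_model
mod.show()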

Relax can mix relay-style dataflow construction with TIR in the same module, so the code below uses both nn.emit and nn.emit_te; it can also call hand-optimized VM functions such as relax.extern("vm.builtin.paged_attention_kv_cache_append") as well as directly written prim_funcs.

class Linear(nn.Module):
    # ...
    def forward(self, input: relax.Expr) -> relax.Var:
        return nn.emit(relax.op.linear(input, self.weight, self.bias))

def apply_rotary_pos_emb(q, k, position_embedding_base, offset: int = 0):
    def f_rotary_embedding(tensor, offset):
        def rotary_compute(*idx):
            pos = (offset + idx[-3]).astype("float32")
            return rotary_modulate_by_freq(
                tensor,
                idx,
                pos,
                position_embedding_base,
            )

        return tvm.te.compute(tensor.shape, rotary_compute, name="rotary")

    q_embed = nn.emit_te(f_rotary_embedding, q, offset, primfunc_name_hint="rotary_embedding")
    k_embed = nn.emit_te(f_rotary_embedding, k, offset, primfunc_name_hint="rotary_embedding")
    return q_embed, k_embed

class LlamaPagedAttention(LlamaAttentionBase):
    # ...
    def attention_fwd(
        self,
        query_states: relax.Expr,
        key_states: relax.Expr,
        value_states: relax.Expr,
        past_key_values: relax.Expr,
        batch_size: tir.PrimExpr,
        q_len: tir.PrimExpr,
        **kwargs,
    ) -> Tuple[relax.Expr, relax.Expr]:
        assert "layer_id" in kwargs and isinstance(kwargs["layer_id"], int)
        layer_id = kwargs["layer_id"]

        f_kv_cache_append = relax.extern("vm.builtin.paged_attention_kv_cache_append")
        past_key_values = nn.emit(
            relax.call_pure_packed(
                f_kv_cache_append,
                past_key_values,
                self.kv_cache_transpose_append,
                key_states,
                value_states,
                relax.PrimValue(layer_id),
                sinfo_args=relax.ObjectStructInfo(),
            )
        )
        # ...
        return attn_output, past_key_values

def emit_paged_kv_cache_op(bb: relax.BlockBuilder, dtype: str) -> None:
    from tvm.script import tir as T

    # fmt: off
    @T.prim_func
    def kv_cache_transpose_append(
        var_pages: T.handle,
        var_k_data: T.handle,
        var_v_data: T.handle,
        var_page_table_indptr: T.handle,
        var_page_table_values: T.handle,
        var_last_page_offset: T.handle,
        var_append_length_indptr: T.handle,
        var_pos2seqidx: T.handle,
        layer_id: T.int32,
    ):
        # buffer declarations omitted ...
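        # For each appended token: look up its sequence via pos2seqidx, recover the
        # sequence length from the page table, then scatter K (slot 0) and V (slot 1)
        # into the matching page and in-page offset of the paged cache.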
        for global_pos, h, f in T.grid(ntoken, nhead, nfeat):
            with T.block("k_transpose_append"):
                vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                seq_idx = pos2seqidx[vgpos]
                seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx]
                pages[
                    page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)],
                    layer_id,
                    0,
                    vh,
                    T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size),
                    vf,
                ] = k_data[vgpos, vh, vf]
            with T.block("v_transpose_append"):
                vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                seq_idx = pos2seqidx[vgpos]
                seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx]
                pages[
                    page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)],
                    layer_id,
                    1,
                    vh,
                    T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size),
                    vf,
                ] = v_data[vgpos, vh, vf]
    # fmt: on

    bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append")
    # Todo: integrating attention TIR func/kernel.
    bb.add_func(relax.extern("attention_func"), "attention")

Searching the source code shows that the KV cache is implemented directly in the VM, and its interface is wrapped as packed functions so that relax can call it.

class AttentionKVCacheObj : public Object {
 public:
  /*!
   * \brief Underlying support data.
   */
  NDArray data;

  /*!
   * \brief number of slots already filled.
   */
  int64_t fill_count{0};

  /*!
   * \brief View all current cached values as one array.
   * \param shape The cached values.
   */
  NDArray View(const ShapeTuple& shape) {
    // ..
  }

  /** Clear the cache */
  void Clear() { /* ... */ }

  /** pop n entries */
  void PopN(size_t n) {
    // ...
  }

  void Update(NDArray value) {
    // ...
  }

  /*!
   * \brief Append value to the cache.
   * \param value The value to be appended.
   */
  void Append(NDArray value) {
    // ...
  }

  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
  static constexpr const char* _type_key = "relax.vm.AttentionKVCache";
  TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheObj, Object);
};

// register
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create")
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create_multiple")
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update")
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append")
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view")
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_popn")
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_clear")

TVM's approach of constructing these operations directly in the module is actually quite convenient: in a traditional compilation flow you would have to write patterns to carve out a subgraph for every layer, and KV-cache-related optimizations might have to be forced in at particular places through extra options.

mod_after_get_model.py

2. Module Transform

If quantization is enabled, all of the parameters must be updated as well; the constructed IRModule is then optimized with a set of fairly targeted passes:

def mod_transform_before_build(
    mod: tvm.IRModule,
    param_manager: param_manager.ParamManager,
    args: argparse.Namespace,
    config: Dict,
) -> tvm.IRModule:
  # ...
  mod = param_manager.transform_dequantize()(mod)
  mod = relax.transform.BundleModelParams()(mod)
  use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
  mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod)

  if max_seq_len:
    num_key_value_heads = config.get_num_key_value_heads()
    mod = fuse_split_rotary_embedding(
            config.num_attention_heads // args.num_shards,
            num_key_value_heads // args.num_shards,
            config.hidden_size // args.num_shards,
            config.position_embedding_base,
        )(mod)
  if args.target_kind == "cuda":
    ...  # CUDA-specific passes omitted
  mod = mlc_llm.transform.FuseTransposeMatmul()(mod)
  mod = relax.pipeline.get_pipeline()(mod)  # pylint: disable=no-value-for-parameter
  mod = mlc_llm.transform.FuseDecodeMatmulEwise()(mod)
  mod = mlc_llm.transform.FuseDecodeTake()(mod)
  mod = relax.transform.DeadCodeElimination(model_names)(mod)
  mod = mlc_llm.transform.CleanUpTIRAttrs()(mod)
  mod_deploy = mod
  return mod_deploy
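
All of these steps are ordinary IRModule-to-IRModule passes, so they compose like any other TVM pass pipeline. Below is a minimal sketch of the same idea using tvm.transform.Sequential; the pass selection is illustrative, not the exact mlc-llm pipeline:

import tvm
from tvm import relax

def optimize(mod: tvm.IRModule, entry_funcs) -> tvm.IRModule:
    seq = tvm.transform.Sequential(
        [
            relax.transform.BundleModelParams(),       # fold per-weight params into one tuple
            relax.pipeline.get_pipeline(),             # default relax lowering pipeline
            relax.transform.DeadCodeElimination(entry_funcs),
        ],
        name="illustrative_llm_pipeline",
    )
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)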

The transformed Module is shown below; compared with the original it contains many fused operators.

mod_deploy.py

3. Module Build

The build step simply invokes TVM's usual compilation and lowering; my target here is m1-metal:

def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
    # dump ...
    if target_kind != "cpu":
        dispatch_target = (
            args.target
            if args.target_kind != "webgpu"
            else tvm.target.Target("apple/m1-gpu-restricted")
        )
        with dispatch_target:
            mod_deploy = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            )(mod_deploy)
            mod_deploy = (
                mlc_llm.transform.LiftTIRGlobalBufferAlloc()(  # pylint: disable=not-callable
                    mod_deploy
                )
            )
            mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)

    # CUDA path omitted ...
    args.lib_path = os.path.join(args.artifact_path, output_filename)
    ex.export_library(args.lib_path, **args.export_kwargs)
    print(f"Finish exporting to {args.lib_path}")

Relax natively supports dynamic shapes, so the decode step is expressed and executed as a dataflow block:

    @R.function
    def decode(input_ids1: R.Tensor((1, 1), dtype="int32"), all_seq_len: R.Shape(["n"]), kv_cache:...):
        cls = Module
        with R.dataflow():
          # ...
          lv1897 = R.call_tir(cls.transpose5, (lv1894,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
          lv1898 = R.call_tir(cls.transpose5, (lv1895,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
          lv722 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv1896, lv1897, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
          # ...

For example, before build, the transpose5 function called from decode is constructed in TIR with a dynamic shape:

    @T.prim_func(private=True)
    def transpose5(var_A: T.handle, var_T_transpose: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(80)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(80)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]

After build, lowering has mapped the block's iterVars onto the two GPU levels blockIdx.x and threadIdx.x: the fused loop covers n * 2560 elements (2560 = 32 heads * 80 features per token) and is split into thread blocks of 1024. I assume that dynamically allocated memory in TVM is contiguous by default, which is why the match_buffer here carries no special strides.

    @T.prim_func(private=True)
    def transpose5(var_A: T.handle, var_T_transpose: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 32, 80), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (1, 32, n, 80), "float16")
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v0 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // (80 * n))
                    v1 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (80 * n) // 80)
                    v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80)
                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560)
                    T.reads(A[0, v1, v0, v2])
                    T.writes(T_transpose[0, v0, v1, v2])
                    T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2]
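
A quick standalone check confirms that this fused-index arithmetic is just a re-flattening of the original ax1/ax2/ax3 loop nest (plain Python, for illustration):

from itertools import product

n = 7  # any sample sequence length; 2560 = 32 heads * 80 features
for flat, (head, pos, feat) in enumerate(product(range(32), range(n), range(80))):
    v0 = flat // (80 * n)        # head index, as in the lowered TIR
    v1 = flat % (80 * n) // 80   # sequence position
    v2 = flat % 80               # feature index
    assert (v0, v1, v2) == (head, pos, feat)
print("fused index decomposition matches the original loop nest")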

The module after build looks like this:

mod_build_stage.py

4. Chat

After the compilation steps above, the chat layer is actually very lean: it only needs to fetch the packed functions of the compiled model and call them repeatedly.
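
On the Python side these packed functions are wrapped by mlc_chat's ChatModule. A usage sketch, where the model name is illustrative and assumes the matching compiled artifact and weights are present:

from mlc_chat import ChatModule

cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")  # hypothetical local artifact
print(cm.generate(prompt="What is the capital of Canada?"))
print(cm.stats())  # prefill/decode throughput summary

The C++ ChatModule does the same job, caching the packed functions once in its constructor: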

class ChatModule {
 public:
  /*!
   * \brief Constructor
   * \param device the device to run the chat on.
   */
  explicit ChatModule(const DLDevice& device) {
    this->chat_mod_ = mlc::llm::CreateChatModule(device);
    this->prefill_ = this->chat_mod_->GetFunction("prefill");
    this->decode_ = this->chat_mod_->GetFunction("decode");
    this->stopped_ = this->chat_mod_->GetFunction("stopped");
    this->get_message_ = this->chat_mod_->GetFunction("get_message");
    this->reload_ = this->chat_mod_->GetFunction("reload");
    this->get_role0_ = this->chat_mod_->GetFunction("get_role0");
    this->get_role1_ = this->chat_mod_->GetFunction("get_role1");
    this->runtime_stats_text_ = this->chat_mod_->GetFunction("runtime_stats_text");
    this->verbose_runtime_stats_text_ = this->chat_mod_->GetFunction("verbose_runtime_stats_text");
    this->reset_chat_ = this->chat_mod_->GetFunction("reset_chat");
    this->process_system_prompts_ = this->chat_mod_->GetFunction("process_system_prompts");
    this->lib_path_ = "";
    this->executable_ = tvm::runtime::Module(nullptr);
    ICHECK(prefill_ != nullptr);
    ICHECK(decode_ != nullptr);
    ICHECK(stopped_ != nullptr);
    ICHECK(get_message_ != nullptr);
    ICHECK(reload_ != nullptr);
    ICHECK(get_role0_ != nullptr);
    ICHECK(get_role1_ != nullptr);
    ICHECK(runtime_stats_text_ != nullptr);
    ICHECK(verbose_runtime_stats_text_ != nullptr);
    ICHECK(reset_chat_ != nullptr);
  }