tvm dynamic shape 学习

编译器
Published

November 15, 2023

探究tvm dynamic shape的实现.

tvm ir design

将 Relax IR 的语法树 dump 出来可以看到, 它与 Relay 那种数据流风格的 IR 不同: dataflow 中的每个操作都使用一个 VarBinding 来存储.

# Example Relax function: adds `b` (shape (1,)) to `a`, whose first dimension
# is given by the name "n" and whose second dimension is 10.
@R.function
def fn1(a: R.Tensor(("n", 10), 'float32'), b: R.Tensor((1,), 'float32')):
    with R.dataflow():
        # int64 variable used in the annotation of `c` below.
        # NOTE(review): presumably this refers to the same "n" named in the
        # signature of `a` -- confirm against the TVMScript parser rules.
        n = T.int64()
        # The annotation states only the shape (n, 10); no dtype is given.
        c: R.Tensor((n, 10)) = a + b
        # Expose `c` from the dataflow block so it can be returned.
        R.output(c)
    return c
Function(
    params=[
        Var(
            name_hint="a",
            struct_info=TensorStructInfo(
                dtype=float32,
                shape=ShapeExpr(
                    values=[
                        PrimExpr(value=`n`),
                        PrimExpr(value=`T.int64(10)`)
                    ],
                    struct_info=ShapeStructInfo(
                        ndim=2,
                        values=[
                            PrimExpr(value=`n`),
                            PrimExpr(value=`T.int64(10)`)
                        ]
                    )
                )
            )
        ),
        Var(
            name_hint="b",
            struct_info=TensorStructInfo(
                dtype=float32,
                shape=ShapeExpr(
                    values=[PrimExpr(value=`T.int64(1)`)],
                    struct_info=ShapeStructInfo(
                        ndim=1,
                        values=[PrimExpr(value=`T.int64(1)`)]
                    )
                )
            )
        )
    ],
    body=SeqExpr(
        blocks=[
            BindingBlock(
                bindings=[
                    VarBinding(
                        var=Var(
                            name_hint="c",
                            struct_info=TensorStructInfo(
                                dtype=,
                                shape=ShapeExpr(
                                    values=[
                                        PrimExpr(value=`n`),
                                        PrimExpr(value=`T.int64(10)`)
                                    ],
                                    struct_info=ShapeStructInfo(
                                        ndim=2,
                                        values=[
                                            PrimExpr(value=`n`),
                                            PrimExpr(value=`T.int64(10)`)
                                        ]
                                    )
                                )
                            )
                        ),
                        value=Call(
                            op=Op(name="relax.add"),
                            args=[
                                Var(
                                    name_hint="a",
                                    struct_info=TensorStructInfo(
                                        dtype=float32,
                                        shape=ShapeExpr(
                                            values=[
                                                PrimExpr(value=`n`),
                                                PrimExpr(value=`T.int64(10)`)
                                            ],
                                            struct_info=ShapeStructInfo(
                                                ndim=2,
                                                values=[
                                                    PrimExpr(value=`n`),
                                                    PrimExpr(value=`T.int64(10)`)
                                                ]
                                            )
                                        )
                                    )
                                ),
                                Var(
                                    name_hint="b",
                                    struct_info=TensorStructInfo(
                                        dtype=float32,
                                        shape=ShapeExpr(
                                            values=[PrimExpr(value=`T.int64(1)`)],
                                            struct_info=ShapeStructInfo(
                                                ndim=1,
                                                values=[PrimExpr(value=`T.int64(1)`)]
                                            )
                                        )
                                    )
                                )
                            ],
                            struct_info=TensorStructInfo(
                                dtype=,
                                shape=ShapeExpr(
                                    values=[
                                        PrimExpr(value=`n`),
                                        PrimExpr(value=`T.int64(10)`)
                                    ],
                                    struct_info=ShapeStructInfo(
                                        ndim=2,
                                        values=[
                                            PrimExpr(value=`n`),
                                            PrimExpr(value=`T.int64(10)`)
                                        ]
                                    )
                                )
                            )
                        )
                    )
                ]
            )
        ],
        body=Var(
            name_hint="c",
            struct_info=TensorStructInfo(
                dtype=,
                shape=ShapeExpr(
                    values=[
                        PrimExpr(value=`n`),
                        PrimExpr(value=`T.int64(10)`)
                    ],
                    struct_info=ShapeStructInfo(
                        ndim=2,
                        values=[
                            PrimExpr(value=`n`),
                            PrimExpr(value=`T.int64(10)`)
                        ]
                    )
                )
            )
        ),
        struct_info=TensorStructInfo(
            dtype=,
            shape=ShapeExpr(
                values=[
                    PrimExpr(value=`n`),
                    PrimExpr(value=`T.int64(10)`)
                ],
                struct_info=ShapeStructInfo(
                    ndim=2,
                    values=[
                        PrimExpr(value=`n`),
                        PrimExpr(value=`T.int64(10)`)
                    ]
                )
            )
        )
    ),
    ret_struct_info=TensorStructInfo(
        dtype=,
        shape=ShapeExpr(
            values=[
                PrimExpr(value=`n`),
                PrimExpr(value=`T.int64(10)`)
            ],
            struct_info=ShapeStructInfo(
                ndim=2,
                values=[
                    PrimExpr(value=`n`),
                    PrimExpr(value=`T.int64(10)`)
                ]
            )
        )
    ),
    is_pure=1,
    attrs={"global_symbol": "fn1"},
    struct_info=FuncStructInfo(
        params=[
            TensorStructInfo(
                dtype=float32,
                shape=ShapeExpr(
                    values=[
                        PrimExpr(value=`n`),
                        PrimExpr(value=`T.int64(10)`)
                    ],
                    struct_info=ShapeStructInfo(
                        ndim=2,
                        values=[
                            PrimExpr(value=`n`),
                            PrimExpr(value=`T.int64(10)`)
                        ]
                    )
                )
            ),
            TensorStructInfo(
                dtype=float32,
                shape=ShapeExpr(
                    values=[PrimExpr(value=`T.int64(1)`)],
                    struct_info=ShapeStructInfo(
                        ndim=1,
                        values=[PrimExpr(value=`T.int64(1)`)]
                    )
                )
            )
        ],
        ret=TensorStructInfo(
            dtype=,
            shape=ShapeExpr(
                values=[
                    PrimExpr(value=`n`),
                    PrimExpr(value=`T.int64(10)`)
                ],
                struct_info=ShapeStructInfo(
                    ndim=2,
                    values=[
                        PrimExpr(value=`n`),
                        PrimExpr(value=`T.int64(10)`)
                    ]
                )
            )
        ),
        purity=True
    )
)

而根据 Relax shape 设计文档, 下面这种情况应该是无法支持的:

# Example Relax function: computes `a + b`, takes the shape of the result as a
# value, and indexes into that shape to build the target of a reshape.
@R.function
def fn2(a: R.Tensor(("n", 10), 'float32'), b: R.Tensor((1,), 'float32')):
    with R.dataflow():
        # Declared but not referenced by any later statement in this body.
        n = T.int64()
        c = a + b
        # `cshape` is annotated as a bare R.Shape() with no dimensions listed.
        cshape: R.Shape() = R.shape_of(c)
        # The new shape is assembled from elements indexed out of `cshape`,
        # i.e. from a shape value rather than from symbolic variables.
        d = R.reshape(c, [1, cshape[0], cshape[1], 1])
        R.output(d)
    return d

我在思考是不是应该有一种直接基于数据流的方式来添加symbolic shape的信息?