TVM TensorIR
关于TVM的Tensor level IR.
1. ffi navigator的bug修复
我这里是python3.9, 不知道为什么tvm的ffi
navigator插件有一个类型问题启动不了.
所以需要修改/Users/lisa/mambaforge/lib/python3.9/site-packages/ffi_navigator/dialect/tvm.py line 97
为如下:
if path.startswith("" if self._pypath_api_internal is None else self._pypath_api_internal):
2. tvm.script.tir 与 tvm.tir
tvm.tir
是内在实现.
tvm.script.tir
主要是封装了一层用户友好的python类型接口(不存在实现).可以查看这篇文章.
tvm.script
实际上就是tensor ir
的语法表现形式,我们通过写tvm.script
语法,然后构建出IRModule
.
避免了直接从ir构造的别扭,因为如果是relay这种,不需要考虑太多的条件以及循环等,如果是底层ir,用函数的方式写这些就非常蛋疼了.
比如从tir直接构造ir是这样的: ib = tvm.tir.ir_builder.create()
a = tir.Var("a", "float32")
b = tir.Var("b", "float32")
with ib.if_scope(True):
ib.emit(tir.Evaluate(tir.ret(a)))
ib.emit(tir.Evaluate(tir.ret(b)))
stmt = ib.get()
func = tir.PrimFunc([a, b], stmt)
func = build_tir_func(func)
out = func(1.0, 2.0)script.tir
就方便多了:
def add(a: T.handle, b: T.handle):
for i in T.parallel(0, 2):
for j in T.serial(0, 1):
for z in T.vectorized(3, 4):
T.evaluate(0)
3. tvm.script -> tir的流程
首先我们使用tvm.script.tir
写一个计算函数,然后被转换为python
的ast
,由于不同
python
版本之间的 ast
不同,所以
tvm
单独开发了一个和 python
版本无关的
ast parser
叫 synr
.
在parser
的使用利用tvm
的lower transformer
把ast
进行细化.
要注意,用户层面导入tvm.script.tir as T
实际上都只有类型而已,
他对于这些类型的实际定义并没有导入进来,而是在tvm.script.parser
中使用.
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i, j in T.grid(128, 128):
with T.block("init"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = T.float32(0)
for k in range(128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]tvm.tir.function.PrimFunc
就如下:
PrimFunc([a, b, c]) {
block root() {
reads([])
writes([])
for (i, 0, 128) {
for (j, 0, 128) {
block init(iter_var(vi, range(min=0, ext=128)), iter_var(vj, range(min=0, ext=128))) {
bind(vi, i)
bind(vj, j)
reads([])
writes([C[vi, vj]])
C[vi, vj] = 0f
}
for (k, 0, 128) {
block update(iter_var(vi, range(min=0, ext=128)), iter_var(vj, range(min=0, ext=128)), iter_var(vk, range(min=0, ext=128))) {
bind(vi, i)
bind(vj, j)
bind(vk, k)
reads([C[vi, vj], A[vi, vk], B[vj, vk]])
writes([C[vi, vj]])
C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))
}
}
}
}
}
}
4 ir builder流程
ir builder
提供了另一种构建tir
的方法,典型用法如下:
ib = tvm.tir.ir_builder.create()
n = te.size_var("n")
A = ib.pointer("float32", name="A")
tmod = tvm.tir.truncmod
with ib.for_range(0, n, name="i") as i:
with ib.if_scope(tmod(i, 2) == 0):
A[i] = A[i] + 1
with ib.else_scope():
A[0] = A[i] + 2
body = ib.get()ib.xx
构造的ir
对象都会通过ib.emit
的方式添加到ir builer
内部,然后对于一些存在scope
的比如for if
等等,
是构造了一个with scope
对象,然后在退出这个scope
的时候把中间的所有emit
生成的对象作为body
构造成一个for/if
的ir
.
5. tvm.te 与 tvm.tir
te
里面的实际上是老的写法,他里面又写了一套tensor/data producer
等等的ir
,
te
的ir
定义实际上是以operation
为核心的,然后类似于tensorflow
的placeholder
的方式进行构建的,实际上在转换到IRModule
的时候,还是会把这些东西转化为tir.Buffer
.所以目前可以不看那块的内容.
6. 一些tir的作用
6.1 block reads && writes
block
是tvm
调度的基本单元,他的调度器通常是获得一个block
,然后对这个块进行融合/分割/并行等等操作,同时还可以分析多个块
在parser
的block
的流程,他的func.body
是只会有一个赋值的操作C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
(忽略了前面的iter var
定义,应该是这些定义到时候都会被固化到代码中,所以也不会出现在计算流程中的原因),然后在func.exit_scope
时,他会进入tvm
的callback
函数中
python/tvm/script/tir/scope_handler.py line 255
,构造出带有bind
以及reads/writes
的tir
.
(实际上底层还分有BlockRealize
和Block
两部分)
with T.block("update"): |
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) |
得到的结果,实际上是把remap
的定义融合到了block
这个ir
中.
for (k, 0, 128) {
block update(iter_var(vi, range(min=0, ext=128)), iter_var(vj, range(min=0, ext=128)), iter_var(vk, range(min=0, ext=128))) {
bind(vi, i)
bind(vj, j)
bind(vk, k)
reads([C[vi, vj], A[vi, vk], B[vj, vk]])
writes([C[vi, vj]])
C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))
}
}
6.2 block iter_var
iter_var
我个人把他看作一个symbol var
,他的好处就是我们可以任意绑定一个时机的value
,等到schedule
做完后再消除他得到真正的索引操作.
这里要说明一下iter_var
对于一个Buffer
的索引操作将会得到是BufferLoad
的ir
,他的表现形式就是多维索引B[vi,vj]
.
在后续这个BufferLoad
会被lower
到Load
,表现形式就是B.Handle[i * w + j]
.
即我们取symbol var
绑定的value
并计算出对于一个指针真正的索引.
🌰 原始TIR
: for (i: int32, 0, 128) {
for (j: int32, 0, 128) {
block([128, 128], "B") as [vi, vj] {
bind(vi, i)
bind(vj, j)
tir.reads([A[vi, vj]])
tir.writes([B[vi, vj]])
B[vi, vj] = (A[vi, vj]*2f32)
}
}split
之后,
可以发现我们只需要修改iter var
的绑定即可实现split
,
不然得递归把所有的i
改成((i_0*64) + i_1)
,写transform
就巨麻烦了.
for (i_0: int32, 0, 2) {
for (i_1: int32, 0, 64) {
for (j: int32, 0, 128) {
block([128, 128], "B") as [vi, vj] {
bind(vi, ((i_0*64) + i_1))
bind(vj, j)
tir.reads([A[vi, vj]])
tir.writes([B[vi, vj]])
B[vi, vj] = (A[vi, vj]*2f32)
}
}
}
6.3 BufferLoad lower
- 利用
ConvertBlocksToOpaque
的transform
把iter_var.var
都替换成对应的value
, 这里我其实没明白,为什么不把itervar
也设计成expr
, 理论上应该没啥问题吧.
7. 代码生成
7.1 ssa赋值
我自己写了一下c代码生成才发现不能无脑对综合了stmt以及expr的ir进行ssa赋值.怪不得tvm的c代码生成默认不开ssa赋值.
🌰 把下面的代码转换为c代码 void RefFunc(int[] A, int n)
{
for (i in (0, n))
{
A[i] = A[i] + 1;
for (j in (0, 10))
{
A[i] = A[i] + j;
}
}
}load A[i]
都变成了_1
这个tmep var
了.
然后第二次load
的时候就会出现没有更新值的问题.
void func_0(int32_t* A, int32_t n) {
for (int32_t i = 0; i < n; i++) {
int32_t _3 = (i * 1);
int32_t _2 = (0 + _3);
int32_t _1 = A[_2];
int32_t _0 = (_1 + 1);
A[_2] = _0;
for (int32_t j = 0; j < 10; j++) {
int32_t _4 = (_1 + j); // 这里就会出现load没有更新值的问题
A[_2] = _4;
}
}
}
所以我目前也是按照tvm的方法,把这些计算流程都转换成线性的计算.
这样就保证所有的表达式都会被emit
,不过也带来了一个计算冗余的问题,这个后续我们可以继续优化.
void func_0(int32_t* A, int32_t n) {
for (int32_t i = 0; i < n; i++) {
A[(0 + (i * 1))] = (A[(0 + (i * 1))] + 1);
for (int32_t j = 0; j < 10; j++) {
A[(0 + (i * 1))] = (A[(0 + (i * 1))] + j);
}
}
}
从relay到tir
默认tvm是在codegen中执行这个过程, 不过没法直接dump出对应的tir来看, 不过我们可以通过自定义pass的方法插入print节点.
from tvm import relay |
如何更加优雅的写tiling?
如果在TVM中:
如果是手写tiling的话,最麻烦的一点就是每次都需要手动算tile大小,然后开辟出n个for循环进行写操作.
def simple_split(a: T.handle) -> None:
A = T.match_buffer(a, [16])
for i in T.serial(0, 16):
with T.block("block"):
vi = T.axis.remap("S", [i])
A[vi] = i + 100
def test_simple_split():
sch = tir.Schedule(simple_split)
b = sch.get_block("block")
lps = sch.get_loops(b)
sch.split(lps[0], [7,10])
print(sch.mod.script())
# from tvm.script import tir as T
class Module:
def main(a: T.handle) -> None:
A = T.match_buffer(a, [16], dtype="float32")
# body
# with T.block("root")
for i_0, i_1 in T.grid(7, 10):
with T.block("block"):
vi = T.axis.spatial(16, i_0 * 10 + i_1)
T.where(i_0 * 10 + i_1 < 16)
T.reads([])
T.writes([A[vi]])
A[vi] = i_0 * 10 + i_1 + 100
不过tvm的tir中是简化了for循环,也就是无法自定义stride,因为他面向的对象都是cpu/gpu这些的设备.
但是如果对于一些大颗粒算子的dsa来说,最好还是带有stride的for循环比较合理,否则对于一段程序我们需要这样写:
def simple_split(a: T.handle) -> None:
A = T.match_buffer(a, [16])
chunk_n = 3
chunk_c = 5
for n in T.serial(0, compute_segment(16, chunk_n)):
for c in T.serial(0, compute_segment(32, chunk_c)):
with T.block("block"):
vi, vj = T.axis.remap("SS", [n,c])
A[vi * chunk_n + vj * chunk_c] = 100
如果每次都自己控制chunk,那么如果有6d的tensor,也就是6层循环, 那么变量绝对多到难以控制的程度.
如果可以这样写肯定就舒服多了,
然后关键是就是chunk固定但是length还得每次求, 不过应该是合理一些了:
def simple_split(a: T.handle) -> None:
A = T.match_buffer(a, [16])
chunk_n = 3
chunk_c = 5
for n in T.serial(0, 16, chunk_n):
for c in T.serial(0, 32, chunk_c):
with T.block("block"):
vi, vj = T.axis.remap("SS", [n,c])
with T.let(length_n, min(chunk_n, 16 - vi)):
with T.let(length_c, min(chunk_c, 32 - vj)):
A[vi + vj] = 100
但是还是有一点非常麻烦,那就是求tir中定义一个变量就需要声明他的作用域,那么对于真的多层的循环复杂逻辑肯定还是很麻烦的.
如果在CSharp中:
我的想法是在csharp中基于Linq实现两套写法, 那些shape之类的可能还是没法用expr进行lazy的运算,因为一旦那样就很难用linq语法, 写起来就复杂很多.
1. 适配老架构的segment的方式
之前因为是cpp的语法,所以要实现一套基于Enumerable的dsl还是比较麻烦,所以for循环之类的刻板代码比较多, 目前我也先支持这种写法. 通过linq拆分出segment之后构造segment 4d然后进行计算. csharp的linq可以再嵌套linq所以不用担心复杂的逻辑无法处理, 最后返回出expr即可.
T.PrimFunc("TileLoadStore").Body( |
2. 输入glb_tensor,可以通过索引的方式进行tiling, 而后构造指令.
这个glb_tensor应该是一个可以多层级的数据结构,比如当前的sub_tensor可以求关于上一层tensor的地址偏移,然后也可以求关于父节点的内存偏移. 然后基于之前segment的逻辑,就可以把写出一个优雅的tensor处理逻辑.
from in_seg in compute_segment(N,chunk_n) |