This paper comes from Tianqi Chen's group. It proposes a unified hardware-aware abstraction (Axe Layout) that maps logical tensor coordinates into a multi-dimensional physical space, and builds a multi-granularity, distribution-aware compiler DSL on top of it. In this post I walk through the design ideas and implementation details of Axe Layout.
The source code for this post lives in handson-polyhedral/17_axe.ipynb.
Core Formula
Axe Layout maps a logical tensor index to a set of physical coordinates:
\[L(x) = \{ D(x) + r + O \mid r \in R \}\]
where:

### D (Shard) - sharding map

- An ordered list of iters
- Each iter = (extent, stride, @axis)
  - extent: the logical size along that hardware dimension
  - stride: the distance between adjacent logical elements on that hardware dimension
  - @axis: the label of the physical axis (device, warp, lane, ...)
- Role: converts the logical index into physical coordinates

### R (Replica) - replication dimensions

- A set (unordered), independent of the logical index
- Format: {axis_name: replica_count, ...}
- Role: adds extra dimensions for parallel execution
- Example: 4 threads independently executing the same computation

### O (Offset) - offset

- A fixed base address or resource reservation
- Format: {axis_name: offset_value, ...}
- Examples: the starting position of the data in memory, an offset of the executing device
```python
import itertools
from typing import List, Tuple, Dict
from dataclasses import dataclass

import numpy as np
import pycute as cute


@dataclass
class Iter:
    extent: int
    stride: int
    axis: str

    def __repr__(self):
        return f"({self.extent}, {self.stride}@{self.axis})"


@dataclass
class AxeLayout:
    D: List[Iter]      # shard iters: decompose the logical index onto hardware axes
    R: List[Iter]      # replica iters: extra parallel dimensions, independent of the index
    O: Dict[str, int]  # fixed offsets per hardware axis

    def __repr__(self):
        d_str = " × ".join(map(str, self.D))
        r_str = " × ".join(map(str, self.R)) if self.R else "∅"
        o_str = ", ".join([f"{v}@{k}" for k, v in self.O.items()]) if self.O else "∅"
        return f"D: {d_str} | R: {r_str} | O: {o_str}"

    def print(self, name=""):
        # Pretty-print the layout as aligned extent / stride rows.
        if name:
            print(f"\n{name}:")
        if not self.D:
            d_extent_line = "( )"
            d_stride_line = "( )"
        else:
            col_widths = []
            for it in self.D:
                col_widths.append(max(len(str(it.extent)), len(f"{it.stride}@{it.axis}")))
            extent_parts = [str(it.extent).center(col_widths[i]) for i, it in enumerate(self.D)]
            d_extent_line = "( " + " ".join(extent_parts) + " )"
            stride_parts = [f"{it.stride}@{it.axis}".rjust(col_widths[i]) for i, it in enumerate(self.D)]
            d_stride_line = "( " + ", ".join(stride_parts) + " )"
        if self.R:
            r_col_widths = []
            for it in self.R:
                r_col_widths.append(max(len(str(it.extent)), len(f"{it.stride}@{it.axis}")))
            r_extent_parts = [str(it.extent).center(r_col_widths[i]) for i, it in enumerate(self.R)]
            r_extent_line = "( " + " ".join(r_extent_parts) + " )"
            r_stride_parts = [f"{it.stride}@{it.axis}".rjust(r_col_widths[i]) for i, it in enumerate(self.R)]
            r_stride_line = "( " + ", ".join(r_stride_parts) + " )"
            print(d_extent_line + " " + r_extent_line)
            print(d_stride_line + " + " + r_stride_line, end="")
        else:
            print(d_extent_line)
            print(d_stride_line, end="")
        if self.O:
            o_items = [f"{offset}@{axis}" for axis, offset in self.O.items()]
            print(" + " + " + ".join(o_items), end="")
        print()
```
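As a quick smoke test (my own toy example, not from the paper or the notebook; the name `toy` is made up), we can build a trivial layout and look at both the compact `__repr__` and the aligned `print` form:

```python
# Hypothetical toy layout: 4 logical elements spread across 4 lanes,
# replicated over 2 warps, with a fixed offset of 1 on the warp axis.
toy = AxeLayout(D=[Iter(4, 1, "lane")], R=[Iter(2, 1, "warp")], O={"warp": 1})
print(repr(toy))   # D: (4, 1@lane) | R: (2, 1@warp) | O: 1@warp
toy.print("toy")   # aligned extent/stride form of the same layout
```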
Examples
NVIDIA Tensor Core tile
With the basic definitions in place, we can implement the first example from the paper: mapping a logical \(8×16\) tile onto 2 GPU warps (32 lanes each) plus 2 registers:
\[\begin{pmatrix}8 & 2 & 4 & 2
\\ 4@\texttt{lane} & 1@\texttt{warp} & 1@\texttt{lane} &
1@\texttt{reg}\end{pmatrix} + \begin{bmatrix}2 \\
4@\texttt{warp}\end{bmatrix} + 5@\texttt{warp}\]
example 1
```python
layout_a = AxeLayout(
    D=[
        Iter(8, 4, "lane"),
        Iter(2, 1, "warp"),
        Iter(4, 1, "lane"),
        Iter(2, 1, "reg"),
    ],
    R=[Iter(2, 4, "warp")],
    O={"warp": 5},
)
layout_a.print("layout_a")
```
layout_a:
( 8 2 4 2 ) ( 2 )
( 4@lane, 1@warp, 1@lane, 1@reg ) + ( 4@warp ) + 5@warp
Distributed sharding on a 2×2 GPU mesh
Suppose a \(64×128\) tensor is distributed across 4 GPUs with a mix of sharding and replication:
example 2
Full sharding (rows split across GPU rows, columns split across GPU columns):
\[\begin{pmatrix}2 & 32 & 2 & 64 \\
1@\texttt{gpuid} & 128@\texttt{m} & 2@\texttt{gpuid} &
1@\texttt{m}\end{pmatrix}\]
Here gpuid denotes the device dimension and m the memory dimension. If we extract the memory-dimension iters on their own, we can compute the layout of the local tensor on each device; extracting the gpuid iters, their strides implicitly encode how the GPU mesh is distributed (see the sketch after the layout_b code below).
Row sharding + column replication (rows are sharded, and each row shard is replicated to the two GPUs within its row):
\[\begin{pmatrix}2 & 32 & 128 \\
1@\texttt{gpuid} & 128@\texttt{m} & 1@\texttt{m}\end{pmatrix} +
\begin{bmatrix}2 \\ 2@\texttt{gpuid}\end{bmatrix}\]
These are in fact the two classic distributed tensor partitioning schemes; in SBP notation they correspond to:
\[
\begin{aligned}
&(\text{split}(0),\ \text{split}(1)) \\
&(\text{split}(0),\ \text{broadcast})
\end{aligned}
\]
```python
layout_b = AxeLayout(
    D=[
        Iter(2, 1, "gpuid"),
        Iter(32, 128, "m"),
        Iter(2, 2, "gpuid"),
        Iter(64, 1, "m"),
    ],
    R=[],
    O={},
)
layout_b.print("layout_b")
```
layout_b:
( 2 32 2 64 )
( 1@gpuid, 128@m, 2@gpuid, 1@m )
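For completeness, here is a sketch (my own code; the names layout_b_rep, local_layout, and mesh_layout are made up) that builds the row-sharding + column-replication variant from the second formula, and then splits layout_b by axis to read off the per-device local layout and the GPU mesh mentioned above:

```python
# Row sharding + column replication (second formula above): each row shard is
# replicated across the two GPUs of its row via the R part.
layout_b_rep = AxeLayout(
    D=[Iter(2, 1, "gpuid"), Iter(32, 128, "m"), Iter(128, 1, "m")],
    R=[Iter(2, 2, "gpuid")],
    O={},
)
layout_b_rep.print("layout_b_rep")

# Splitting layout_b.D by axis: the m iters describe the local tensor on each
# device (a 32×64 tile with strides 128@m and 1@m), while the gpuid iters
# describe the 2×2 device mesh (strides 1@gpuid and 2@gpuid).
local_layout = AxeLayout(D=[it for it in layout_b.D if it.axis == "m"], R=[], O={})
mesh_layout = AxeLayout(D=[it for it in layout_b.D if it.axis == "gpuid"], R=[], O={})
local_layout.print("local tensor layout")
mesh_layout.print("gpu mesh")
```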
Native multidimensional memory in accelerators
example 3
This example maps onto physical memory hardware, where P stands for memory bank partitions and F for free dimensions. The point is that Axe Layout expresses parallel hardware dimensions and memory hardware in the same formalism, so when no parallel dimensions are involved it can serve as a traditional layout.
```python
layout_c = AxeLayout(
    D=[Iter(2, 512, "F"), Iter(128, 1, "P"), Iter(512, 1, "F")],
    R=[],
    O={},
)
layout_d = AxeLayout(
    D=[Iter(2, 112, "Col"), Iter(128, 1, "Lane"), Iter(112, 1, "Col")],
    R=[],
    O={},
)
layout_c.print("layout_c")
layout_d.print("layout_d")
```
layout_c:
( 2 128 512 )
( 512@F, 1@P, 1@F )
layout_d:
( 2 128 112 )
( 112@Col, 1@Lane, 1@Col )
Forward Mapping
An AxeLayout effectively defines a relation between logical indices and hardware indices; the mapping it induces from a logical index to parallel indices is called the forward mapping.

Let's work through the forward computation on the first example. The shard part decomposes the original logical index, and its strides are applied to the hardware axes; the replica part then expands the result into a set of mappings, and together with the offset part applies offsets to the hardware axes.
```
logical coord: (i=2, j=9)
      ▼
shape: (8, 16)
      ▼
linear index: 2 * 16 + 9 = 41
      ▼
decomposed index as:
  > 41 ÷ 16 = 2
  > (41 % 16) ÷ 8 = 1
  > ((41 % 16) % 8) // 2 = 0
  > ((41 % 16) % 8) % 2 = 1
      ▼
shard coord: ( 2, 1, 0, 1 )
      ▼
    (   8       2       4      2   )          (   2    )
D = ( 4@lane, 1@warp, 1@lane, 1@reg )  +  R = ( 4@warp )  +  O = ( 5@warp )
      ▼
  > lane = 2*4 + 0*1 = 8     > warp0 = 0*4 = 0     > warp = 5
  > warp = 1*1 = 1           > warp1 = 1*4 = 4
  > reg  = 1*1 = 1
      ▼
{lane: 8, warp: 1, reg: 1} + [{warp: 0}, {warp: 4}] + {warp: 5}
      ▼
offset = [
  {warp: 1+0+5 = 6,  lane: 8, reg: 1},
  {warp: 1+4+5 = 10, lane: 8, reg: 1}
]
```
The flow above may feel a bit abstract, so let's make the forward computation of Axe Layout more concrete in code. Here I reuse a few functions from cute to help with the mapping arithmetic:
```python
def forward(self: AxeLayout, coord: Tuple[int, ...], shape: Tuple[int, ...]):
    # 1. Linearize the logical coordinate, then re-decompose it over the shard extents.
    shard_extents = tuple(i.extent for i in self.D)
    org_idx = cute.crd2idx(coord, shape, cute.suffix_product(shape))
    shard_crd = cute.idx2crd(org_idx, shard_extents[::-1])[::-1]
    # 2. Apply each shard iter's stride on its hardware axis.
    shards = [(idx * it.stride, it.axis) for (idx, it) in zip(shard_crd, self.D)]
    # 3. Expand over every replica combination, then add the fixed offsets.
    rep_set = []
    replica = self.R if len(self.R) else [Iter(1, 0, None)]
    for reps in itertools.product(
        *[[(i * it.stride, it.axis) for i in range(it.extent)] for it in replica]
    ):
        d = dict()
        for val, axis in shards:
            d[axis] = d.get(axis, 0) + val
        for val, axis in reps:
            if axis is not None:
                d[axis] = d.get(axis, 0) + val
        for axis, val in self.O.items():
            d[axis] = d.get(axis, 0) + val
        rep_set.append(d)
    return rep_set
```
Next, we take a concrete logical coordinate (2, 9) and use the forward function to compute the set of parallel-axis coordinates it maps to:
```python
def pretty_print(crd: List[Dict[str, int]], **kwargs):
    # Sort each coordinate dict's keys in the order given via kwargs for readability.
    key_order = dict()
    for k, v in kwargs.items():
        key_order[k] = v
    print([{k: d[k] for k in sorted(d, key=lambda k: key_order.get(k, 99))} for d in crd])


AxeLayout.forward = forward
hardware_coords = layout_a.forward((2, 9), (8, 16))
pretty_print(hardware_coords, warp=0, lane=1, reg=2)
```
[{'warp': 6, 'lane': 8, 'reg': 1}, {'warp': 10, 'lane': 8, 'reg': 1}]
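As a cross-check (a quick sketch of mine, no cute calls involved), the same result can be reproduced with plain integer arithmetic, following the hand trace above:

```python
# Decompose logical (2, 9) over the shard extents (8, 2, 4, 2) and apply strides.
i, j = 2, 9
linear = i * 16 + j            # 41
c0 = linear // 16              # 2  (remaining extents 2*4*2 = 16)
c1 = linear % 16 // 8          # 1  (remaining extents 4*2 = 8)
c2 = linear % 16 % 8 // 2      # 0  (remaining extent 2)
c3 = linear % 16 % 8 % 2       # 1
lane = c0 * 4 + c2 * 1         # 8
warp = c1 * 1                  # 1
reg = c3 * 1                   # 1
# The replica adds 0*4 or 1*4 on warp; the offset adds a fixed 5 on warp.
print([{"warp": warp + r * 4 + 5, "lane": lane, "reg": reg} for r in range(2)])
# -> [{'warp': 6, 'lane': 8, 'reg': 1}, {'warp': 10, 'lane': 8, 'reg': 1}]
```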
Backward Mapping
Axe Layout also supports the reverse mapping from a hardware index back to a logical index, called the backward mapping. This direction is a bit simpler, since there is no replica expansion to handle: we only need to undo the offset and shard steps. I'll skip the step-by-step description and go straight to the code:
```python
def backward(self: AxeLayout, indices: List[Dict[str, int]], shape: Tuple[int, ...]):
    # Take one representative from the replica set and undo the fixed offsets.
    index = indices[0].copy()
    for axis, offset in self.O.items():
        index[axis] = index.get(axis, 0) - offset
    # Group the shard iters by hardware axis, remembering their positions in D.
    shard_domains: Dict[str, List[Tuple[int, int, int]]] = {}
    for i, it in enumerate(self.D):
        sdomain = shard_domains.get(it.axis, [])
        sdomain.append((i, it.extent, it.stride))
        shard_domains[it.axis] = sdomain
    # Invert each per-axis layout to recover the shard coordinates.
    scoords = [-1] * len(self.D)
    for axis, tps in shard_domains.items():
        tps.sort(key=lambda it: it[0])
        slayout = cute.Layout(tuple(tp[1] for tp in tps), tuple(tp[2] for tp in tps))
        sindex = index[axis]
        scrd = cute.idx2crd(sindex, slayout.shape, slayout.stride)
        for i, (idx, _, _) in enumerate(tps):
            scoords[idx] = scrd[i]
    # Re-linearize the shard coordinates and map back to the logical coordinate.
    shard_extents = tuple(it.extent for it in self.D)
    linear_idx = cute.crd2idx(tuple(scoords), shard_extents, cute.suffix_product(shard_extents))
    logical_coord = cute.idx2crd(linear_idx, shape, cute.suffix_product(shape))
    return logical_coord
```
Running backward on the hardware coordinates produced for layout_a above, we can check that the result matches the coordinate we fed into forward:
```python
AxeLayout.backward = backward
logical_coord = layout_a.backward(hardware_coords, (8, 16))
assert logical_coord == (2, 9)
print(logical_coord)
```
(2, 9)
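As a further sanity check (my addition, not from the notebook), the round trip can be exercised over every logical coordinate of the 8×16 tile:

```python
# forward followed by backward should recover every logical coordinate.
for i in range(8):
    for j in range(16):
        hw = layout_a.forward((i, j), (8, 16))
        assert layout_a.backward(hw, (8, 16)) == (i, j)
print("round-trip check passed")
```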
Summary
The paper also covers operations such as Tile and Slice; these can likewise be implemented by reusing the cute layout algebra, so I won't expand on them here. There is also some front-end compiler syntax, for example how tensors and execution scopes are defined, which is not my focus either.
Overall, Axe Layout is a fairly concise yet powerful tensor layout abstraction. Its key observation is that sharding really requires binding the tensor's logical shape to the hardware shape, so both can be handled uniformly in a single Shard structure. This brings parallel resources and memory resources under one layout framework, staying compatible with traditional layouts while making it easy to extend to distributed execution.