affine fusion pass浅析
学习mlir
中affine fusion pass
,
主要关注依赖分析部分.
1. 准备工作
首先我们的待测试的ir
为:
module { |
2.
performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount)
进入
affine fusion pass
之后, 通过dstId
在MemRefDependenceGraph
中找到producer
的affine for
节点作为src
节点. 在我们的例子中, 显然是融合上下两个循环块.通过
gatherProducerConsumerMemrefs(srcId, dstId, mdg, producerConsumerMemrefs)
收集src
节点与dst
节点中的存在生产消费链接的store/load
.通过
dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps)
获取dst
节点中的内存操作的循环层级, 我们的例子中的循环深度为4.遍历目标循环层级
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { FusionResult result = affine::canFuseLoops(...) }
, 分别在每一层循环测试是否可以进行fusion
3.
affine::canFuseLoops(srcAffineForOp, dstAffineForOp, i, &depthSliceUnions[i - 1], strategy)
验证是否可以fusion
是一个复杂的过程.
经过一些琐碎的边界条件处理后, 开始执行判断过程.
numCommonLoops = affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);
检查两个op
外围是否存在共同的循环, 目前的例子中并没有.switch (fusionStrategy.getStrategy())
根据不同的策略放入不同的关键路径op
, 这里opsA
表示先执行的,opsB
表示后执行的.sliceComputationResult = affine::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops, isSrcForOpBeforeDstForOp, srcSlice)
3.1 computeSliceUnion
计算opsA
和OpsB
在指定循环层级位置计算得到的slice bounds
是否满足他们之间的依赖.
因为我们需要测试前一个执行节点的内存对后面所有的执行的内存的依赖关系,
所以这里是一个card product
.
这里我们可以从affine load/store
中构造出access relation
.
首先检查操作的是否同一个memref
, 如果不是那必然没有依赖,
跳过.
for (auto *i : opsA) { |
否则就需要进行依赖测试:
bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) && |
3.2
checkMemrefAccessDependence
此时我们的src/dst
分别为: Checking for dependence at depth: 1 between:
mlir-asm-printer: Verifying operation: func.func
affine.store %4, %arg2[%arg5, %arg6, %arg7] : memref<8x128x512xf32>
mlir-asm-printer: Verifying operation: func.func
%0 = affine.load %arg2[%arg5, %arg6, %arg8] : memref<8x128x512xf32>
接下来从access
中获得对应的access relation
:
// Create access relation from each MemRefAccess. |
在mlir
中对于src
和dst
的access relation
使用的是FlatAffineRelation
表达,
代表的是domain -> range
的映射,
存储的是constraints
. 在我们的例子中得到如下,
类型分别是[domainVars, rangeVars, symbolVars, localVars, constant]
:
Domain: 0, Range: 7, Symbols: 0, Locals: 0
11 constraints
(Value Value Value Value None None None const)
1 0 0 0 -1 0 0 0 = 0
0 1 0 0 0 -1 0 0 = 0
0 0 1 0 0 0 -1 0 = 0
1 0 0 0 0 0 0 0 >= 0
-1 0 0 0 0 0 0 7 >= 0
0 1 0 0 0 0 0 0 >= 0
0 -1 0 0 0 0 0 127 >= 0
0 0 1 0 0 0 0 0 >= 0
0 0 -1 0 0 0 0 511 >= 0
0 0 0 1 0 0 0 0 >= 0
0 0 0 -1 0 0 0 383 >= 0
Domain: 0, Range: 7, Symbols: 0, Locals: 0
11 constraints
(Value Value Value Value None None None const)
1 0 0 0 -1 0 0 0 = 0
0 1 0 0 0 -1 0 0 = 0
0 0 0 1 0 0 -1 0 = 0
1 0 0 0 0 0 0 0 >= 0
-1 0 0 0 0 0 0 7 >= 0
0 1 0 0 0 0 0 0 >= 0
0 -1 0 0 0 0 0 127 >= 0
0 0 1 0 0 0 0 0 >= 0
0 0 -1 0 0 0 0 63 >= 0
0 0 0 1 0 0 0 0 >= 0
0 0 0 -1 0 0 0 511 >= 0
转换为常用的map
表示如下: { [i,j,k,l] -> [i,j,k] : 0<= i < 8 and 0<= j < 128 and 0<= k < 512 and 0<= l < 384 }
{ [i,j,k,l] -> [i,j,l] : 0<= i < 8 and 0<= j < 128 and 0<= k < 64 and 0<= l < 512 }
然后组合两个relation
: dstRel.inverse();
dstRel.compose(srcRel); // src.domain -> dst.range
compose
后此时dstRel
为: Domain: 0, Range: 8, Symbols: 0, Locals: 0
19 constraints
(Value Value Value Value Value Value Value Value const)
-1 0 0 0 1 0 0 0 0 = 0
0 -1 0 0 0 1 0 0 0 = 0
0 0 -1 0 0 0 0 1 0 = 0
1 0 0 0 0 0 0 0 0 >= 0
-1 0 0 0 0 0 0 0 7 >= 0
0 1 0 0 0 0 0 0 0 >= 0
0 -1 0 0 0 0 0 0 127 >= 0
0 0 1 0 0 0 0 0 0 >= 0
0 0 -1 0 0 0 0 0 511 >= 0
0 0 0 1 0 0 0 0 0 >= 0
0 0 0 -1 0 0 0 0 383 >= 0
0 0 0 0 1 0 0 0 0 >= 0
0 0 0 0 -1 0 0 0 7 >= 0
0 0 0 0 0 1 0 0 0 >= 0
0 0 0 0 0 -1 0 0 127 >= 0
0 0 0 0 0 0 1 0 0 >= 0
0 0 0 0 0 0 -1 0 63 >= 0
0 0 0 0 0 0 0 1 0 >= 0
0 0 0 0 0 0 0 -1 511 >= 0
这里我不理解的是为什么只有range var
了,
理论上这里的map
形式应该是: { [i, j, k, l] -> [i' = i, j' = j, k', l' = k] : 0 <= i <= 7 and 0 <= j <= 127 and 0 <= k <= 511 and 0 <= l <= 383 and 0 <= k' <= 63 }
得到新的dstRel
后, 添加顺序约束,
这里内部是检查他们在共享循环中的顺序, 目前没有共享循环, 所以也不做什么.
// Add 'src' happens before 'dst' ordering constraints.
addOrderingConstraints(srcDomain, dstDomain, loopDepth, &dstRel);
最终就是检查约束: dstRel.isEmpty()
,
这里isEmpty
检查的是否存在整数解.
内部实现我看他是使用gaussianEliminateVars
/fourierMotzkinEliminate
以及GCDTest
来保证存在整数解.
TODO 在多面体编译中, 依赖约束的存在解意味着什么?
3.3
getComputationSliceState
如果我们确定存在依赖,且依赖约束存在解, 就开始计算所计算出的点集.
getComputationSliceState( |
首先这里的sourceOp
是在前面执行的op
,
depSinkOp
是在后面执行的. 依赖约束同上一小节.
然后删除不需要的维度: // Project out dimensions other than those up to 'loopDepth'.
unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
unsigned num =
isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
dependenceConstraints->projectOut(pos, num);
因为之前构造dependenceConstraints
的时候,
经过了compose
,
所以他现在的变量按src [i,j,k,l], dest [i,j,k,l]
的顺序排列,
同时目前要插入的循环深度loopDepth=1
,
所以就是将dest [j,k,l]
消除, 得到如下: Domain: 0, Range: 5, Symbols: 0, Locals: 0
11 constraints
(Value Value Value Value Value const)
-1 0 0 0 1 0 = 0
1 0 0 0 0 0 >= 0
-1 0 0 0 0 7 >= 0
0 1 0 0 0 0 >= 0
0 -1 0 0 0 127 >= 0
0 0 1 0 0 0 >= 0
0 0 -1 0 0 511 >= 0
0 0 0 1 0 0 >= 0
0 0 0 -1 0 383 >= 0
0 0 0 0 1 0 >= 0
0 0 0 0 -1 7 >= 0
获得src
节点中所有的循环变量i,j,k,l
:
// Add slice loop IV values to 'sliceState'.
unsigned offset = isBackwardSlice ? 0 : loopDepth;
unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
dependenceConstraints->getValues(offset, offset + numSliceLoopIVs,
&sliceState->ivs);