Notes from studying the MLIR affine fusion pass, focusing mainly on the dependence analysis.

1. Preparation

The IR under test is:

module {
  func.func @main(%arg0: memref<8x128x384xf32>, %arg1: memref<8x384x512xf32>, %arg2: memref<8x128x512xf32>, %arg3: memref<8x512x64xf32>, %arg4: memref<8x128x64xf32>) {
    affine.for %arg5 = 0 to 8 {
      affine.for %arg6 = 0 to 128 {
        affine.for %arg7 = 0 to 512 {
          affine.for %arg8 = 0 to 384 {
            %0 = affine.load %arg0[%arg5, %arg6, %arg8] : memref<8x128x384xf32>
            %1 = affine.load %arg1[%arg5, %arg8, %arg7] : memref<8x384x512xf32>
            %2 = affine.load %arg2[%arg5, %arg6, %arg7] : memref<8x128x512xf32>
            %3 = arith.mulf %0, %1 : f32
            %4 = arith.addf %2, %3 : f32
            affine.store %4, %arg2[%arg5, %arg6, %arg7] : memref<8x128x512xf32>
          }
        }
      }
    }
    affine.for %arg5 = 0 to 8 {
      affine.for %arg6 = 0 to 128 {
        affine.for %arg7 = 0 to 64 {
          affine.for %arg8 = 0 to 512 {
            %0 = affine.load %arg2[%arg5, %arg6, %arg8] : memref<8x128x512xf32>
            %1 = affine.load %arg3[%arg5, %arg8, %arg7] : memref<8x512x64xf32>
            %2 = affine.load %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
            %3 = arith.mulf %0, %1 : f32
            %4 = arith.addf %2, %3 : f32
            affine.store %4, %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
          }
        }
      }
    }
    return
  }
}

2. performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount)

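  1. After entering the affine fusion pass, dstId is used to look up the producer affine.for nodes in the MemRefDependenceGraph, and each of them is taken as a src node. In our example this clearly means fusing the first and the second loop nest.

  2. gatherProducerConsumerMemrefs(srcId, dstId, mdg, producerConsumerMemrefs) collects the memrefs through which a store in the src node feeds a load in the dst node.

  3. dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps) returns the loop depth shared by the memory operations in the dst node; in our example this depth is 4.

  4. The candidate depths are then iterated, for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { FusionResult result = affine::canFuseLoops(...) }, testing at each loop depth whether fusion is possible.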

3. affine::canFuseLoops(srcAffineForOp, dstAffineForOp, i, &depthSliceUnions[i - 1], strategy)

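Checking whether fusion is legal is an involved process. After some minor boundary-condition handling, the actual check begins.

  1. numCommonLoops = affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp); checks whether the two ops share any surrounding loops; in the current example they do not.
  2. switch (fusionStrategy.getStrategy()) selects which ops to analyze depending on the strategy; here opsA are the ops that execute first and opsB the ops that execute later.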
  3. sliceComputationResult = affine::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops, isSrcForOpBeforeDstForOp, srcSlice)

3.1 computeSliceUnion

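This computes whether the slice bounds obtained for opsA and opsB at the given loop depth satisfy the dependences between them. Every memory access in the earlier-executing node has to be tested against every memory access in the later-executing node, so this is a Cartesian product.

An access relation can be built from each affine load/store. First check whether the two ops access the same memref; if they do not, there can be no dependence and the pair is skipped.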

for (auto *i : opsA) {
  MemRefAccess srcAccess(i);
  for (auto *j : opsB) {
    MemRefAccess dstAccess(j);
    if (srcAccess.memref != dstAccess.memref)
      continue;
  }
}

Otherwise a dependence test is required:

bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
                        isa<AffineReadOpInterface>(dstAccess.opInst);
FlatAffineValueConstraints dependenceConstraints;
// Check dependence between 'srcAccess' and 'dstAccess'.
// Both accesses touch the same buffer, so the dependence has to be checked.
DependenceResult result = checkMemrefAccessDependence(
    srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
    &dependenceConstraints, /*dependenceComponents=*/nullptr,
    /*allowRAR=*/readReadAccesses);
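
To make the pairing concrete, here is a minimal Python sketch (not MLIR code) that lists the accesses of the two loop nests from section 1 and applies the same memref filter. The tuple layout and the helper name candidate_pairs are made up for illustration. It pairs every access of the first nest with every access of the second, which may be a superset of what a particular fusion strategy actually puts into opsA and opsB.

```python
from itertools import product

# (kind, memref) for each affine access, in program order, taken from the IR in section 1.
SRC_ACCESSES = [("load", "%arg0"), ("load", "%arg1"), ("load", "%arg2"), ("store", "%arg2")]
DST_ACCESSES = [("load", "%arg2"), ("load", "%arg3"), ("load", "%arg4"), ("store", "%arg4")]


def candidate_pairs(ops_a, ops_b):
    """Return (src, dst, read_read) for every pair that touches the same memref."""
    pairs = []
    for src, dst in product(ops_a, ops_b):  # Cartesian product of the two op lists
        if src[1] != dst[1]:  # different memrefs: no dependence possible
            continue
        read_read = src[0] == "load" and dst[0] == "load"
        pairs.append((src, dst, read_read))
    return pairs


PAIRS = candidate_pairs(SRC_ACCESSES, DST_ACCESSES)
for src, dst, read_read in PAIRS:
    print(src, "->", dst, "read-read" if read_read else "needs a dependence check")
```

Only %arg2 is shared between the two nests, so of the 16 combinations just two survive the filter, and only the store/load pair involves a write.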

3.2 checkMemrefAccessDependence

At this point src and dst are:

Checking for dependence at depth: 1 between:
mlir-asm-printer: Verifying operation: func.func
affine.store %4, %arg2[%arg5, %arg6, %arg7] : memref<8x128x512xf32>
mlir-asm-printer: Verifying operation: func.func
%0 = affine.load %arg2[%arg5, %arg6, %arg8] : memref<8x128x512xf32>

Next, the access relation is obtained from each access:

// Create access relation from each MemRefAccess.
FlatAffineRelation srcRel, dstRel;
if (failed(srcAccess.getAccessRelation(srcRel)))
  return DependenceResult::Failure;
if (failed(dstAccess.getAccessRelation(dstRel)))
  return DependenceResult::Failure;

FlatAffineValueConstraints srcDomain = srcRel.getDomainSet();
FlatAffineValueConstraints dstDomain = dstRel.getDomainSet();

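MLIR represents the src and dst access relations as FlatAffineRelation, a domain -> range mapping that is stored as constraints. For our example the two dumps are shown below (src first, then dst); the columns are ordered [domainVars, rangeVars, symbolVars, localVars, constant]: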

Domain: 0, Range: 7, Symbols: 0, Locals: 0
11 constraints
(Value Value Value Value None None None const)
1 0 0 0 -1 0 0 0 = 0
0 1 0 0 0 -1 0 0 = 0
0 0 1 0 0 0 -1 0 = 0
1 0 0 0 0 0 0 0 >= 0
-1 0 0 0 0 0 0 7 >= 0
0 1 0 0 0 0 0 0 >= 0
0 -1 0 0 0 0 0 127 >= 0
0 0 1 0 0 0 0 0 >= 0
0 0 -1 0 0 0 0 511 >= 0
0 0 0 1 0 0 0 0 >= 0
0 0 0 -1 0 0 0 383 >= 0

Domain: 0, Range: 7, Symbols: 0, Locals: 0
11 constraints
(Value Value Value Value None None None const)
1 0 0 0 -1 0 0 0 = 0
0 1 0 0 0 -1 0 0 = 0
0 0 0 1 0 0 -1 0 = 0
1 0 0 0 0 0 0 0 >= 0
-1 0 0 0 0 0 0 7 >= 0
0 1 0 0 0 0 0 0 >= 0
0 -1 0 0 0 0 0 127 >= 0
0 0 1 0 0 0 0 0 >= 0
0 0 -1 0 0 0 0 63 >= 0
0 0 0 1 0 0 0 0 >= 0
0 0 0 -1 0 0 0 511 >= 0

Converted to the usual map notation:

{ [i,j,k,l] -> [i,j,k] : 0<= i < 8 and 0<= j < 128 and 0<= k < 512 and 0<= l < 384 }
{ [i,j,k,l] -> [i,j,l] : 0<= i < 8 and 0<= j < 128 and 0<= k < 64 and 0<= l < 512 }
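
As a sanity check on how the rows of the two dumps relate to the loops, the following Python sketch rebuilds them from the loop bounds and the subscripts. It is not how MLIR builds the relation internally; access_relation_rows is a hypothetical helper, and it only handles subscripts that are a single loop IV, which is all this example needs.

```python
def access_relation_rows(bounds, subscripts):
    """Build constraint rows for an access whose subscripts are plain loop IVs.

    bounds:     one (lower, upper_exclusive) pair per loop IV, outermost first
    subscripts: for each memref dimension, the position of the IV used as its index
    Row layout: [loop IVs..., memref dims..., constant]
    """
    n_iv, n_mem = len(bounds), len(subscripts)
    width = n_iv + n_mem + 1
    eqs = []
    for dim, iv in enumerate(subscripts):
        row = [0] * width
        row[iv] = 1            # iv - memref_dim == 0
        row[n_iv + dim] = -1
        eqs.append(row)
    ineqs = []
    for iv, (lb, ub) in enumerate(bounds):
        lower = [0] * width
        lower[iv], lower[-1] = 1, -lb        # iv - lb >= 0
        upper = [0] * width
        upper[iv], upper[-1] = -1, ub - 1    # -iv + (ub - 1) >= 0
        ineqs += [lower, upper]
    return eqs, ineqs


# store %arg2[i, j, k] inside the first nest, load %arg2[i, j, l] inside the second
SRC_EQS, SRC_INEQS = access_relation_rows([(0, 8), (0, 128), (0, 512), (0, 384)], [0, 1, 2])
DST_EQS, DST_INEQS = access_relation_rows([(0, 8), (0, 128), (0, 64), (0, 512)], [0, 1, 3])

for row in SRC_EQS + SRC_INEQS:
    print(*row)
```

The three equalities tie each memref dimension to its subscript, and each loop contributes one lower-bound and one upper-bound inequality, which accounts for the 11 constraints in each dump.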

Then the two relations are composed:

dstRel.inverse();        // dstRel: memref element -> dst iteration
dstRel.compose(srcRel);  // dstRel: src iteration domain -> dst iteration domain

After compose, dstRel is:

Domain: 0, Range: 8, Symbols: 0, Locals: 0
19 constraints
(Value Value Value Value Value Value Value Value const)
-1 0 0 0 1 0 0 0 0 = 0
0 -1 0 0 0 1 0 0 0 = 0
0 0 -1 0 0 0 0 1 0 = 0
1 0 0 0 0 0 0 0 0 >= 0
-1 0 0 0 0 0 0 0 7 >= 0
0 1 0 0 0 0 0 0 0 >= 0
0 -1 0 0 0 0 0 0 127 >= 0
0 0 1 0 0 0 0 0 0 >= 0
0 0 -1 0 0 0 0 0 511 >= 0
0 0 0 1 0 0 0 0 0 >= 0
0 0 0 -1 0 0 0 0 383 >= 0
0 0 0 0 1 0 0 0 0 >= 0
0 0 0 0 -1 0 0 0 7 >= 0
0 0 0 0 0 1 0 0 0 >= 0
0 0 0 0 0 -1 0 0 127 >= 0
0 0 0 0 0 0 1 0 0 >= 0
0 0 0 0 0 0 -1 0 63 >= 0
0 0 0 0 0 0 0 1 0 >= 0
0 0 0 0 0 0 0 -1 511 >= 0

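What I do not understand here is why only range vars remain. In theory the map should have the form:

{ [i, j, k, l] -> [i' = i, j' = j, k', l' = k] : 0 <= i <= 7 and 0 <= j <= 127 and 0 <= k <= 511 and 0 <= l <= 383 and 0 <= k' <= 63 }

A plausible explanation, which I have not fully verified: FlatAffineRelation is built on top of a set-style constraint system in which every dimension is printed under Range, and the domain/range split is tracked separately in its own numDomainDims/numRangeDims fields, which this dump does not show. The constraint rows themselves do match the map above.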
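
The map above can be checked by brute force on scaled-down loop bounds. The Python sketch below enumerates every (src iteration, dst iteration) pair in which the store and the load touch the same element of %arg2. The small bounds are an assumption made to keep the enumeration cheap; the src k range and the dst l' range are kept equal, as they are (512) in the real IR. Ordering constraints are not modelled because the two nests share no loops.

```python
from itertools import product

# Scaled-down loop bounds (the real ones are 8, 128, 512, 384 and 8, 128, 64, 512).
SRC_BOUNDS = (2, 3, 4, 2)  # i, j, k, l of the first nest
DST_BOUNDS = (2, 3, 2, 4)  # i', j', k', l' of the second nest


def store_element(i, j, k, l):
    return (i, j, k)  # affine.store %4, %arg2[i, j, k]


def load_element(i, j, k, l):
    return (i, j, l)  # affine.load %arg2[i', j', l']


def dependence_pairs():
    """All (src iteration, dst iteration) pairs that access the same element."""
    by_element = {}
    for dst in product(*map(range, DST_BOUNDS)):
        by_element.setdefault(load_element(*dst), []).append(dst)
    return [
        (src, dst)
        for src in product(*map(range, SRC_BOUNDS))
        for dst in by_element.get(store_element(*src), [])
    ]


PAIRS = dependence_pairs()
print(len(PAIRS), "pairs, first one:", PAIRS[0])
```

Every pair satisfies i' = i, j' = j and l' = k while k' is unconstrained, which is exactly the shape of the three equality rows in the composed dump. Each pair is one solution of the dependence constraints: a concrete src iteration and dst iteration that touch the same memory element.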

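With the new dstRel in hand, ordering constraints are added. Internally this constrains the order of the two accesses within their common loops; there are no common loops here, so nothing is added.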

// Add 'src' happens before 'dst' ordering constraints.
addOrderingConstraints(srcDomain, dstDomain, loopDepth, &dstRel);

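The final step is to check the constraints with dstRel.isEmpty(), which tests whether the system has an integer solution. From reading the implementation, it relies on gaussianEliminateVars/fourierMotzkinEliminate together with GCDTest. As far as I can tell this is a conservative test: it returns true only when it can prove that there is no solution, so a false result means a dependence has to be assumed rather than that an integer solution has actually been exhibited.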
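
The idea behind a GCD test can be shown in a few lines of Python. This is only the underlying number-theoretic fact (a single linear equation with integer coefficients has an integer solution exactly when the gcd of the coefficients divides the constant), not a reproduction of MLIR's GCDTest; the function name and the row layout [coefficients..., constant] are chosen to mirror the dumps above.

```python
from functools import reduce
from math import gcd


def gcd_test(row):
    """row = [a_1, ..., a_n, c] encodes a_1*x_1 + ... + a_n*x_n + c == 0.

    Returns False when the equality provably has no integer solution.
    Variable bounds are ignored, so True only means "not ruled out".
    """
    *coeffs, const = row
    g = reduce(gcd, (abs(a) for a in coeffs), 0)
    if g == 0:  # no variables at all: the row reads const == 0
        return const == 0
    return const % g == 0


print(gcd_test([2, 4, -3]))                     # 2x + 4y - 3 == 0
print(gcd_test([-1, 0, 0, 0, 1, 0, 0, 0, 0]))   # first equality row of the composed dstRel
```

The first row fails the test because 2x + 4y is always even, while the equalities of our dependence system all have unit coefficients and therefore pass.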

TODO: in polyhedral compilation, what does it mean for the dependence constraints to have a solution?

3.3 getComputationSliceState

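Once a dependence is known to exist, i.e. the dependence constraints have a solution, the next step is to compute the set of iterations that the slice has to cover.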

getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState)

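Here depSourceOp is the op that executes first and depSinkOp the op that executes later. The dependence constraints are the ones from the previous subsection. Dimensions that are not needed are then removed:

// Project out dimensions other than those up to 'loopDepth'.
unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
unsigned num =
    isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
dependenceConstraints->projectOut(pos, num);

Because dependenceConstraints was built through compose, its variables are ordered as src [i,j,k,l] followed by dst [i,j,k,l]. The loop depth at which the slice is to be inserted is loopDepth=1, so dst [j,k,l] are eliminated, which gives: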

Domain: 0, Range: 5, Symbols: 0, Locals: 0
11 constraints
(Value Value Value Value Value const)
-1 0 0 0 1 0 = 0
1 0 0 0 0 0 >= 0
-1 0 0 0 0 7 >= 0
0 1 0 0 0 0 >= 0
0 -1 0 0 0 127 >= 0
0 0 1 0 0 0 >= 0
0 0 -1 0 0 511 >= 0
0 0 0 1 0 0 >= 0
0 0 0 -1 0 383 >= 0
0 0 0 0 1 0 >= 0
0 0 0 0 -1 7 >= 0
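
The projection can be replayed on the 19-row system from section 3.2 with a small Python sketch. It is a simplified stand-in for projectOut, not MLIR's implementation: eliminate and simplify are hypothetical helpers, equalities are only used as pivots when the coefficient is +1 or -1, and constant-only rows are dropped without checking that they hold. Those shortcuts are enough for this example.

```python
from itertools import product


def eliminate(eqs, ineqs, col):
    """Project variable `col` out of {eq == 0, ineq >= 0}; the last entry of a row is the constant."""
    pivot = next((r for r in eqs if abs(r[col]) == 1), None)
    if pivot is not None:
        # Gaussian step: add a multiple of the pivot equality so that column `col` becomes 0.
        def subst(row):
            factor = row[col] * pivot[col]  # pivot[col] is +1 or -1
            return [a - factor * b for a, b in zip(row, pivot)]

        eqs = [subst(r) for r in eqs if r is not pivot]
        ineqs = [subst(r) for r in ineqs]
    else:
        # Fourier-Motzkin step: combine every lower bound with every upper bound.
        assert all(r[col] == 0 for r in eqs), "sketch only handles unit pivots"
        lower = [r for r in ineqs if r[col] > 0]
        upper = [r for r in ineqs if r[col] < 0]
        ineqs = [r for r in ineqs if r[col] == 0]
        for lo, up in product(lower, upper):
            ineqs.append([-up[col] * a + lo[col] * b for a, b in zip(lo, up)])
    return [r[:col] + r[col + 1:] for r in eqs], [r[:col] + r[col + 1:] for r in ineqs]


def simplify(eqs, ineqs):
    """Drop duplicate rows and rows without variables (such as 63 >= 0)."""
    def clean(rows):
        out = []
        for r in rows:
            if any(c != 0 for c in r[:-1]) and r not in out:
                out.append(r)
        return out

    return clean(eqs), clean(ineqs)


def box_bounds(upper_bounds):
    """Rows for 0 <= x_v <= upper_bounds[v], in the same order as the dump."""
    rows, width = [], len(upper_bounds) + 1
    for v, ub in enumerate(upper_bounds):
        lo, hi = [0] * width, [0] * width
        lo[v] = 1
        hi[v], hi[-1] = -1, ub
        rows += [lo, hi]
    return rows


# Columns: src i, j, k, l, then dst i', j', k', l', then the constant.
EQS = [[-1, 0, 0, 0, 1, 0, 0, 0, 0], [0, -1, 0, 0, 0, 1, 0, 0, 0], [0, 0, -1, 0, 0, 0, 0, 1, 0]]
INEQS = box_bounds([7, 127, 511, 383, 7, 127, 63, 511])

num_src_ivs, num_dst_ivs, loop_depth = 4, 4, 1
pos, num = num_src_ivs + loop_depth, num_dst_ivs - loop_depth  # backward-slice case
eqs, ineqs = EQS, INEQS
for _ in range(num):  # columns shift left after each elimination, so `pos` stays fixed
    eqs, ineqs = eliminate(eqs, ineqs, pos)
PROJ_EQS, PROJ_INEQS = simplify(eqs, ineqs)

for row in PROJ_EQS + PROJ_INEQS:
    print(*row)
```

Eliminating j' and l' goes through their equalities and only reproduces bounds on j and k that already exist, while k' has no equality and its two bounds combine into the trivial row 63 >= 0. What remains is the single equality i' = i plus the bounds on i, j, k, l and i'.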

Collect all loop IVs of the src node, i, j, k, l:

// Add slice loop IV values to 'sliceState'.
unsigned offset = isBackwardSlice ? 0 : loopDepth;
unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
dependenceConstraints->getValues(offset, offset + numSliceLoopIVs,
                                 &sliceState->ivs);
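
The index arithmetic is easy to check against the projected system above, whose columns are [i, j, k, l, i']. The Python sketch below mirrors only the offset computation from the snippet; slice_iv_range and the COLUMNS list are illustrative names, not MLIR APIs.

```python
def slice_iv_range(is_backward_slice, loop_depth, num_src_ivs, num_dst_ivs):
    """Return the [start, end) column range that holds the slice loop IVs."""
    offset = 0 if is_backward_slice else loop_depth
    count = num_src_ivs if is_backward_slice else num_dst_ivs
    return offset, offset + count


# Columns left after projecting out dst [j, k, l] at loopDepth = 1.
COLUMNS = ["i", "j", "k", "l", "i'"]

start, end = slice_iv_range(True, loop_depth=1, num_src_ivs=4, num_dst_ivs=4)
SLICE_IVS = COLUMNS[start:end]
print(SLICE_IVS)
```

For the backward slice in this example the range is columns 0 to 3, i.e. the four src loop IVs; the remaining column i' is the dst loop IV at the insertion depth.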