input  : Operator chain Ops
input  : Permutation Perm = (lp1, lp2, ..., lpI)
input  : Decomposition parameters S = (s1, s2, ..., sI)
output : data movement volume DV
output : memory usage MU
DV = 0, MU = 0
for op in Ops:
    total_DF = 0
    for tensor T in op.allTensors():
        DF = getFootprint(T, S)
        total_DF += DF
        if T in Ops.IOTensors():
            DM = DF
            keep_reuse = true
            for loop lpi in reversed(Perm):
                if lpi in op.allLoops():
                    if lpi accesses tensor T:
                        keep_reuse = false
                    if not keep_reuse:
                        DM *= ceil(Lpi / spi)
            DV += DM
    for loop lpi in Perm:
        if lpi is private to op:
            Perm.erase(lpi)
    MU = max(MU, total_DF)
return DV, MU
def solve_with_perms(all_perms: list[list[str]]):
    """Build an AMPL model of a three-level (L2/L1/L0) tiling over m, k and n.

    The model is assembled in a fresh AMPL session in four steps:

    1. Integer variables ``L{2,1,0}_{m,k,n}`` and the defined variables
       ``L2_{A,B,C}_DF`` / ``L1_{A,B,C}_DF`` built from them.
    2. For levels 2 and 1 and each buffer A, B, C, a defined variable
       ``L{level}_{buffer}_DM`` equal to the buffer's ``_DF`` variable
       multiplied by zero or more ``L{level}_{dim}`` variables selected from
       that level's loop permutation.
    3. Lower bounds of 1 on every ``L*_*`` variable and the product
       constraints m = 256, k = 2048, n = 64.
    4. Two capacity constraints on the summed ``_DF`` variables.

    Args:
        all_perms: Loop permutations, one list of dimension names per level.
            The entry at index ``len(all_perms) - level`` is used for
            ``level`` in (2, 1), i.e. the last two entries. Dimension names
            are matched against 'm', 'k' and 'n'.

    Nothing is returned by the statements shown here.
    """
    ap = AMPL()
    ap.eval('reset;')

    # Per-level integer variables and the *_DF defined variables.
    # NOTE(review): A is scaled by 4 while B and C are scaled by 4 * 4 * 4;
    # confirm the asymmetry is intended.
    ap.eval("""
        var L2_m integer; var L2_k integer; var L2_n integer;
        var L1_m integer; var L1_k integer; var L1_n integer;
        var L0_m integer; var L0_k integer; var L0_n integer;
        var L2_A_DF = (L1_m * L0_m) * (L1_k * L0_k) * 4;
        var L2_B_DF = (L1_n * L0_n) * (L1_k * L0_k) * 4 * 4 * 4;
        var L2_C_DF = (L1_m * L0_m) * (L1_n * L0_n) * 4 * 4 * 4;
        var L1_A_DF = (L0_m) * (L0_k) * 4;
        var L1_B_DF = (L0_n) * (L0_k) * 4 * 4 * 4;
        var L1_C_DF = (L0_m) * (L0_n) * 4 * 4 * 4;
    """)

    # Dimensions that index each buffer.
    access = {
        'A': ['m', 'k'],
        'B': ['n', 'k'],
        'C': ['m', 'n'],
    }

    for level in [2, 1]:
        # The permutation depends only on the level, so look it up once
        # instead of once per buffer.
        level_perms = all_perms[len(all_perms) - level]
        for bf in ['A', 'B', 'C']:
            factors = [f'L{level}_{bf}_DF']
            # Walk the permutation from its last entry backwards, collecting
            # the level's variable for every dimension that does not index
            # this buffer, and stop at the first dimension that does.
            # NOTE(review): this multiplies by the trailing non-indexing
            # dimensions only; confirm this is the intended reuse model.
            for reuse_dim in reversed(level_perms):
                if reuse_dim in access[bf]:
                    break
                factors.append(f'L{level}_{reuse_dim}')
            product = " * ".join(factors)
            ap.eval(f'var L{level}_{bf}_DM = {product};')

    # Bounds, problem-size products and capacity limits.
    ap.eval("""
        subject to L2_m_c: L2_m >= 1;
        subject to L2_k_c: L2_k >= 1;
        subject to L2_n_c: L2_n >= 1;
        subject to L1_m_c: L1_m >= 1;
        subject to L1_k_c: L1_k >= 1;
        subject to L1_n_c: L1_n >= 1;
        subject to L0_m_c: L0_m >= 1;
        subject to L0_k_c: L0_k >= 1;
        subject to L0_n_c: L0_n >= 1;
        subject to M_c: (L0_m*L1_m*L2_m) = 256;
        subject to K_c: (L0_k*L1_k*L2_k) = 2048;
        subject to N_c: (L0_n*L1_n*L2_n) = 64;
        subject to l1_capacity_c: (L1_A_DF + L1_B_DF + L1_C_DF) <= (1024 * 1024);
        subject to l2_capacity_c: (L2_A_DF + L2_B_DF + L2_C_DF) <= (512 * 1024);
    """)