// SPMD-partitioned HLO module over num_partitions=256. The module name and the
// op_name metadata ("jit(reshard_3)/reshard") suggest it was lowered from a
// jitted `reshard_3` function; confirm against the producer.
// Entry signature: f32[2048,128]{1,0} -> f32[8,2048]{1,0}.
// is_scheduled=true: the textual instruction order of the non-fusion (ENTRY)
// computation is its execution schedule, so instructions must not be reordered.
HloModule jit_reshard_3, is_scheduled=true, entry_computation_layout={(f32[2048,128]{1,0})->f32[8,2048]{1,0}}, num_partitions=256
// Final layout fix-up. The ENTRY computation's ROOT (%copy_bitcast_fusion) calls
// this. Input: the [1,16,8,128] stack of 16 received 8x128 blocks, with block j
// along dim 1. Output: f32[8,2048] with out[r, 128*j + k] = in[0, j, r, k].
// The transpose carries layout {3,1,2,0}, which keeps the input's physical
// element order, so it only relabels dimensions. The copy to {3,2,1,0} does the
// actual data movement. The bitcast to [8,2048] then reinterprets that buffer.
%fused_computation (param_0.2: f32[1,16,8,128]) -> f32[8,2048] {
  %param_0.2 = f32[1,16,8,128]{3,2,1,0} parameter(0)
  // Swap dims 1 and 2: [1,16,8,128] -> [1,8,16,128] (layout-only at this point).
  %transpose.1 = f32[1,8,16,128]{3,1,2,0} transpose(%param_0.2), dimensions={0,2,1,3}, metadata={op_name="jit(reshard_3)/reshard" }
  // Materialize the transposed element order in row-major layout.
  %copy.1 = f32[1,8,16,128]{3,2,1,0} copy(%transpose.1), metadata={op_name="jit(reshard_3)/reshard" }
  // Merge dims (16,128) into a single 2048-wide row.
  ROOT %bitcast.3 = f32[8,2048]{1,0} bitcast(%copy.1), metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 15. Called from ENTRY as %bitcast_slice_fusion
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.1 of table %param_1.2.
//   2. Take the 128x128 block of %param_0.5 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 15, i.e. rows 120..127 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.1 (param_0.5: f32[2048,128], param_1.2: s32[256], param_2.1: u32[]) -> f32[1,1,8,128] {
  %param_0.5 = f32[2048,128]{1,0} parameter(0)
  %param_1.2 = s32[256]{0} parameter(1)
  %param_2.1 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.10 = s32[1]{0} dynamic-slice(%param_1.2, %param_2.1), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.5 = s32[] bitcast(%dynamic-slice.10), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.22 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.9 = f32[128,128]{1,0} dynamic-slice(%param_0.5, %bitcast.5, %constant.22), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.4 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.9), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.16 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.4), slice={[0:1], [15:16], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 14. Called from ENTRY as %bitcast_slice_fusion.1
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.3 of table %param_1.5.
//   2. Take the 128x128 block of %param_0.8 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 14, i.e. rows 112..119 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.2 (param_0.8: f32[2048,128], param_1.5: s32[256], param_2.3: u32[]) -> f32[1,1,8,128] {
  %param_0.8 = f32[2048,128]{1,0} parameter(0)
  %param_1.5 = s32[256]{0} parameter(1)
  %param_2.3 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.12 = s32[1]{0} dynamic-slice(%param_1.5, %param_2.3), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.7 = s32[] bitcast(%dynamic-slice.12), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.23 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.11 = f32[128,128]{1,0} dynamic-slice(%param_0.8, %bitcast.7, %constant.23), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.6 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.11), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.17 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.6), slice={[0:1], [14:15], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 13. Called from ENTRY as %bitcast_slice_fusion.2
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.5 of table %param_1.8.
//   2. Take the 128x128 block of %param_0.11 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 13, i.e. rows 104..111 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.3 (param_0.11: f32[2048,128], param_1.8: s32[256], param_2.5: u32[]) -> f32[1,1,8,128] {
  %param_0.11 = f32[2048,128]{1,0} parameter(0)
  %param_1.8 = s32[256]{0} parameter(1)
  %param_2.5 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.14 = s32[1]{0} dynamic-slice(%param_1.8, %param_2.5), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.9 = s32[] bitcast(%dynamic-slice.14), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.24 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.13 = f32[128,128]{1,0} dynamic-slice(%param_0.11, %bitcast.9, %constant.24), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.8 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.13), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.18 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.8), slice={[0:1], [13:14], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 12. Called from ENTRY as %bitcast_slice_fusion.3
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.7 of table %param_1.11.
//   2. Take the 128x128 block of %param_0.14 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 12, i.e. rows 96..103 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.4 (param_0.14: f32[2048,128], param_1.11: s32[256], param_2.7: u32[]) -> f32[1,1,8,128] {
  %param_0.14 = f32[2048,128]{1,0} parameter(0)
  %param_1.11 = s32[256]{0} parameter(1)
  %param_2.7 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.16 = s32[1]{0} dynamic-slice(%param_1.11, %param_2.7), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.11 = s32[] bitcast(%dynamic-slice.16), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.25 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.15 = f32[128,128]{1,0} dynamic-slice(%param_0.14, %bitcast.11, %constant.25), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.10 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.15), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.19 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.10), slice={[0:1], [12:13], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 11. Called from ENTRY as %bitcast_slice_fusion.4
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.9 of table %param_1.14.
//   2. Take the 128x128 block of %param_0.17 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 11, i.e. rows 88..95 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.5 (param_0.17: f32[2048,128], param_1.14: s32[256], param_2.9: u32[]) -> f32[1,1,8,128] {
  %param_0.17 = f32[2048,128]{1,0} parameter(0)
  %param_1.14 = s32[256]{0} parameter(1)
  %param_2.9 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.18 = s32[1]{0} dynamic-slice(%param_1.14, %param_2.9), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.13 = s32[] bitcast(%dynamic-slice.18), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.26 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.17 = f32[128,128]{1,0} dynamic-slice(%param_0.17, %bitcast.13, %constant.26), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.12 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.17), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.20 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.12), slice={[0:1], [11:12], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 10. Called from ENTRY as %bitcast_slice_fusion.5
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.11 of table %param_1.17.
//   2. Take the 128x128 block of %param_0.20 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 10, i.e. rows 80..87 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.6 (param_0.20: f32[2048,128], param_1.17: s32[256], param_2.11: u32[]) -> f32[1,1,8,128] {
  %param_0.20 = f32[2048,128]{1,0} parameter(0)
  %param_1.17 = s32[256]{0} parameter(1)
  %param_2.11 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.20 = s32[1]{0} dynamic-slice(%param_1.17, %param_2.11), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.15 = s32[] bitcast(%dynamic-slice.20), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.27 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.19 = f32[128,128]{1,0} dynamic-slice(%param_0.20, %bitcast.15, %constant.27), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.14 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.19), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.21 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.14), slice={[0:1], [10:11], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 9. Called from ENTRY as %bitcast_slice_fusion.6
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.13 of table %param_1.20.
//   2. Take the 128x128 block of %param_0.23 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 9, i.e. rows 72..79 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.7 (param_0.23: f32[2048,128], param_1.20: s32[256], param_2.13: u32[]) -> f32[1,1,8,128] {
  %param_0.23 = f32[2048,128]{1,0} parameter(0)
  %param_1.20 = s32[256]{0} parameter(1)
  %param_2.13 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.22 = s32[1]{0} dynamic-slice(%param_1.20, %param_2.13), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.17 = s32[] bitcast(%dynamic-slice.22), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.28 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.21 = f32[128,128]{1,0} dynamic-slice(%param_0.23, %bitcast.17, %constant.28), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.16 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.21), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.22 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.16), slice={[0:1], [9:10], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 8. Called from ENTRY as %bitcast_slice_fusion.7
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.15 of table %param_1.23.
//   2. Take the 128x128 block of %param_0.26 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 8, i.e. rows 64..71 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.8 (param_0.26: f32[2048,128], param_1.23: s32[256], param_2.15: u32[]) -> f32[1,1,8,128] {
  %param_0.26 = f32[2048,128]{1,0} parameter(0)
  %param_1.23 = s32[256]{0} parameter(1)
  %param_2.15 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.24 = s32[1]{0} dynamic-slice(%param_1.23, %param_2.15), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.19 = s32[] bitcast(%dynamic-slice.24), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.29 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.23 = f32[128,128]{1,0} dynamic-slice(%param_0.26, %bitcast.19, %constant.29), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.18 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.23), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.23 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.18), slice={[0:1], [8:9], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 7. Called from ENTRY as %bitcast_slice_fusion.8
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.17 of table %param_1.26.
//   2. Take the 128x128 block of %param_0.29 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 7, i.e. rows 56..63 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.9 (param_0.29: f32[2048,128], param_1.26: s32[256], param_2.17: u32[]) -> f32[1,1,8,128] {
  %param_0.29 = f32[2048,128]{1,0} parameter(0)
  %param_1.26 = s32[256]{0} parameter(1)
  %param_2.17 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.26 = s32[1]{0} dynamic-slice(%param_1.26, %param_2.17), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.21 = s32[] bitcast(%dynamic-slice.26), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.30 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.25 = f32[128,128]{1,0} dynamic-slice(%param_0.29, %bitcast.21, %constant.30), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.20 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.25), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.24 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.20), slice={[0:1], [7:8], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 6. Called from ENTRY as %bitcast_slice_fusion.9
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.19 of table %param_1.29.
//   2. Take the 128x128 block of %param_0.32 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 6, i.e. rows 48..55 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.10 (param_0.32: f32[2048,128], param_1.29: s32[256], param_2.19: u32[]) -> f32[1,1,8,128] {
  %param_0.32 = f32[2048,128]{1,0} parameter(0)
  %param_1.29 = s32[256]{0} parameter(1)
  %param_2.19 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.28 = s32[1]{0} dynamic-slice(%param_1.29, %param_2.19), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.23 = s32[] bitcast(%dynamic-slice.28), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.31 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.27 = f32[128,128]{1,0} dynamic-slice(%param_0.32, %bitcast.23, %constant.31), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.22 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.27), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.25 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.22), slice={[0:1], [6:7], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 5. Called from ENTRY as %bitcast_slice_fusion.10
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.21 of table %param_1.32.
//   2. Take the 128x128 block of %param_0.35 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 5, i.e. rows 40..47 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.11 (param_0.35: f32[2048,128], param_1.32: s32[256], param_2.21: u32[]) -> f32[1,1,8,128] {
  %param_0.35 = f32[2048,128]{1,0} parameter(0)
  %param_1.32 = s32[256]{0} parameter(1)
  %param_2.21 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.30 = s32[1]{0} dynamic-slice(%param_1.32, %param_2.21), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.25 = s32[] bitcast(%dynamic-slice.30), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.32 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.29 = f32[128,128]{1,0} dynamic-slice(%param_0.35, %bitcast.25, %constant.32), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.24 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.29), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.26 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.24), slice={[0:1], [5:6], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 4. Called from ENTRY as %bitcast_slice_fusion.11
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.23 of table %param_1.35.
//   2. Take the 128x128 block of %param_0.38 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 4, i.e. rows 32..39 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.12 (param_0.38: f32[2048,128], param_1.35: s32[256], param_2.23: u32[]) -> f32[1,1,8,128] {
  %param_0.38 = f32[2048,128]{1,0} parameter(0)
  %param_1.35 = s32[256]{0} parameter(1)
  %param_2.23 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.32 = s32[1]{0} dynamic-slice(%param_1.35, %param_2.23), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.27 = s32[] bitcast(%dynamic-slice.32), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.33 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.31 = f32[128,128]{1,0} dynamic-slice(%param_0.38, %bitcast.27, %constant.33), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.26 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.31), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.27 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.26), slice={[0:1], [4:5], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 3. Called from ENTRY as %bitcast_slice_fusion.12
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.25 of table %param_1.38.
//   2. Take the 128x128 block of %param_0.41 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 3, i.e. rows 24..31 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.13 (param_0.41: f32[2048,128], param_1.38: s32[256], param_2.25: u32[]) -> f32[1,1,8,128] {
  %param_0.41 = f32[2048,128]{1,0} parameter(0)
  %param_1.38 = s32[256]{0} parameter(1)
  %param_2.25 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.34 = s32[1]{0} dynamic-slice(%param_1.38, %param_2.25), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.29 = s32[] bitcast(%dynamic-slice.34), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.34 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.33 = f32[128,128]{1,0} dynamic-slice(%param_0.41, %bitcast.29, %constant.34), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.28 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.33), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.28 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.28), slice={[0:1], [3:4], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 2. Called from ENTRY as %bitcast_slice_fusion.13
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.27 of table %param_1.41.
//   2. Take the 128x128 block of %param_0.44 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 2, i.e. rows 16..23 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.14 (param_0.44: f32[2048,128], param_1.41: s32[256], param_2.27: u32[]) -> f32[1,1,8,128] {
  %param_0.44 = f32[2048,128]{1,0} parameter(0)
  %param_1.41 = s32[256]{0} parameter(1)
  %param_2.27 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.36 = s32[1]{0} dynamic-slice(%param_1.41, %param_2.27), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.31 = s32[] bitcast(%dynamic-slice.36), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.35 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.35 = f32[128,128]{1,0} dynamic-slice(%param_0.44, %bitcast.31, %constant.35), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.30 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.35), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.29 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.30), slice={[0:1], [2:3], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 1. Called from ENTRY as %bitcast_slice_fusion.14
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.29 of table %param_1.44.
//   2. Take the 128x128 block of %param_0.47 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 1, i.e. rows 8..15 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.15 (param_0.47: f32[2048,128], param_1.44: s32[256], param_2.29: u32[]) -> f32[1,1,8,128] {
  %param_0.47 = f32[2048,128]{1,0} parameter(0)
  %param_1.44 = s32[256]{0} parameter(1)
  %param_2.29 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.38 = s32[1]{0} dynamic-slice(%param_1.44, %param_2.29), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.33 = s32[] bitcast(%dynamic-slice.38), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.36 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.37 = f32[128,128]{1,0} dynamic-slice(%param_0.47, %bitcast.33, %constant.36), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.32 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.37), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.30 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.32), slice={[0:1], [1:2], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Builds all-to-all send chunk 0. Called from ENTRY as %bitcast_slice_fusion.15
// with (%param, %constant.3, %partition-id).
// Steps:
//   1. Read the s32 row offset at index %param_2.31 of table %param_1.47.
//   2. Take the 128x128 block of %param_0.50 that starts at (offset, 0).
//   3. View that block row-major as 16 chunks of 8 rows ([1,16,8,128]).
//   4. Return chunk 0, i.e. rows 0..7 of the block.
// dynamic-slice clamps the start so the 128-row window stays inside the 2048
// rows. The sibling fused_computation.1-.16 differ only in the chunk index.
%fused_computation.16 (param_0.50: f32[2048,128], param_1.47: s32[256], param_2.31: u32[]) -> f32[1,1,8,128] {
  %param_0.50 = f32[2048,128]{1,0} parameter(0)
  %param_1.47 = s32[256]{0} parameter(1)
  %param_2.31 = u32[] parameter(2)
  // table[index] as s32[1], then bitcast to a scalar row-start index.
  %dynamic-slice.40 = s32[1]{0} dynamic-slice(%param_1.47, %param_2.31), dynamic_slice_sizes={1}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.35 = s32[] bitcast(%dynamic-slice.40), metadata={op_name="jit(reshard_3)/reshard" }
  %constant.37 = s32[] constant(0), metadata={op_name="jit(reshard_3)/reshard" }
  %dynamic-slice.39 = f32[128,128]{1,0} dynamic-slice(%param_0.50, %bitcast.35, %constant.37), dynamic_slice_sizes={128,128}, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast.34 = f32[1,16,8,128]{3,2,1,0} bitcast(%dynamic-slice.39), metadata={op_name="jit(reshard_3)/reshard" }
  ROOT %slice.31 = f32[1,1,8,128]{3,2,1,0} slice(%bitcast.34), slice={[0:1], [0:1], [0:8], [0:128]}, metadata={op_name="jit(reshard_3)/reshard" }
}
// Entry computation. Each of the num_partitions=256 partitions:
//   1. Slices the 128-row block of %param that starts at row
//      %constant.3[partition-id] and splits it into 16 chunks of 8 rows. There
//      is one bitcast_slice_fusion per chunk.
//   2. Exchanges those chunks within its replica group through a tuple
//      all-to-all: chunk i goes to group member i.
//   3. Stacks the 16 received [1,1,8,128] pieces along dim 1 and lays them out
//      as f32[8,2048] (%copy_bitcast_fusion).
// The module header sets is_scheduled=true, so the order below is the
// execution schedule; do not reorder instructions.
ENTRY %main.0_spmd (param: f32[2048,128]) -> f32[8,2048] {
  %partition-id = u32[] partition-id()
  // Sharding spec: dim 0 is untiled and dim 1 is tiled 16 ways; each tile is
  // replicated on 16 devices (last_tile_dim_replicate). The device list is
  // iota [16,16] transposed by T(1,0).
  %param = f32[2048,128]{1,0} parameter(0), sharding={devices=[1,16,16]<=[16,16]T(1,0) last_tile_dim_replicate}, metadata={op_name="x"}
  // Per-partition row-offset table, indexed by %partition-id inside each fusion.
  // Its values are not printed in this dump ({...}).
  %constant.3 = s32[256]{0} constant({...}), metadata={op_name="jit(reshard_3)/reshard" }
  // Chunk index sliced by each fusion (from the called computation's slice bounds):
  //   %bitcast_slice_fusion.15 -> chunk 0
  //   %bitcast_slice_fusion.k  -> chunk 15-k, for k = 1..14
  //   %bitcast_slice_fusion    -> chunk 15
  %bitcast_slice_fusion.15 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.16, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.1 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.2, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.2 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.3 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.4, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.4 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.5, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.5 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.6, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.6 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.7, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.7 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.8, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.8 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.9, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.9 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.10, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.10 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.11, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.11 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.12, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.12 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.13, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.13 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.14, metadata={op_name="jit(reshard_3)/reshard" }
  %bitcast_slice_fusion.14 = f32[1,1,8,128]{3,2,1,0} fusion(%param, %constant.3, %partition-id), kind=kLoop, calls=%fused_computation.15, metadata={op_name="jit(reshard_3)/reshard" }
  // Operands are listed in chunk order 0..15. replica_groups=[16,16]<=[256]
  // forms 16 groups of 16 consecutive partition ids ({0..15}, {16..31}, ...).
  // Operand i is sent to member i of the group, and result i is received from
  // member i.
  %all-to-all.1 = (f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}, f32[1,1,8,128]{3,2,1,0}) all-to-all(%bitcast_slice_fusion.15, %bitcast_slice_fusion.14, %bitcast_slice_fusion.13, %bitcast_slice_fusion.12,
      %bitcast_slice_fusion.11, %bitcast_slice_fusion.10, %bitcast_slice_fusion.9, %bitcast_slice_fusion.8,
      %bitcast_slice_fusion.7, %bitcast_slice_fusion.6, %bitcast_slice_fusion.5, %bitcast_slice_fusion.4,
      %bitcast_slice_fusion.3, %bitcast_slice_fusion.2, %bitcast_slice_fusion.1, %bitcast_slice_fusion), channel_id=1, replica_groups=[16,16]<=[256], metadata={op_name="jit(reshard_3)/reshard" }
  // Unpack the 16 received pieces; index i is the piece from group member i.
  %get-tuple-element.2 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=0
  %get-tuple-element.3 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=1
  %get-tuple-element.4 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=2
  %get-tuple-element.5 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=3
  %get-tuple-element.6 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=4
  %get-tuple-element.7 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=5
  %get-tuple-element.8 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=6
  %get-tuple-element.9 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=7
  %get-tuple-element.10 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=8
  %get-tuple-element.11 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=9
  %get-tuple-element.12 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=10
  %get-tuple-element.13 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=11
  %get-tuple-element.14 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=12
  %get-tuple-element.15 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=13
  %get-tuple-element.16 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=14
  %get-tuple-element.17 = f32[1,1,8,128]{3,2,1,0} get-tuple-element(%all-to-all.1), index=15
  // Stack the pieces along dim 1 -> [1,16,8,128] (piece j at dim-1 index j).
  %concatenate = f32[1,16,8,128]{3,2,1,0} concatenate(%get-tuple-element.2, %get-tuple-element.3, %get-tuple-element.4, %get-tuple-element.5,
      %get-tuple-element.6, %get-tuple-element.7, %get-tuple-element.8, %get-tuple-element.9,
      %get-tuple-element.10, %get-tuple-element.11, %get-tuple-element.12, %get-tuple-element.13,
      %get-tuple-element.14, %get-tuple-element.15, %get-tuple-element.16, %get-tuple-element.17), dimensions={1}, metadata={op_name="jit(reshard_3)/reshard" }
  // out[r, 128*j + k] = concatenate[0, j, r, k].
  ROOT %copy_bitcast_fusion = f32[8,2048]{1,0} fusion(%concatenate), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(reshard_3)/reshard" }
}
|