diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index 7ee20d23df3..f6eb08bb58c 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -1395,3 +1395,304 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
     return %1, %3 : tensor<*xi32>, tensor<*xi1>
   }
 }
+
+// -----
+
+// Tests that inputs are correctly split and fed into the TPU computation for
+// tiled input sharding.
+
+// The following OpSharding is used for the TPU computation inputs in the test
+// below:
+// Proto debug string:
+//  input 0
+//   type: OTHER
+//   tile_assignment_dimensions: 1
+//   tile_assignment_dimensions: 2
+//   tile_assignment_devices: 0
+//   tile_assignment_devices: 1
+//  Serialized string:
+//   "\08\03\1A\02\01\02\22\02\00\01"
+//
+//  input 1
+//   type: MAXIMAL
+//   tile_assignment_dimensions: 1
+//   tile_assignment_devices: 1
+//  Serialized string:
+//   "\08\01\1A\01\01\22\01\01"
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
+  // CHECK-LABEL: func @parallel_execute_with_tiled_input
+  // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
+  func @parallel_execute_with_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
+    // CHECK: tf_device.replicate
+    // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
+    // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
+    // CHECK-SAME: devices =
+    // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
+    // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
+    %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
+      // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
+      // CHECK-NEXT: "tf._TPUCompileMlir"
+      // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
+      // CHECK: "tf_device.launch"
+      // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
+      // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
+      //
+      // CHECK: %[[CONST_SPLIT_DIM:[0-9]*]] = "tf.Const"()
+      // CHECK: %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[RI_0]])
+      // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
+      // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+      //
+      // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1)
+      // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]]
+      // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
+      // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
+      // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]]
+      // CHECK: device = "TPU_REPLICATED_CORE_1"
+      %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
+      tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
+    }
+    return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
+  }
+  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
+    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
+    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
+    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
+    return %4, %3 : tensor<*xi32>, tensor<*xi1>
+  }
+}
+
+// -----
+
+// Tests that tile sharding an input with a number of splits that does not
+// evenly divide the sharded dimension results in an error.
+
+// The following OpSharding is used for the TPU computation inputs in the test
+// below:
+// Proto debug string:
+//  input 0
+//   type: OTHER
+//   tile_assignment_dimensions: 1
+//   tile_assignment_dimensions: 4
+//   tile_assignment_devices: 0
+//   tile_assignment_devices: 1
+//   tile_assignment_devices: 2
+//   tile_assignment_devices: 3
+//
+//  input 1
+//   type: MAXIMAL
+//   tile_assignment_dimensions: 1
+//   tile_assignment_devices: 1
+//  Serialized string:
+//   "\08\01\1A\01\01\22\01\01"
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
+  func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
+    %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
+      // expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}}
+      %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
+      tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
+    }
+    return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
+  }
+  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
+    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
+    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
+    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
+    return %4, %3 : tensor<*xi32>, tensor<*xi1>
+  }
+}
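+
+// For reference, the serialized sharding strings in these tests are the proto
+// wire format of xla.OpSharding. A minimal sketch for regenerating the first
+// one (assuming the xla_data.pb.h C++ proto bindings used elsewhere in this
+// change):
+//
+//   xla::OpSharding sharding;
+//   sharding.set_type(xla::OpSharding::OTHER);
+//   sharding.add_tile_assignment_dimensions(1);
+//   sharding.add_tile_assignment_dimensions(2);
+//   sharding.add_tile_assignment_devices(0);
+//   sharding.add_tile_assignment_devices(1);
+//   // SerializeAsString() yields "\x08\x03\x1A\x02\x01\x02\x22\x02\x00\x01",
+//   // i.e. the "\08\03\1A\02\01\02\22\02\00\01" string used by the first test.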
+
+// -----
+
+// The following topology is used in the subsequent test cases:
+// Proto debug string:
+//  mesh_shape: 2
+//  mesh_shape: 1
+//  mesh_shape: 2
+//  num_tasks: 2
+//  num_tpu_devices_per_task: 2
+//  device_coordinates: 0
+//  device_coordinates: 0
+//  device_coordinates: 0
+//  device_coordinates: 0
+//  device_coordinates: 0
+//  device_coordinates: 0
+//  device_coordinates: 0
+//  device_coordinates: 1
+//  device_coordinates: 0
+//  device_coordinates: 1
+//  device_coordinates: 0
+//  device_coordinates: 0
+//  device_coordinates: 0
+//  device_coordinates: 1
+//  device_coordinates: 0
+//  device_coordinates: 1
+
+// The following OpSharding is used for the TPU computation inputs in the test
+// below:
+// Proto debug string:
+//  input 0
+//   type: OTHER
+//   tile_shape {
+//     element_type: F32
+//     dimensions: 2
+//     dimensions: 2
+//     layout {
+//       minor_to_major: 1
+//       minor_to_major: 0
+//       format: DENSE
+//     }
+//     is_dynamic_dimension: false
+//     is_dynamic_dimension: false
+//   }
+//   tile_assignment_dimensions: 2
+//   tile_assignment_dimensions: 2
+//   tile_assignment_devices: 0
+//   tile_assignment_devices: 1
+//   tile_assignment_devices: 2
+//   tile_assignment_devices: 3
+//  Serialized string:
+//   "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03"
+//
+//  input 1
+//   type: MAXIMAL
+//   tile_assignment_dimensions: 1
+//   tile_assignment_devices: 1
+//  Serialized string:
+//   "\08\01\1A\01\01\22\01\01"
+
+// Tests inputs to the TPU computation that are tiled in multiple dimensions.
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
+  // CHECK-LABEL: func @parallel_execute_with_multi_dimension_tiled_input
+  // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
+  func @parallel_execute_with_multi_dimension_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
+    // CHECK: tf_device.replicate
+    // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
+    // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
+    %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
+      // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
+      // CHECK-NEXT: "tf._TPUCompileMlir"
+      // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
+      // CHECK: "tf_device.launch"
+      // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
+      // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
+      // CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"()
+      // CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
+      // CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"()
+      // CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
+      // CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"()
+      // CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
+      // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
+      // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#1)
+      // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
+      // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
+      // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
+      // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[COMPILE]]#3)
+      // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
+      // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4)
+      // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
+      %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
+      tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
+    }
+    return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
+  }
+  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
+    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
+    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
+    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
+    return %4, %3 : tensor<*xi32>, tensor<*xi1>
+  }
+}
+
+// -----
+
+// Tests that the device assignment order of inputs is preserved for tiled
+// input sharding.
+
+// The following OpSharding is used for the TPU computation inputs in the test
+// below:
+// Proto debug string:
+//  input 0
+//   type: OTHER
+//   tile_shape {
+//     element_type: F32
+//     dimensions: 2
+//     dimensions: 2
+//     layout {
+//       minor_to_major: 1
+//       minor_to_major: 0
+//       format: DENSE
+//     }
+//     is_dynamic_dimension: false
+//     is_dynamic_dimension: false
+//   }
+//   tile_assignment_dimensions: 2
+//   tile_assignment_dimensions: 2
+//   tile_assignment_devices: 3
+//   tile_assignment_devices: 2
+//   tile_assignment_devices: 1
+//   tile_assignment_devices: 0
+//  Serialized string:
+//   "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00"
+//
+//  input 1
+//   type: MAXIMAL
+//   tile_assignment_dimensions: 1
+//   tile_assignment_devices: 1
+//  Serialized string:
+//   "\08\01\1A\01\01\22\01\01"
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
+  // CHECK-LABEL: func @tiled_input_sharding_with_device_assignment_order
+  // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
+  func @tiled_input_sharding_with_device_assignment_order(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
+    // CHECK: tf_device.replicate
+    // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
+    // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
+    %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
+      // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
+      // CHECK-NEXT: "tf._TPUCompileMlir"
+      // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
+      // CHECK: "tf_device.launch"
+      // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
+      // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
+      // CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"()
+      // CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
+      // CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"()
+      // CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
+      // CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"()
+      // CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
+      // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
+      // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#1)
+      // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
+      // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[RI_1]], %[[COMPILE]]#2)
+      // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
+      // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[COMPILE]]#3)
+      // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
+      // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#4)
+      // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
+      %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
+      tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
+    }
+    return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
+  }
+  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
+    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
+    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
+    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
+    return %4, %3 : tensor<*xi32>, tensor<*xi1>
+  }
+}
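+
+// For reference: with tile_assignment_dimensions [2, 2], the split tree above
+// produces the four tiles of the 128x10 input in row-major order:
+//   tile 0 = SPLIT_1#0 (rows 0..63,   columns 0..4)
+//   tile 1 = SPLIT_1#1 (rows 0..63,   columns 5..9)
+//   tile 2 = SPLIT_2#0 (rows 64..127, columns 0..4)
+//   tile 3 = SPLIT_2#1 (rows 64..127, columns 5..9)
+// Tile i is routed to logical core tile_assignment_devices[i], so the
+// [3, 2, 1, 0] assignment in this test sends tile 0 to core 3 and tile 3 to
+// core 0. This is why EXECUTE_0 above consumes SPLIT_2#1 while EXECUTE_3
+// consumes SPLIT_1#0.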
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 50b6555076d..e20e78a243c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -488,11 +488,11 @@ Operation* BuildExecuteOp(
 // Creates a tf_device.parallel_execute op that wraps TPUExecute ops to
 // represent execution of a TPU program on multiple logical cores.
-tf_device::ParallelExecuteOp BuildParallelExecuteOp(
+LogicalResult BuildParallelExecuteOp(
     llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
     llvm::ArrayRef<xla::OpSharding> output_sharding_config,
     Operation* compile_op, tf_device::LaunchFuncOp launch_func,
-    OpBuilder* builder) {
+    OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
   const int num_cores_per_replica = execution_devices.front().size();
   // parallel_execute op returns the concatenated list of return values of
   // all of its regions.
@@ -510,20 +510,23 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp(
     for (Type t : output_types) concatenated_output_types.emplace_back(t);
   }
 
-  auto parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
+  *parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
       launch_func.getLoc(), num_cores_per_replica, concatenated_output_types);
 
   // Extract inputs for each region of the parallel_execute op. The i-th
   // element in the list represents the input lists to TPU computation for
   // i-th logical core.
-  auto input_list = tensorflow::ExtractInputsForLogicalDevices(
-      num_cores_per_replica, launch_func);
+  llvm::SmallVector<llvm::SmallVector<Value, 4>, 4> input_list;
+  builder->setInsertionPoint(*parallel_execute_op);
+  auto result = tensorflow::ExtractInputsForLogicalDevices(
+      num_cores_per_replica, launch_func, builder, &input_list);
+  if (failed(result)) return failure();
 
   const bool replicated = execution_devices.size() != 1;
   // For each logical core, create a region with TPUExecute op.
   assert(input_list.size() == num_cores_per_replica);
   for (int core = 0; core < num_cores_per_replica; ++core) {
-    auto& region = parallel_execute_op.GetRegionBlockWithIndex(core);
+    auto& region = parallel_execute_op->GetRegionBlockWithIndex(core);
     builder->setInsertionPointToEnd(&region);
 
     // Create Execute op.
@@ -551,7 +554,7 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp(
         region_launch_op.getResults());
   }
 
-  return parallel_execute_op;
+  return success();
 }
 
 tf_device::LaunchOp AssignDevicesToReplicatedExecute(
@@ -703,9 +706,12 @@ LogicalResult Rewrite(
   if (num_cores_per_replica > 1) {
     // For model parallelism, tf_device.parallel_execute is used to express
     // concurrent device execution across multiple logical devices.
-    tf_device::ParallelExecuteOp execute_op = BuildParallelExecuteOp(
-        tpu_device_assignment.execution_devices, output_shardings, compile_op,
-        launch_func, builder);
+
+    tf_device::ParallelExecuteOp execute_op;
+    result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices,
+                                    output_shardings, compile_op, launch_func,
+                                    builder, &execute_op);
+    if (failed(result)) return failure();
 
     // As tf_device.parallel_execute wraps # logical cores number of TPUExecute
     // ops, the number of return values of parallel_execute op exceeds that of
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -36,14 +45,135 @@ llvm::Optional ParseShardingAttribute( return sharding_attr.getValue(); } -llvm::SmallVector, 4> -ExtractInputsForLogicalDevices(int num_logical_cores, - mlir::tf_device::LaunchFuncOp launch_func) { +namespace { + +constexpr char kNumSplitAttr[] = "num_split"; + +// Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways +// in 'split_dimension' dimension and returns the split values. +mlir::LogicalResult CreateSplitOp(const int num_split, + const int split_dimension, + const mlir::Location& location, + mlir::Value src_input, + mlir::OpBuilder* builder, + mlir::TF::SplitOp* split_op) { + // Creates a const op to hold split dimension value. + auto split_dim_type = + mlir::RankedTensorType::get({}, builder->getIntegerType(32)); + auto split_dimension_attr = + mlir::DenseElementsAttr::get(split_dim_type, split_dimension); + auto split_dimension_op = builder->create( + location, split_dim_type, split_dimension_attr); + + // Correctly set output shapes of split op output if input shape is statically + // known. + mlir::Type output_type; + auto input_type = src_input.getType().cast(); + + if (input_type.hasRank()) { + if (input_type.getShape()[split_dimension] == + mlir::ShapedType::kDynamicSize) { + output_type = input_type; + } else { + auto shape = llvm::to_vector<4>(input_type.getShape()); + if (shape[split_dimension] % num_split != 0) { + return mlir::emitError( + location, + llvm::formatv( + "incorrect input sharding configuration received. " + "{0}-th dimension of the input must be evenly divisible by {1}", + split_dimension, num_split)); + } + + shape[split_dimension] = shape[split_dimension] / num_split; + output_type = + mlir::RankedTensorType::get(shape, input_type.getElementType()); + } + } else { + output_type = input_type; + } + + // Creates a split op that splits |src_input| along |split_dimension|. + llvm::SmallVector output_types(num_split, output_type); + *split_op = builder->create( + location, output_types, split_dimension_op.output(), src_input); + split_op->setAttr(kNumSplitAttr, builder->getIntegerAttr( + builder->getIntegerType(32), num_split)); + return mlir::success(); +} + +// For tile sharded inputs to TPU computation, inject split op between the +// input values and TPU computation so that tiled input values are passed in +// as inputs to TPU computations. If more than one dimension is sharded, then +// a tree of connected split ops are added before tf_device.parallel_execute op. 
+
+// For tile-sharded inputs to a TPU computation, injects split ops between the
+// input values and the TPU computation so that tiled input values are passed
+// in as inputs to the TPU computation. If more than one dimension is sharded,
+// a tree of connected split ops is added before the tf_device.parallel_execute
+// op.
+mlir::LogicalResult HandleTileShardedInputs(
+    const mlir::Location& location, const xla::OpSharding& input_sharding,
+    const mlir::Value& original_source, mlir::OpBuilder* builder,
+    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
+  llvm::SmallVector<mlir::TF::SplitOp, 4> split_ops_for_tiled_input;
+  split_ops_for_tiled_input.reserve(
+      input_sharding.tile_assignment_devices_size());
+
+  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
+  // are created such that input data is sharded in row-major order.
+  // Split nodes at the i-th depth from the original input node represent
+  // nodes that split the input data along the i-th dimension.
+  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
+  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
+    const int num_splits = num_splits_and_index.value();
+    const int dimension_index = num_splits_and_index.index();
+    if (num_splits == 1) continue;
+
+    // Creates the root split op.
+    if (split_ops_for_tiled_input.empty()) {
+      mlir::TF::SplitOp root_split_op;
+      auto result = CreateSplitOp(num_splits, dimension_index, location,
+                                  original_source, builder, &root_split_op);
+      if (mlir::failed(result)) return mlir::failure();
+
+      split_ops_for_tiled_input.emplace_back(root_split_op);
+      continue;
+    }
+
+    llvm::SmallVector<mlir::TF::SplitOp, 4> new_split_ops;
+    new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits);
+
+    for (auto split_op : split_ops_for_tiled_input) {
+      for (auto parent_split_output_value : split_op.getResults()) {
+        mlir::TF::SplitOp child_split_op;
+        auto result =
+            CreateSplitOp(num_splits, dimension_index, location,
+                          parent_split_output_value, builder, &child_split_op);
+        if (mlir::failed(result)) return mlir::failure();
+
+        new_split_ops.emplace_back(child_split_op);
+      }
+    }
+
+    std::swap(new_split_ops, split_ops_for_tiled_input);
+  }
+
+  // `split_ops_for_tiled_input` now includes the final split nodes from which
+  // sharded data will be fed into the TPUExecute ops, sorted in row-major
+  // order.
+  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
+  for (auto split_op : split_ops_for_tiled_input)
+    tiled_inputs->append(split_op.getResults().begin(),
+                         split_op.getResults().end());
+
+  return mlir::success();
+}
+
+}  // namespace
+
+mlir::LogicalResult ExtractInputsForLogicalDevices(
+    int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func,
+    mlir::OpBuilder* builder,
+    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
   // Initialize the input list for each logical device.
-  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
-  input_list.reserve(num_logical_cores);
+  input_list->reserve(num_logical_cores);
   for (int i = 0; i < num_logical_cores; ++i)
-    input_list.emplace_back(llvm::SmallVector<mlir::Value, 4>());
+    input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());
 
   llvm::SmallVector<mlir::Value, 4> launch_func_inputs(
       launch_func.getOperands());
@@ -53,8 +183,8 @@ ExtractInputsForLogicalDevices(int num_logical_cores,
   // If sharding attribute does not exist, then all inputs are placed on 0th
   // logical core by default.
   if (!sharding_attrs) {
-    input_list[0] = launch_func_inputs;
-    return input_list;
+    (*input_list)[0] = launch_func_inputs;
+    return mlir::success();
   }
 
   // Enumerate sharding configuration for each input. If input has replicate
@@ -71,19 +201,32 @@ ExtractInputsForLogicalDevices(int num_logical_cores,
         sharding_attr.cast<mlir::StringAttr>().getValue().str());
 
     const auto input_sharing_type = sharding.type();
-    if (input_sharing_type == xla::OpSharding::OTHER)
-      launch_func.emitError(
-          "tiled inputs are not yet supported for model parallelism");
+    if (input_sharing_type == xla::OpSharding::OTHER) {
+      llvm::SmallVector<mlir::Value, 4> tiled_inputs;
+      auto result = HandleTileShardedInputs(
+          launch_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
+      if (mlir::failed(result)) return mlir::failure();
 
-    if (input_sharing_type == xla::OpSharding::REPLICATED) {
-      for (auto inputs : input_list) inputs.emplace_back(input_value);
+      if (tiled_inputs.size() != num_logical_cores)
+        return launch_func.emitError(llvm::formatv(
+            "incorrect {0}-th tiled input sharding received. "
+            "Product of tile sharding splits({1}) must be equal to "
+            "number of logical devices : {2}",
+            input_index, tiled_inputs.size(), num_logical_cores));
+
+      for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
+        const int assigned_logical_device = sharding.tile_assignment_devices(i);
+        (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
+      }
+    } else if (input_sharing_type == xla::OpSharding::REPLICATED) {
+      for (auto& inputs : *input_list) inputs.emplace_back(input_value);
     } else {
       assert(input_sharing_type == xla::OpSharding::MAXIMAL);
       const int logical_device_id = sharding.tile_assignment_devices(0);
-      input_list[logical_device_id].emplace_back(input_value);
+      (*input_list)[logical_device_id].emplace_back(input_value);
     }
   }
-  return input_list;
+  return mlir::success();
 }
 
 mlir::LogicalResult ParseAndValidateOutputSharding(
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h
index 4f548ca95aa..f7a9dbf2c81 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
 #include "mlir/IR/Value.h"  // TF:llvm-project
 #include "mlir/Support/LogicalResult.h"  // TF:llvm-project
@@ -38,9 +39,10 @@ llvm::Optional<llvm::StringRef> ParseShardingAttribute(
 // i-th element is a list of mlir::Value's which represent inputs for the
 // TPU computation corresponding to the i-th logical device. If the attribute
 // does not exist, then all inputs are placed on logical core 0.
-llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
-ExtractInputsForLogicalDevices(int num_logical_cores,
-                               mlir::tf_device::LaunchFuncOp launch_func);
+mlir::LogicalResult ExtractInputsForLogicalDevices(
+    int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func,
+    mlir::OpBuilder* builder,
+    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list);
 
 // Extracts a list of OpSharding that represent output sharding configuration
 // of `tf_device.launch`.