Supported tiled input sharding for model parallelism.

PiperOrigin-RevId: 301882389
Change-Id: I7da33977da3f05881f7dc66742cd0fd9cb89d358
A. Unique TensorFlower authored 2020-03-19 13:21:09 -07:00; committed by TensorFlower Gardener
parent 71313bcb18
commit 0c53d83001
4 changed files with 480 additions and 28 deletions


@@ -1395,3 +1395,304 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
return %1, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// Tests that inputs are correctly split and fed into the TPU computation for
// tiled input sharding.
// The following OpSharding is used for the TPU computation inputs in the
// test below:
// Proto debug string:
// input 0
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// Serialized string:
// "\08\03\1A\02\01\02\22\02\00\01"
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_tiled_input
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func @parallel_execute_with_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
// CHECK-SAME: devices =
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
//
// CHECK: %[[CONST_SPLIT_DIM:[0-9]*]] = "tf.Const"()
// CHECK: %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[RI_0]])
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
//
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1)
// CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
// CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: device = "TPU_REPLICATED_CORE_1"
%1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
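
The [1, 2] tile assignment above splits the 128x10 input along dimension 1, so each logical core receives a tensor<128x5xf32> slice. The shape arithmetic is sketched below with a hypothetical TileShape helper; the real computation lives in CreateSplitOp in the xla_sharding_util diff further down:

#include <cstdint>
#include <vector>

// Hypothetical helper: computes the per-tile shape for a tiled sharding,
// returning false when a dimension is not evenly divisible by its splits.
bool TileShape(const std::vector<int64_t>& shape,
               const std::vector<int64_t>& splits,
               std::vector<int64_t>* tile_shape) {
  if (shape.size() != splits.size()) return false;
  tile_shape->clear();
  for (size_t i = 0; i < shape.size(); ++i) {
    if (splits[i] == 0 || shape[i] % splits[i] != 0) return false;
    tile_shape->push_back(shape[i] / splits[i]);
  }
  return true;
}
// TileShape({128, 10}, {1, 2}, &tile) -> true, tile == {128, 5}.
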
// -----
// The following OpSharding is used for the TPU computation inputs in the
// test below:
// Proto debug string:
// input 0
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 4
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// tile_assignment_devices: 2
// tile_assignment_devices: 3
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03"
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
//
// Tests that tile sharding of inputs where the number of splits does not
// evenly divide the input dimension results in an error. Here, dimension 1 of
// the tensor<128x10xf32> input has size 10, which is not divisible by 4.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}}
%1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
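
With the TileShape sketch above, TileShape({128, 10}, {1, 4}, &tile) returns false because 10 % 4 == 2; this is exactly the condition that produces the expected-error diagnostic checked in this test.
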
// -----
// The following topology is used in the subsequent test cases (it matches the
// serialized `topology` attribute in the tests below):
// Proto debug string:
// mesh_shape: 2
// mesh_shape: 2
// mesh_shape: 1
// mesh_shape: 2
// num_tasks: 1
// num_tpu_devices_per_task: 8
// device_coordinates: [0, 0, 0, 0]
// device_coordinates: [0, 0, 0, 1]
// device_coordinates: [1, 0, 0, 0]
// device_coordinates: [1, 0, 0, 1]
// device_coordinates: [0, 1, 0, 0]
// device_coordinates: [0, 1, 0, 1]
// device_coordinates: [1, 1, 0, 0]
// device_coordinates: [1, 1, 0, 1]
// Serialized string:
// "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01"
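
For reference, the serialized topology attribute can be generated from this debug string; a minimal sketch, assuming the tensorflow::tpu::TopologyProto message from tensorflow/core/protobuf/tpu/topology.proto:

#include <string>
#include "tensorflow/core/protobuf/tpu/topology.pb.h"

std::string TestTopologyString() {
  tensorflow::tpu::TopologyProto topology;
  for (int dim : {2, 2, 1, 2}) topology.add_mesh_shape(dim);
  topology.set_num_tasks(1);
  topology.set_num_tpu_devices_per_task(8);
  // One mesh-coordinate tuple per TPU device, in device order.
  const int coords[] = {0, 0, 0, 0,  0, 0, 0, 1,  1, 0, 0, 0,  1, 0, 0, 1,
                        0, 1, 0, 0,  0, 1, 0, 1,  1, 1, 0, 0,  1, 1, 0, 1};
  for (int c : coords) topology.add_device_coordinates(c);
  return topology.SerializeAsString();  // the `topology` attribute payload
}
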
// The following OpSharding is used for the TPU computation inputs in the
// test below:
// Proto debug string:
// input 0
// type: OTHER
// tile_shape {
// element_type: F32
// dimensions: 2
// dimensions: 2
// layout {
// minor_to_major: 1
// minor_to_major: 0
// format: DENSE
// }
// is_dynamic_dimension: false
// is_dynamic_dimension: false
// }
// tile_assignment_dimensions: 2
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// tile_assignment_devices: 2
// tile_assignment_devices: 3
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03"
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
// Tests inputs to the TPU computation that are tiled across multiple
// dimensions.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_multi_dimension_tiled_input
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func @parallel_execute_with_multi_dimension_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"()
// CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
// CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"()
// CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
// CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"()
// CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#1)
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[COMPILE]]#3)
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4)
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
%1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
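
The pass realizes the 2x2 tiling above as a tree of tf.Split ops: SPLIT_0 divides dimension 0, then SPLIT_1 and SPLIT_2 divide dimension 1 of each half, so the four leaves come out in row-major order ((0,0), (0,1), (1,0), (1,1)). A standalone sketch of that row-major enumeration:

#include <cstdio>
#include <vector>

int main() {
  // tile_assignment_dimensions from the sharding above.
  const std::vector<int> splits = {2, 2};
  int num_tiles = 1;
  for (int s : splits) num_tiles *= s;
  // Leaves of the split tree are collected in row-major order.
  for (int tile = 0; tile < num_tiles; ++tile) {
    int rest = tile, stride = num_tiles;
    std::printf("tile %d ->", tile);
    for (int s : splits) {
      stride /= s;
      std::printf(" %d", rest / stride);
      rest %= stride;
    }
    std::printf("\n");  // prints (0,0), (0,1), (1,0), (1,1) for tiles 0..3
  }
  return 0;
}
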
// -----
// Tests that the device assignment order of inputs is preserved for tiled
// input sharding.
// The following OpSharding is used for the TPU computation inputs in the
// test below:
// Proto debug string:
// input 0
// type: OTHER
// tile_shape {
// element_type: F32
// dimensions: 2
// dimensions: 2
// layout {
// minor_to_major: 1
// minor_to_major: 0
// format: DENSE
// }
// is_dynamic_dimension: false
// is_dynamic_dimension: false
// }
// tile_assignment_dimensions: 2
// tile_assignment_dimensions: 2
// tile_assignment_devices: 3
// tile_assignment_devices: 2
// tile_assignment_devices: 1
// tile_assignment_devices: 0
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00"
//
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
//
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @tiled_input_sharding_with_device_assignment_order
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func @tiled_input_sharding_with_device_assignment_order(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"()
// CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
// CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"()
// CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
// CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"()
// CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#1)
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[RI_1]], %[[COMPILE]]#2)
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[COMPILE]]#3)
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#4)
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
%1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
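
The reversed CHECK pattern above follows from the distribution loop in ExtractInputsForLogicalDevices (see the diff below): row-major tile i is appended to the input list of logical device tile_assignment_devices(i). A small hypothetical sketch of that mapping, using the [3, 2, 1, 0] assignment from this test:

#include <vector>

// Returns, for each logical device, the row-major index of the tile it
// receives: tile i goes to device assignment[i].
std::vector<int> TileForDevice(const std::vector<int>& assignment) {
  std::vector<int> tile_of_device(assignment.size());
  for (int i = 0; i < static_cast<int>(assignment.size()); ++i)
    tile_of_device[assignment[i]] = i;
  return tile_of_device;
}
// TileForDevice({3, 2, 1, 0}) == {3, 2, 1, 0}: logical core 0 executes the
// last row-major tile (SPLIT_2_OUT#1) and core 3 the first (SPLIT_1_OUT#0),
// matching the EXECUTE_* operands checked above.
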


@@ -488,11 +488,11 @@ Operation* BuildExecuteOp(
 // Creates a tf_device.parallel_execute op that wraps TPUExecute ops to
 // represent execution of the TPU program across multiple logical cores.
-tf_device::ParallelExecuteOp BuildParallelExecuteOp(
+LogicalResult BuildParallelExecuteOp(
     llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
     llvm::ArrayRef<xla::OpSharding> output_sharding_config,
     Operation* compile_op, tf_device::LaunchFuncOp launch_func,
-    OpBuilder* builder) {
+    OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
   const int num_cores_per_replica = execution_devices.front().size();
   // The parallel_execute op returns the concatenated list of return values of
   // all of its regions.
@@ -510,20 +510,23 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp(
     for (Type t : output_types) concatenated_output_types.emplace_back(t);
   }
 
-  auto parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
+  *parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
       launch_func.getLoc(), num_cores_per_replica, concatenated_output_types);
 
   // Extract inputs for each region of the parallel_execute op. The i-th
   // element in the list represents the inputs to the TPU computation for
   // the i-th logical core.
-  auto input_list = tensorflow::ExtractInputsForLogicalDevices(
-      num_cores_per_replica, launch_func);
+  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
+  builder->setInsertionPoint(*parallel_execute_op);
+  auto result = tensorflow::ExtractInputsForLogicalDevices(
+      num_cores_per_replica, launch_func, builder, &input_list);
+  if (failed(result)) return failure();
 
   const bool replicated = execution_devices.size() != 1;
   // For each logical core, create a region with a TPUExecute op.
   assert(input_list.size() == num_cores_per_replica);
   for (int core = 0; core < num_cores_per_replica; ++core) {
-    auto& region = parallel_execute_op.GetRegionBlockWithIndex(core);
+    auto& region = parallel_execute_op->GetRegionBlockWithIndex(core);
     builder->setInsertionPointToEnd(&region);
 
     // Create Execute op.
@@ -551,7 +554,7 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp(
         region_launch_op.getResults());
   }
 
-  return parallel_execute_op;
+  return success();
 }
 
 tf_device::LaunchOp AssignDevicesToReplicatedExecute(
@@ -703,9 +706,12 @@ LogicalResult Rewrite(
   if (num_cores_per_replica > 1) {
     // For model parallelism, tf_device.parallel_execute is used to express
     // concurrent device execution across multiple logical devices.
-    tf_device::ParallelExecuteOp execute_op = BuildParallelExecuteOp(
-        tpu_device_assignment.execution_devices, output_shardings, compile_op,
-        launch_func, builder);
+    tf_device::ParallelExecuteOp execute_op;
+    result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices,
+                                    output_shardings, compile_op, launch_func,
+                                    builder, &execute_op);
+    if (failed(result)) return failure();
 
     // As tf_device.parallel_execute wraps one TPUExecute op per logical core,
     // the number of return values of the parallel_execute op exceeds that of
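
The signature change above follows the usual MLIR convention for fallible rewrites: return mlir::LogicalResult and hand results back through an out-parameter so callers can short-circuit on failure. A generic, hypothetical sketch of the pattern:

#include "mlir/Support/LogicalResult.h"

// Hypothetical helper in the style of the refactored BuildParallelExecuteOp.
mlir::LogicalResult ComputeValue(int input, int* out) {
  if (input < 0) return mlir::failure();  // diagnose-and-bail path
  *out = 2 * input;
  return mlir::success();
}

mlir::LogicalResult Caller(int input) {
  int value;
  if (mlir::failed(ComputeValue(input, &value))) return mlir::failure();
  // ... use `value` ...
  return mlir::success();
}
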


@@ -16,10 +16,19 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -36,14 +45,135 @@ llvm::Optional<mlir::StringRef> ParseShardingAttribute(
   return sharding_attr.getValue();
 }
 
-llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
-ExtractInputsForLogicalDevices(int num_logical_cores,
-                               mlir::tf_device::LaunchFuncOp launch_func) {
+namespace {
+
+constexpr char kNumSplitAttr[] = "num_split";
+
+// Creates a tf::SplitOp that splits `src_input` into `num_split` slices along
+// the `split_dimension` dimension and returns the split values.
+mlir::LogicalResult CreateSplitOp(const int num_split,
+                                  const int split_dimension,
+                                  const mlir::Location& location,
+                                  mlir::Value src_input,
+                                  mlir::OpBuilder* builder,
+                                  mlir::TF::SplitOp* split_op) {
+  // Creates a const op to hold the split dimension value.
+  auto split_dim_type =
+      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
+  auto split_dimension_attr =
+      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
+  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
+      location, split_dim_type, split_dimension_attr);
+
+  // Correctly set the output shape of the split op if the input shape is
+  // statically known.
+  mlir::Type output_type;
+  auto input_type = src_input.getType().cast<mlir::TensorType>();
+  if (input_type.hasRank()) {
+    if (input_type.getShape()[split_dimension] ==
+        mlir::ShapedType::kDynamicSize) {
+      output_type = input_type;
+    } else {
+      auto shape = llvm::to_vector<4>(input_type.getShape());
+      if (shape[split_dimension] % num_split != 0) {
+        return mlir::emitError(
+            location,
+            llvm::formatv(
+                "incorrect input sharding configuration received. "
+                "{0}-th dimension of the input must be evenly divisible by {1}",
+                split_dimension, num_split));
+      }
+
+      shape[split_dimension] = shape[split_dimension] / num_split;
+      output_type =
+          mlir::RankedTensorType::get(shape, input_type.getElementType());
+    }
+  } else {
+    output_type = input_type;
+  }
+
+  // Creates a split op that splits |src_input| along |split_dimension|.
+  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
+  *split_op = builder->create<mlir::TF::SplitOp>(
+      location, output_types, split_dimension_op.output(), src_input);
+  split_op->setAttr(kNumSplitAttr,
+                    builder->getIntegerAttr(builder->getIntegerType(32),
+                                            num_split));
+  return mlir::success();
+}
+
+// For tile-sharded inputs to a TPU computation, injects split ops between the
+// input values and the TPU computation so that tiled input values are passed
+// in as inputs to the TPU computation. If more than one dimension is sharded,
+// a tree of connected split ops is added before the
+// tf_device.parallel_execute op.
+mlir::LogicalResult HandleTileShardedInputs(
+    const mlir::Location& location, const xla::OpSharding& input_sharding,
+    const mlir::Value& original_source, mlir::OpBuilder* builder,
+    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
+  llvm::SmallVector<mlir::TF::SplitOp, 4> split_ops_for_tiled_input;
+  split_ops_for_tiled_input.reserve(
+      input_sharding.tile_assignment_devices_size());
+
+  // Creates a tree of split nodes for sharding tiled inputs. Split nodes are
+  // created such that input data is sharded in row-major order. Split nodes
+  // at the i-th depth from the original input node split the input data along
+  // the i-th dimension.
+  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
+  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
+    const int num_splits = num_splits_and_index.value();
+    const int dimension_index = num_splits_and_index.index();
+    if (num_splits == 1) continue;
+
+    // Creates the root split op.
+    if (split_ops_for_tiled_input.empty()) {
+      mlir::TF::SplitOp root_split_op;
+      auto result = CreateSplitOp(num_splits, dimension_index, location,
+                                  original_source, builder, &root_split_op);
+      if (mlir::failed(result)) return mlir::failure();
+
+      split_ops_for_tiled_input.emplace_back(root_split_op);
+      continue;
+    }
+
+    llvm::SmallVector<mlir::TF::SplitOp, 4> new_split_ops;
+    new_split_ops.reserve(split_ops_for_tiled_input.size() * num_splits);
+    for (auto split_op : split_ops_for_tiled_input) {
+      for (auto parent_split_output_value : split_op.getResults()) {
+        mlir::TF::SplitOp child_split_op;
+        auto result =
+            CreateSplitOp(num_splits, dimension_index, location,
+                          parent_split_output_value, builder, &child_split_op);
+        if (mlir::failed(result)) return mlir::failure();
+
+        new_split_ops.emplace_back(child_split_op);
+      }
+    }
+    std::swap(new_split_ops, split_ops_for_tiled_input);
+  }
+
+  // `split_ops_for_tiled_input` now holds the final split nodes from which
+  // sharded data will be fed into TPUExecute ops, sorted in row-major order.
+  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
+  for (auto split_op : split_ops_for_tiled_input)
+    tiled_inputs->append(split_op.getResults().begin(),
+                         split_op.getResults().end());
+
+  return mlir::success();
+}
+
+}  // namespace
+
+mlir::LogicalResult ExtractInputsForLogicalDevices(
+    int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func,
+    mlir::OpBuilder* builder,
+    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
   // Initialize the input list for each logical device.
-  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
-  input_list.reserve(num_logical_cores);
+  input_list->reserve(num_logical_cores);
   for (int i = 0; i < num_logical_cores; ++i)
-    input_list.emplace_back(llvm::SmallVector<mlir::Value, 4>());
+    input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());
 
   llvm::SmallVector<mlir::Value, 4> launch_func_inputs(
       launch_func.getOperands());
@@ -53,8 +183,8 @@ ExtractInputsForLogicalDevices(int num_logical_cores,
   // If the sharding attribute does not exist, then all inputs are placed on
   // the 0th logical core by default.
   if (!sharding_attrs) {
-    input_list[0] = launch_func_inputs;
-    return input_list;
+    (*input_list)[0] = launch_func_inputs;
+    return mlir::success();
   }
 
   // Enumerate the sharding configuration for each input. If an input has replicate
@@ -71,19 +201,32 @@ ExtractInputsForLogicalDevices(int num_logical_cores,
         sharding_attr.cast<mlir::StringAttr>().getValue().str());
 
     const auto input_sharing_type = sharding.type();
-    if (input_sharing_type == xla::OpSharding::OTHER)
-      launch_func.emitError(
-          "tiled inputs are not yet supported for model parallelism");
-
-    if (input_sharing_type == xla::OpSharding::REPLICATED) {
-      for (auto inputs : input_list) inputs.emplace_back(input_value);
+    if (input_sharing_type == xla::OpSharding::OTHER) {
+      llvm::SmallVector<mlir::Value, 4> tiled_inputs;
+      auto result = HandleTileShardedInputs(
+          launch_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
+      if (mlir::failed(result)) return mlir::failure();
+
+      if (tiled_inputs.size() != num_logical_cores) {
+        launch_func.emitError(llvm::formatv(
+            "incorrect {0}-th tiled input sharding received. "
+            "Product of tile sharding splits({1}) must be equal to "
+            "number of logical devices : {2}",
+            input_index, tiled_inputs.size(), num_logical_cores));
+        return mlir::failure();
+      }
+
+      for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
+        const int assigned_logical_device = sharding.tile_assignment_devices(i);
+        (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
+      }
+    } else if (input_sharing_type == xla::OpSharding::REPLICATED) {
+      for (auto& inputs : *input_list) inputs.emplace_back(input_value);
     } else {
       assert(input_sharing_type == xla::OpSharding::MAXIMAL);
       const int logical_device_id = sharding.tile_assignment_devices(0);
-      input_list[logical_device_id].emplace_back(input_value);
+      (*input_list)[logical_device_id].emplace_back(input_value);
     }
   }
 
-  return input_list;
+  return mlir::success();
 }
 
 mlir::LogicalResult ParseAndValidateOutputSharding(
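
One invariant worth noting in the code above: the split tree produces exactly the product of tile_assignment_dimensions leaf values, which is then checked against the number of logical devices. A small sketch of that count, using only the xla_data.pb.h proto API:

#include <cstdint>
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Sketch: the number of leaves produced by the split tree for a tiled
// sharding. For a well-formed OpSharding this equals
// tile_assignment_devices_size(), the invariant checked above.
int64_t NumTiledInputs(const xla::OpSharding& sharding) {
  int64_t num_tiles = 1;
  for (auto dim : sharding.tile_assignment_dimensions()) num_tiles *= dim;
  return num_tiles;  // e.g. tile_assignment_dimensions {2, 2} -> 4 tiles
}
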


@@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
@@ -38,9 +39,10 @@ llvm::Optional<mlir::StringRef> ParseShardingAttribute(
 // i-th element is a list of mlir::Value's which represent the inputs for the
 // TPU computation corresponding to the i-th logical device. If the attribute
 // does not exist, then all inputs are placed on logical core 0.
-llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
-ExtractInputsForLogicalDevices(int num_logical_cores,
-                               mlir::tf_device::LaunchFuncOp launch_func);
+mlir::LogicalResult ExtractInputsForLogicalDevices(
+    int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func,
+    mlir::OpBuilder* builder,
+    llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list);
 
 // Extracts a list of OpSharding that represent the output sharding
 // configuration of `tf_device.launch`.