Supported tiled input sharding for model parallelism.
PiperOrigin-RevId: 301882389 Change-Id: I7da33977da3f05881f7dc66742cd0fd9cb89d358
This commit is contained in:
parent
71313bcb18
commit
0c53d83001
@ -1395,3 +1395,304 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
|
||||
return %1, %3 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests inputs are correctly split and fed into TPU computation for
|
||||
// tiled input sharding.
|
||||
|
||||
// The following OpSharding is used for TPU computation inputs in below test:
|
||||
// Proto debug string:
|
||||
// input 0
|
||||
// type: OTHER
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_dimensions: 2
|
||||
// tile_assignment_devices: 0
|
||||
// tile_assignment_devices: 1
|
||||
// Serialized string:
|
||||
// "\08\03\1A\02\01\02\22\02\00\01"
|
||||
//
|
||||
// input 1
|
||||
// type: MAXIMAL
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_devices: 1
|
||||
// Serialized string:
|
||||
// "\08\01\1A\01\01\22\01\01"
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
|
||||
// CHECK-LABEL: func @parallel_execute_with_tiled_input
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
|
||||
func @parallel_execute_with_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
|
||||
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
|
||||
// CHECK-SAME: devices =
|
||||
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
|
||||
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
|
||||
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
|
||||
// CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"
|
||||
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
|
||||
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: %[[CONST_SPLIT_DIM:[0-9]*]] = "tf.Const"()
|
||||
// CHECK: %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[RI_0]])
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
//
|
||||
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1)
|
||||
// CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]]
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
|
||||
// CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]]
|
||||
// CHECK: device = "TPU_REPLICATED_CORE_1"
|
||||
%1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
|
||||
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
return %4, %3 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// The following OpSharding is used for TPU computation inputs in below test:
|
||||
// Proto debug string:
|
||||
// input 0
|
||||
// type: OTHER
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_dimensions: 4
|
||||
// tile_assignment_devices: 0
|
||||
// tile_assignment_devices: 1
|
||||
//
|
||||
// input 1
|
||||
// type: MAXIMAL
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_devices: 1
|
||||
// Serialized string:
|
||||
// "\08\01\1A\01\01\22\01\01"
|
||||
//
|
||||
// -----
|
||||
|
||||
|
||||
// Tests tile sharding of inputs with number of splits that does not evenly divide
|
||||
// the input results in an error.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
|
||||
func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
|
||||
// expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}}
|
||||
%1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
|
||||
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
return %4, %3 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// The following topology is used in subsequent test cases:
|
||||
// Proto debug string:
|
||||
// mesh_shape: 2
|
||||
// mesh_shape: 1
|
||||
// mesh_shape: 2
|
||||
// num_tasks: 2
|
||||
// num_tpu_devices_per_task: 2
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 1
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 1
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 1
|
||||
// device_coordinates: 0
|
||||
// device_coordinates: 1
|
||||
|
||||
// The following OpSharding is used for TPU computation inputs in below test:
|
||||
// Proto debug string:
|
||||
// input 0
|
||||
// type: OTHER
|
||||
// tile_shape {
|
||||
// element_type: F32
|
||||
// dimensions: 2
|
||||
// dimensions: 2
|
||||
// layout {
|
||||
// minor_to_major: 1
|
||||
// minor_to_major: 0
|
||||
// format: DENSE
|
||||
// }
|
||||
// is_dynamic_dimension: false
|
||||
// is_dynamic_dimension: false
|
||||
// }
|
||||
// tile_assignment_dimensions: 2
|
||||
// tile_assignment_dimensions: 2
|
||||
// tile_assignment_devices: 0
|
||||
// tile_assignment_devices: 1
|
||||
// tile_assignment_devices: 2
|
||||
// tile_assignment_devices: 3
|
||||
// Serialized string:
|
||||
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03"
|
||||
//
|
||||
// input 1
|
||||
// type: MAXIMAL
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_devices: 1
|
||||
// Serialized string:
|
||||
// "\08\01\1A\01\01\22\01\01"
|
||||
|
||||
// Tests inputs to TPUComputation that are tiled in multiple dimensions.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
|
||||
// CHECK-LABEL: func @parallel_execute_with_multi_dimension_tiled_input
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
|
||||
func @parallel_execute_with_multi_dimension_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
|
||||
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
|
||||
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
|
||||
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"
|
||||
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
|
||||
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"()
|
||||
// CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
|
||||
// CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"()
|
||||
// CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
|
||||
// CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"()
|
||||
// CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#1)
|
||||
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
|
||||
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
|
||||
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
|
||||
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[COMPILE]]#3)
|
||||
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
|
||||
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4)
|
||||
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
|
||||
%1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
|
||||
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
return %4, %3 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Tests inputs device assignment order is well preserved for tiled input sharding.
|
||||
|
||||
// The following OpSharding is used for TPU computation inputs in below test:
|
||||
// Proto debug string:
|
||||
// input 0
|
||||
// type: OTHER
|
||||
// tile_shape {
|
||||
// element_type: F32
|
||||
// dimensions: 2
|
||||
// dimensions: 2
|
||||
// layout {
|
||||
// minor_to_major: 1
|
||||
// minor_to_major: 0
|
||||
// format: DENSE
|
||||
// }
|
||||
// is_dynamic_dimension: false
|
||||
// is_dynamic_dimension: false
|
||||
// }
|
||||
// tile_assignment_dimensions: 2
|
||||
// tile_assignment_dimensions: 2
|
||||
// tile_assignment_devices: 3
|
||||
// tile_assignment_devices: 2
|
||||
// tile_assignment_devices: 1
|
||||
// tile_assignment_devices: 0
|
||||
// Serialized string:
|
||||
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00"
|
||||
//
|
||||
//
|
||||
// input 1
|
||||
// type: MAXIMAL
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_devices: 1
|
||||
// Serialized string:
|
||||
// "\08\01\1A\01\01\22\01\01"
|
||||
//
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
|
||||
// CHECK-LABEL: func @tiled_input_sharding_with_device_assignment_order
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
|
||||
func @tiled_input_sharding_with_device_assignment_order(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
|
||||
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
|
||||
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
|
||||
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"
|
||||
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
|
||||
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: %[[CONST_SPLIT_0_DIM:[0-9]*]] = "tf.Const"()
|
||||
// CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
|
||||
// CHECK: %[[CONST_SPLIT_1_DIM:[0-9]*]] = "tf.Const"()
|
||||
// CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
|
||||
// CHECK: %[[CONST_SPLIT_2_DIM:[0-9]*]] = "tf.Const"()
|
||||
// CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#1)
|
||||
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
|
||||
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[RI_1]], %[[COMPILE]]#2)
|
||||
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
|
||||
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[COMPILE]]#3)
|
||||
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
|
||||
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#4)
|
||||
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
|
||||
%1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
|
||||
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
return %4, %3 : tensor<*xi32>, tensor<*xi1>
|
||||
}
|
||||
}
|
||||
|
@ -488,11 +488,11 @@ Operation* BuildExecuteOp(
|
||||
|
||||
// Creates a tf_device.parallel_execute op that wraps TPUExecute op to
|
||||
// represent execution of TPU program in multiple logical cores.
|
||||
tf_device::ParallelExecuteOp BuildParallelExecuteOp(
|
||||
LogicalResult BuildParallelExecuteOp(
|
||||
llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
|
||||
llvm::ArrayRef<xla::OpSharding> output_sharding_config,
|
||||
Operation* compile_op, tf_device::LaunchFuncOp launch_func,
|
||||
OpBuilder* builder) {
|
||||
OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
|
||||
const int num_cores_per_replica = execution_devices.front().size();
|
||||
// parallel_execute op returns concatenated list of return values of
|
||||
// all its regions.
|
||||
@ -510,20 +510,23 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp(
|
||||
for (Type t : output_types) concatenated_output_types.emplace_back(t);
|
||||
}
|
||||
|
||||
auto parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
|
||||
*parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
|
||||
launch_func.getLoc(), num_cores_per_replica, concatenated_output_types);
|
||||
|
||||
// Extract inputs for each region of the parallel_execute op. The i-th
|
||||
// element in the list represents the input lists to TPU computation for
|
||||
// i-th logical core.
|
||||
auto input_list = tensorflow::ExtractInputsForLogicalDevices(
|
||||
num_cores_per_replica, launch_func);
|
||||
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
|
||||
builder->setInsertionPoint(*parallel_execute_op);
|
||||
auto result = tensorflow::ExtractInputsForLogicalDevices(
|
||||
num_cores_per_replica, launch_func, builder, &input_list);
|
||||
if (failed(result)) return failure();
|
||||
|
||||
const bool replicated = execution_devices.size() != 1;
|
||||
// For each logical core, create a region with TPUExecute op.
|
||||
assert(input_list.size() == num_cores_per_replica);
|
||||
for (int core = 0; core < num_cores_per_replica; ++core) {
|
||||
auto& region = parallel_execute_op.GetRegionBlockWithIndex(core);
|
||||
auto& region = parallel_execute_op->GetRegionBlockWithIndex(core);
|
||||
builder->setInsertionPointToEnd(®ion);
|
||||
|
||||
// Create Execute op.
|
||||
@ -551,7 +554,7 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp(
|
||||
region_launch_op.getResults());
|
||||
}
|
||||
|
||||
return parallel_execute_op;
|
||||
return success();
|
||||
}
|
||||
|
||||
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
|
||||
@ -703,9 +706,12 @@ LogicalResult Rewrite(
|
||||
if (num_cores_per_replica > 1) {
|
||||
// For model parallelism, tf_device.parallel_execute is used to express
|
||||
// concurrent device execution across multiple logical devices.
|
||||
tf_device::ParallelExecuteOp execute_op = BuildParallelExecuteOp(
|
||||
tpu_device_assignment.execution_devices, output_shardings, compile_op,
|
||||
launch_func, builder);
|
||||
|
||||
tf_device::ParallelExecuteOp execute_op;
|
||||
result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices,
|
||||
output_shardings, compile_op, launch_func,
|
||||
builder, &execute_op);
|
||||
if (failed(result)) return failure();
|
||||
|
||||
// As tf_device.parallel_execute wraps # logical cores number of TPUExecute
|
||||
// ops, the number of return values of parallel_execute op exceeds that of
|
||||
|
@ -16,10 +16,19 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -36,14 +45,135 @@ llvm::Optional<mlir::StringRef> ParseShardingAttribute(
|
||||
return sharding_attr.getValue();
|
||||
}
|
||||
|
||||
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
|
||||
ExtractInputsForLogicalDevices(int num_logical_cores,
|
||||
mlir::tf_device::LaunchFuncOp launch_func) {
|
||||
namespace {
|
||||
|
||||
constexpr char kNumSplitAttr[] = "num_split";
|
||||
|
||||
// Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways
|
||||
// in 'split_dimension' dimension and returns the split values.
|
||||
// Injects a tf.Split op that partitions `src_input` into `num_split` slices
// along `split_dimension`, returning the created op through `split_op`.
// Emits an error at `location` if the split dimension is statically known
// and its size is not evenly divisible by `num_split`.
mlir::LogicalResult CreateSplitOp(const int num_split,
                                  const int split_dimension,
                                  const mlir::Location& location,
                                  mlir::Value src_input,
                                  mlir::OpBuilder* builder,
                                  mlir::TF::SplitOp* split_op) {
  // Materialize the split dimension as a scalar i32 constant; tf.Split takes
  // the dimension as its first operand rather than as an attribute.
  auto dim_tensor_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto dim_attr = mlir::DenseElementsAttr::get(dim_tensor_type, split_dimension);
  auto dim_const = builder->create<mlir::TF::ConstOp>(location, dim_tensor_type,
                                                      dim_attr);

  // Derive the per-slice result type. When the input is unranked or the split
  // dimension size is dynamic, the input type is propagated unchanged.
  auto input_type = src_input.getType().cast<mlir::TensorType>();
  mlir::Type slice_type = input_type;
  if (input_type.hasRank() &&
      input_type.getShape()[split_dimension] !=
          mlir::ShapedType::kDynamicSize) {
    auto new_shape = llvm::to_vector<4>(input_type.getShape());
    // A statically-known dimension must divide evenly into `num_split` parts.
    if (new_shape[split_dimension] % num_split != 0) {
      return mlir::emitError(
          location,
          llvm::formatv(
              "incorrect input sharding configuration received. "
              "{0}-th dimension of the input must be evenly divisible by {1}",
              split_dimension, num_split));
    }
    new_shape[split_dimension] = new_shape[split_dimension] / num_split;
    slice_type =
        mlir::RankedTensorType::get(new_shape, input_type.getElementType());
  }

  // Create the split op; all `num_split` results share the slice type.
  llvm::SmallVector<mlir::Type, 4> result_types(num_split, slice_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, result_types, dim_const.output(), src_input);
  split_op->setAttr(kNumSplitAttr, builder->getIntegerAttr(
                                       builder->getIntegerType(32), num_split));
  return mlir::success();
}
|
||||
|
||||
// For tile sharded inputs to TPU computation, inject split op between the
|
||||
// input values and TPU computation so that tiled input values are passed in
|
||||
// as inputs to TPU computations. If more than one dimension is sharded, then
|
||||
// a tree of connected split ops are added before tf_device.parallel_execute op.
|
||||
// For a tile-sharded input to a TPU computation, injects tf.Split ops between
// `original_source` and the computation so each logical core receives its
// tile. When more than one dimension is sharded, a tree of split ops is
// built, one tree level per sharded dimension, inserted before the
// tf_device.parallel_execute op. Results are appended to `tiled_inputs` in
// row-major tile order.
mlir::LogicalResult HandleTileShardedInputs(
    const mlir::Location& location, const xla::OpSharding& input_sharding,
    const mlir::Value& original_source, mlir::OpBuilder* builder,
    llvm::SmallVectorImpl<mlir::Value>* tiled_inputs) {
  // `frontier` holds the split ops at the current (deepest) level of the
  // split tree; level i splits the data along the i-th sharded dimension.
  llvm::SmallVector<mlir::TF::SplitOp, 4> frontier;
  frontier.reserve(input_sharding.tile_assignment_devices_size());

  const auto& dimension_splits = input_sharding.tile_assignment_dimensions();
  for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) {
    const int num_splits = num_splits_and_index.value();
    const int dimension_index = num_splits_and_index.index();
    // Dimensions with a single tile need no split op.
    if (num_splits == 1) continue;

    // The first sharded dimension produces the root of the split tree,
    // consuming the original input value directly.
    if (frontier.empty()) {
      mlir::TF::SplitOp root;
      if (mlir::failed(CreateSplitOp(num_splits, dimension_index, location,
                                     original_source, builder, &root)))
        return mlir::failure();
      frontier.emplace_back(root);
      continue;
    }

    // Expand the tree one level: every output of every frontier split op
    // becomes the input of a child split op along this dimension.
    llvm::SmallVector<mlir::TF::SplitOp, 4> next_level;
    next_level.reserve(frontier.size() * num_splits);
    for (auto parent : frontier) {
      for (auto parent_output : parent.getResults()) {
        mlir::TF::SplitOp child;
        if (mlir::failed(CreateSplitOp(num_splits, dimension_index, location,
                                       parent_output, builder, &child)))
          return mlir::failure();
        next_level.emplace_back(child);
      }
    }
    std::swap(next_level, frontier);
  }

  // `frontier` now holds the leaf split ops whose results feed the TPUExecute
  // ops; flattening them in order yields the tiles in row-major order.
  tiled_inputs->reserve(input_sharding.tile_assignment_devices_size());
  for (auto leaf : frontier)
    tiled_inputs->append(leaf.getResults().begin(), leaf.getResults().end());

  return mlir::success();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
mlir::LogicalResult ExtractInputsForLogicalDevices(
|
||||
int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func,
|
||||
mlir::OpBuilder* builder,
|
||||
llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list) {
|
||||
// Initialize the input list for each logical devices.
|
||||
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
|
||||
input_list.reserve(num_logical_cores);
|
||||
input_list->reserve(num_logical_cores);
|
||||
for (int i = 0; i < num_logical_cores; ++i)
|
||||
input_list.emplace_back(llvm::SmallVector<mlir::Value, 4>());
|
||||
input_list->emplace_back(llvm::SmallVector<mlir::Value, 4>());
|
||||
|
||||
llvm::SmallVector<mlir::Value, 4> launch_func_inputs(
|
||||
launch_func.getOperands());
|
||||
@ -53,8 +183,8 @@ ExtractInputsForLogicalDevices(int num_logical_cores,
|
||||
// If sharding attribute does not exist, then all inputs are placed on 0th
|
||||
// logical core by default.
|
||||
if (!sharding_attrs) {
|
||||
input_list[0] = launch_func_inputs;
|
||||
return input_list;
|
||||
(*input_list)[0] = launch_func_inputs;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Enumerate sharding configuration for each inputs. If input has replicate
|
||||
@ -71,19 +201,32 @@ ExtractInputsForLogicalDevices(int num_logical_cores,
|
||||
sharding_attr.cast<mlir::StringAttr>().getValue().str());
|
||||
|
||||
const auto input_sharing_type = sharding.type();
|
||||
if (input_sharing_type == xla::OpSharding::OTHER)
|
||||
launch_func.emitError(
|
||||
"tiled inputs are not yet supported for model parallelism");
|
||||
if (input_sharing_type == xla::OpSharding::OTHER) {
|
||||
llvm::SmallVector<mlir::Value, 4> tiled_inputs;
|
||||
auto result = HandleTileShardedInputs(
|
||||
launch_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
|
||||
if (mlir::failed(result)) return mlir::failure();
|
||||
|
||||
if (input_sharing_type == xla::OpSharding::REPLICATED) {
|
||||
for (auto inputs : input_list) inputs.emplace_back(input_value);
|
||||
if (tiled_inputs.size() != num_logical_cores)
|
||||
launch_func.emitError(llvm::formatv(
|
||||
"incorrect {0}-th tiled input sharding received. "
|
||||
"Product of tile sharding splits({1}) must be equal to "
|
||||
"number of logical devices : {2}",
|
||||
input_index, tiled_inputs.size(), num_logical_cores));
|
||||
|
||||
for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) {
|
||||
const int assigned_logical_device = sharding.tile_assignment_devices(i);
|
||||
(*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]);
|
||||
}
|
||||
} else if (input_sharing_type == xla::OpSharding::REPLICATED) {
|
||||
for (auto inputs : *input_list) inputs.emplace_back(input_value);
|
||||
} else {
|
||||
assert(input_sharing_type == xla::OpSharding::MAXIMAL);
|
||||
const int logical_device_id = sharding.tile_assignment_devices(0);
|
||||
input_list[logical_device_id].emplace_back(input_value);
|
||||
(*input_list)[logical_device_id].emplace_back(input_value);
|
||||
}
|
||||
}
|
||||
return input_list;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult ParseAndValidateOutputSharding(
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
@ -38,9 +39,10 @@ llvm::Optional<mlir::StringRef> ParseShardingAttribute(
|
||||
// i-th element is a list of mlir::Value's which represent inputs for the
|
||||
// TPU computation corresponding to i-th logical device. If the attribute
|
||||
// does not exist, all inputs are placed on logical core 0.
|
||||
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
|
||||
ExtractInputsForLogicalDevices(int num_logical_cores,
|
||||
mlir::tf_device::LaunchFuncOp launch_func);
|
||||
mlir::LogicalResult ExtractInputsForLogicalDevices(
|
||||
int num_logical_cores, mlir::tf_device::LaunchFuncOp launch_func,
|
||||
mlir::OpBuilder* builder,
|
||||
llvm::SmallVectorImpl<llvm::SmallVector<mlir::Value, 4>>* input_list);
|
||||
|
||||
// Extracts a list of OpSharding that represent output sharding configuration
|
||||
// of `tf_device.launch`.
|
||||
|
Loading…
Reference in New Issue
Block a user