Set host device for outside compilation LaunchOp.

PiperOrigin-RevId: 318092701
Change-Id: Ib8f9de05110030a54be12c36ba004939c4c9d832

parent fcad3004bc
commit 2434d24013
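For orientation, a minimal before/after sketch of the rewrite this pass performs once the host device is set (illustrative only, not taken from the commit; the op names and the CPU:0 host device mirror the tests below, and a non-replicated single-core cluster is assumed):

    "tf_device.cluster"() ( {
      "tf.A"() : () -> ()
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      tf_device.return
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()

becomes

    "tf_device.parallel_execute"() ( {
      // Host region: the outside compiled op, now launched on an explicit host device.
      "tf_device.launch"() ( {
        "tf.B"() : () -> ()
        tf_device.return
      }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
      tf_device.return
    }, {
      // Device region: the remaining TPU cluster computation.
      "tf_device.cluster"() ( {
        "tf.A"() : () -> ()
        tf_device.return
      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
      tf_device.return
    }) {} : () -> ()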
@@ -2,460 +2,455 @@
// Tests that missing `_xla_outside_compilation` attribute value results in an error.

func @missing_outside_compilation_attribute() -> () {
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    // expected-error@+1 {{attribute '_xla_outside_compilation' is empty}}
    "tf.B"() {_xla_outside_compilation = ""} : () -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

// -----

module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {

// Tests that TPU cluster with no outside compilation does not generate parallel_execute.

// CHECK-NOT: "tf_device.parallel_execute"

// CHECK-LABEL: func @no_outside_compilation
func @no_outside_compilation() -> tensor<?xi32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.A"() : () -> tensor<?xi32>
    %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// Tests that TPU cluster with no outside compilation does not generate parallel_execute.

// CHECK-NOT: "tf_device.parallel_execute"

// CHECK-LABEL: func @no_outside_compilation
func @no_outside_compilation() -> tensor<?xi32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.A"() : () -> tensor<?xi32>
    %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with no input or output dependencies.

// CHECK-LABEL: func @nodep_single_outside_compilation
func @nodep_single_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK: cluster_attr = "cluster_attr"
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

// Tests extraction of a single outside compiled cluster with multiple ops and no input or output dependencies.

// CHECK-LABEL: func @nodep_single_cluster_multiple_ops_outside_compilation
func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NEXT: "tf.C"
  // CHECK-NEXT: "tf.D"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK-NEXT: "tf.E"
  // CHECK: cluster_attr = "cluster_attr"
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.E"() : () -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

// Tests extraction of multiple outside compiled clusters with no input or output dependencies.

// CHECK-LABEL: func @nodep_multiple_outside_compilation
func @nodep_multiple_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-COUNT-2: "tf_device.launch"
  // CHECK: "tf_device.cluster"
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
    "tf.E"() : () -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

// Tests extraction of a single outside compiled cluster with single TPU cluster return.

// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      "tf.A"() : () -> ()
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %3 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %3 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with no input or output dependencies.

// CHECK-LABEL: func @nodep_single_outside_compilation
func @nodep_single_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK-NEXT: tf_device.return
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  return
}

// Tests extraction of a single outside compiled cluster with multiple TPU cluster return.

// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2, %3 = "tf_device.cluster"() ( {
      %4 = "tf.A"() : () -> tensor<?xf32>
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %5 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
    tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
  }
  return %1 : tensor<?xf32>
}

// Tests extraction of a single outside compiled cluster with multiple ops and no input or output dependencies.

// CHECK-LABEL: func @nodep_single_cluster_multiple_ops_outside_compilation
func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NEXT: "tf.C"
  // CHECK-NEXT: "tf.D"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK-NEXT: "tf.E"
  // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.E"() : () -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  return
}

// Tests extraction of multiple outside compiled clusters with no input or output dependencies.

// CHECK-LABEL: func @nodep_multiple_outside_compilation
func @nodep_multiple_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-COUNT-2: "tf_device.launch"
  // CHECK: "tf_device.cluster"
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
    "tf.E"() : () -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  return
}

// Tests extraction of a single outside compiled cluster with single device->host input.

// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with single TPU cluster return.

// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NEXT: tf_device.return
  // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      "tf.A"() : () -> ()
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %3 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %3 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with single host->device output.

// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation
func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"()
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with multiple TPU cluster return.

// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2, %3 = "tf_device.cluster"() ( {
      %4 = "tf.A"() : () -> tensor<?xf32>
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %5 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xf32>, tensor<?xi32>)
    tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
  }
  return %1 : tensor<?xf32>
}

// Tests extraction of a single outside compiled cluster host output returned by TPU cluster.

// CHECK-LABEL: func @return_host_output_outside_compilation
func @return_host_output_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: tf_device.return %[[HOST_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>)
      tf_device.return %4 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with single device->host input.

// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with single input/output.

// CHECK-LABEL: func @single_outside_compiled_input_output_single_outside_compilation
func @single_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with single host->device output.

// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation
func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"()
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with multiple input/output.

// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
  // CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() : () -> (tensor<?xi32>)
      %5, %6 = "tf.C"(%3, %4) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
      %7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32>
      %8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %8 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster host output returned by TPU cluster.

// CHECK-LABEL: func @return_host_output_outside_compilation
func @return_host_output_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: tf_device.return %[[HOST_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>)
      tf_device.return %4 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of multiple outside compiled clusters with input/output.

// CHECK-LABEL: func @outside_compiled_input_output_multiple_outside_compilation
func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]])
  // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
  // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  // CHECK: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
  // CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._HostComputeMlir"(%[[C_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  // CHECK: "tf.E"(%[[HOST_OUTPUT2]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>)
      %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %7 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with single input/output.

// CHECK-LABEL: func @single_outside_compiled_input_output_single_outside_compilation
func @single_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with arg input and single device->host input.

// CHECK-LABEL: func @mixed_input_single_outside_compilation
func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      "tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with multiple input/output.

// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
  // CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() : () -> (tensor<?xi32>)
      %5, %6 = "tf.C"(%3, %4) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
      %7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32>
      %8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %8 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of multiple outside compiled clusters with single device->host input.

// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  // CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
  // CHECK: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
  // CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      "tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
      tf_device.return %4 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of multiple outside compiled clusters with input/output.

// CHECK-LABEL: func @outside_compiled_input_output_multiple_outside_compilation
func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]])
  // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
  // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  // CHECK: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
  // CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._HostComputeMlir"(%[[C_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  // CHECK: "tf.E"(%[[HOST_OUTPUT2]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>)
      %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %7 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with multiple device->host inputs.

// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
  // CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() : () -> (tensor<?xi32>)
      "tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
      %5 = "tf.E"() : () -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with arg input and single device->host input.

// CHECK-LABEL: func @mixed_input_single_outside_compilation
func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      "tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests only directly used results of tpu cluster are remapped with
// parallel_execute.

// CHECK-LABEL: func @remapped_results
func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<?xi32>
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2:2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
      tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> (tensor<?xi32>, tensor<?xi32>)
    tf_device.return %2#1 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of multiple outside compiled clusters with single device->host input.

// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  // CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
  // CHECK: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
  // CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      "tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
      tf_device.return %4 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests extraction of a single outside compiled cluster with multiple device->host inputs.

// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
  // CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() : () -> (tensor<?xi32>)
      "tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
      %5 = "tf.E"() : () -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

// Tests only directly used results of tpu cluster are remapped with
// parallel_execute.

// CHECK-LABEL: func @remapped_results
func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<?xi32>
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2:2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
      tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xi32>, tensor<?xi32>)
    tf_device.return %2#1 : tensor<?xi32>
  }
  return %1 : tensor<?xi32>
}

}

@@ -26,6 +26,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
 
 namespace mlir {
 namespace TFTPU {
@@ -91,13 +93,14 @@ void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,
 
 // Creates a `tf_device::LaunchOp` to wrap cluster ops.
 tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
-    OpBuilder* builder, Operation* last_cluster_op) {
-  // TODO(b/154363171): Set the CPU device.
+    OpBuilder* builder, Operation* last_cluster_op,
+    llvm::StringRef host_device) {
   // An empty string placeholder is used for the device as that will be later
   // populated with the device of the associated TPUReplicateMetadata op.
   llvm::SmallVector<Type, 8> result_types;
   auto launch_op = builder->create<tf_device::LaunchOp>(
-      last_cluster_op->getLoc(), builder->getStringAttr(""), result_types);
+      last_cluster_op->getLoc(), builder->getStringAttr(host_device),
+      result_types);
 
   launch_op.body().push_back(new Block);
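Pieced together from the hunk above, the updated helper now reads roughly as follows (a reconstruction for readability, not verbatim source):

    // Creates a `tf_device::LaunchOp` to wrap cluster ops, placing it on the
    // host device resolved for the enclosing TPU cluster.
    tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
        OpBuilder* builder, Operation* last_cluster_op,
        llvm::StringRef host_device) {
      llvm::SmallVector<Type, 8> result_types;
      auto launch_op = builder->create<tf_device::LaunchOp>(
          last_cluster_op->getLoc(), builder->getStringAttr(host_device),
          result_types);

      launch_op.body().push_back(new Block);
      // ... (rest of the function is unchanged by this commit)
    }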
@@ -253,8 +256,9 @@ void MoveOutsideCompiledOps(
 
 // Creates a `parallel_execute` op in place of launch with `clusters` and
 // `launch` as regions.
-void CreateParallelExecuteFromOutsideClusters(
-    tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) {
+void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster,
+                                              const OutsideClusterMap& clusters,
+                                              llvm::StringRef host_device) {
   OpBuilder builder(tpu_cluster);
   // Create parallel_execute regions. The original TPU cluster computation
   // is the extra region.
@@ -269,8 +273,8 @@ void CreateParallelExecuteFromOutsideClusters(
     Block& outside_block =
         parallel_execute_op.GetRegionBlockWithIndex(cluster.index());
     builder.setInsertionPointToEnd(&outside_block);
-    tf_device::LaunchOp host_launch_op =
-        CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back());
+    tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster(
+        &builder, cluster_ops.back(), host_device);
 
     // Determine if there are any inputs that are provided out of cluster.
     auto external_inputs = GetExternalOperands(cluster_ops);
@@ -307,8 +311,14 @@ void CreateParallelExecuteFromOutsideClusters(
 }
 
 void TPUExtractOutsideCompilation::runOnOperation() {
+  // Get runtime devices information from the closest parent module.
+  auto module = getOperation();
+  mlir::TF::RuntimeDevices devices;
+  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
+    return signalPassFailure();
+
   auto extract_result =
-      getOperation().walk([&](tf_device::ClusterOp tpu_cluster) {
+      module.walk([&](tf_device::ClusterOp tpu_cluster) {
         OutsideClusterMap clusters;
         if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
                                                     &clusters)))
@@ -316,7 +326,11 @@ void TPUExtractOutsideCompilation::runOnOperation() {
 
         if (clusters.empty()) return WalkResult::advance();
 
-        CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters);
+        std::string host_device;
+        tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
+                                                    &host_device);
+        CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters,
+                                                 host_device);
 
         return WalkResult::advance();
       });
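One detail worth noting from the tests: for a replicated cluster the launch is not pinned to a concrete device string but to a host alias, which is presumably rebound to a per-replica host device later in the pipeline. The shape, mirroring the CHECK lines in @single_tpu_return_single_outside_compilation above:

    "tf_device.launch"() ( {
      "tf.B"() : () -> ()
      tf_device.return
    }) {device = "TPU_REPLICATED_HOST"} : () -> ()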