Rewrite TpuExtractOutsideCompilation pass without assuming clustering.

This pass is now handled by first:
1. Decomposing control flow into the host and device portions first.
2. Moving all outside compilation ops to run in a single host parallel_execute region.

This simplifies the flow greatly and leads to a simpler code path. It will also allow the use of tf_device::LaunchOp to represent outside compilation regions.

PiperOrigin-RevId: 351399003
Change-Id: I558ef175a993b2e9861c14fc3e34dc46bb8f41e6
This commit is contained in:
Ken Franko 2021-01-12 10:30:09 -08:00 committed by TensorFlower Gardener
parent c62635d663
commit c213508ff9
2 changed files with 375 additions and 747 deletions

View File

@ -68,8 +68,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-LABEL: func @nodep_multiple_outside_compilation
func @nodep_multiple_outside_compilation() -> () {
// CHECK: "tf_device.parallel_execute"
// CHECK-COUNT-2: "tf_device.launch"
// CHECK: "tf_device.parallel_execute"
// CHECK: "tf_device.launch"
// CHECK: "tf.B"
// CHECK-NEXT: "tf.D"
// CHECK-NOT: "tf_device.launch"
// CHECK: "tf_device.cluster"
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
@ -146,12 +149,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:[a-z_0-9]+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
@ -199,18 +202,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-NOT: "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> (tensor<?xi32>)
@ -230,16 +230,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
@ -267,11 +265,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK: tf_device.return %[[HOST_OUTPUT]]
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
@ -299,11 +297,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
@ -367,12 +365,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
// CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -398,26 +396,21 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT2]], %[[DEVICE_ORDINAL2]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL2]])
// CHECK-SAME: key = "host_compute_channel_cluster2_0_retvals"
// CHECK: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT1]], %[[DEVICE_ORDINAL1]])
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL1]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_1_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster2_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_1_retvals"
// CHECK: "tf.E"(%[[HOST_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
@ -445,12 +438,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
@ -472,24 +465,21 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT_2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL_2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
// CHECK-SAME: key = "host_compute_channel_cluster2_0_args"
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
// CHECK: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT_1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL_1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_1_args"
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: "tf._XlaHostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_cluster2_0_args"
// CHECK-SAME: send_key = "host_compute_channel_1_args"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
@ -514,15 +504,19 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.C"(%[[RECV_OUTPUT_1]])
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_1_args"
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]], %[[RECV_OUTPUT_1]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_1_args"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
@ -571,10 +565,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
// CHECK-NOT: "tf._XlaSendFromHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: "tf.Yield"() : () -> ()
@ -582,11 +576,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
// CHECK: "tf._XlaHostComputeMlir"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: tpu_core = 0
// CHECK-NEXT: "tf.Yield"() : () -> ()
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -624,7 +619,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<?xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2)
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
@ -634,8 +629,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: tpu_core = 0
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%7 = "tf.F"() : () -> tensor<?xi32>
@ -676,11 +671,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[RECV_OUTPUT_PREDICATE:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT_PREDICATE]])
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#1)
// CHECK-NEXT: "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]])
@ -691,15 +686,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-SAME: (tensor<i1>) -> ()
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[D_OUT:[0-9]*]] = "tf.D"
// CHECK-NEXT: %[[F_OUT:[0-9]*]] = "tf.F"
// CHECK: "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: tpu_core = 0
// CHECK: "tf.Yield"() : () -> ()
// CHECK: "tf.Yield"() : () -> ()
@ -751,23 +746,24 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK-NEXT: "tf.Yield"() : () -> ()
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: tpu_core = 0
// CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -806,7 +802,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-SAME: key = "if_predicate_channel_0"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK: "tf.D"
// CHECK-NEXT: "tf.Yield"() : () -> ()
@ -814,7 +810,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_0"
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -851,14 +848,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-SAME: key = "if_predicate_channel_2"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK-NEXT: %[[PREDICATE2_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_1"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE2_RECV_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
// CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]])
// CHECK-NOT: "tf._XlaSendFromHostV2"
// CHECK-NEXT: "tf.Yield"() : () -> ()
@ -867,10 +864,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_0"}
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_2"
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[B_OUTPUT]])
// CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_1"}
// CHECK: "tf._XlaHostComputeMlir"(%[[H_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: tf.IfRegion"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
// CHECK: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[H_OUTPUT]])
@ -921,7 +920,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
@ -980,7 +979,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
@ -1030,26 +1029,19 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT_2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL_2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
// CHECK-SAME: key = "while_condition_channel_cluster2_0_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
// CHECK-NEXT: "tf.Yield"
// CHECK: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT_1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL_1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
@ -1057,7 +1049,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: %[[HOST_COMPUTE_OUTPUT:.+]] = "tf._XlaHostComputeMlir"
@ -1103,11 +1094,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
// CHECK-SAME: key = "while_condition_channel_1"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"
@ -1120,7 +1110,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUTPUT]])
// CHECK-NEXT: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
// CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUTPUT]])
@ -1168,7 +1158,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<?xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2)
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
@ -1182,8 +1172,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: tpu_core = 0
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%7 = "tf.F"() : () -> tensor<?xi32>
@ -1289,18 +1279,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[B_OUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.WhileRegion"()
// CHECK-NEXT: %[[WHILE_COND:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"(%[[WHILE_COND]])
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]])
// CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]])
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[C_OUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[IF_COND:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.IfRegion"(%[[IF_COND]])
// CHECK-NEXT: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[C_OUT]])
// CHECK: "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
@ -1310,8 +1297,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[H_OUT:.*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUT]])
// CHECK: %[[C_OUT_DEVICE:.*]] = "tf._XlaHostComputeMlir"()
// CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUT]])
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUT]])
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUT]])
// CHECK-NEXT: "tf._XlaHostComputeMlir"()
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -1349,10 +1335,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// violated. tf.C op in this case.
// CHECK-LABEL: func @device_op_dominance
func @device_op_dominance() -> () {
// CHECK: tf._XlaRecvAtHost
// CHECK: tf.B
// CHECK: tf.D
// CHECK: tf._XlaSendFromHost
// CHECK: tf._XlaRecvAtHost
// CHECK: tf.D
// CHECK: tf.A
// CHECK: tf._XlaHostComputeMlir
@ -1375,16 +1361,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-LABEL: func @device_op_dominance_with_indirect_dependency
func @device_op_dominance_with_indirect_dependency() -> () {
// CHECK: tf._XlaRecvAtHost
// CHECK: tf.B
// CHECK: tf._XlaRecvAtHost
// CHECK: tf.F
// CHECK: tf._XlaSendFromHost
// CHECK: tf.A
// CHECK: tf.D
// CHECK: tf._XlaHostComputeMlir
// CHECK: tf.C
// CHECK: tf.D
// CHECK: tf.E
// CHECK: tf._XlaHostComputeMlir
// CHECK: tf.G
"tf_device.cluster"() ( {