Rewrite TpuExtractOutsideCompilation pass without assuming clustering.

This pass now works in two steps:
1. First, decomposing control flow into the host and device portions.
2. Then, moving all outside compilation ops to run in a single host parallel_execute region.

This simplifies the flow greatly and leads to a simpler code path. It will also allow the use of tf_device::LaunchOp to represent outside compilation regions.

PiperOrigin-RevId: 351418970
Change-Id: Ie28bb760a9f2a46f5f9d4bb29360feed8a4dde38
This commit is contained in:
Ken Franko 2021-01-12 12:00:59 -08:00 committed by TensorFlower Gardener
parent c377472dc4
commit dd5bbd57dd
2 changed files with 748 additions and 376 deletions

View File

@ -68,11 +68,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-LABEL: func @nodep_multiple_outside_compilation
func @nodep_multiple_outside_compilation() -> () {
// CHECK: "tf_device.parallel_execute"
// CHECK: "tf_device.launch"
// CHECK: "tf.B"
// CHECK-NEXT: "tf.D"
// CHECK-NOT "tf_device.launch"
// CHECK: "tf_device.parallel_execute"
// CHECK-COUNT-2: "tf_device.launch"
// CHECK: "tf_device.cluster"
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
@ -149,12 +146,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:[a-z_0-9]+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
@ -202,15 +199,18 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-NOT: "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> (tensor<?xi32>)
@ -230,14 +230,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
@ -265,11 +267,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK: tf_device.return %[[HOST_OUTPUT]]
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
@ -297,11 +299,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
@ -365,12 +367,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
// CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -396,21 +398,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-DAG: %[[PROGRAM_OUTPUT2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT2]], %[[DEVICE_ORDINAL2]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_1_retvals"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL2]])
// CHECK-SAME: key = "host_compute_channel_cluster2_0_retvals"
// CHECK: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT1]], %[[DEVICE_ORDINAL1]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL1]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_1_retvals"
// CHECK-SAME: recv_key = "host_compute_channel_cluster2_0_retvals"
// CHECK: "tf.E"(%[[HOST_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
@ -438,12 +445,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
@ -465,21 +472,24 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_1_args"
// CHECK-DAG: %[[PROGRAM_OUTPUT_2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL_2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
// CHECK-SAME: key = "host_compute_channel_cluster2_0_args"
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
// CHECK: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT_1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL_1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: "tf._XlaHostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_1_args"
// CHECK-SAME: send_key = "host_compute_channel_cluster2_0_args"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
@ -504,19 +514,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.C"(%[[RECV_OUTPUT_1]])
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_1_args"
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]], %[[RECV_OUTPUT_1]])
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_1_args"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
@ -565,10 +571,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
// CHECK-NOT: "tf._XlaSendFromHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: "tf.Yield"() : () -> ()
@ -576,12 +582,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: tpu_core = 0
// CHECK-NEXT: "tf.Yield"() : () -> ()
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -619,7 +624,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<?xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2)
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
@ -629,8 +634,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: tpu_core = 0
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%7 = "tf.F"() : () -> tensor<?xi32>
@ -671,11 +676,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[RECV_OUTPUT_PREDICATE:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT_PREDICATE]])
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#1)
// CHECK-NEXT: "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]])
@ -686,15 +691,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-SAME: (tensor<i1>) -> ()
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[D_OUT:[0-9]*]] = "tf.D"
// CHECK-NEXT: %[[F_OUT:[0-9]*]] = "tf.F"
// CHECK: "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: tpu_core = 0
// CHECK: "tf.Yield"() : () -> ()
// CHECK: "tf.Yield"() : () -> ()
@ -746,24 +751,23 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
// CHECK-NEXT: "tf.Yield"() : () -> ()
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%6)
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: tpu_core = 0
// CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -802,7 +806,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_0"
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK: "tf.D"
// CHECK-NEXT: "tf.Yield"() : () -> ()
@ -810,8 +814,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%6)
// CHECK-SAME: key = "if_predicate_channel_0"
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -848,14 +851,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_2"
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK-NEXT: %[[PREDICATE2_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_1"
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE2_RECV_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
// CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]])
// CHECK-NOT: "tf._XlaSendFromHostV2"
// CHECK-NEXT: "tf.Yield"() : () -> ()
@ -864,12 +867,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME key = "if_predicate_channel_2"
// CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_0"}
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[B_OUTPUT]])
// CHECK: "tf._XlaHostComputeMlir"(%[[H_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_1"}
// CHECK-NEXT: tf.IfRegion"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
// CHECK: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[H_OUTPUT]])
@ -920,7 +921,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
@ -979,7 +980,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
@ -1029,19 +1030,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-DAG: %[[PROGRAM_OUTPUT_2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL_2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
// CHECK-SAME: key = "while_condition_channel_cluster2_0_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
// CHECK-NEXT: "tf.Yield"
// CHECK: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT_1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL_1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
@ -1049,6 +1057,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: %[[HOST_COMPUTE_OUTPUT:.+]] = "tf._XlaHostComputeMlir"
@ -1094,10 +1103,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_1"
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"
@ -1110,7 +1120,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUTPUT]])
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
// CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUTPUT]])
@ -1158,7 +1168,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<?xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2)
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
@ -1172,8 +1182,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
// CHECK-SAME: tpu_core = 0
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%7 = "tf.F"() : () -> tensor<?xi32>
@ -1279,15 +1289,18 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[B_OUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.WhileRegion"()
// CHECK-NEXT: %[[WHILE_COND:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"(%[[WHILE_COND]])
// CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]])
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]])
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[C_OUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[IF_COND:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.IfRegion"(%[[IF_COND]])
// CHECK-NEXT: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[C_OUT]])
// CHECK: "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
@ -1297,7 +1310,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[H_OUT:.*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUT]])
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUT]])
// CHECK: %[[C_OUT_DEVICE:.*]] = "tf._XlaHostComputeMlir"()
// CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUT]])
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUT]])
// CHECK-NEXT: "tf._XlaHostComputeMlir"()
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@ -1335,10 +1349,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// violated. tf.C op in this case.
// CHECK-LABEL: func @device_op_dominance
func @device_op_dominance() -> () {
// CHECK: tf.B
// CHECK: tf._XlaSendFromHost
// CHECK: tf._XlaRecvAtHost
// CHECK: tf.B
// CHECK: tf.D
// CHECK: tf._XlaSendFromHost
// CHECK: tf.A
// CHECK: tf._XlaHostComputeMlir
@ -1361,15 +1375,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-LABEL: func @device_op_dominance_with_indirect_dependency
func @device_op_dominance_with_indirect_dependency() -> () {
// CHECK: tf.B
// CHECK: tf._XlaRecvAtHost
// CHECK: tf.B
// CHECK: tf.F
// CHECK: tf._XlaSendFromHost
// CHECK: tf.A
// CHECK: tf.C
// CHECK: tf.D
// CHECK: tf.E
// CHECK: tf._XlaHostComputeMlir
// CHECK: tf.C
// CHECK: tf.E
// CHECK: tf.G
"tf_device.cluster"() ( {