Rewrite TpuExtractOutsideCompilation pass without assuming clustering.
The pass now works in two steps: 1. Decomposing control flow into the host and device portions. 2. Moving all outside compilation ops to run in a single host parallel_execute region. This simplifies the flow greatly and leads to a simpler code path. It will also allow the use of tf_device::LaunchOp to represent outside compilation regions. PiperOrigin-RevId: 351418970 Change-Id: Ie28bb760a9f2a46f5f9d4bb29360feed8a4dde38
This commit is contained in:
parent
c377472dc4
commit
dd5bbd57dd
@ -68,11 +68,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
// CHECK-LABEL: func @nodep_multiple_outside_compilation
|
||||
func @nodep_multiple_outside_compilation() -> () {
|
||||
// CHECK: "tf_device.parallel_execute"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: "tf.B"
|
||||
// CHECK-NEXT: "tf.D"
|
||||
// CHECK-NOT "tf_device.launch"
|
||||
// CHECK: "tf_device.parallel_execute"
|
||||
// CHECK-COUNT-2: "tf_device.launch"
|
||||
// CHECK: "tf_device.cluster"
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
@ -149,12 +146,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:[a-z_0-9]+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: "tf.B"(%[[RECV_OUTPUT]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
@ -202,15 +199,18 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-NOT: "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: device_ordinal = 0
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
|
||||
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: device_ordinal = 0
|
||||
// CHECK-SAME: key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
@ -230,14 +230,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
@ -265,11 +267,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: tf_device.return %[[HOST_OUTPUT]]
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
@ -297,11 +299,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
@ -365,12 +367,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
|
||||
// CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
@ -396,21 +398,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
|
||||
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT2]], %[[DEVICE_ORDINAL2]])
|
||||
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_1_retvals"
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL2]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2_0_retvals"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT1]], %[[DEVICE_ORDINAL1]])
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL1]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
|
||||
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[C_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_1_retvals"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster2_0_retvals"
|
||||
// CHECK: "tf.E"(%[[HOST_OUTPUT2]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
@ -438,12 +445,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
@ -465,21 +472,24 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
|
||||
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_1_args"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT_2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL_2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2_0_args"
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT_1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL_1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[C_OUTPUT]])
|
||||
// CHECK-SAME: send_key = "host_compute_channel_1_args"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster2_0_args"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
@ -504,19 +514,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK: "tf.C"(%[[RECV_OUTPUT_1]])
|
||||
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_1_args"
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]], %[[RECV_OUTPUT_1]])
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]])
|
||||
// CHECK-SAME: send_key = "host_compute_channel_1_args"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
@ -565,10 +571,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_1"
|
||||
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
|
||||
// CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
|
||||
// CHECK-NOT: "tf._XlaSendFromHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: "tf.Yield"() : () -> ()
|
||||
@ -576,12 +582,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
|
||||
// CHECK: "tf._XlaHostComputeMlir"
|
||||
// CHECK-SAME: key = "if_predicate_channel_1"
|
||||
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK-SAME: tpu_core = 0
|
||||
// CHECK-NEXT: "tf.Yield"() : () -> ()
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
@ -619,7 +624,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<?xi32>, tensor<i1>)
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2)
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
|
||||
@ -629,8 +634,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK-SAME: tpu_core = 0
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%7 = "tf.F"() : () -> tensor<?xi32>
|
||||
@ -671,11 +676,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: %[[RECV_OUTPUT_PREDICATE:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_1"
|
||||
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
|
||||
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT_PREDICATE]])
|
||||
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<i1>)
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#1)
|
||||
// CHECK-NEXT: "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]])
|
||||
@ -686,15 +691,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_1"
|
||||
// CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
|
||||
// CHECK-SAME: (tensor<i1>) -> ()
|
||||
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
|
||||
// CHECK: %[[D_OUT:[0-9]*]] = "tf.D"
|
||||
// CHECK-NEXT: %[[F_OUT:[0-9]*]] = "tf.F"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK-SAME: tpu_core = 0
|
||||
// CHECK: "tf.Yield"() : () -> ()
|
||||
// CHECK: "tf.Yield"() : () -> ()
|
||||
@ -746,24 +751,23 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_1"
|
||||
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
|
||||
// CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK-NEXT: "tf.Yield"() : () -> ()
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%6)
|
||||
// CHECK-SAME: key = "if_predicate_channel_1"
|
||||
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
|
||||
// CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK-SAME: tpu_core = 0
|
||||
// CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUT]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
@ -802,7 +806,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_0"
|
||||
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
|
||||
// CHECK: "tf.D"
|
||||
// CHECK-NEXT: "tf.Yield"() : () -> ()
|
||||
@ -810,8 +814,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%6)
|
||||
// CHECK-SAME: key = "if_predicate_channel_0"
|
||||
// CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"}
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
|
||||
// CHECK-NEXT: "tf.Yield"() : () -> ()
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
@ -848,14 +851,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_2"
|
||||
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_0"
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
|
||||
// CHECK-NEXT: %[[PREDICATE2_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_1"
|
||||
// CHECK-SAME: key = "if_predicate_channel_cluster1_0_1"
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE2_RECV_OUTPUT]])
|
||||
// CHECK-NEXT: "tf.Yield"() : () -> ()
|
||||
// CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]])
|
||||
// CHECK-NOT: "tf._XlaSendFromHostV2"
|
||||
// CHECK-NEXT: "tf.Yield"() : () -> ()
|
||||
@ -864,12 +867,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
|
||||
// CHECK-SAME key = "if_predicate_channel_2"
|
||||
// CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_0"}
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
|
||||
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[B_OUTPUT]])
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[H_OUTPUT]])
|
||||
// CHECK-SAME: key = "if_predicate_channel_1"
|
||||
// CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_1"}
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[H_OUTPUT]])
|
||||
// CHECK-NEXT: "tf.Yield"() : () -> ()
|
||||
// CHECK: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[H_OUTPUT]])
|
||||
@ -920,7 +921,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: tf.WhileRegion"
|
||||
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "while_condition_channel_0"
|
||||
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
|
||||
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
|
||||
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
|
||||
@ -979,7 +980,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
|
||||
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "while_condition_channel_0"
|
||||
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
|
||||
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
@ -1029,19 +1030,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT_2:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL_2:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: tf.WhileRegion"
|
||||
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
|
||||
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "while_condition_channel_0"
|
||||
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
|
||||
// CHECK-SAME: key = "while_condition_channel_cluster2_0_0"
|
||||
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
|
||||
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
|
||||
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT_2]], %[[DEVICE_ORDINAL_2]])
|
||||
// CHECK-NEXT: "tf.Yield"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT_1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL_1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: tf.WhileRegion"
|
||||
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
|
||||
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
|
||||
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
|
||||
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT_1]], %[[DEVICE_ORDINAL_1]])
|
||||
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
|
||||
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
@ -1049,6 +1057,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
|
||||
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
|
||||
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
|
||||
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
|
||||
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
|
||||
// CHECK: %[[HOST_COMPUTE_OUTPUT:.+]] = "tf._XlaHostComputeMlir"
|
||||
@ -1094,10 +1103,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: tf.WhileRegion"
|
||||
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "while_condition_channel_1"
|
||||
// CHECK-SAME: key = "while_condition_channel_cluster1_0_0"
|
||||
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
|
||||
// CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
|
||||
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
|
||||
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: "tf.Yield"
|
||||
@ -1110,7 +1120,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
|
||||
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
|
||||
// CHECK-NEXT: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
|
||||
// CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUTPUT]])
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
|
||||
// CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
|
||||
// CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUTPUT]])
|
||||
@ -1158,7 +1168,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-SAME: key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK-SAME: (tensor<2x!tf.string>, tensor<i64>) -> (tensor<?xi32>, tensor<?xi32>, tensor<i1>)
|
||||
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2)
|
||||
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
|
||||
@ -1172,8 +1182,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_0_args"
|
||||
// CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals"
|
||||
// CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args"
|
||||
// CHECK-SAME: tpu_core = 0
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%7 = "tf.F"() : () -> tensor<?xi32>
|
||||
@ -1279,15 +1289,18 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
|
||||
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
|
||||
// CHECK-NEXT: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
|
||||
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[B_OUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: "tf.WhileRegion"()
|
||||
// CHECK-NEXT: %[[WHILE_COND:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: "tf.Yield"(%[[WHILE_COND]])
|
||||
// CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]])
|
||||
// CHECK: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]])
|
||||
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[C_OUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: %[[IF_COND:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: "tf.IfRegion"(%[[IF_COND]])
|
||||
// CHECK-NEXT: "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
|
||||
// CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[C_OUT]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
|
||||
@ -1297,7 +1310,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[H_OUT:.*]] = "tf.H"
|
||||
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUT]])
|
||||
// CHECK-NEXT: "tf.Yield"(%[[H_OUT]])
|
||||
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUT]])
|
||||
// CHECK: %[[C_OUT_DEVICE:.*]] = "tf._XlaHostComputeMlir"()
|
||||
// CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUT]])
|
||||
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUT]])
|
||||
// CHECK-NEXT: "tf._XlaHostComputeMlir"()
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
@ -1335,10 +1349,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// violated. tf.C op in this case.
|
||||
// CHECK-LABEL: func @device_op_dominance
|
||||
func @device_op_dominance() -> () {
|
||||
// CHECK: tf.B
|
||||
// CHECK: tf._XlaSendFromHost
|
||||
// CHECK: tf._XlaRecvAtHost
|
||||
// CHECK: tf.B
|
||||
// CHECK: tf.D
|
||||
// CHECK: tf._XlaSendFromHost
|
||||
|
||||
// CHECK: tf.A
|
||||
// CHECK: tf._XlaHostComputeMlir
|
||||
@ -1361,15 +1375,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
// CHECK-LABEL: func @device_op_dominance_with_indirect_dependency
|
||||
func @device_op_dominance_with_indirect_dependency() -> () {
|
||||
// CHECK: tf.B
|
||||
// CHECK: tf._XlaRecvAtHost
|
||||
// CHECK: tf.B
|
||||
// CHECK: tf.F
|
||||
// CHECK: tf._XlaSendFromHost
|
||||
|
||||
// CHECK: tf.A
|
||||
// CHECK: tf.C
|
||||
// CHECK: tf.D
|
||||
// CHECK: tf.E
|
||||
// CHECK: tf._XlaHostComputeMlir
|
||||
// CHECK: tf.C
|
||||
// CHECK: tf.E
|
||||
// CHECK: tf.G
|
||||
|
||||
"tf_device.cluster"() ( {
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user