From f4202184458ac97fdc447805a8cc57b75d4f01cf Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 30 Jun 2020 10:47:19 -0700 Subject: [PATCH] [mlir] Change tf_executor.SwitchN to tf_executor._SwitchN The actual TF op is _SwitchN, mapping it to tf_executor.SwithN op requires special handling on export and that has hit us 2x now, so just make it _SwitchN and be uniform instead. PiperOrigin-RevId: 319052946 Change-Id: I9bd833dd5a4351f5fdd9fb648af8844f0dcdb76b --- .../mlir/tensorflow/ir/tf_executor_ops.td | 12 ++-- .../tensorflow/tests/breakup-islands.mlir | 4 +- .../tests/graphdef2mlir/switch_n.pbtxt | 4 +- .../tests/mlir2graphdef/switchn.mlir | 60 +++++++++++++++++++ .../tensorflow/tests/shape_inference.mlir | 4 +- .../tensorflow/tests/tf_executor_ops.mlir | 8 +-- .../tests/tf_executor_ops_invalid.mlir | 30 +++++----- 7 files changed, 91 insertions(+), 31 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/switchn.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 0efe578f151..3081018b8da 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -278,14 +278,14 @@ def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", let verifier = ?; } -def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN", +def TfExecutor_SwitchNOp : TfExecutor_Op<"_SwitchN", [ControlOperandsAfterAllData, HasParent<"GraphOp">]> { let summary = [{ - The "tf_executor.SwitchN" operation takes two inputs, `data` and `index` and - an integer attribute `num_outs` indicating the number of outputs. The `data` - input is copied to output indicated by the `index` input. The other outputs - are marked as dead. If one of the inputs or a control token is dead, then - all of the outputs are marked as dead as well. + The "tf_executor._SwitchN" operation takes two inputs, `data` and `index` + and an integer attribute `num_outs` indicating the number of outputs. The + `data` input is copied to output indicated by the `index` input. The other + outputs are marked as dead. If one of the inputs or a control token is + dead, then all of the outputs are marked as dead as well. }]; let description = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir index bdfe6f2ce07..05d34eb0755 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir @@ -331,7 +331,7 @@ func @enter_control_input() { } // CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" -// CHECK: tf_executor.SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]]) +// CHECK: tf_executor._SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]]) func @switchn_control_input(%arg1: tensor) { tf_executor.graph { %island:2 = tf_executor.island { @@ -339,7 +339,7 @@ func @switchn_control_input(%arg1: tensor) { %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>) tf_executor.yield %const : tensor<*xi32> } - %switchn:4 = tf_executor.SwitchN %island#0, %arg1 of 3: tensor<*xi32> + %switchn:4 = tf_executor._SwitchN %island#0, %arg1 of 3: tensor<*xi32> tf_executor.fetch %switchn#0 : tensor<*xi32> } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt index 4c4c8011932..59731b7cdb3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt @@ -1,10 +1,10 @@ # RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - -mlir-print-debuginfo | FileCheck %s -# CHECK: tf_executor.SwitchN +# CHECK: tf_executor._SwitchN # CHECK-SAME: of 3 : tensor<*xi32> # CHECK-SAME: T = i32 # CHECK-SAME: loc("Case/branch_index/_3") -# CHECK: tf_executor.SwitchN +# CHECK: tf_executor._SwitchN # CHECK-SAME: of 2 : tensor<*xf32> # CHECK-SAME: T = f32 # CHECK-SAME: loc("Case/Case/input_0/_7") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/switchn.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/switchn.mlir new file mode 100644 index 00000000000..25f50603521 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/switchn.mlir @@ -0,0 +1,60 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +"module"() ( { + "func"() ( { + "tf_executor.graph"() ( { + %outputs, %control = "tf_executor.island"() ( { + %0 = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + "tf_executor.yield"(%0) : (tensor) -> () + }) : () -> (tensor, !tf_executor.control) + %outputs_0:3, %control_1 = "tf_executor._SwitchN"(%outputs, %outputs) {T = i32, device = "", num_outs = 3 : i64} : (tensor, tensor) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, !tf_executor.control) + %outputs_2, %control_3 = "tf_executor.island"() ( { + %0 = "tf.Identity"(%outputs_0#0) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + "tf_executor.yield"(%0) : (tensor<*xi32>) -> () + }) : () -> (tensor<*xi32>, !tf_executor.control) + %outputs_4, %control_5 = "tf_executor.island"(%control_3) ( { + %0 = "tf.Const"() {device = "", value = dense<2.000000e+00> : tensor} : () -> tensor + "tf_executor.yield"(%0) : (tensor) -> () + }) : (!tf_executor.control) -> (tensor, !tf_executor.control) + %outputs_6, %control_7 = "tf_executor.island"() ( { + %0 = "tf.Identity"(%outputs_0#1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + "tf_executor.yield"(%0) : (tensor<*xi32>) -> () + }) : () -> (tensor<*xi32>, !tf_executor.control) + %outputs_8, %control_9 = "tf_executor.island"(%control_7) ( { + %0 = "tf.Const"() {device = "", value = dense<3.000000e+00> : tensor} : () -> tensor + "tf_executor.yield"(%0) : (tensor) -> () + }) : (!tf_executor.control) -> (tensor, !tf_executor.control) + %outputs_10, %control_11 = "tf_executor.island"() ( { + %0 = "tf.Identity"(%outputs_0#2) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + "tf_executor.yield"(%0) : (tensor<*xi32>) -> () + }) : () -> (tensor<*xi32>, !tf_executor.control) + %outputs_12, %control_13 = "tf_executor.island"(%control_11) ( { + %0 = "tf.Const"() {device = "", value = dense<4.000000e+00> : tensor} : () -> tensor + "tf_executor.yield"(%0) : (tensor) -> () + }) : (!tf_executor.control) -> (tensor, !tf_executor.control) + %outputs_14, %control_15 = "tf_executor.island"() ( { + %0 = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor} : () -> tensor + "tf_executor.yield"(%0) : (tensor) -> () + }) : () -> (tensor, !tf_executor.control) + %outputs_16:2, %control_17 = "tf_executor._SwitchN"(%outputs_14, %outputs) {T = f32, _class = ["Case/input_0"], device = "", num_outs = 2 : i64} : (tensor, tensor) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) + %outputs_18, %control_19 = "tf_executor.island"() ( { + %0 = "tf.Mul"(%outputs_16#0, %outputs_4) {device = ""} : (tensor<*xf32>, tensor) -> tensor<*xf32> + "tf_executor.yield"(%0) : (tensor<*xf32>) -> () + }) : () -> (tensor<*xf32>, !tf_executor.control) + %outputs_20, %control_21 = "tf_executor.island"() ( { + %0 = "tf.Mul"(%outputs_16#1, %outputs_8) {device = ""} : (tensor<*xf32>, tensor) -> tensor<*xf32> + "tf_executor.yield"(%0) : (tensor<*xf32>) -> () + }) : () -> (tensor<*xf32>, !tf_executor.control) + %output, %value_index, %control_22 = "tf_executor.Merge"(%outputs_18, %outputs_20) {N = 2 : i64, T = f32, device = ""} : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>, !tf_executor.control) + %control_23 = "tf_executor.island"() ( { + "tf._Retval"(%output) {T = f32, device = "/job:localhost/replica:0/task:0/device:CPU:0", index = 0 : i64} : (tensor<*xf32>) -> () + "tf_executor.yield"() : () -> () + }) : () -> !tf_executor.control + "tf_executor.fetch"() : () -> () + }) : () -> () + "std.return"() : () -> () + }) {sym_name = "main", type = () -> ()} : () -> () + "module_terminator"() : () -> () +}) {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 126 : i32}} : () -> () + +// CHECK: _SwitchN diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 1fd30953799..4d623e67257 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -273,7 +273,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-SAME: : (tensor<32x?x4xf32>, tensor) -> // CHECK: tf_executor.Switch // CHECK-SAME: : (tensor<32x?x4xf32>, tensor) -> - // CHECK: tf_executor.SwitchN + // CHECK: tf_executor._SwitchN // CHECK-SAME: : tensor // CHECK: tf_executor.Enter // CHECK-SAME: : (tensor<32x?x4xf32>) -> @@ -283,7 +283,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-SAME: tensor %merge:3 = "tf_executor.Merge"(%island#0, %arg1) : (tensor, tensor) -> (tensor, tensor, !tf_executor.control) %switch:3 = "tf_executor.Switch"(%island#0, %arg2) : (tensor, tensor) -> (tensor, tensor, !tf_executor.control) - %switchn:3 = "tf_executor.SwitchN"(%island#0, %arg3) {num_outs = 2} : (tensor, tensor) -> (tensor, tensor, !tf_executor.control) + %switchn:3 = "tf_executor._SwitchN"(%island#0, %arg3) {num_outs = 2} : (tensor, tensor) -> (tensor, tensor, !tf_executor.control) %enter:2 = "tf_executor.Enter"(%island#0) { frame_name = "frame"} : (tensor) -> (tensor, !tf_executor.control) %exit:2 = "tf_executor.Exit"(%island#0) : (tensor) -> (tensor, !tf_executor.control) %loop_cond:2 = "tf_executor.LoopCond" (%island#1) : (tensor<*xi1>) -> (tensor<*xi1>, !tf_executor.control) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 5c2e5afd263..1e537880620 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -211,11 +211,11 @@ func @switch_with_control_inputs_functional(%arg0: tensor, %arg1: !tf_execut func @switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { -// CHECK: tf_executor.SwitchN %{{.*}}, %{{.*}} of 5 : tensor<*xf32> - %1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32> +// CHECK: tf_executor._SwitchN %{{.*}}, %{{.*}} of 5 : tensor<*xf32> + %1:6 = tf_executor._SwitchN %arg1, %arg0 of 5 : tensor<*xf32> -// CHECK: tf_executor.SwitchN %{{.*}}, %{{.*}} of 12 (%{{.*}}) : tensor<*xf32> - %2:13 = tf_executor.SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32> +// CHECK: tf_executor._SwitchN %{{.*}}, %{{.*}} of 12 (%{{.*}}) : tensor<*xf32> + %2:13 = tf_executor._SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32> tf_executor.fetch %2#0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index 1fdc99d1ec8..2f034f1bfae 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -391,11 +391,11 @@ func @invalid_switch(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { // ----- -// Check that a tf_executor.SwitchN parent is a graph. +// Check that a tf_executor._SwitchN parent is a graph. func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor) { "tf.some_op"() ({ - %1:6 = tf_executor.SwitchN %arg0, %arg1 of 5 : tensor<*xf32> -// expected-error@-1 {{'tf_executor.SwitchN' op expects parent op 'tf_executor.graph'}} + %1:6 = tf_executor._SwitchN %arg0, %arg1 of 5 : tensor<*xf32> +// expected-error@-1 {{'tf_executor._SwitchN' op expects parent op 'tf_executor.graph'}} }) : () -> () return } @@ -406,8 +406,8 @@ func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor) { func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 5} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) -// expected-error@-1 {{'tf_executor.SwitchN' op expect `num_outs` (5) results but got 2}} + %1:3 = "tf_executor._SwitchN"(%arg1, %arg0) {num_outs = 5} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) +// expected-error@-1 {{'tf_executor._SwitchN' op expect `num_outs` (5) results but got 2}} tf_executor.fetch %1#0 : tensor<*xf32> } @@ -419,8 +419,8 @@ func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> // Check that data operands of SwitchN have tensor type func @invalid_switchN(%arg0: i32, %arg1: tensor) -> tensor<*xi32> { %result = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor) -> (tensor<*xi32>, tensor, !tf_executor.control) -// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}} + %1:3 = "tf_executor._SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor) -> (tensor<*xi32>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor._SwitchN' op expects data operand to have tensor type but got 'i32'}} tf_executor.fetch %1#0 : tensor<*xi32> } return %result : tensor<*xi32> @@ -431,8 +431,8 @@ func @invalid_switchN(%arg0: i32, %arg1: tensor) -> tensor<*xi32> { // Check that result of SwitchN has tensor type func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor) -> i32 { %result = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor) -> (i32, tensor, !tf_executor.control) -// expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}} + %1:3 = "tf_executor._SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor) -> (i32, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor._SwitchN' op expects outputs to have tensor type but got 'i32'}} tf_executor.fetch %1#0 : i32 } return %result : i32 @@ -444,8 +444,8 @@ func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor) -> i32 { func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4x!tf.f32ref> { %fetches = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control) -// expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}} + %1:3 = "tf_executor._SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control) +// expected-error@-1 {{'tf_executor._SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}} tf_executor.fetch %1#0 : tensor<4x!tf.f32ref> } return %fetches : tensor<4x!tf.f32ref> @@ -457,8 +457,8 @@ func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4x!tf. func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { %fetches = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor, !tf_executor.control) -// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor'}} + %1:3 = "tf_executor._SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor._SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor'}} tf_executor.fetch %1#0 : tensor<*xf32> } @@ -471,8 +471,8 @@ func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { - %1:3 = tf_executor.SwitchN %arg1, %arg0 of 2 : tensor<*xf32>, i32 -// expected-error@-1 {{custom op 'tf_executor.SwitchN' expects only a single data type}} + %1:3 = tf_executor._SwitchN %arg1, %arg0 of 2 : tensor<*xf32>, i32 +// expected-error@-1 {{custom op 'tf_executor._SwitchN' expects only a single data type}} tf_executor.fetch %1#0 : tensor<*xf32> }