[mlir] Change tf_executor.SwitchN to tf_executor._SwitchN

The actual TF op is _SwitchN, mapping it to tf_executor.SwithN op requires
special handling on export and that has hit us 2x now, so just make it
_SwitchN and be uniform instead.

PiperOrigin-RevId: 319052946
Change-Id: I9bd833dd5a4351f5fdd9fb648af8844f0dcdb76b
This commit is contained in:
Jacques Pienaar 2020-06-30 10:47:19 -07:00 committed by TensorFlower Gardener
parent 7046e2fbed
commit f420218445
7 changed files with 91 additions and 31 deletions

View File

@ -278,14 +278,14 @@ def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
let verifier = ?; let verifier = ?;
} }
def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN", def TfExecutor_SwitchNOp : TfExecutor_Op<"_SwitchN",
[ControlOperandsAfterAllData, HasParent<"GraphOp">]> { [ControlOperandsAfterAllData, HasParent<"GraphOp">]> {
let summary = [{ let summary = [{
The "tf_executor.SwitchN" operation takes two inputs, `data` and `index` and The "tf_executor._SwitchN" operation takes two inputs, `data` and `index`
an integer attribute `num_outs` indicating the number of outputs. The `data` and an integer attribute `num_outs` indicating the number of outputs. The
input is copied to output indicated by the `index` input. The other outputs `data` input is copied to output indicated by the `index` input. The other
are marked as dead. If one of the inputs or a control token is dead, then outputs are marked as dead. If one of the inputs or a control token is
all of the outputs are marked as dead as well. dead, then all of the outputs are marked as dead as well.
}]; }];
let description = [{ let description = [{

View File

@ -331,7 +331,7 @@ func @enter_control_input() {
} }
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" // CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
// CHECK: tf_executor.SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]]) // CHECK: tf_executor._SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]])
func @switchn_control_input(%arg1: tensor<i32>) { func @switchn_control_input(%arg1: tensor<i32>) {
tf_executor.graph { tf_executor.graph {
%island:2 = tf_executor.island { %island:2 = tf_executor.island {
@ -339,7 +339,7 @@ func @switchn_control_input(%arg1: tensor<i32>) {
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>) %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield %const : tensor<*xi32> tf_executor.yield %const : tensor<*xi32>
} }
%switchn:4 = tf_executor.SwitchN %island#0, %arg1 of 3: tensor<*xi32> %switchn:4 = tf_executor._SwitchN %island#0, %arg1 of 3: tensor<*xi32>
tf_executor.fetch %switchn#0 : tensor<*xi32> tf_executor.fetch %switchn#0 : tensor<*xi32>
} }
return return

View File

@ -1,10 +1,10 @@
# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - -mlir-print-debuginfo | FileCheck %s # RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - -mlir-print-debuginfo | FileCheck %s
# CHECK: tf_executor.SwitchN # CHECK: tf_executor._SwitchN
# CHECK-SAME: of 3 : tensor<*xi32> # CHECK-SAME: of 3 : tensor<*xi32>
# CHECK-SAME: T = i32 # CHECK-SAME: T = i32
# CHECK-SAME: loc("Case/branch_index/_3") # CHECK-SAME: loc("Case/branch_index/_3")
# CHECK: tf_executor.SwitchN # CHECK: tf_executor._SwitchN
# CHECK-SAME: of 2 : tensor<*xf32> # CHECK-SAME: of 2 : tensor<*xf32>
# CHECK-SAME: T = f32 # CHECK-SAME: T = f32
# CHECK-SAME: loc("Case/Case/input_0/_7") # CHECK-SAME: loc("Case/Case/input_0/_7")

View File

@ -0,0 +1,60 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
"module"() ( {
"func"() ( {
"tf_executor.graph"() ( {
%outputs, %control = "tf_executor.island"() ( {
%0 = "tf.Const"() {device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
"tf_executor.yield"(%0) : (tensor<i32>) -> ()
}) : () -> (tensor<i32>, !tf_executor.control)
%outputs_0:3, %control_1 = "tf_executor._SwitchN"(%outputs, %outputs) {T = i32, device = "", num_outs = 3 : i64} : (tensor<i32>, tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, !tf_executor.control)
%outputs_2, %control_3 = "tf_executor.island"() ( {
%0 = "tf.Identity"(%outputs_0#0) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
"tf_executor.yield"(%0) : (tensor<*xi32>) -> ()
}) : () -> (tensor<*xi32>, !tf_executor.control)
%outputs_4, %control_5 = "tf_executor.island"(%control_3) ( {
%0 = "tf.Const"() {device = "", value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
"tf_executor.yield"(%0) : (tensor<f32>) -> ()
}) : (!tf_executor.control) -> (tensor<f32>, !tf_executor.control)
%outputs_6, %control_7 = "tf_executor.island"() ( {
%0 = "tf.Identity"(%outputs_0#1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
"tf_executor.yield"(%0) : (tensor<*xi32>) -> ()
}) : () -> (tensor<*xi32>, !tf_executor.control)
%outputs_8, %control_9 = "tf_executor.island"(%control_7) ( {
%0 = "tf.Const"() {device = "", value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
"tf_executor.yield"(%0) : (tensor<f32>) -> ()
}) : (!tf_executor.control) -> (tensor<f32>, !tf_executor.control)
%outputs_10, %control_11 = "tf_executor.island"() ( {
%0 = "tf.Identity"(%outputs_0#2) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
"tf_executor.yield"(%0) : (tensor<*xi32>) -> ()
}) : () -> (tensor<*xi32>, !tf_executor.control)
%outputs_12, %control_13 = "tf_executor.island"(%control_11) ( {
%0 = "tf.Const"() {device = "", value = dense<4.000000e+00> : tensor<f32>} : () -> tensor<f32>
"tf_executor.yield"(%0) : (tensor<f32>) -> ()
}) : (!tf_executor.control) -> (tensor<f32>, !tf_executor.control)
%outputs_14, %control_15 = "tf_executor.island"() ( {
%0 = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
"tf_executor.yield"(%0) : (tensor<f32>) -> ()
}) : () -> (tensor<f32>, !tf_executor.control)
%outputs_16:2, %control_17 = "tf_executor._SwitchN"(%outputs_14, %outputs) {T = f32, _class = ["Case/input_0"], device = "", num_outs = 2 : i64} : (tensor<f32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
%outputs_18, %control_19 = "tf_executor.island"() ( {
%0 = "tf.Mul"(%outputs_16#0, %outputs_4) {device = ""} : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
"tf_executor.yield"(%0) : (tensor<*xf32>) -> ()
}) : () -> (tensor<*xf32>, !tf_executor.control)
%outputs_20, %control_21 = "tf_executor.island"() ( {
%0 = "tf.Mul"(%outputs_16#1, %outputs_8) {device = ""} : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
"tf_executor.yield"(%0) : (tensor<*xf32>) -> ()
}) : () -> (tensor<*xf32>, !tf_executor.control)
%output, %value_index, %control_22 = "tf_executor.Merge"(%outputs_18, %outputs_20) {N = 2 : i64, T = f32, device = ""} : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>, !tf_executor.control)
%control_23 = "tf_executor.island"() ( {
"tf._Retval"(%output) {T = f32, device = "/job:localhost/replica:0/task:0/device:CPU:0", index = 0 : i64} : (tensor<*xf32>) -> ()
"tf_executor.yield"() : () -> ()
}) : () -> !tf_executor.control
"tf_executor.fetch"() : () -> ()
}) : () -> ()
"std.return"() : () -> ()
}) {sym_name = "main", type = () -> ()} : () -> ()
"module_terminator"() : () -> ()
}) {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 126 : i32}} : () -> ()
// CHECK: _SwitchN

View File

@ -273,7 +273,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-SAME: : (tensor<32x?x4xf32>, tensor<?x?x?xf32>) -> // CHECK-SAME: : (tensor<32x?x4xf32>, tensor<?x?x?xf32>) ->
// CHECK: tf_executor.Switch // CHECK: tf_executor.Switch
// CHECK-SAME: : (tensor<32x?x4xf32>, tensor<i1>) -> // CHECK-SAME: : (tensor<32x?x4xf32>, tensor<i1>) ->
// CHECK: tf_executor.SwitchN // CHECK: tf_executor._SwitchN
// CHECK-SAME: : tensor<?x?x?xf32> // CHECK-SAME: : tensor<?x?x?xf32>
// CHECK: tf_executor.Enter // CHECK: tf_executor.Enter
// CHECK-SAME: : (tensor<32x?x4xf32>) -> // CHECK-SAME: : (tensor<32x?x4xf32>) ->
@ -283,7 +283,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-SAME: tensor<i1> // CHECK-SAME: tensor<i1>
%merge:3 = "tf_executor.Merge"(%island#0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<i32>, !tf_executor.control) %merge:3 = "tf_executor.Merge"(%island#0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<i32>, !tf_executor.control)
%switch:3 = "tf_executor.Switch"(%island#0, %arg2) : (tensor<?x?x?xf32>, tensor<i1>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control) %switch:3 = "tf_executor.Switch"(%island#0, %arg2) : (tensor<?x?x?xf32>, tensor<i1>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control)
%switchn:3 = "tf_executor.SwitchN"(%island#0, %arg3) {num_outs = 2} : (tensor<?x?x?xf32>, tensor<i32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control) %switchn:3 = "tf_executor._SwitchN"(%island#0, %arg3) {num_outs = 2} : (tensor<?x?x?xf32>, tensor<i32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control)
%enter:2 = "tf_executor.Enter"(%island#0) { frame_name = "frame"} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, !tf_executor.control) %enter:2 = "tf_executor.Enter"(%island#0) { frame_name = "frame"} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, !tf_executor.control)
%exit:2 = "tf_executor.Exit"(%island#0) : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, !tf_executor.control) %exit:2 = "tf_executor.Exit"(%island#0) : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, !tf_executor.control)
%loop_cond:2 = "tf_executor.LoopCond" (%island#1) : (tensor<*xi1>) -> (tensor<*xi1>, !tf_executor.control) %loop_cond:2 = "tf_executor.LoopCond" (%island#1) : (tensor<*xi1>) -> (tensor<*xi1>, !tf_executor.control)

View File

@ -211,11 +211,11 @@ func @switch_with_control_inputs_functional(%arg0: tensor<i1>, %arg1: !tf_execut
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%fetches = tf_executor.graph { %fetches = tf_executor.graph {
// CHECK: tf_executor.SwitchN %{{.*}}, %{{.*}} of 5 : tensor<*xf32> // CHECK: tf_executor._SwitchN %{{.*}}, %{{.*}} of 5 : tensor<*xf32>
%1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32> %1:6 = tf_executor._SwitchN %arg1, %arg0 of 5 : tensor<*xf32>
// CHECK: tf_executor.SwitchN %{{.*}}, %{{.*}} of 12 (%{{.*}}) : tensor<*xf32> // CHECK: tf_executor._SwitchN %{{.*}}, %{{.*}} of 12 (%{{.*}}) : tensor<*xf32>
%2:13 = tf_executor.SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32> %2:13 = tf_executor._SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32>
tf_executor.fetch %2#0 : tensor<*xf32> tf_executor.fetch %2#0 : tensor<*xf32>
} }

View File

@ -391,11 +391,11 @@ func @invalid_switch(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
// ----- // -----
// Check that a tf_executor.SwitchN parent is a graph. // Check that a tf_executor._SwitchN parent is a graph.
func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor<i32>) { func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor<i32>) {
"tf.some_op"() ({ "tf.some_op"() ({
%1:6 = tf_executor.SwitchN %arg0, %arg1 of 5 : tensor<*xf32> %1:6 = tf_executor._SwitchN %arg0, %arg1 of 5 : tensor<*xf32>
// expected-error@-1 {{'tf_executor.SwitchN' op expects parent op 'tf_executor.graph'}} // expected-error@-1 {{'tf_executor._SwitchN' op expects parent op 'tf_executor.graph'}}
}) : () -> () }) : () -> ()
return return
} }
@ -406,8 +406,8 @@ func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor<i32>) {
func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%fetches = tf_executor.graph { %fetches = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 5} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) %1:3 = "tf_executor._SwitchN"(%arg1, %arg0) {num_outs = 5} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expect `num_outs` (5) results but got 2}} // expected-error@-1 {{'tf_executor._SwitchN' op expect `num_outs` (5) results but got 2}}
tf_executor.fetch %1#0 : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32>
} }
@ -419,8 +419,8 @@ func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
// Check that data operands of SwitchN have tensor type // Check that data operands of SwitchN have tensor type
func @invalid_switchN(%arg0: i32, %arg1: tensor<i32>) -> tensor<*xi32> { func @invalid_switchN(%arg0: i32, %arg1: tensor<i32>) -> tensor<*xi32> {
%result = tf_executor.graph { %result = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor<i32>) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control) %1:3 = "tf_executor._SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor<i32>) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}} // expected-error@-1 {{'tf_executor._SwitchN' op expects data operand to have tensor type but got 'i32'}}
tf_executor.fetch %1#0 : tensor<*xi32> tf_executor.fetch %1#0 : tensor<*xi32>
} }
return %result : tensor<*xi32> return %result : tensor<*xi32>
@ -431,8 +431,8 @@ func @invalid_switchN(%arg0: i32, %arg1: tensor<i32>) -> tensor<*xi32> {
// Check that result of SwitchN has tensor type // Check that result of SwitchN has tensor type
func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 { func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
%result = tf_executor.graph { %result = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control) %1:3 = "tf_executor._SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}} // expected-error@-1 {{'tf_executor._SwitchN' op expects outputs to have tensor type but got 'i32'}}
tf_executor.fetch %1#0 : i32 tf_executor.fetch %1#0 : i32
} }
return %result : i32 return %result : i32
@ -444,8 +444,8 @@ func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4x!tf.f32ref> { func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4x!tf.f32ref> {
%fetches = tf_executor.graph { %fetches = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor<i32>) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control) %1:3 = "tf_executor._SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor<i32>) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}} // expected-error@-1 {{'tf_executor._SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
tf_executor.fetch %1#0 : tensor<4x!tf.f32ref> tf_executor.fetch %1#0 : tensor<4x!tf.f32ref>
} }
return %fetches : tensor<4x!tf.f32ref> return %fetches : tensor<4x!tf.f32ref>
@ -457,8 +457,8 @@ func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4x!tf.
func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> { func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
%fetches = tf_executor.graph { %fetches = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control) %1:3 = "tf_executor._SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor<i32>'}} // expected-error@-1 {{'tf_executor._SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor<i32>'}}
tf_executor.fetch %1#0 : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32>
} }
@ -471,8 +471,8 @@ func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32>
func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%fetches = tf_executor.graph { %fetches = tf_executor.graph {
%1:3 = tf_executor.SwitchN %arg1, %arg0 of 2 : tensor<*xf32>, i32 %1:3 = tf_executor._SwitchN %arg1, %arg0 of 2 : tensor<*xf32>, i32
// expected-error@-1 {{custom op 'tf_executor.SwitchN' expects only a single data type}} // expected-error@-1 {{custom op 'tf_executor._SwitchN' expects only a single data type}}
tf_executor.fetch %1#0 : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32>
} }