From cc6e729c9e210bb305c99a9d711787337a64b3d3 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Sun, 4 Aug 2019 18:05:07 -0700 Subject: [PATCH] Add import and export tests for handling of If and StatelessIf ops PiperOrigin-RevId: 261600770 --- .../graphdef2mlir/functional-if-ops.pbtxt | 256 ++++++++++++++++++ .../mlir2graphdef/functional-if-ops.mlir | 34 +++ 2 files changed, 290 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt new file mode 100644 index 00000000000..cbfa973fd64 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt @@ -0,0 +1,256 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - | FileCheck %s + +# Verify that TensorFlow If and StatelessIf ops are mapped to the +# composite If op in MLIR with is_stateless attribute set accordingly to +# distinguish between them. + +# CHECK-DAG: "tf.If"{{.*}} is_stateless = false, name = "StatefulIf" +# CHECK-DAG: "tf.If"{{.*}} is_stateless = true, name = "StatelessIf" + +node { + name: "tf.Less" + op: "Less" + input: "a" + input: "b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "StatefulIf" + op: "If" + input: "tf.Less" + input: "a" + input: "b" + attr { + key: "Tcond" + value { + type: DT_BOOL + } + } + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "else_branch" + value { + func { + name: "cond_false" + } + } + } + attr { + key: "then_branch" + value { + func { + name: "cond_true" + } + } + } + experimental_debug_info { + } +} +node { + name: "StatelessIf" + op: "StatelessIf" + input: "tf.Less" + input: "a" + input: "b" + attr { + key: "Tcond" + value { + type: DT_BOOL + } + } + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "else_branch" + value { + func { + name: "cond_false" + } + } + } + attr { + key: "then_branch" + value { + func { + name: "cond_true" + } + } + } + experimental_debug_info { + } +} +node { + name: "main" + op: "_Retval" + input: "StatefulIf" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "main1" + op: "_Retval" + input: "StatelessIf" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +node { + name: "a" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "b" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +library { + function { + signature { + name: "cond_true" + input_arg { + name: "cond_true" + type: DT_FLOAT + } + input_arg { + name: "cond_true1" + type: DT_FLOAT + } + output_arg { + name: "cond_true2" + type: DT_FLOAT + } + } + node_def { + name: "tf.Add" + op: "Add" + input: "cond_true" + input: "cond_true1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + original_node_names: "tf.Add" + } + } + ret { + key: "cond_true2" + value: "tf.Add:z:0" + } + } + function { + signature { + name: "cond_false" + input_arg { + name: "cond_false" + type: DT_FLOAT + } + input_arg { + name: "cond_false1" + type: DT_FLOAT + } + output_arg { + name: "cond_false2" + type: DT_FLOAT + } + } + node_def { + name: "tf.Mul" + op: "Mul" + input: "cond_false" + input: "cond_false1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + original_node_names: "tf.Mul" + } + } + ret { + key: "cond_false2" + value: "tf.Mul:z:0" + } + } +} +versions { + producer: 115 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir new file mode 100644 index 00000000000..ccd058842a9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir @@ -0,0 +1,34 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "tf.Placeholder.input"(%arg0) : (tensor) -> tensor + %1 = "tf.Placeholder.input"(%arg1) : (tensor) -> tensor + %2 = "tf.Less"(%0, %1) : (tensor, tensor) -> tensor + %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor, tensor, tensor) -> tensor loc("StatefulIf") + %4 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = true} : (tensor, tensor, tensor) -> tensor loc("StatelessIf") + return %3, %4 : tensor, tensor +} + +func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// Verify that If op is mapped to TensorFlow StatelessIf op if the is_stateless +// attribute is present and otherwise it is mapped to TensorFlow If op. In both +// cases, the additional attribute should be dropped. + +// CHECK: name: "StatefulIf" +// CHECK-NOT: name: +// CHECK: op: "If" +// CHECK-NOT: is_stateless + +// CHECK: name: "StatelessIf" +// CHECK-NOT: name: +// CHECK: op: "StatelessIf" +// CHECK-NOT: is_stateless