Add import and export tests for handling of If and StatelessIf ops

PiperOrigin-RevId: 261600770
This commit is contained in:
Smit Hinsu 2019-08-04 18:05:07 -07:00 committed by TensorFlower Gardener
parent 8029f9f172
commit cc6e729c9e
2 changed files with 290 additions and 0 deletions

View File

@ -0,0 +1,256 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - | FileCheck %s
# Verify that TensorFlow If and StatelessIf ops are mapped to the
# composite If op in MLIR with is_stateless attribute set accordingly to
# distinguish between them.
# CHECK-DAG: "tf.If"{{.*}} is_stateless = false, name = "StatefulIf"
# CHECK-DAG: "tf.If"{{.*}} is_stateless = true, name = "StatelessIf"
node {
name: "tf.Less"
op: "Less"
input: "a"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "StatefulIf"
op: "If"
input: "tf.Less"
input: "a"
input: "b"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true"
}
}
}
experimental_debug_info {
}
}
node {
name: "StatelessIf"
op: "StatelessIf"
input: "tf.Less"
input: "a"
input: "b"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true"
}
}
}
experimental_debug_info {
}
}
node {
name: "main"
op: "_Retval"
input: "StatefulIf"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 0
}
}
}
node {
name: "main1"
op: "_Retval"
input: "StatelessIf"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 1
}
}
}
node {
name: "a"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "b"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
library {
function {
signature {
name: "cond_true"
input_arg {
name: "cond_true"
type: DT_FLOAT
}
input_arg {
name: "cond_true1"
type: DT_FLOAT
}
output_arg {
name: "cond_true2"
type: DT_FLOAT
}
}
node_def {
name: "tf.Add"
op: "Add"
input: "cond_true"
input: "cond_true1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Add"
}
}
ret {
key: "cond_true2"
value: "tf.Add:z:0"
}
}
function {
signature {
name: "cond_false"
input_arg {
name: "cond_false"
type: DT_FLOAT
}
input_arg {
name: "cond_false1"
type: DT_FLOAT
}
output_arg {
name: "cond_false2"
type: DT_FLOAT
}
}
node_def {
name: "tf.Mul"
op: "Mul"
input: "cond_false"
input: "cond_false1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Mul"
}
}
ret {
key: "cond_false2"
value: "tf.Mul:z:0"
}
}
}
versions {
producer: 115
min_consumer: 12
}

View File

@ -0,0 +1,34 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
%1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
%2 = "tf.Less"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<i1>
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatefulIf")
%4 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatelessIf")
return %3, %4 : tensor<f32>, tensor<f32>
}
func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// Verify that If op is mapped to TensorFlow StatelessIf op if the is_stateless
// attribute is present and otherwise it is mapped to TensorFlow If op. In both
// cases, the additional attribute should be dropped.
// CHECK: name: "StatefulIf"
// CHECK-NOT: name:
// CHECK: op: "If"
// CHECK-NOT: is_stateless
// CHECK: name: "StatelessIf"
// CHECK-NOT: name:
// CHECK: op: "StatelessIf"
// CHECK-NOT: is_stateless