Add import and export tests for handling of While and StatelessWhile ops

PiperOrigin-RevId: 261596381
This commit is contained in:
Smit Hinsu 2019-08-04 16:50:40 -07:00 committed by TensorFlower Gardener
parent 1442bc2116
commit 8029f9f172
2 changed files with 326 additions and 0 deletions

View File

@ -0,0 +1,283 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - | FileCheck %s
# Verify that TensorFlow While and StatelessWhile ops are mapped to the
# composite While op in MLIR with is_stateless attribute set accordingly to
# distinguish between them.
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false, name = "StatefulWhile"
# CHECK-DAG: "tf.While"{{.*}} is_stateless = true, name = "StatelessWhile"
node {
name: "StatefulWhile"
op: "While"
input: "iter"
input: "val"
attr {
key: "T"
value {
list {
type: DT_INT32
type: DT_FLOAT
}
}
}
attr {
key: "body"
value {
func {
name: "body"
}
}
}
attr {
key: "cond"
value {
func {
name: "cond"
}
}
}
experimental_debug_info {
}
}
node {
name: "StatelessWhile"
op: "StatelessWhile"
input: "iter"
input: "val"
attr {
key: "T"
value {
list {
type: DT_INT32
type: DT_FLOAT
}
}
}
attr {
key: "body"
value {
func {
name: "body"
}
}
}
attr {
key: "cond"
value {
func {
name: "cond"
}
}
}
experimental_debug_info {
}
}
node {
name: "main"
op: "_Retval"
input: "StatefulWhile:1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 0
}
}
}
node {
name: "main1"
op: "_Retval"
input: "StatelessWhile:1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 1
}
}
}
node {
name: "iter"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
experimental_debug_info {
}
}
node {
name: "val"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
library {
function {
signature {
name: "cond"
input_arg {
name: "cond"
type: DT_INT32
}
input_arg {
name: "cond1"
type: DT_FLOAT
}
output_arg {
name: "cond2"
type: DT_BOOL
}
}
node_def {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
experimental_debug_info {
original_node_names: "Const"
}
}
node_def {
name: "tf.Greater"
op: "Greater"
input: "cond"
input: "Const:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "tf.Greater"
}
}
ret {
key: "cond2"
value: "tf.Greater:z:0"
}
}
function {
signature {
name: "body"
input_arg {
name: "body"
type: DT_INT32
}
input_arg {
name: "body1"
type: DT_FLOAT
}
output_arg {
name: "body2"
type: DT_INT32
}
output_arg {
name: "body3"
type: DT_FLOAT
}
}
node_def {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
experimental_debug_info {
original_node_names: "Const"
}
}
node_def {
name: "tf.Sub"
op: "Sub"
input: "body"
input: "Const:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "tf.Sub"
}
}
node_def {
name: "tf.Add"
op: "Add"
input: "body1"
input: "body1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Add"
}
}
ret {
key: "body2"
value: "tf.Sub:z:0"
}
ret {
key: "body3"
value: "tf.Add:z:0"
}
}
}
versions {
producer: 115
min_consumer: 12
}

View File

@ -0,0 +1,43 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%iter = "tf.Placeholder.input"(%arg0) : (tensor<i32>) -> tensor<i32> loc("iter")
%val = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32> loc("val")
// Element wise add `val` with itself for `iter` number of times.
%2:2 = "tf.While"(%iter, %val) {
cond = @cond, body = @body, is_stateless = false
} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatefulWhile")
%3:2 = "tf.While"(%iter, %val) {
cond = @cond, body = @body, is_stateless = true
} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatelessWhile")
return %2#1, %3#1 : tensor<f32>, tensor<f32>
}
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
%0 = "tf.Const" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
%1 = "tf.Greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
%0 = "tf.Const" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
%1 = "tf.Sub"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = "tf.Add"(%arg1, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %1, %2 : tensor<*xi32>, tensor<*xf32>
}
// Verify that While op is mapped to TensorFlow StatelessWhile op if the
// is_stateless attribute is present and otherwise it is mapped to TensorFlow
// While op. In both cases, the additional attribute should be dropped.
// CHECK: name: "StatefulWhile"
// CHECK-NOT: name:
// CHECK: op: "While"
// CHECK-NOT: is_stateless
// CHECK: name: "StatelessWhile"
// CHECK-NOT: name:
// CHECK: op: "StatelessWhile"
// CHECK-NOT: is_stateless