diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt new file mode 100644 index 00000000000..953f83a9f68 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt @@ -0,0 +1,283 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - | FileCheck %s + +# Verify that TensorFlow While and StatelessWhile ops are mapped to the +# composite While op in MLIR with is_stateless attribute set accordingly to +# distinguish between them. + +# CHECK-DAG: "tf.While"{{.*}} is_stateless = false, name = "StatefulWhile" +# CHECK-DAG: "tf.While"{{.*}} is_stateless = true, name = "StatelessWhile" + +node { + name: "StatefulWhile" + op: "While" + input: "iter" + input: "val" + attr { + key: "T" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "body" + value { + func { + name: "body" + } + } + } + attr { + key: "cond" + value { + func { + name: "cond" + } + } + } + experimental_debug_info { + } +} +node { + name: "StatelessWhile" + op: "StatelessWhile" + input: "iter" + input: "val" + attr { + key: "T" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "body" + value { + func { + name: "body" + } + } + } + attr { + key: "cond" + value { + func { + name: "cond" + } + } + } + experimental_debug_info { + } +} +node { + name: "main" + op: "_Retval" + input: "StatefulWhile:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "main1" + op: "_Retval" + input: "StatelessWhile:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +node { + name: "iter" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + experimental_debug_info { + } +} +node { + name: "val" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +library { + function { + signature { + name: "cond" + input_arg { + name: "cond" + type: DT_INT32 + } + input_arg { + name: "cond1" + type: DT_FLOAT + } + output_arg { + name: "cond2" + type: DT_BOOL + } + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + experimental_debug_info { + original_node_names: "Const" + } + } + node_def { + name: "tf.Greater" + op: "Greater" + input: "cond" + input: "Const:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "tf.Greater" + } + } + ret { + key: "cond2" + value: "tf.Greater:z:0" + } + } + function { + signature { + name: "body" + input_arg { + name: "body" + type: DT_INT32 + } + input_arg { + name: "body1" + type: DT_FLOAT + } + output_arg { + name: "body2" + type: DT_INT32 + } + output_arg { + name: "body3" + type: DT_FLOAT + } + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + experimental_debug_info { + original_node_names: "Const" + } + } + node_def { + name: "tf.Sub" + op: "Sub" + input: "body" + input: "Const:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "tf.Sub" + } + } + node_def { + name: "tf.Add" + op: "Add" + input: "body1" + input: "body1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + original_node_names: "tf.Add" + } + } + ret { + key: "body2" + value: "tf.Sub:z:0" + } + ret { + key: "body3" + value: "tf.Add:z:0" + } + } +} +versions { + producer: 115 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir new file mode 100644 index 00000000000..0009c7a4dc4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir @@ -0,0 +1,43 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %iter = "tf.Placeholder.input"(%arg0) : (tensor) -> tensor loc("iter") + %val = "tf.Placeholder.input"(%arg1) : (tensor) -> tensor loc("val") + + // Element wise add `val` with itself for `iter` number of times. + %2:2 = "tf.While"(%iter, %val) { + cond = @cond, body = @body, is_stateless = false + } : (tensor, tensor) -> (tensor, tensor) loc("StatefulWhile") + %3:2 = "tf.While"(%iter, %val) { + cond = @cond, body = @body, is_stateless = true + } : (tensor, tensor) -> (tensor, tensor) loc("StatelessWhile") + + return %2#1, %3#1 : tensor, tensor +} + +func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor { + %0 = "tf.Const" () {value = dense<0> : tensor} : () -> tensor loc("Const") + %1 = "tf.Greater"(%arg0, %0) : (tensor<*xi32>, tensor) -> tensor + return %1 : tensor +} + +func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) { + %0 = "tf.Const" () {value = dense<1> : tensor} : () -> tensor loc("Const") + %1 = "tf.Sub"(%arg0, %0) : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Add"(%arg1, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %1, %2 : tensor<*xi32>, tensor<*xf32> +} + +// Verify that While op is mapped to TensorFlow StatelessWhile op if the +// is_stateless attribute is present and otherwise it is mapped to TensorFlow +// While op. In both cases, the additional attribute should be dropped. + +// CHECK: name: "StatefulWhile" +// CHECK-NOT: name: +// CHECK: op: "While" +// CHECK-NOT: is_stateless + +// CHECK: name: "StatelessWhile" +// CHECK-NOT: name: +// CHECK: op: "StatelessWhile" +// CHECK-NOT: is_stateless