Add import and export tests for handling of While and StatelessWhile ops
PiperOrigin-RevId: 261596381
This commit is contained in:
parent
1442bc2116
commit
8029f9f172
@ -0,0 +1,283 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - | FileCheck %s
|
||||
|
||||
# Verify that TensorFlow While and StatelessWhile ops are mapped to the
|
||||
# composite While op in MLIR with is_stateless attribute set accordingly to
|
||||
# distinguish between them.
|
||||
|
||||
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false, name = "StatefulWhile"
|
||||
# CHECK-DAG: "tf.While"{{.*}} is_stateless = true, name = "StatelessWhile"
|
||||
|
||||
node {
|
||||
name: "StatefulWhile"
|
||||
op: "While"
|
||||
input: "iter"
|
||||
input: "val"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "body"
|
||||
value {
|
||||
func {
|
||||
name: "body"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "cond"
|
||||
value {
|
||||
func {
|
||||
name: "cond"
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "StatelessWhile"
|
||||
op: "StatelessWhile"
|
||||
input: "iter"
|
||||
input: "val"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "body"
|
||||
value {
|
||||
func {
|
||||
name: "body"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "cond"
|
||||
value {
|
||||
func {
|
||||
name: "cond"
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main"
|
||||
op: "_Retval"
|
||||
input: "StatefulWhile:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main1"
|
||||
op: "_Retval"
|
||||
input: "StatelessWhile:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "iter"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "val"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "cond"
|
||||
input_arg {
|
||||
name: "cond"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "cond1"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "cond2"
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "Const"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "tf.Greater"
|
||||
op: "Greater"
|
||||
input: "cond"
|
||||
input: "Const:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "tf.Greater"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "cond2"
|
||||
value: "tf.Greater:z:0"
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "body"
|
||||
input_arg {
|
||||
name: "body"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "body1"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "body2"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "body3"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "Const"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "tf.Sub"
|
||||
op: "Sub"
|
||||
input: "body"
|
||||
input: "Const:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "tf.Sub"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "tf.Add"
|
||||
op: "Add"
|
||||
input: "body1"
|
||||
input: "body1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "tf.Add"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "body2"
|
||||
value: "tf.Sub:z:0"
|
||||
}
|
||||
ret {
|
||||
key: "body3"
|
||||
value: "tf.Add:z:0"
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 115
|
||||
min_consumer: 12
|
||||
}
|
||||
|
@ -0,0 +1,43 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
|
||||
%iter = "tf.Placeholder.input"(%arg0) : (tensor<i32>) -> tensor<i32> loc("iter")
|
||||
%val = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32> loc("val")
|
||||
|
||||
// Element wise add `val` with itself for `iter` number of times.
|
||||
%2:2 = "tf.While"(%iter, %val) {
|
||||
cond = @cond, body = @body, is_stateless = false
|
||||
} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatefulWhile")
|
||||
%3:2 = "tf.While"(%iter, %val) {
|
||||
cond = @cond, body = @body, is_stateless = true
|
||||
} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatelessWhile")
|
||||
|
||||
return %2#1, %3#1 : tensor<f32>, tensor<f32>
|
||||
}
|
||||
|
||||
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
|
||||
%0 = "tf.Const" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
|
||||
%1 = "tf.Greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
||||
func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
|
||||
%0 = "tf.Const" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
|
||||
%1 = "tf.Sub"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = "tf.Add"(%arg1, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %1, %2 : tensor<*xi32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
// Verify that While op is mapped to TensorFlow StatelessWhile op if the
|
||||
// is_stateless attribute is present and otherwise it is mapped to TensorFlow
|
||||
// While op. In both cases, the additional attribute should be dropped.
|
||||
|
||||
// CHECK: name: "StatefulWhile"
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "While"
|
||||
// CHECK-NOT: is_stateless
|
||||
|
||||
// CHECK: name: "StatelessWhile"
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "StatelessWhile"
|
||||
// CHECK-NOT: is_stateless
|
Loading…
x
Reference in New Issue
Block a user