Add import and export tests for handling of If and StatelessIf ops
PiperOrigin-RevId: 261600770
This commit is contained in:
parent
8029f9f172
commit
cc6e729c9e
@ -0,0 +1,256 @@
|
|||||||
|
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - | FileCheck %s
|
||||||
|
|
||||||
|
# Verify that TensorFlow If and StatelessIf ops are mapped to the
|
||||||
|
# composite If op in MLIR with is_stateless attribute set accordingly to
|
||||||
|
# distinguish between them.
|
||||||
|
|
||||||
|
# CHECK-DAG: "tf.If"{{.*}} is_stateless = false, name = "StatefulIf"
|
||||||
|
# CHECK-DAG: "tf.If"{{.*}} is_stateless = true, name = "StatelessIf"
|
||||||
|
|
||||||
|
node {
|
||||||
|
name: "tf.Less"
|
||||||
|
op: "Less"
|
||||||
|
input: "a"
|
||||||
|
input: "b"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "StatefulIf"
|
||||||
|
op: "If"
|
||||||
|
input: "tf.Less"
|
||||||
|
input: "a"
|
||||||
|
input: "b"
|
||||||
|
attr {
|
||||||
|
key: "Tcond"
|
||||||
|
value {
|
||||||
|
type: DT_BOOL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tin"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tout"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "else_branch"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "cond_false"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "then_branch"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "cond_true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "StatelessIf"
|
||||||
|
op: "StatelessIf"
|
||||||
|
input: "tf.Less"
|
||||||
|
input: "a"
|
||||||
|
input: "b"
|
||||||
|
attr {
|
||||||
|
key: "Tcond"
|
||||||
|
value {
|
||||||
|
type: DT_BOOL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tin"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tout"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "else_branch"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "cond_false"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "then_branch"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "cond_true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "main"
|
||||||
|
op: "_Retval"
|
||||||
|
input: "StatefulIf"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "index"
|
||||||
|
value {
|
||||||
|
i: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "main1"
|
||||||
|
op: "_Retval"
|
||||||
|
input: "StatelessIf"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "index"
|
||||||
|
value {
|
||||||
|
i: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "a"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "b"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
library {
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "cond_true"
|
||||||
|
input_arg {
|
||||||
|
name: "cond_true"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_true1"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "cond_true2"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "tf.Add"
|
||||||
|
op: "Add"
|
||||||
|
input: "cond_true"
|
||||||
|
input: "cond_true1"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
original_node_names: "tf.Add"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret {
|
||||||
|
key: "cond_true2"
|
||||||
|
value: "tf.Add:z:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "cond_false"
|
||||||
|
input_arg {
|
||||||
|
name: "cond_false"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_false1"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "cond_false2"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "tf.Mul"
|
||||||
|
op: "Mul"
|
||||||
|
input: "cond_false"
|
||||||
|
input: "cond_false1"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
original_node_names: "tf.Mul"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret {
|
||||||
|
key: "cond_false2"
|
||||||
|
value: "tf.Mul:z:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
versions {
|
||||||
|
producer: 115
|
||||||
|
min_consumer: 12
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,34 @@
|
|||||||
|
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
|
||||||
|
%0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
%1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
|
||||||
|
%2 = "tf.Less"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatefulIf")
|
||||||
|
%4 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatelessIf")
|
||||||
|
return %3, %4 : tensor<f32>, tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that If op is mapped to TensorFlow StatelessIf op if the is_stateless
|
||||||
|
// attribute is present and otherwise it is mapped to TensorFlow If op. In both
|
||||||
|
// cases, the additional attribute should be dropped.
|
||||||
|
|
||||||
|
// CHECK: name: "StatefulIf"
|
||||||
|
// CHECK-NOT: name:
|
||||||
|
// CHECK: op: "If"
|
||||||
|
// CHECK-NOT: is_stateless
|
||||||
|
|
||||||
|
// CHECK: name: "StatelessIf"
|
||||||
|
// CHECK-NOT: name:
|
||||||
|
// CHECK: op: "StatelessIf"
|
||||||
|
// CHECK-NOT: is_stateless
|
Loading…
Reference in New Issue
Block a user