Add support for target nodes of TensorFlow Graphs in TF MLIR importer.

Graphs can have target nodes/control rets: nodes that must execute even though they may produce no data outputs. For example, side-effecting ops, such as resource writes, may rely on being set as target nodes in order to be executed. This also unifies the control ret support previously added for function graphs.

PiperOrigin-RevId: 302683770
Change-Id: I33a66ae0ce08ac2230b575caae2cc752636d0d8b
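
For context, the motivating case is a graph whose interesting node runs purely for its side effect and has no fetched data output. Below is a minimal sketch mirroring the Python generator embedded in the new test in this commit; the `write_graph` call and file names are illustrative additions, not part of the test:

```python
import tensorflow as tf

# AssignAdd executes only for its side effect; nothing fetches its data
# output, so the importer must be told to treat it as a target node.
with tf.compat.v1.Graph().as_default() as g:
  var = tf.Variable(2.0)
  var_add = var.assign_add(3.0)

# Dump the GraphDef for the importer, which is then pointed at the node
# by name, e.g.:
#   tf-mlir-translate -graphdef-to-mlir graph.pbtxt \
#     -tf-control-output-arrays=AssignAdd
tf.io.write_graph(g.as_graph_def(), '/tmp', 'graph.pbtxt')
```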
Andy Ly 2020-03-24 09:39:51 -07:00 committed by TensorFlower Gardener
parent e1afcc5feb
commit 463bec0d92
15 changed files with 386 additions and 121 deletions

View File

@@ -36,8 +36,11 @@ versions {
producer: 27
}
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %0 : tensor<*xi32>
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>
# CHECK-NEXT: }

View File

@@ -443,12 +443,15 @@ node {
}
}
# MLIR-LABEL: func @main(%arg0: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR: attributes {tf.entry_function = {inputs = "input", outputs = "output"}
# MLIR-LABEL: func @main
# MLIR-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR-SAME: control_outputs = ""
# MLIR-SAME: inputs = "input"
# MLIR-SAME: outputs = "output"
# MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR: }

View File

@@ -47,7 +47,10 @@ versions {
producer: 27
}
# CHECK: func @main(%arg0: tensor<4xi32>) -> tensor<4xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
# CHECK-NEXT: return %arg0 : tensor<4xi32>
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>) -> tensor<4xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: return %[[ARG_0]] : tensor<4xi32>
# CHECK-NEXT: }

View File

@@ -137,8 +137,11 @@ versions {
producer: 198
}
# CHECK: func @main([[VAL_0:%.*]]: tensor<1x8x8x2xi32>) -> (tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>)
# CHECK: attributes {tf.entry_function = {inputs = "unranked", outputs = "unranked,static,static_10"}} {
# CHECK-LABEL: func @main
# CHECK-SAME: ([[VAL_0:%.*]]: tensor<1x8x8x2xi32>) -> (tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>)
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "unranked"
# CHECK-SAME: outputs = "unranked,static,static_10"
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32>
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>

View File

@@ -7783,8 +7783,11 @@ node {
library {
}
# CHECK: func @main(%arg0: tensor<1x3x3xf32>) -> tensor<1x3xf32>
# CHECK: attributes {tf.entry_function = {inputs = "INPUT", outputs = "OUTPUT"}} {
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x3x3xf32>) -> tensor<1x3xf32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "INPUT"
# CHECK-SAME: outputs = "OUTPUT"
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
@@ -7808,7 +7811,7 @@ library {
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_23:%.*]] = constant unit
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%[[ARG_0]]) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>

View File

@@ -38,17 +38,26 @@ versions {
producer: 27
}
# CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "Add"
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# CHECK: fetch %[[add]]
# SOME: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
# SOME: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
# SOME: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
# SOME-LABEL: func @main
# SOME-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
# SOME-SAME: control_outputs = ""
# SOME-SAME: inputs = "input0,input1"
# SOME-SAME: outputs = "Add"
# SOME: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# SOME: fetch %[[add]]
# NONE: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
# NONE: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
# NONE-LABEL: func @main
# NONE-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
# NONE-SAME: control_outputs = ""
# NONE-SAME: inputs = "input0,input1"
# NONE-SAME: outputs = "Add"
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# NONE: fetch %[[add]]

View File

@@ -20,8 +20,11 @@ versions {
producer: 27
}
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
# CHECK: attributes {tf.entry_function = {inputs = "arg", outputs = "arg"}} {
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "arg"
# CHECK-SAME: outputs = "arg"
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: tf_executor.fetch %[[ARG_0]]
# CHECK: return %[[GRAPH]]

View File

@@ -14,8 +14,11 @@ versions {
producer: 27
}
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "input"}} {
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input"
# CHECK-SAME: outputs = "input"
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: tf_executor.fetch %[[ARG_0]]
# CHECK: return %[[GRAPH]]

View File

@@ -59,8 +59,11 @@ library {
versions {
}
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32>
# CHECK: tf.entry_function = {inputs = "input", outputs = "output_node"}
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input"
# CHECK-SAME: outputs = "output_node"
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: %[[CONST:.*]], %[[CONST_control:.*]] = tf_executor.island wraps "tf.Const"()
# CHECK: %[[OUTPUT:.*]], %[[OUTPUT_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[CONST]])

View File

@@ -2,12 +2,15 @@
# Verify that we correctly match the input/output when they are scalar.
# CHECK: func @main(%arg0: tensor<f32> {tf.device = "/device:CPU:0"}) -> (tensor<f32>, tensor<f32>)
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "out:1,out"}} {
# CHECK-LABEL: func @main
# CHECK-SAME: (%{{[a-z0-9]+}}: tensor<f32> {tf.device = "/device:CPU:0"}) -> (tensor<f32>, tensor<f32>)
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input"
# CHECK-SAME: outputs = "out:1,out"
# CHECK: tf.Relu
# CHECK: %[[IDENTITY:[a-z_0-9]+]]:2, {{.*}} = tf_executor.island wraps "tf.IdentityN"
# CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
node {
name: "input"

View File

@@ -268,8 +268,11 @@ versions {
# function. Rets that happen to coincide with a feed should take on the value
# of the feed.
#
# CHECK: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# CHECK: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "z:1,z:2"
# CHECK-SAME: outputs = "z:2,z:1,a:0"
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
@@ -278,8 +281,11 @@ versions {
# Test that when non-zero index output tensors are feeds, the remaining
# unreachable ops are pruned if pruning is enabled.
#
# PRUNE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# PRUNE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
# PRUNE-LABEL: func @main
# PRUNE-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# PRUNE-SAME: control_outputs = ""
# PRUNE-SAME: inputs = "z:1,z:2"
# PRUNE-SAME: outputs = "z:2,z:1,a:0"
# PRUNE-NOT: "tf.Const"
# PRUNE-NOT: "tf.VariableV2"
# PRUNE-NOT: "tf.Assign"
@@ -292,8 +298,11 @@ versions {
# Test that when non-zero index output tensors are feeds, the remaining
# unreachable ops are preserved if pruning is not enabled.
#
# PRESERVE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
# PRESERVE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:0,a:0"}}
# PRESERVE-LABEL: func @main
# PRESERVE-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
# PRESERVE-SAME: control_outputs = ""
# PRESERVE-SAME: inputs = "z:1,z:2"
# PRESERVE-SAME: outputs = "z:0,a:0"
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])

View File

@@ -0,0 +1,203 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-control-output-arrays=AssignAdd -o - | FileCheck %s --dump-input=fail
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=AssignAdd -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=Variable/Assign,AssignAdd -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail
# Generated in Python via
# ```
# import tensorflow as tf
#
# with tf.compat.v1.Graph().as_default() as g:
# var = tf.Variable(2.0)
# var_add = var.assign_add(3.0)
# ```
node {
name: "Variable/initial_value"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
}
node {
name: "Variable"
op: "VariableV2"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "Variable/Assign"
op: "Assign"
input: "Variable"
input: "Variable/initial_value"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Variable"
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "Variable/read"
op: "Identity"
input: "Variable"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Variable"
}
}
}
}
node {
name: "AssignAdd/value"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 3.0
}
}
}
}
node {
name: "AssignAdd"
op: "AssignAdd"
input: "Variable"
input: "AssignAdd/value"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Variable"
}
}
}
attr {
key: "use_locking"
value {
b: false
}
}
}
versions {
producer: 309
}
# Tests a single target node with no pruning set. All nodes remain in the
# graph, and the target node is added to the graph fetch as a control.
#
# CHECK-LABEL: func @main
# CHECK-SAME: control_outputs = "AssignAdd"
# CHECK-SAME: inputs = ""
# CHECK-SAME: outputs = ""
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# CHECK: tf_executor.fetch %[[ASSIGN_ADD_CTRL]]
# Tests a single target node with pruning set. Nodes unreachable from the
# fetch, including side-effecting nodes, are removed. The target node is added
# to the graph fetch as a control.
#
# PRUNE-LABEL: func @main
# PRUNE-SAME: control_outputs = "AssignAdd"
# PRUNE-SAME: inputs = ""
# PRUNE-SAME: outputs = ""
# PRUNE-NOT: "tf.Assign"
# PRUNE-NOT: "tf.Identity"
# PRUNE-DAG: %[[CONST:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"
# PRUNE-DAG: %[[VAR:.*]], %{{.*}} = tf_executor.island wraps "tf.VariableV2"
# PRUNE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"(%[[VAR]], %[[CONST]])
# PRUNE-NEXT: tf_executor.fetch %[[ASSIGN_ADD_CTRL]]
# Tests multiple target nodes with pruning set. Nodes unreachable from the
# fetch, including side-effecting nodes, are removed. The target nodes are
# added to the graph fetch as controls.
#
# PRESERVE-LABEL: func @main
# PRESERVE-SAME: control_outputs = "Variable/Assign,AssignAdd"
# PRESERVE-SAME: inputs = ""
# PRESERVE-SAME: outputs = ""
# PRESERVE-NOT: "tf.Identity"
# PRESERVE: %[[VAR:.*]], %{{.*}} = tf_executor.island wraps "tf.VariableV2"
# PRESERVE-DAG: %{{.*}}, %[[ASSIGN_CTRL:.*]] = tf_executor.island wraps "tf.Assign"
# PRESERVE-DAG: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# PRESERVE-NEXT: tf_executor.fetch %[[ASSIGN_CTRL]], %[[ASSIGN_ADD_CTRL]]

View File

@@ -570,6 +570,13 @@ StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
edge->dst_input()));
}
// TODO(lyandy): Preserve control dependencies properly by not forwarding
// control dependencies to data outputs and not removing single-output nodes.
// When a data output is replaced as a feed, transitive control dependencies
// should not be executed unless the same node has another non-feed data
// output or an explicit control output that is used. For single-output nodes,
// Placeholders can be converted to a NoOp if there are no uses, and
// PlaceholderWithDefault can be converted to an Identity.
for (const auto* edge : control_edges) {
graph_->AddControlEdge(placeholder_node, edge->dst());
graph_->RemoveControlEdge(edge);
@@ -616,6 +623,9 @@ Status ImporterBase::GetInputOutputNodes(
}
}
for (const auto& control_output : specs_.control_outputs)
TF_RETURN_IF_ERROR(add_node(control_output));
return Status::OK();
}
@@ -658,7 +668,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
DataType dtype = array_info.imported_dtype;
// Uses the existing output type if it isn't specified by the user.
if (dtype == DT_INVALID) {
dtype = node->output_type(0);
dtype = node->output_type(index);
}
TF_ASSIGN_OR_RETURN(
@@ -1787,10 +1797,10 @@ class GraphDefImporter : public ImporterBase {
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
resource_arg_unique_ids);
// Finds the function's control ret nodes based on supplied node names in
// `control_outputs`. If `control_outputs` are not unique or a control ret
// node is missing, an error will be returned.
Status GetControlRetsFromFunctionGraph(
// Finds the graph's target nodes/function's control ret nodes based on
// supplied node names in `control_outputs`. If `control_outputs` are not
// unique or a control ret node is missing, an error will be returned.
Status GetControlRetsFromGraph(
llvm::ArrayRef<std::string> control_outputs,
absl::InlinedVector<Node*, 4>* control_ret_nodes);
};
@@ -1827,8 +1837,8 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
importer.GetArgsRetsAndTypesFromFunctionGraph(
context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids));
TF_RETURN_IF_ERROR(importer.GetControlRetsFromFunctionGraph(
specs.control_outputs, &control_ret_nodes));
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
&control_ret_nodes));
if (!arg_nodes.empty() || !ret_nodes.empty() ||
!control_ret_nodes.empty()) {
@@ -1858,10 +1868,14 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
specs, context, &arg_nodes, &ret_nodes));
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
&control_ret_nodes));
// TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
// decoding in a centralized place.
// Record the input and output mapping.
if (!specs.inputs.empty() || !specs.outputs.empty()) {
if (!specs.inputs.empty() || !specs.outputs.empty() ||
!specs.control_outputs.empty()) {
mlir::Builder b(context);
std::string s;
llvm::raw_string_ostream ss(s);
@@ -1873,9 +1887,14 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
s.clear();
mlir::interleave(specs.outputs, ss, ",");
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
s.clear();
mlir::interleave(specs.control_outputs, ss, ",");
auto control_outputs =
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
attrs.push_back(b.getNamedAttr("tf.entry_function",
b.getDictionaryAttr({inputs, outputs})));
attrs.push_back(b.getNamedAttr(
"tf.entry_function",
b.getDictionaryAttr({inputs, outputs, control_outputs})));
}
}
@@ -2064,7 +2083,7 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
return builder.getFunctionType(arg_types, ret_types);
}
Status GraphDefImporter::GetControlRetsFromFunctionGraph(
Status GraphDefImporter::GetControlRetsFromGraph(
llvm::ArrayRef<std::string> control_outputs,
absl::InlinedVector<Node*, 4>* control_ret_nodes) {
if (control_outputs.empty()) return Status::OK();

View File

@@ -42,8 +42,7 @@ struct GraphImportConfig {
InputArrays inputs;
// name:index strings for the data outputs.
std::vector<string> outputs;
// name strings for the control outputs. This is currently only used when
// `graph_as_function` is set.
// name strings for the control outputs.
std::vector<string> control_outputs;
// Setting prune_unused_nodes to true, would prune unreachable nodes if
// output_arrays is specified.

View File

@@ -50,8 +50,7 @@ opt<std::string> output_arrays(
// NOLINTNEXTLINE
opt<std::string> control_output_arrays(
"tf-control-output-arrays",
llvm::cl::desc("Control output node names, separated by ',', for main "
"graphs that are functions"),
llvm::cl::desc("Control output node names, separated by ','"),
llvm::cl::init(""));
// NOLINTNEXTLINE
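
With the help text generalized, `-tf-control-output-arrays` now applies to plain graph imports as well as graphs imported as functions. The RUN lines of the new test above exercise it; for example (with `graph.pbtxt` standing in for the test's `%s` input): `tf-mlir-translate -graphdef-to-mlir graph.pbtxt -tf-prune-unused-nodes -tf-control-output-arrays=Variable/Assign,AssignAdd -o -`.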