Update Grappler to use existing functions for retrieving a node's

name and position.

PiperOrigin-RevId: 205465354
This commit is contained in:
A. Unique TensorFlower 2018-07-20 15:41:04 -07:00 committed by TensorFlower Gardener
parent 9151e61398
commit f4f37efdc9
3 changed files with 285 additions and 10 deletions

View File

@ -496,18 +496,11 @@ class SymbolicShapeRefiner {
"supported.");
}
// It is guaranteed that output_tensors does not contain any control
// inputs, so port_id >= 0.
string out_tensor = out_arg.output_tensors[0];
auto out_tensor_pieces = str_util::Split(out_tensor, ",");
string node_name = out_tensor_pieces[0];
int port_id;
// Check if port_id was included in out_tensor
if (out_tensor_pieces.size() <= 1) {
port_id = 0;
} else if (!strings::safe_strto32(out_tensor_pieces[1], &port_id)) {
return errors::FailedPrecondition(
"Failed string to integer conversion for ", out_tensor_pieces[1]);
}
string node_name = ParseNodeName(out_tensor, &port_id);
const NodeDef* retnode = gv.GetNode(node_name);
if (retnode == nullptr) {
@ -516,6 +509,11 @@ class SymbolicShapeRefiner {
}
auto output_properties = gp.GetOutputProperties(retnode->name());
if (port_id >= output_properties.size()) {
return errors::InvalidArgument(
out_tensor, " has invalid position ", port_id,
" (output_properties.size() = ", output_properties.size(), ").");
}
auto const& outprop = output_properties[port_id];
const TensorShapeProto& shape = outprop.shape();
ShapeHandle out;

View File

@ -887,6 +887,44 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
EXPECT_EQ(8, in_prop3.shape().dim(3).size());
}
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
// Test graph produced in python using:
/*
@function.Defun(noinline=True)
def MyFunc():
@function.Defun(*[tf.float32] * 2)
def Cond(n, unused_x):
return n > 0
@function.Defun(*[tf.float32] * 2)
def Body(n, x):
return n - 1, x + n
i = tf.constant(10)
return functional_ops.While([i, 0.], Cond, Body)
with tf.Graph().as_default():
z = MyFunc()
*/
GrapplerItem item;
string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
"function_functional_while.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
EXPECT_EQ(2, out_props.size());
const OpInfo::TensorProperties& out_prop0 = out_props[0];
EXPECT_EQ(DT_INT32, out_prop0.dtype());
EXPECT_FALSE(out_prop0.shape().unknown_rank());
const OpInfo::TensorProperties& out_prop1 = out_props[1];
EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
EXPECT_FALSE(out_prop1.shape().unknown_rank());
}
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
GrapplerItem item;
string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,

View File

@ -0,0 +1,239 @@
node {
name: "MyFunc_AenMyWWx1Us"
op: "MyFunc_AenMyWWx1Us"
}
library {
function {
signature {
name: "MyFunc_AenMyWWx1Us"
output_arg {
name: "while"
type: DT_INT32
}
output_arg {
name: "while_0"
type: DT_FLOAT
}
is_stateful: true
}
node_def {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 10
}
}
}
}
node_def {
name: "While/input_1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 0.0
}
}
}
}
node_def {
name: "While"
op: "While"
input: "Const:output:0"
input: "While/input_1:output:0"
attr {
key: "T"
value {
list {
type: DT_INT32
type: DT_FLOAT
}
}
}
attr {
key: "body"
value {
func {
name: "Body_8GOMGeZeK5c"
}
}
}
attr {
key: "cond"
value {
func {
name: "Cond_Xf5ttAHgUCg"
}
}
}
}
ret {
key: "while"
value: "While:output:0"
}
ret {
key: "while_0"
value: "While:output:1"
}
attr {
key: "_noinline"
value {
b: true
}
}
}
function {
signature {
name: "Body_8GOMGeZeK5c"
input_arg {
name: "n"
type: DT_FLOAT
}
input_arg {
name: "x"
type: DT_FLOAT
}
output_arg {
name: "sub"
type: DT_FLOAT
}
output_arg {
name: "add"
type: DT_FLOAT
}
}
node_def {
name: "sub/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 1.0
}
}
}
}
node_def {
name: "sub_0"
op: "Sub"
input: "n"
input: "sub/y:output:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node_def {
name: "add_0"
op: "Add"
input: "x"
input: "n"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
ret {
key: "add"
value: "add_0:z:0"
}
ret {
key: "sub"
value: "sub_0:z:0"
}
}
function {
signature {
name: "Cond_Xf5ttAHgUCg"
input_arg {
name: "n"
type: DT_FLOAT
}
input_arg {
name: "unused_x"
type: DT_FLOAT
}
output_arg {
name: "greater"
type: DT_BOOL
}
}
node_def {
name: "Greater/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 0.0
}
}
}
}
node_def {
name: "Greater"
op: "Greater"
input: "n"
input: "Greater/y:output:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
ret {
key: "greater"
value: "Greater:z:0"
}
}
}
versions {
producer: 26
min_consumer: 12
}