Update Grappler to use existing functions for retrieving a node's
name and position. PiperOrigin-RevId: 205465354
This commit is contained in:
parent
9151e61398
commit
f4f37efdc9
@ -496,18 +496,11 @@ class SymbolicShapeRefiner {
|
||||
"supported.");
|
||||
}
|
||||
|
||||
// It is guaranteed that output_tensors does not contain any control
|
||||
// inputs, so port_id >= 0.
|
||||
string out_tensor = out_arg.output_tensors[0];
|
||||
auto out_tensor_pieces = str_util::Split(out_tensor, ",");
|
||||
string node_name = out_tensor_pieces[0];
|
||||
int port_id;
|
||||
|
||||
// Check if port_id was included in out_tensor
|
||||
if (out_tensor_pieces.size() <= 1) {
|
||||
port_id = 0;
|
||||
} else if (!strings::safe_strto32(out_tensor_pieces[1], &port_id)) {
|
||||
return errors::FailedPrecondition(
|
||||
"Failed string to integer conversion for ", out_tensor_pieces[1]);
|
||||
}
|
||||
string node_name = ParseNodeName(out_tensor, &port_id);
|
||||
|
||||
const NodeDef* retnode = gv.GetNode(node_name);
|
||||
if (retnode == nullptr) {
|
||||
@ -516,6 +509,11 @@ class SymbolicShapeRefiner {
|
||||
}
|
||||
|
||||
auto output_properties = gp.GetOutputProperties(retnode->name());
|
||||
if (port_id >= output_properties.size()) {
|
||||
return errors::InvalidArgument(
|
||||
out_tensor, " has invalid position ", port_id,
|
||||
" (output_properties.size() = ", output_properties.size(), ").");
|
||||
}
|
||||
auto const& outprop = output_properties[port_id];
|
||||
const TensorShapeProto& shape = outprop.shape();
|
||||
ShapeHandle out;
|
||||
|
||||
@ -887,6 +887,44 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
|
||||
EXPECT_EQ(8, in_prop3.shape().dim(3).size());
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
|
||||
// Test graph produced in python using:
|
||||
/*
|
||||
@function.Defun(noinline=True)
|
||||
def MyFunc():
|
||||
@function.Defun(*[tf.float32] * 2)
|
||||
def Cond(n, unused_x):
|
||||
return n > 0
|
||||
|
||||
@function.Defun(*[tf.float32] * 2)
|
||||
def Body(n, x):
|
||||
return n - 1, x + n
|
||||
|
||||
i = tf.constant(10)
|
||||
return functional_ops.While([i, 0.], Cond, Body)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
z = MyFunc()
|
||||
*/
|
||||
GrapplerItem item;
|
||||
string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
|
||||
"function_functional_while.pbtxt");
|
||||
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically(false));
|
||||
|
||||
const auto out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
|
||||
EXPECT_EQ(2, out_props.size());
|
||||
|
||||
const OpInfo::TensorProperties& out_prop0 = out_props[0];
|
||||
EXPECT_EQ(DT_INT32, out_prop0.dtype());
|
||||
EXPECT_FALSE(out_prop0.shape().unknown_rank());
|
||||
|
||||
const OpInfo::TensorProperties& out_prop1 = out_props[1];
|
||||
EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
|
||||
EXPECT_FALSE(out_prop1.shape().unknown_rank());
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
|
||||
GrapplerItem item;
|
||||
string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
|
||||
|
||||
@ -0,0 +1,239 @@
|
||||
node {
|
||||
name: "MyFunc_AenMyWWx1Us"
|
||||
op: "MyFunc_AenMyWWx1Us"
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "MyFunc_AenMyWWx1Us"
|
||||
output_arg {
|
||||
name: "while"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "while_0"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
node_def {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "While/input_1"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "While"
|
||||
op: "While"
|
||||
input: "Const:output:0"
|
||||
input: "While/input_1:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "body"
|
||||
value {
|
||||
func {
|
||||
name: "Body_8GOMGeZeK5c"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "cond"
|
||||
value {
|
||||
func {
|
||||
name: "Cond_Xf5ttAHgUCg"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "while"
|
||||
value: "While:output:0"
|
||||
}
|
||||
ret {
|
||||
key: "while_0"
|
||||
value: "While:output:1"
|
||||
}
|
||||
attr {
|
||||
key: "_noinline"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "Body_8GOMGeZeK5c"
|
||||
input_arg {
|
||||
name: "n"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "sub"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "add"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "sub/y"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "sub_0"
|
||||
op: "Sub"
|
||||
input: "n"
|
||||
input: "sub/y:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "add_0"
|
||||
op: "Add"
|
||||
input: "x"
|
||||
input: "n"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "add"
|
||||
value: "add_0:z:0"
|
||||
}
|
||||
ret {
|
||||
key: "sub"
|
||||
value: "sub_0:z:0"
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "Cond_Xf5ttAHgUCg"
|
||||
input_arg {
|
||||
name: "n"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "unused_x"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "greater"
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Greater/y"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Greater"
|
||||
op: "Greater"
|
||||
input: "n"
|
||||
input: "Greater/y:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "greater"
|
||||
value: "Greater:z:0"
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 26
|
||||
min_consumer: 12
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user