Calling InferenceContext::UnknownShapes twice produces two shape handles for which ShapeHandle::SameHandle returns false. Therefore we need to merge the shape handles in InferenceContext::set_input, InferenceContext::set_input_handle_shape, and InferenceContext::set_output_handle_shape.

Change: 154911808
Benoit Steiner 2017-05-02 18:02:26 -08:00 committed by TensorFlower Gardener
parent 5ad12420e7
commit a8d720c2b7
4 changed files with 268 additions and 27 deletions
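
For illustration only (this is not part of the commit and not the TensorFlow API): the sketch below uses toy stand-in types to show why comparing handle identity reports a change whenever two independently created unknown shapes are stored, even though they carry the same information, while merging first and comparing afterwards reaches a fixed point. The toy Merge returns bool where the real InferenceContext::Merge returns a Status.

  // Toy model: a "shape" is an identity-compared handle to a rank (-1 = unknown rank).
  #include <iostream>
  #include <memory>

  struct TinyShape {
    std::shared_ptr<int> rank;
    bool SameHandle(const TinyShape& other) const { return rank == other.rank; }
  };

  TinyShape UnknownShape() { return TinyShape{std::make_shared<int>(-1)}; }

  // Merge keeps the stored handle when the incoming shape adds no information.
  bool Merge(const TinyShape& stored, const TinyShape& incoming, TinyShape* out) {
    if (*incoming.rank == -1 || *incoming.rank == *stored.rank) {
      *out = stored;    // nothing new: reuse the stored handle
      return true;
    }
    if (*stored.rank == -1) {
      *out = incoming;  // incoming is strictly more precise
      return true;
    }
    return false;       // incompatible ranks
  }

  // Old behaviour: store whenever the handle differs, so fresh unknowns always "change".
  bool SetInput(TinyShape* stored, TinyShape shape) {
    if (!stored->SameHandle(shape)) { *stored = shape; return true; }
    return false;
  }

  // New behaviour: merge first, report a change only if the merged result is a new handle.
  bool MergeInput(TinyShape* stored, TinyShape shape) {
    TinyShape merged;
    if (!Merge(*stored, shape, &merged) || stored->SameHandle(merged)) return false;
    *stored = merged;
    return true;
  }

  int main() {
    TinyShape a = UnknownShape();
    TinyShape b = UnknownShape();  // distinct handle, same (absent) information
    TinyShape stored = a;
    std::cout << "identity-based setter reports change: " << SetInput(&stored, b) << "\n";   // prints 1
    stored = a;
    std::cout << "merge-based setter reports change: " << MergeInput(&stored, b) << "\n";    // prints 0
  }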

tensorflow/core/common_runtime/shape_refiner.cc

@@ -163,7 +163,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
     InferenceContext* c = iter->second.get();
     DCHECK_GE(e->dst_input(), 0);
-    if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) {
+    if (node_context->MergeInput(e->dst_input(), c->output(e->src_output()))) {
       *refined = true;
     }
@@ -174,7 +174,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
             e->dst_input(), c->output_handle_dtype(e->src_output()))) {
       *refined = true;
     }
-    if (node_context->set_input_handle_shape(
+    if (node_context->MergeInputHandleShape(
             e->dst_input(), c->output_handle_shape(e->src_output()))) {
       *refined = true;
     }

tensorflow/core/framework/shape_inference.h

@@ -191,16 +191,18 @@ class InferenceContext {
     return s;
   }
-  // Set the shape of the input in position idx. This requires idx to be in the
-  // [0, num_inputs) range. Returns true iff the stored input shape has been
-  // updated with a different handle.
-  bool set_input(int idx, ShapeHandle shape) {
-    if (!inputs_[idx].SameHandle(shape)) {
-      inputs_[idx] = shape;
-      return true;
-    } else {
+  // Merge the stored shape of the input in position idx with the specified
+  // shape. This requires idx to be in the [0, num_inputs) range. If the merge
+  // is successful and the new shape differs from the old one, store the new
+  // shape and return true. Return false otherwise.
+  bool MergeInput(int idx, ShapeHandle shape) {
+    ShapeHandle new_shape;
+    if (!Merge(inputs_[idx], shape, &new_shape).ok() ||
+        inputs_[idx].SameHandle(new_shape)) {
       return false;
     }
+    inputs_[idx] = new_shape;
+    return true;
   }
   ShapeHandle input(int64 idx) const { return inputs_[idx]; }
   Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
@@ -442,15 +444,18 @@ class InferenceContext {
   // propagate that information. Output handle dtypes and shapes are ignored if
   // the output tensor is not of type DT_RESOURCE.
-  // Set the shape corresponding to the resource in position idx. This requires
-  // idx to be in the [0, num_inputs) range. Returns true iff the stored shape
-  // has been updated with a different handle.
-  bool set_input_handle_shape(int idx, ShapeHandle shape) {
-    if (!input_handle_shape_[idx].SameHandle(shape)) {
-      input_handle_shape_[idx] = shape;
-      return true;
+  // Merge the stored shape corresponding to the input handle in position idx
+  // with the specified shape. This requires idx to be in the [0, num_inputs)
+  // range. If the merge is successful and the new shape differs from the old
+  // one, store the new shape and return true. Return false otherwise.
+  bool MergeInputHandleShape(int idx, ShapeHandle shape) {
+    ShapeHandle new_shape;
+    if (!Merge(input_handle_shape_[idx], shape, &new_shape).ok() ||
+        input_handle_shape_[idx].SameHandle(new_shape)) {
+      return false;
     }
-    return false;
+    input_handle_shape_[idx] = shape;
+    return true;
   }
   // Set the type corresponding to the resource in position idx. This requires
@@ -468,15 +473,24 @@
     return input_handle_dtype_[idx];
   }
-  // Set the shape corresponding to the resource in position idx. This requires
-  // idx to be in the [0, num_outputs) range.
-  // Returns true iff the stored shape has been updated with a different handle.
-  bool set_output_handle_shape(int idx, ShapeHandle shape) {
-    if (!output_handle_shape_[idx].SameHandle(shape)) {
-      output_handle_shape_[idx] = shape;
-      return true;
+  // Merge the stored shape corresponding to the output handle in position idx
+  // with the specified shape. This requires idx to be in the [0, num_outputs)
+  // range. If the merge is successful and the new shape differs from the old
+  // one, store the new shape and return true. Return false otherwise.
+  bool MergeOutputHandleShape(int idx, ShapeHandle shape) {
+    ShapeHandle new_shape;
+    if (!Merge(output_handle_shape_[idx], shape, &new_shape).ok() ||
+        output_handle_shape_[idx].SameHandle(new_shape)) {
+      return false;
     }
-    return false;
+    output_handle_shape_[idx] = shape;
+    return true;
   }
+  // Overwrite the shape corresponding to the output handle in position idx with
+  // the specified shape.
+  void set_output_handle_shape(int idx, ShapeHandle shape) {
+    output_handle_shape_[idx] = shape;
+  }
   // Set the type corresponding to the resource in position idx. This requires

tensorflow/core/grappler/costs/graph_properties.cc

@@ -85,7 +85,7 @@ Status GraphProperties::InferStatically() {
         }
       }
       if (qctx->set_output_handle_dtype(0, queue_type) ||
-          qctx->set_output_handle_shape(0, queue_shp)) {
+          qctx->MergeOutputHandleShape(0, queue_shp)) {
         new_shapes.push(qnode);
       }
     }
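
The change-detection return value matters at this call site because InferStatically treats it as a worklist fixed point: the queue node is pushed back onto new_shapes only when the merge reports new information, whereas the old identity-based setter saw two independently created unknown shapes as different handles and defeated the convergence check. Below is a minimal, self-contained sketch of that worklist pattern; the types and names (Node, MergeValue) are toy stand-ins, not the GraphProperties code.

  // Toy worklist fixed-point loop: a node is re-enqueued only while merging
  // actually changes its stored information, so the loop terminates once the
  // merge-based update starts returning false.
  #include <iostream>
  #include <queue>
  #include <vector>

  struct Node {
    int rank;                    // -1 means "unknown" in this toy model
    std::vector<int> consumers;  // indices of nodes fed by this one
  };

  // Merge incoming info into the node; report whether anything changed.
  bool MergeValue(Node* node, int incoming_rank) {
    if (incoming_rank == -1 || incoming_rank == node->rank) return false;
    if (node->rank == -1) {
      node->rank = incoming_rank;
      return true;
    }
    return false;  // conflicting known ranks: keep the stored value in this toy
  }

  int main() {
    // Chain 0 -> 1 -> 2, where only node 0 starts with a known rank.
    std::vector<Node> graph = {{2, {1}}, {-1, {2}}, {-1, {}}};

    std::queue<int> worklist;
    worklist.push(0);
    while (!worklist.empty()) {
      int n = worklist.front();
      worklist.pop();
      for (int consumer : graph[n].consumers) {
        // Analogous to: if (qctx->MergeOutputHandleShape(...)) new_shapes.push(qnode);
        if (MergeValue(&graph[consumer], graph[n].rank)) worklist.push(consumer);
      }
    }
    for (int i = 0; i < 3; ++i) {
      std::cout << "node " << i << " rank: " << graph[i].rank << "\n";
    }
  }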

tensorflow/core/grappler/costs/graph_properties_test.cc

@@ -177,10 +177,14 @@ TEST_F(GraphPropertiesTest, Queues) {
auto dequeue2 =
ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
// Create a queue that feeds itself.
auto q3 =
ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT});
auto dequeue3 =
ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT});
auto merge3 = ops::Merge(root.WithOpName("Merge3"), {dequeue3[0], square2});
auto enqueue3 =
ops::QueueEnqueue(root.WithOpName("Enqueue3"), q3, {merge3.output});
auto q4 =
ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
@@ -227,6 +231,229 @@ TEST_F(GraphPropertiesTest, Queues) {
EXPECT_EQ(7, prop4.shape().dim(1).size());
}
TEST_F(GraphPropertiesTest, Loops) {
// Test graph produced in python using:
/*
with tf.Graph().as_default():
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i])
  with open('/tmp/graph.txt', 'w') as f:
    f.write(str(tf.get_default_graph().as_graph_def()))
*/
const string gdef_ascii = R"EOF(
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "while/Enter"
op: "Enter"
input: "Const"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "frame_name"
value {
s: "while/while/"
}
}
attr {
key: "is_constant"
value {
b: false
}
}
attr {
key: "parallel_iterations"
value {
i: 10
}
}
}
node {
name: "while/Merge"
op: "Merge"
input: "while/Enter"
input: "while/NextIteration"
attr {
key: "N"
value {
i: 2
}
}
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Less/y"
op: "Const"
input: "^while/Merge"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 10
}
}
}
}
node {
name: "while/Less"
op: "Less"
input: "while/Merge"
input: "while/Less/y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/LoopCond"
op: "LoopCond"
input: "while/Less"
}
node {
name: "while/Switch"
op: "Switch"
input: "while/Merge"
input: "while/LoopCond"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_class"
value {
list {
s: "loc:@while/Merge"
}
}
}
}
node {
name: "while/Identity"
op: "Identity"
input: "while/Switch:1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Add/y"
op: "Const"
input: "^while/Identity"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
}
node {
name: "while/Add"
op: "Add"
input: "while/Identity"
input: "while/Add/y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/NextIteration"
op: "NextIteration"
input: "while/Add"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Exit"
op: "Exit"
input: "while/Switch"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
versions {
producer: 11
}
)EOF";
GrapplerItem item;
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically());
const auto props = properties.GetOutputProperties("while/Exit");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_INT32, prop.dtype());
EXPECT_TRUE(prop.shape().unknown_rank());
}
} // namespace
} // namespace grappler
} // namespace tensorflow