Calling InferenceContext::UnknownShapes twice produces 2 shape handles for which ShapeHandle::SameHandle returns false. Therefore we need to merge the shapes handles in InferenceContext::set_input, InferenceContext::set_input_handle_shape, and InferenceContext::set_output_handle_shape
Change: 154911808
This commit is contained in:
parent
5ad12420e7
commit
a8d720c2b7
@ -163,7 +163,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
|
||||
|
||||
InferenceContext* c = iter->second.get();
|
||||
DCHECK_GE(e->dst_input(), 0);
|
||||
if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) {
|
||||
if (node_context->MergeInput(e->dst_input(), c->output(e->src_output()))) {
|
||||
*refined = true;
|
||||
}
|
||||
|
||||
@ -174,7 +174,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
|
||||
e->dst_input(), c->output_handle_dtype(e->src_output()))) {
|
||||
*refined = true;
|
||||
}
|
||||
if (node_context->set_input_handle_shape(
|
||||
if (node_context->MergeInputHandleShape(
|
||||
e->dst_input(), c->output_handle_shape(e->src_output()))) {
|
||||
*refined = true;
|
||||
}
|
||||
|
@ -191,16 +191,18 @@ class InferenceContext {
|
||||
return s;
|
||||
}
|
||||
|
||||
// Set the shape of the input in position idx. This requires idx to be in the
|
||||
// [0, num_inputs) range. Returns true iff the stored input shape has been
|
||||
// updated with a different handle.
|
||||
bool set_input(int idx, ShapeHandle shape) {
|
||||
if (!inputs_[idx].SameHandle(shape)) {
|
||||
inputs_[idx] = shape;
|
||||
return true;
|
||||
} else {
|
||||
// Merge the stored shape of the input in position idx with the specified
|
||||
// shape. This requires idx to be in the [0, num_inputs) range. If the merge
|
||||
// is successful and the new shape differs from the old one, store the new
|
||||
// shape and return true. Return false otherwise.
|
||||
bool MergeInput(int idx, ShapeHandle shape) {
|
||||
ShapeHandle new_shape;
|
||||
if (!Merge(inputs_[idx], shape, &new_shape).ok() ||
|
||||
inputs_[idx].SameHandle(new_shape)) {
|
||||
return false;
|
||||
}
|
||||
inputs_[idx] = new_shape;
|
||||
return true;
|
||||
}
|
||||
ShapeHandle input(int64 idx) const { return inputs_[idx]; }
|
||||
Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
|
||||
@ -442,15 +444,18 @@ class InferenceContext {
|
||||
// propagate that information. Output handle dtypes and shapes are ignored if
|
||||
// the output tensor is not of type DT_RESOURCE.
|
||||
|
||||
// Set the shape corresponding to the resource in position idx. This requires
|
||||
// idx to be in the [0, num_inputs) range. Returns true iff the stored shape
|
||||
// has been updated with a different handle.
|
||||
bool set_input_handle_shape(int idx, ShapeHandle shape) {
|
||||
if (!input_handle_shape_[idx].SameHandle(shape)) {
|
||||
input_handle_shape_[idx] = shape;
|
||||
return true;
|
||||
// Merge the stored shape corresponding to the input handle in position idx
|
||||
// with the specified shape. This requires idx to be in the [0, num_inputs)
|
||||
// range. If the merge is successful and the new shape differs from the old
|
||||
// one, store the new shape and return true. Return false otherwise.
|
||||
bool MergeInputHandleShape(int idx, ShapeHandle shape) {
|
||||
ShapeHandle new_shape;
|
||||
if (!Merge(input_handle_shape_[idx], shape, &new_shape).ok() ||
|
||||
input_handle_shape_[idx].SameHandle(new_shape)) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
input_handle_shape_[idx] = shape;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Set the type corresponding to the resource in position idx. This requires
|
||||
@ -468,15 +473,24 @@ class InferenceContext {
|
||||
return input_handle_dtype_[idx];
|
||||
}
|
||||
|
||||
// Set the shape corresponding to the resource in position idx. This requires
|
||||
// idx to be in the [0, num_outputs) range.
|
||||
// Returns true iff the stored shape has been updated with a different handle.
|
||||
bool set_output_handle_shape(int idx, ShapeHandle shape) {
|
||||
if (!output_handle_shape_[idx].SameHandle(shape)) {
|
||||
output_handle_shape_[idx] = shape;
|
||||
return true;
|
||||
// Merge the stored shape corresponding to the output handle in position idx
|
||||
// with the specified shape. This requires idx to be in the [0, num_outputs)
|
||||
// range. If the merge is successful and the new shape differs from the old
|
||||
// one, store the new shape and return true. Return false otherwise.
|
||||
|
||||
bool MergeOutputHandleShape(int idx, ShapeHandle shape) {
|
||||
ShapeHandle new_shape;
|
||||
if (!Merge(output_handle_shape_[idx], shape, &new_shape).ok() ||
|
||||
output_handle_shape_[idx].SameHandle(new_shape)) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
output_handle_shape_[idx] = shape;
|
||||
return true;
|
||||
}
|
||||
// Overwrite the shape corresponding to the output handle in position idx with
|
||||
// the specified shape.
|
||||
void set_output_handle_shape(int idx, ShapeHandle shape) {
|
||||
output_handle_shape_[idx] = shape;
|
||||
}
|
||||
|
||||
// Set the type corresponding to the resource in position idx. This requires
|
||||
|
@ -85,7 +85,7 @@ Status GraphProperties::InferStatically() {
|
||||
}
|
||||
}
|
||||
if (qctx->set_output_handle_dtype(0, queue_type) ||
|
||||
qctx->set_output_handle_shape(0, queue_shp)) {
|
||||
qctx->MergeOutputHandleShape(0, queue_shp)) {
|
||||
new_shapes.push(qnode);
|
||||
}
|
||||
}
|
||||
|
@ -177,10 +177,14 @@ TEST_F(GraphPropertiesTest, Queues) {
|
||||
auto dequeue2 =
|
||||
ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
|
||||
|
||||
// Create a queue that feeds itself.
|
||||
auto q3 =
|
||||
ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT});
|
||||
auto dequeue3 =
|
||||
ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT});
|
||||
auto merge3 = ops::Merge(root.WithOpName("Merge3"), {dequeue3[0], square2});
|
||||
auto enqueue3 =
|
||||
ops::QueueEnqueue(root.WithOpName("Enqueue3"), q3, {merge3.output});
|
||||
|
||||
auto q4 =
|
||||
ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
|
||||
@ -227,6 +231,229 @@ TEST_F(GraphPropertiesTest, Queues) {
|
||||
EXPECT_EQ(7, prop4.shape().dim(1).size());
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, Loops) {
|
||||
// Test graph produced in python using:
|
||||
/*
|
||||
with tf.Graph().as_default():
|
||||
i = tf.constant(0)
|
||||
c = lambda i: tf.less(i, 10)
|
||||
b = lambda i: tf.add(i, 1)
|
||||
r = tf.while_loop(c, b, [i])
|
||||
with open('/tmp/graph.txt', 'w') as f:
|
||||
f.write(str(tf.get_default_graph().as_graph_def()))
|
||||
*/
|
||||
const string gdef_ascii = R"EOF(
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Enter"
|
||||
op: "Enter"
|
||||
input: "Const"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "frame_name"
|
||||
value {
|
||||
s: "while/while/"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "is_constant"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "parallel_iterations"
|
||||
value {
|
||||
i: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Merge"
|
||||
op: "Merge"
|
||||
input: "while/Enter"
|
||||
input: "while/NextIteration"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Less/y"
|
||||
op: "Const"
|
||||
input: "^while/Merge"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Less"
|
||||
op: "Less"
|
||||
input: "while/Merge"
|
||||
input: "while/Less/y"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/LoopCond"
|
||||
op: "LoopCond"
|
||||
input: "while/Less"
|
||||
}
|
||||
node {
|
||||
name: "while/Switch"
|
||||
op: "Switch"
|
||||
input: "while/Merge"
|
||||
input: "while/LoopCond"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@while/Merge"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Identity"
|
||||
op: "Identity"
|
||||
input: "while/Switch:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Add/y"
|
||||
op: "Const"
|
||||
input: "^while/Identity"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Add"
|
||||
op: "Add"
|
||||
input: "while/Identity"
|
||||
input: "while/Add/y"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/NextIteration"
|
||||
op: "NextIteration"
|
||||
input: "while/Add"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Exit"
|
||||
op: "Exit"
|
||||
input: "while/Switch"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 11
|
||||
}
|
||||
)EOF";
|
||||
|
||||
GrapplerItem item;
|
||||
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically());
|
||||
|
||||
const auto props = properties.GetOutputProperties("while/Exit");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
EXPECT_EQ(DT_INT32, prop.dtype());
|
||||
EXPECT_TRUE(prop.shape().unknown_rank());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user