From 0287e879ac67bf41b862b0e4583ceab4e678ea2b Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 1 May 2017 17:41:00 -0800 Subject: [PATCH] Enable grappler to propagate shapes through queues. Change: 154789133 --- .../core/common_runtime/shape_refiner.cc | 228 ++++++++++++------ .../core/common_runtime/shape_refiner.h | 8 + .../core/common_runtime/shape_refiner_test.cc | 33 +++ tensorflow/core/framework/shape_inference.h | 57 ++++- tensorflow/core/grappler/costs/BUILD | 4 +- .../core/grappler/costs/graph_properties.cc | 68 ++++++ .../grappler/costs/graph_properties_test.cc | 73 ++++++ tensorflow/core/ops/data_flow_ops.cc | 12 +- 8 files changed, 402 insertions(+), 81 deletions(-) diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 5135355a949..daa9e5091af 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -89,9 +89,6 @@ Status ShapeRefiner::AddNode(const Node* node) { // This needs to be filled in with real data in a second pass. std::vector input_tensors(node->num_inputs()); - std::vector real_tensors(node->num_inputs()); - std::vector attempted_materialization(node->num_inputs()); - std::vector attempted_tensor_as_shape_conversion(node->num_inputs()); std::vector input_tensors_as_shapes; // Create the inference context for this node with the existing input shapes. @@ -104,78 +101,7 @@ Status ShapeRefiner::AddNode(const Node* node) { } // Run the shape inference function, and return if there was an error. - if (op_reg_data->shape_inference_fn) { - TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn)); - } else { - TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape)); - } - - // We must run the shape function repeatedly, in case users write - // shape functions where they only conditionally call input_tensor() - // based on the values of another input tensor. - bool rerun_shape_fn; - do { - // If the result of running shape inference would have benefitted - // from knowing the values of input tensors, try to materialize - // the results of those tensors, and then run the shape inference - // function again using those known tensors. - rerun_shape_fn = false; - - // NOTE: It is possible to batch the extraction and - // materialization of inputs, instead of materializing one input - // at a time like we do below. If input-at-a-time computation - // becomes a bottleneck, we could separate ExtractConstantSubgraph - // into two functions: one that returns true if an input is - // derivable from constants, and another function that extracts - // the subgraph for multiple target nodes and executes the whole - // subgraph once. - - for (int i = 0; i < c->num_inputs(); ++i) { - if (!c->requested_input_tensor(i)) { - continue; - } - // Check if we have not already filled in the requested input, - // and if not, try to materialize the tensors. - if (!attempted_materialization[i]) { - attempted_materialization[i] = true; - - Tensor result; - bool evaluated = false; - TF_RETURN_IF_ERROR( - EvaluateConstantTensorForEdge(node, i, &evaluated, &result)); - if (evaluated) { - real_tensors[i] = result; - input_tensors[i] = &real_tensors[i]; - // We have more concrete information about a shape, - // so re-run shape inference. - rerun_shape_fn = true; - } - } - if (c->requested_input_tensor_as_partial_shape(i) && - !attempted_tensor_as_shape_conversion[i]) { - attempted_tensor_as_shape_conversion[i] = true; - if (i >= input_tensors_as_shapes.size()) { - input_tensors_as_shapes.resize(i + 1); - } - ShapeHandle s; - TF_RETURN_IF_ERROR(ConstantPartialShape(c.get(), node, i, &s)); - input_tensors_as_shapes[i] = s; - rerun_shape_fn = true; - } - } - - if (rerun_shape_fn) { - // We have more information about the shapes on this pass, - // so re-run shape inference. - c->set_input_tensors(input_tensors); - c->set_input_tensors_as_shapes(input_tensors_as_shapes); - if (op_reg_data->shape_inference_fn) { - TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get())); - } else { - TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c.get())); - } - } - } while (rerun_shape_fn); + TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, c.get())); // Store the resulting InferenceContext object in the map. node_to_context_[node].swap(c); @@ -211,6 +137,71 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port, return Status::OK(); } +Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) { + auto it = node_to_context_.find(node); + if (it == node_to_context_.end()) { + *refined = true; + return AddNode(node); + } + InferenceContext* node_context = it->second.get(); + + // Check if the shapes of the nodes in the fan-in of this node have changed, + // and if they have update the node input shapes. + for (const Edge* e : node->in_edges()) { + if (e->IsControlEdge()) continue; + + Node* input = e->src(); + auto iter = node_to_context_.find(input); + if (iter == node_to_context_.end()) { + return errors::FailedPrecondition( + "Input ", e->dst_input(), " ('", input->name(), "') for '", + node->name(), "' was not previously added to ShapeRefiner."); + } + + InferenceContext* c = iter->second.get(); + DCHECK_GE(e->dst_input(), 0); + if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) { + *refined = true; + } + + // Also propagate handle shape and dtype of edges which are carrying + // resource handles. + if (e->src()->output_type(e->src_output()) == DT_RESOURCE) { + if (node_context->set_input_handle_dtype( + e->dst_input(), c->output_handle_dtype(e->src_output()))) { + *refined = true; + } + if (node_context->set_input_handle_shape( + e->dst_input(), c->output_handle_shape(e->src_output()))) { + *refined = true; + } + } + } + + if (!*refined) { + // No input shape has changed, we're done + return Status::OK(); + } + + // Get and run the shape function for this node to update the shapes of the + // outputs. + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data)); + if (op_reg_data->shape_inference_fn == nullptr && + require_shape_inference_fns_) { + return errors::InvalidArgument( + "No shape inference function exists for op '", node->type_string(), + "', did you forget to define it?"); + } + + if (!op_reg_data->shape_inference_fn) { + // There is nothing more we can infer + return Status::OK(); + } + + return RunShapeFn(node, op_reg_data, node_context); +} + Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node, int dst_idx, bool* evaluated, Tensor* result) { @@ -463,4 +454,91 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, return Status::OK(); } +Status ShapeRefiner::RunShapeFn(const Node* node, + const OpRegistrationData* op_reg_data, + shape_inference::InferenceContext* c) { + // This will be filled in with real data in a second pass. + std::vector input_tensors(node->num_inputs()); + std::vector real_tensors(node->num_inputs()); + std::vector attempted_materialization(node->num_inputs()); + std::vector attempted_tensor_as_shape_conversion(node->num_inputs()); + std::vector input_tensors_as_shapes; + + // Run the shape inference function, and return if there was an error. + if (op_reg_data->shape_inference_fn) { + TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn)); + } else { + TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape)); + } + + // We must run the shape function repeatedly, in case users write + // shape functions where they only conditionally call input_tensor() + // based on the values of another input tensor. + bool rerun_shape_fn; + do { + // If the result of running shape inference would have benefitted + // from knowing the values of input tensors, try to materialize + // the results of those tensors, and then run the shape inference + // function again using those known tensors. + rerun_shape_fn = false; + + // NOTE: It is possible to batch the extraction and + // materialization of inputs, instead of materializing one input + // at a time like we do below. If input-at-a-time computation + // becomes a bottleneck, we could separate ExtractConstantSubgraph + // into two functions: one that returns true if an input is + // derivable from constants, and another function that extracts + // the subgraph for multiple target nodes and executes the whole + // subgraph once. + + for (int i = 0; i < c->num_inputs(); ++i) { + if (!c->requested_input_tensor(i)) { + continue; + } + // Check if we have not already filled in the requested input, + // and if not, try to materialize the tensors. + if (!attempted_materialization[i]) { + attempted_materialization[i] = true; + + Tensor result; + bool evaluated = false; + TF_RETURN_IF_ERROR( + EvaluateConstantTensorForEdge(node, i, &evaluated, &result)); + if (evaluated) { + real_tensors[i] = result; + input_tensors[i] = &real_tensors[i]; + // We have more concrete information about a shape, + // so re-run shape inference. + rerun_shape_fn = true; + } + } + if (c->requested_input_tensor_as_partial_shape(i) && + !attempted_tensor_as_shape_conversion[i]) { + attempted_tensor_as_shape_conversion[i] = true; + if (i >= input_tensors_as_shapes.size()) { + input_tensors_as_shapes.resize(i + 1); + } + ShapeHandle s; + TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s)); + input_tensors_as_shapes[i] = s; + rerun_shape_fn = true; + } + } + + if (rerun_shape_fn) { + // We have more information about the shapes on this pass, + // so re-run shape inference. + c->set_input_tensors(input_tensors); + c->set_input_tensors_as_shapes(input_tensors_as_shapes); + if (op_reg_data->shape_inference_fn) { + TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c)); + } else { + TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c)); + } + } + } while (rerun_shape_fn); + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 2d04ea15055..9709bd03021 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -55,6 +55,11 @@ class ShapeRefiner { Status SetShape(const Node* node, int output_port, shape_inference::ShapeHandle shape); + // Update the input shapes of node in case the shapes of the fan-ins of 'node' + // have themselves been modified (For example, in case of incremental shape + // refinement). Sets refined to true if any of the node shape has changed. + Status UpdateNode(const Node* node, bool* refined); + // Returns the InferenceContext for 'node', if present. shape_inference::InferenceContext* GetContext(const Node* node) const { auto it = node_to_context_.find(node); @@ -108,6 +113,9 @@ class ShapeRefiner { const Node* node, int dst_idx, shape_inference::ShapeHandle* result); + Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, + shape_inference::InferenceContext* c); + int32 graph_def_version_; const OpRegistryInterface* const ops_registry_; diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index d7e7c3b5ad5..b8df6dd4f62 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -768,5 +768,38 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { m.AddNode(result).error_message()); } +TEST(ShapeRefinerTest, IncrementalUpdates) { + Scope root = Scope::NewRootScope(); + Graph* g = root.graph(); + Node* queue; + TF_CHECK_OK(NodeBuilder("queue", "FIFOQueueV2") + .Attr("component_types", {DT_FLOAT}) + .Finalize(g, &queue)); + Node* dequeue; + TF_CHECK_OK(NodeBuilder("dequeue", "QueueDequeueV2") + .Attr("component_types", {DT_FLOAT}) + .Input(queue) + .Finalize(g, &dequeue)); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(queue)); + TF_ASSERT_OK(m.AddNode(dequeue)); + + // At this point, the shapes of the dequeued tensor are unknown. + shape_inference::InferenceContext* ctx = m.GetContext(dequeue); + EXPECT_EQ("?", ctx->DebugString(ctx->output(0))); + + // Inject a shape, and incrementally propagate it to the dequeue op. + ctx = m.GetContext(queue); + shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7}); + ctx->set_output_handle_shape(0, shp); + ctx->set_output_handle_dtype(0, DT_FLOAT); + + bool refined = false; + TF_ASSERT_OK(m.UpdateNode(dequeue, &refined)); + EXPECT_TRUE(refined); + ctx = m.GetContext(dequeue); + EXPECT_EQ("[3,7]", ctx->DebugString(ctx->output(0))); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index e88f6dbb042..71663027b3c 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -191,6 +191,17 @@ class InferenceContext { return s; } + // Set the shape of the input in position idx. This requires idx to be in the + // [0, num_inputs) range. Returns true iff the stored input shape has been + // updated with a different handle. + bool set_input(int idx, ShapeHandle shape) { + if (!inputs_[idx].SameHandle(shape)) { + inputs_[idx] = shape; + return true; + } else { + return false; + } + } ShapeHandle input(int64 idx) const { return inputs_[idx]; } Status input(StringPiece input_name, std::vector* output) const; int num_inputs() const { return inputs_.size(); } @@ -430,15 +441,53 @@ class InferenceContext { // and dtypes of tensors which can be accessed via the handle. These methods // propagate that information. Output handle dtypes and shapes are ignored if // the output tensor is not of type DT_RESOURCE. + + // Set the shape corresponding to the resource in position idx. This requires + // idx to be in the [0, num_inputs) range. Returns true iff the stored shape + // has been updated with a different handle. + bool set_input_handle_shape(int idx, ShapeHandle shape) { + if (!input_handle_shape_[idx].SameHandle(shape)) { + input_handle_shape_[idx] = shape; + return true; + } + return false; + } + + // Set the type corresponding to the resource in position idx. This requires + // idx to be in the [0, num_inputs) range. Returns true iff the stored type + // has been updated. + bool set_input_handle_dtype(int idx, DataType dtype) { + if (input_handle_dtype_[idx] != dtype) { + input_handle_dtype_[idx] = dtype; + return true; + } + return false; + } ShapeHandle input_handle_shape(int idx); DataType input_handle_dtype(int idx) const { return input_handle_dtype_[idx]; } - void set_output_handle_shape(int idx, ShapeHandle shape) { - output_handle_shape_[idx] = shape; + + // Set the shape corresponding to the resource in position idx. This requires + // idx to be in the [0, num_outputs) range. + // Returns true iff the stored shape has been updated with a different handle. + bool set_output_handle_shape(int idx, ShapeHandle shape) { + if (!output_handle_shape_[idx].SameHandle(shape)) { + output_handle_shape_[idx] = shape; + return true; + } + return false; } - void set_output_handle_dtype(int idx, DataType dtype) { - output_handle_dtype_[idx] = dtype; + + // Set the type corresponding to the resource in position idx. This requires + // idx to be in the [0, num_outputs) range. Returns true iff the stored type + // has been updated. + bool set_output_handle_dtype(int idx, DataType dtype) { + if (output_handle_dtype_[idx] != dtype) { + output_handle_dtype_[idx] = dtype; + return true; + } + return false; } ShapeHandle output_handle_shape(int idx) const { return output_handle_shape_[idx]; diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index d078d9af09e..e784c2df443 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -50,11 +50,13 @@ cc_test( args = ["--heap_check=local"], # The GPU tracer leaks memory deps = [ ":graph_properties", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:grappler_item_builder", "//tensorflow/core/grappler/clusters:single_machine", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", ], diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 06e91af2c2a..ad8f4f3f7cc 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" +#include +#include +#include #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -31,6 +34,71 @@ Status GraphProperties::InferStatically() { Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); + // List the resources and the nodes using them + std::unordered_map> resources; + for (const Node* const node : graph.nodes()) { + for (int i = 0; i < node->num_inputs(); ++i) { + if (node->input_type(i) == DataType::DT_RESOURCE) { + const Node* resource; + TF_CHECK_OK(node->input_node(i, &resource)); + resources[resource].insert(node); + } + } + } + + // If we found a resource, try to propagate the shapes through it. + bool done = true; + do { + std::queue new_shapes; + for (const auto& resource_data : resources) { + const Node* qnode = resource_data.first; + StringPiece type(qnode->type_string()); + if (!type.ends_with("QueueV2")) { + continue; + } + auto qctx = shape_refiner.GetContext(qnode); + if (!qctx) { + continue; + } + shape_inference::ShapeHandle data_shp = qctx->output_handle_shape(0); + if (qctx->FullyDefined(data_shp)) { + continue; + } + + for (const auto& node : resource_data.second) { + auto ctx = shape_refiner.GetContext(node); + if (!ctx) { + continue; + } + if (node->type_string().find("Enqueue") != std::string::npos) { + if (ctx->num_inputs() == 2) { + const DataType dtype = node->input_type(1); + shape_inference::ShapeHandle shp = ctx->input(1); + shape_inference::ShapeHandle refined; + TF_RETURN_IF_ERROR(qctx->Merge(shp, data_shp, &refined)); + if (qctx->set_output_handle_shape(0, refined) || + qctx->set_output_handle_dtype(0, dtype)) { + new_shapes.push(qnode); + } + } + } + } + } + // Propagate the shapes in the transitive fan-out of the queue. + done = new_shapes.empty(); + while (!new_shapes.empty()) { + const Node* n = new_shapes.front(); + new_shapes.pop(); + for (const Node* fanout : n->out_nodes()) { + bool updated = false; + TF_RETURN_IF_ERROR(shape_refiner.UpdateNode(fanout, &updated)); + if (updated) { + new_shapes.push(fanout); + } + } + } + } while (!done); + for (const Node* const node : graph.nodes()) { VLOG(1) << " " << node->name(); auto ctx = shape_refiner.GetContext(node); diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 32683644fbb..1eff52ba0e6 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" @@ -129,6 +132,76 @@ TEST_F(GraphPropertiesTest, DynamicProperties) { } } +TEST_F(GraphPropertiesTest, VarHandles) { + GrapplerItem item; + TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp") + .Attr("dtype", DT_FLOAT) + .Attr("shape", TensorShape({3, 7})) + .Finalize(item.graph.add_node())); + + TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp") + .Attr("dtype", DT_FLOAT) + .Input("Var", 0, DT_RESOURCE) + .Finalize(item.graph.add_node())); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + + const auto props = properties.GetOutputProperties("VarRead"); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_FALSE(prop.shape().unknown_rank()); + EXPECT_EQ(2, prop.shape().dim_size()); + EXPECT_EQ(3, prop.shape().dim(0).size()); + EXPECT_EQ(7, prop.shape().dim(1).size()); +} + +TEST_F(GraphPropertiesTest, Queues) { + // Create a graph with known input shapes, and propagate the shapes through a + // couple of queues. + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + + auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT}); + Output rnd = + ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT); + Output square1 = ops::Square(root.WithOpName("Square1"), rnd); + auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1}); + auto dequeue1 = + ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT}); + + auto q2 = + ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT}); + Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]); + auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2}); + auto dequeue2 = + ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT}); + + GrapplerItem item; + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + + const auto props1 = properties.GetOutputProperties("Dequeue1"); + EXPECT_EQ(1, props1.size()); + const OpInfo::TensorProperties& prop1 = props1[0]; + EXPECT_EQ(DT_FLOAT, prop1.dtype()); + EXPECT_FALSE(prop1.shape().unknown_rank()); + EXPECT_EQ(2, prop1.shape().dim_size()); + EXPECT_EQ(3, prop1.shape().dim(0).size()); + EXPECT_EQ(7, prop1.shape().dim(1).size()); + + const auto props2 = properties.GetOutputProperties("Dequeue2"); + EXPECT_EQ(1, props2.size()); + const OpInfo::TensorProperties& prop2 = props2[0]; + EXPECT_EQ(DT_FLOAT, prop2.dtype()); + EXPECT_FALSE(prop2.shape().unknown_rank()); + EXPECT_EQ(2, prop2.shape().dim_size()); + EXPECT_EQ(3, prop2.shape().dim(0).size()); + EXPECT_EQ(7, prop2.shape().dim(1).size()); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index f82e9d1eb76..f35a1bb6489 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -623,7 +623,17 @@ REGISTER_OP("QueueDequeueV2") .Output("components: component_types") .Attr("component_types: list(type) >= 1") .Attr("timeout_ms: int = -1") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](InferenceContext* c) { + if (c->num_outputs() == 1) { + c->set_output(0, c->input_handle_shape(0)); + } else { + // TODO(vrv): handle the case of multiple outputs. + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->UnknownShape()); + } + } + return Status::OK(); + }) .Doc(R"doc( Dequeues a tuple of one or more tensors from the given queue.