Enable grappler to propagate shapes through queues.

Change: 154789133
This commit is contained in:
Benoit Steiner 2017-05-01 17:41:00 -08:00 committed by TensorFlower Gardener
parent aebaf317ce
commit 0287e879ac
8 changed files with 402 additions and 81 deletions

View File

@ -89,9 +89,6 @@ Status ShapeRefiner::AddNode(const Node* node) {
// This needs to be filled in with real data in a second pass.
std::vector<const Tensor*> input_tensors(node->num_inputs());
std::vector<Tensor> real_tensors(node->num_inputs());
std::vector<bool> attempted_materialization(node->num_inputs());
std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
std::vector<ShapeHandle> input_tensors_as_shapes;
// Create the inference context for this node with the existing input shapes.
@ -104,78 +101,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
}
// Run the shape inference function, and return if there was an error.
if (op_reg_data->shape_inference_fn) {
TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
} else {
TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
}
// We must run the shape function repeatedly, in case users write
// shape functions where they only conditionally call input_tensor()
// based on the values of another input tensor.
bool rerun_shape_fn;
do {
// If the result of running shape inference would have benefitted
// from knowing the values of input tensors, try to materialize
// the results of those tensors, and then run the shape inference
// function again using those known tensors.
rerun_shape_fn = false;
// NOTE: It is possible to batch the extraction and
// materialization of inputs, instead of materializing one input
// at a time like we do below. If input-at-a-time computation
// becomes a bottleneck, we could separate ExtractConstantSubgraph
// into two functions: one that returns true if an input is
// derivable from constants, and another function that extracts
// the subgraph for multiple target nodes and executes the whole
// subgraph once.
for (int i = 0; i < c->num_inputs(); ++i) {
if (!c->requested_input_tensor(i)) {
continue;
}
// Check if we have not already filled in the requested input,
// and if not, try to materialize the tensors.
if (!attempted_materialization[i]) {
attempted_materialization[i] = true;
Tensor result;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
if (evaluated) {
real_tensors[i] = result;
input_tensors[i] = &real_tensors[i];
// We have more concrete information about a shape,
// so re-run shape inference.
rerun_shape_fn = true;
}
}
if (c->requested_input_tensor_as_partial_shape(i) &&
!attempted_tensor_as_shape_conversion[i]) {
attempted_tensor_as_shape_conversion[i] = true;
if (i >= input_tensors_as_shapes.size()) {
input_tensors_as_shapes.resize(i + 1);
}
ShapeHandle s;
TF_RETURN_IF_ERROR(ConstantPartialShape(c.get(), node, i, &s));
input_tensors_as_shapes[i] = s;
rerun_shape_fn = true;
}
}
if (rerun_shape_fn) {
// We have more information about the shapes on this pass,
// so re-run shape inference.
c->set_input_tensors(input_tensors);
c->set_input_tensors_as_shapes(input_tensors_as_shapes);
if (op_reg_data->shape_inference_fn) {
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get()));
} else {
TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c.get()));
}
}
} while (rerun_shape_fn);
TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, c.get()));
// Store the resulting InferenceContext object in the map.
node_to_context_[node].swap(c);
@ -211,6 +137,71 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
return Status::OK();
}
// Re-runs shape inference for 'node' if the shapes of its fan-in have changed
// since the node was last processed. Sets '*refined' to true iff any input
// shape (or resource handle shape/dtype) differed from the previously stored
// value, in which case the node's shape function is re-run to update its
// output shapes.
Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
  // BUG FIX: '*refined' was previously never initialized here, yet it is read
  // below ('if (!*refined)'). A caller passing an uninitialized or stale
  // 'true' flag would force (or skip) shape-function reruns incorrectly.
  *refined = false;
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    // First time we see this node: a full AddNode always counts as a
    // refinement.
    *refined = true;
    return AddNode(node);
  }
  InferenceContext* node_context = it->second.get();
  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update this node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;
    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", e->dst_input(), " ('", input->name(), "') for '",
          node->name(), "' was not previously added to ShapeRefiner.");
    }
    InferenceContext* c = iter->second.get();
    DCHECK_GE(e->dst_input(), 0);
    if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) {
      *refined = true;
    }
    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
      if (node_context->set_input_handle_dtype(
              e->dst_input(), c->output_handle_dtype(e->src_output()))) {
        *refined = true;
      }
      if (node_context->set_input_handle_shape(
              e->dst_input(), c->output_handle_shape(e->src_output()))) {
        *refined = true;
      }
    }
  }
  if (!*refined) {
    // No input shape has changed; we're done.
    return Status::OK();
  }
  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }
  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return Status::OK();
  }
  return RunShapeFn(node, op_reg_data, node_context);
}
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
int dst_idx, bool* evaluated,
Tensor* result) {
@ -463,4 +454,91 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
return Status::OK();
}
// Runs the shape inference function of 'node' on context 'c', iterating to a
// fixed point: whenever the shape function requests the constant value (or
// partial-shape value) of an input, we try to materialize that constant and
// re-run the shape function with the extra information.
Status ShapeRefiner::RunShapeFn(const Node* node,
const OpRegistrationData* op_reg_data,
shape_inference::InferenceContext* c) {
// This will be filled in with real data in a second pass.
std::vector<const Tensor*> input_tensors(node->num_inputs());
std::vector<Tensor> real_tensors(node->num_inputs());
// Per-input flags ensuring each materialization / tensor-as-shape conversion
// is attempted at most once across iterations of the do/while loop below.
std::vector<bool> attempted_materialization(node->num_inputs());
std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
std::vector<ShapeHandle> input_tensors_as_shapes;
// Run the shape inference function, and return if there was an error.
if (op_reg_data->shape_inference_fn) {
TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
} else {
TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
}
// We must run the shape function repeatedly, in case users write
// shape functions where they only conditionally call input_tensor()
// based on the values of another input tensor.
bool rerun_shape_fn;
do {
// If the result of running shape inference would have benefitted
// from knowing the values of input tensors, try to materialize
// the results of those tensors, and then run the shape inference
// function again using those known tensors.
rerun_shape_fn = false;
// NOTE: It is possible to batch the extraction and
// materialization of inputs, instead of materializing one input
// at a time like we do below. If input-at-a-time computation
// becomes a bottleneck, we could separate ExtractConstantSubgraph
// into two functions: one that returns true if an input is
// derivable from constants, and another function that extracts
// the subgraph for multiple target nodes and executes the whole
// subgraph once.
for (int i = 0; i < c->num_inputs(); ++i) {
// Only inputs whose value the shape function actually asked for are
// worth materializing.
if (!c->requested_input_tensor(i)) {
continue;
}
// Check if we have not already filled in the requested input,
// and if not, try to materialize the tensors.
if (!attempted_materialization[i]) {
attempted_materialization[i] = true;
Tensor result;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
if (evaluated) {
// Keep the tensor alive in 'real_tensors' and expose a pointer to it
// through 'input_tensors'.
real_tensors[i] = result;
input_tensors[i] = &real_tensors[i];
// We have more concrete information about a shape,
// so re-run shape inference.
rerun_shape_fn = true;
}
}
if (c->requested_input_tensor_as_partial_shape(i) &&
!attempted_tensor_as_shape_conversion[i]) {
attempted_tensor_as_shape_conversion[i] = true;
// 'input_tensors_as_shapes' grows lazily; resize only as far as needed.
if (i >= input_tensors_as_shapes.size()) {
input_tensors_as_shapes.resize(i + 1);
}
ShapeHandle s;
TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
input_tensors_as_shapes[i] = s;
rerun_shape_fn = true;
}
}
if (rerun_shape_fn) {
// We have more information about the shapes on this pass,
// so re-run shape inference.
c->set_input_tensors(input_tensors);
c->set_input_tensors_as_shapes(input_tensors_as_shapes);
if (op_reg_data->shape_inference_fn) {
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c));
} else {
TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c));
}
}
} while (rerun_shape_fn);
return Status::OK();
}
} // namespace tensorflow

View File

@ -55,6 +55,11 @@ class ShapeRefiner {
Status SetShape(const Node* node, int output_port,
shape_inference::ShapeHandle shape);
// Update the input shapes of 'node' in case the shapes of the fan-ins of
// 'node' have themselves been modified (for example, in case of incremental
// shape refinement). Sets '*refined' to true if any of the node's shapes has
// changed.
Status UpdateNode(const Node* node, bool* refined);
Status UpdateNode(const Node* node, bool* refined);
// Returns the InferenceContext for 'node', if present.
shape_inference::InferenceContext* GetContext(const Node* node) const {
auto it = node_to_context_.find(node);
@ -108,6 +113,9 @@ class ShapeRefiner {
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
shape_inference::InferenceContext* c);
int32 graph_def_version_;
const OpRegistryInterface* const ops_registry_;

View File

@ -768,5 +768,38 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
m.AddNode(result).error_message());
}
// Verifies that UpdateNode() incrementally propagates a shape injected on a
// queue's resource handle to a previously-added dequeue op.
TEST(ShapeRefinerTest, IncrementalUpdates) {
  Scope root = Scope::NewRootScope();
  Graph* graph = root.graph();

  Node* queue_node;
  TF_CHECK_OK(NodeBuilder("queue", "FIFOQueueV2")
                  .Attr("component_types", {DT_FLOAT})
                  .Finalize(graph, &queue_node));

  Node* dequeue_node;
  TF_CHECK_OK(NodeBuilder("dequeue", "QueueDequeueV2")
                  .Attr("component_types", {DT_FLOAT})
                  .Input(queue_node)
                  .Finalize(graph, &dequeue_node));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(queue_node));
  TF_ASSERT_OK(refiner.AddNode(dequeue_node));

  // At this point, the shape of the dequeued tensor is unknown.
  shape_inference::InferenceContext* dequeue_ctx =
      refiner.GetContext(dequeue_node);
  EXPECT_EQ("?", dequeue_ctx->DebugString(dequeue_ctx->output(0)));

  // Inject a concrete shape onto the queue's handle, then incrementally
  // propagate it to the dequeue op.
  shape_inference::InferenceContext* queue_ctx =
      refiner.GetContext(queue_node);
  shape_inference::ShapeHandle injected_shape = queue_ctx->MakeShape({3, 7});
  queue_ctx->set_output_handle_shape(0, injected_shape);
  queue_ctx->set_output_handle_dtype(0, DT_FLOAT);

  bool refined = false;
  TF_ASSERT_OK(refiner.UpdateNode(dequeue_node, &refined));
  EXPECT_TRUE(refined);
  dequeue_ctx = refiner.GetContext(dequeue_node);
  EXPECT_EQ("[3,7]", dequeue_ctx->DebugString(dequeue_ctx->output(0)));
}
} // namespace
} // namespace tensorflow

View File

@ -191,6 +191,17 @@ class InferenceContext {
return s;
}
// Set the shape of the input in position idx. This requires idx to be in the
// [0, num_inputs) range. Returns true iff the stored input shape has been
// updated with a different handle.
// Stores 'shape' as the input in position idx ('idx' must be in
// [0, num_inputs)). Returns true iff the stored handle changed.
bool set_input(int idx, ShapeHandle shape) {
  if (inputs_[idx].SameHandle(shape)) return false;
  inputs_[idx] = shape;
  return true;
}
ShapeHandle input(int64 idx) const { return inputs_[idx]; }
Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
int num_inputs() const { return inputs_.size(); }
@ -430,15 +441,53 @@ class InferenceContext {
// and dtypes of tensors which can be accessed via the handle. These methods
// propagate that information. Output handle dtypes and shapes are ignored if
// the output tensor is not of type DT_RESOURCE.
// Set the shape corresponding to the resource in position idx. This requires
// idx to be in the [0, num_inputs) range. Returns true iff the stored shape
// has been updated with a different handle.
bool set_input_handle_shape(int idx, ShapeHandle shape) {
  // Only overwrite (and report a refinement) when the handle differs.
  const bool changed = !input_handle_shape_[idx].SameHandle(shape);
  if (changed) {
    input_handle_shape_[idx] = shape;
  }
  return changed;
}
// Set the type corresponding to the resource in position idx. This requires
// idx to be in the [0, num_inputs) range. Returns true iff the stored type
// has been updated.
bool set_input_handle_dtype(int idx, DataType dtype) {
  // No-op (and no refinement reported) when the stored dtype already matches.
  if (input_handle_dtype_[idx] == dtype) return false;
  input_handle_dtype_[idx] = dtype;
  return true;
}
ShapeHandle input_handle_shape(int idx);
DataType input_handle_dtype(int idx) const {
return input_handle_dtype_[idx];
}
void set_output_handle_shape(int idx, ShapeHandle shape) {
output_handle_shape_[idx] = shape;
// Set the shape corresponding to the resource in position idx. This requires
// idx to be in the [0, num_outputs) range.
// Returns true iff the stored shape has been updated with a different handle.
bool set_output_handle_shape(int idx, ShapeHandle shape) {
  // Skip the write (and report no change) if the handle is already stored.
  if (output_handle_shape_[idx].SameHandle(shape)) return false;
  output_handle_shape_[idx] = shape;
  return true;
}
void set_output_handle_dtype(int idx, DataType dtype) {
output_handle_dtype_[idx] = dtype;
// Set the type corresponding to the resource in position idx. This requires
// idx to be in the [0, num_outputs) range. Returns true iff the stored type
// has been updated.
bool set_output_handle_dtype(int idx, DataType dtype) {
  // Only overwrite (and report a change) when the dtype differs.
  const bool changed = (output_handle_dtype_[idx] != dtype);
  if (changed) {
    output_handle_dtype_[idx] = dtype;
  }
  return changed;
}
ShapeHandle output_handle_shape(int idx) const {
return output_handle_shape_[idx];

View File

@ -50,11 +50,13 @@ cc_test(
args = ["--heap_check=local"], # The GPU tracer leaks memory
deps = [
":graph_properties",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler/clusters:single_machine",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
@ -31,6 +34,71 @@ Status GraphProperties::InferStatically() {
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
// List the resources and the nodes using them
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
for (const Node* const node : graph.nodes()) {
for (int i = 0; i < node->num_inputs(); ++i) {
if (node->input_type(i) == DataType::DT_RESOURCE) {
const Node* resource;
TF_CHECK_OK(node->input_node(i, &resource));
resources[resource].insert(node);
}
}
}
// If we found a resource, try to propagate the shapes through it.
bool done = true;
do {
std::queue<const Node*> new_shapes;
for (const auto& resource_data : resources) {
const Node* qnode = resource_data.first;
StringPiece type(qnode->type_string());
if (!type.ends_with("QueueV2")) {
continue;
}
auto qctx = shape_refiner.GetContext(qnode);
if (!qctx) {
continue;
}
shape_inference::ShapeHandle data_shp = qctx->output_handle_shape(0);
if (qctx->FullyDefined(data_shp)) {
continue;
}
for (const auto& node : resource_data.second) {
auto ctx = shape_refiner.GetContext(node);
if (!ctx) {
continue;
}
if (node->type_string().find("Enqueue") != std::string::npos) {
if (ctx->num_inputs() == 2) {
const DataType dtype = node->input_type(1);
shape_inference::ShapeHandle shp = ctx->input(1);
shape_inference::ShapeHandle refined;
TF_RETURN_IF_ERROR(qctx->Merge(shp, data_shp, &refined));
if (qctx->set_output_handle_shape(0, refined) ||
qctx->set_output_handle_dtype(0, dtype)) {
new_shapes.push(qnode);
}
}
}
}
}
// Propagate the shapes in the transitive fan-out of the queue.
done = new_shapes.empty();
while (!new_shapes.empty()) {
const Node* n = new_shapes.front();
new_shapes.pop();
for (const Node* fanout : n->out_nodes()) {
bool updated = false;
TF_RETURN_IF_ERROR(shape_refiner.UpdateNode(fanout, &updated));
if (updated) {
new_shapes.push(fanout);
}
}
}
} while (!done);
for (const Node* const node : graph.nodes()) {
VLOG(1) << "<Node> " << node->name();
auto ctx = shape_refiner.GetContext(node);

View File

@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@ -129,6 +132,76 @@ TEST_F(GraphPropertiesTest, DynamicProperties) {
}
}
// Checks that the shape attached to a VarHandleOp's resource is propagated
// through a ReadVariableOp during static shape inference.
TEST_F(GraphPropertiesTest, VarHandles) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", TensorShape({3, 7}))
                  .Finalize(item.graph.add_node()));

  TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
                  .Attr("dtype", DT_FLOAT)
                  .Input("Var", 0, DT_RESOURCE)
                  .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically());

  // The read should report a fully defined float tensor of shape [3, 7].
  const auto read_props = properties.GetOutputProperties("VarRead");
  EXPECT_EQ(1, read_props.size());
  const OpInfo::TensorProperties& read_prop = read_props[0];
  EXPECT_EQ(DT_FLOAT, read_prop.dtype());
  EXPECT_FALSE(read_prop.shape().unknown_rank());
  EXPECT_EQ(2, read_prop.shape().dim_size());
  EXPECT_EQ(3, read_prop.shape().dim(0).size());
  EXPECT_EQ(7, read_prop.shape().dim(1).size());
}
// Create a graph with known input shapes, and propagate the shapes through a
// couple of chained queues (FIFO followed by RandomShuffle).
TEST_F(GraphPropertiesTest, Queues) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  auto fifo_queue =
      ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
  Output source =
      ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
  Output squared_source = ops::Square(root.WithOpName("Square1"), source);
  auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), fifo_queue,
                                    {squared_source});
  auto dequeue1 = ops::QueueDequeue(root.WithOpName("Dequeue1"), fifo_queue,
                                    {DataType::DT_FLOAT});

  auto shuffle_queue =
      ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
  Output squared_dequeued =
      ops::Square(root.WithOpName("Square2"), dequeue1[0]);
  auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), shuffle_queue,
                                    {squared_dequeued});
  auto dequeue2 = ops::QueueDequeue(root.WithOpName("Dequeue2"), shuffle_queue,
                                    {DataType::DT_FLOAT});

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically());

  // Both dequeue ops should report a single fully defined float tensor of
  // shape [3, 7].
  for (const auto* dequeue_name : {"Dequeue1", "Dequeue2"}) {
    const auto props = properties.GetOutputProperties(dequeue_name);
    EXPECT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    EXPECT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -623,7 +623,17 @@ REGISTER_OP("QueueDequeueV2")
.Output("components: component_types")
.Attr("component_types: list(type) >= 1")
.Attr("timeout_ms: int = -1")
.SetShapeFn(shape_inference::UnknownShape)
.SetShapeFn([](InferenceContext* c) {
// With a single component, the dequeued tensor takes whatever shape was
// recorded on the queue's resource handle (unknown if none was set).
if (c->num_outputs() == 1) {
c->set_output(0, c->input_handle_shape(0));
} else {
// TODO(vrv): handle the case of multiple outputs.
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->UnknownShape());
}
}
return Status::OK();
})
.Doc(R"doc(
Dequeues a tuple of one or more tensors from the given queue.