Fix shape inference for outside_compilation clusters that include cycles.

PiperOrigin-RevId: 192637289
This commit is contained in:
A. Unique TensorFlower 2018-04-12 11:04:55 -07:00 committed by TensorFlower Gardener
parent 151c31ce75
commit dc2d1c297a
4 changed files with 228 additions and 14 deletions

View File

@ -183,6 +183,13 @@ cc_library(
],
)
# Small helper library (BackEdgeHelper) for temporarily removing back edges
# from a graph so that simple acyclic shape-inference traversals can run on it.
cc_library(
name = "shape_inference_helpers",
srcs = ["shape_inference_helpers.cc"],
hdrs = ["shape_inference_helpers.h"],
# Only depends on the core graph library; keep this target lightweight so it
# can be used by any pass that needs cycle-free traversal.
deps = ["//tensorflow/core:graph"],
)
# Internal targets below this point.
cc_library(
@ -293,6 +300,7 @@ cc_library(
deps = [
":common",
":graph_to_functiondef",
":shape_inference_helpers",
":union_find",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -36,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
@ -576,7 +578,8 @@ class Encapsulator {
// satisfied, e.g., because send_node depends on a node that doesn't have a
// registered shape inference function.
Status DoStaticShapeInferenceForOutsideCompilationSend(
const Graph& graph_in, const ShapeRefiner& shape_refiner,
const Graph& graph_in, const BackEdgeHelper& back_edge_helper,
const ShapeRefiner& shape_refiner,
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
@ -599,7 +602,7 @@ class Encapsulator {
// to nodes in pruned_graph.
Status MakeGraphForOutsideCompilationSends(
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
ShapeRefiner* shape_refiner,
BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library);
@ -1712,9 +1715,13 @@ namespace {
// matter because it will only be used subsequently for shape inference. (It
// would be possible to add a switch statement over data_type to create a value
// for the constant, but that would entail maintaining the logic as new types
// are added, and is not necessary.)
Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
Graph* graph_out) {
// are added, and is not necessary.) If the node being replaced was within a
// control flow frame, adds appropriate Enter nodes so that the use of the Const
// is well-formed.
Node* AddDummyShapedNode(const Node* src_node, int src_port,
const std::vector<ControlFlowInfo>& control_flow_info,
const TensorShapeProto& shape, Graph* graph_out) {
DataType data_type = src_node->output_type(src_port);
TensorProto dummy_proto;
dummy_proto.set_dtype(data_type);
*dummy_proto.mutable_tensor_shape() = shape;
@ -1725,7 +1732,23 @@ Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
options.op_registry());
node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
return options.FinalizeBuilder(&node_builder);
Node* node = options.FinalizeBuilder(&node_builder);
// Add any Enter nodes required to bring the constant to the correct control
// flow frame.
while (!control_flow_info[src_node->id()].frame_name.empty()) {
NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
options.op_registry());
enter_builder.Attr("frame_name",
control_flow_info[src_node->id()].frame_name);
enter_builder.Attr("is_constant", true);
enter_builder.Input(node, 0);
Node* enter_node = options.FinalizeBuilder(&enter_builder);
// Adopt the new Enter node as the value in the current frame.
node = enter_node;
// Recurse to the parent frame to see if more Enter nodes need to be added.
src_node = control_flow_info[src_node->id()].parent_frame;
}
return node;
}
// Adds a copy of node_in to graph_out and adds the mapping to
@ -1767,17 +1790,30 @@ Status CopyShapeInferenceNodeToGraph(
}
}
}
// Work around the fact that Enter nodes refuse to propagate shape information
// unless they are marked loop invariant. Since we are never going to execute
// this graph, marking them all loop invariant is fine.
if (node_out->type_string() == "Enter") {
node_out->ClearAttr("is_constant");
node_out->AddAttr("is_constant", true);
}
return Status::OK();
}
} // namespace
Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
const Graph& graph_in, const ShapeRefiner& shape_refiner,
const Graph& graph_in, const BackEdgeHelper& back_edge_helper,
const ShapeRefiner& shape_refiner,
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
std::unique_ptr<Graph>* graph_out) {
// Get the control flow structure of the input graph so we can build
// well-formed output graphs.
std::vector<ControlFlowInfo> control_flow_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph_in, &control_flow_info));
// Maps from nodes in graph_in to nodes in graph_out.
//
// When an edge has fully defined shape the source node in graph_in is
@ -1802,7 +1838,6 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
// We don't use the standard ReverseDFS because we want to cut off traversal
// whenever we find an output with fully defined shape.
// TODO(misard) make this work properly in the presence of control flow.
struct Work {
Node* node;
bool leave; // Are we entering or leaving node?
@ -1840,8 +1875,9 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
dummy_node_images[src_node] = AddDummyShapedNode(
src_node->output_type(src_port), proto, graph_out->get());
dummy_node_images[src_node] =
AddDummyShapedNode(src_node, src_port, control_flow_info,
proto, graph_out->get());
}
// The final input to the send node is the dynamic key, which we
// don't include in the static shapes.
@ -1889,6 +1925,38 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
}
}
for (const auto edge : back_edge_helper.RemovedEdges()) {
if (copied_node_images.find(edge.dst) != copied_node_images.end()) {
// The destination of this back edge was added to the inference graph, so
// fix it up.
Node* dst = copied_node_images[edge.dst];
if (dst->type_string() != "Merge") {
return errors::InvalidArgument(
"outside_compilation cluster contains a back-edge to node ",
dst->name(), " of type ", dst->type_string(),
". The analysis pass only supports back-edges to Merge nodes.");
}
const Edge* existing_input_edge;
if (edge.dst_input != 1 || dst->num_inputs() != 2 ||
!dst->input_edge(0, &existing_input_edge).ok()) {
// TODO(misard) if we see graphs built with a different structure, relax
// this constraint. Leaving it here for now to avoid writing unnecessary
// complex code since we believe graphs generated by front ends all have
// the back edge as the second input to the merge node.
return errors::Internal(
"Internal assumption failed while rewriting an outside_compilation "
"cluster that contains a while loop. Logic assumes back-edge is to "
"port 1 of a 2-input "
"Merge node.");
}
// Connect the existing edge to both inputs of the Merge node so that the
// graph will be well-formed.
(*graph_out)
->AddEdge(existing_input_edge->src(),
existing_input_edge->src_output(), dst, edge.dst_input);
}
}
return Status::OK();
}
@ -1956,7 +2024,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
Status Encapsulator::MakeGraphForOutsideCompilationSends(
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
ShapeRefiner* shape_refiner,
BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library) {
// Find all the send_from_host nodes in all subgraphs, to use as roots for the
@ -1978,10 +2046,15 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends(
// nodes, inlining any functions as needed.
TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
graph, send_from_host_nodes, pruned_graph, node_images, library));
FixupSourceAndSinkEdges(pruned_graph->get());
// Remove back edges from any cycles in the pruned graph to simplify shape
// inference traversal. They will be fixed up in the per-subgraph shape
// inference graphs stored in the function library.
TF_RETURN_IF_ERROR(back_edge_helper->Remove(pruned_graph->get()));
// Perform shape inference on the pruned graph.
shape_refiner->set_require_shape_inference_fns(false);
FixupSourceAndSinkEdges(pruned_graph->get());
std::vector<Node*> post_order;
GetReversePostOrder(*(*pruned_graph), &post_order);
for (auto node : post_order) {
@ -1999,11 +2072,13 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends(
Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
Graph* graph_out, FunctionLibraryDefinition* library) {
BackEdgeHelper back_edge_helper;
std::unique_ptr<Graph> pruned_graph;
ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
std::unordered_map<const Node*, Node*> node_images;
TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
*graph_out, &pruned_graph, &shape_refiner, &node_images, library));
*graph_out, &pruned_graph, &back_edge_helper, &shape_refiner,
&node_images, library));
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference",
@ -2033,7 +2108,7 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
std::unique_ptr<Graph> graph;
if (send_node != nullptr) {
TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
*pruned_graph, shape_refiner, recv_at_host_names,
*pruned_graph, back_edge_helper, shape_refiner, recv_at_host_names,
node_images[send_node], library, &static_shape, &graph));
if (graph == nullptr) {
VLOG(2) << "Send node " << send_node->name() << " shapes";

View File

@ -0,0 +1,66 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Contains helpers for use in shape inference.
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include <vector>
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Records and detaches every back edge (NextIteration -> Merge) in `graph`,
// leaving the graph acyclic for simple traversals such as shape inference.
// May be called at most once per helper instance; the removed edges are kept
// in back_edges_ so that Replace() can later restore them.
Status BackEdgeHelper::Remove(Graph* graph) {
  if (graph_ != nullptr) {
    return errors::Internal("BackEdgeHelper duplicate call to Remove.");
  }
  graph_ = graph;
  // Pass 1: collect the back edges. A back edge is an in-edge of a Merge node
  // whose source is a NextIteration node. Collection is kept separate from
  // removal so we never mutate an edge set while iterating over it.
  for (Node* node : graph_->nodes()) {
    if (!node->IsMerge()) continue;
    for (const Edge* edge : node->in_edges()) {
      if (!edge->src()->IsNextIteration()) continue;
      back_edges_.push_back(BackEdge{edge, edge->src(), edge->src_output(),
                                     edge->dst(), edge->dst_input()});
    }
  }
  // Pass 2: detach the collected edges from the graph.
  for (const BackEdge& back_edge : back_edges_) {
    graph_->RemoveEdge(back_edge.edge);
  }
  return Status::OK();
}
// Returns the back edges detached by a prior Remove() call (empty if Remove
// has not been called yet). The returned reference stays valid for the
// lifetime of this helper.
const std::vector<BackEdgeHelper::BackEdge>& BackEdgeHelper::RemovedEdges()
    const {
  return back_edges_;
}
// Restores the back edges that a prior Remove() call detached, returning the
// graph to its original cyclic form. Requires exactly one preceding call to
// Remove(), and may itself run at most once.
Status BackEdgeHelper::Replace() {
  if (graph_ == nullptr) {
    return errors::Internal("BackEdgeHelper Replace called before Remove.");
  }
  if (replaced_) {
    return errors::Internal("BackEdgeHelper Replace called more than once.");
  }
  replaced_ = true;
  // Re-attach each recorded edge at its original endpoints and ports. AddEdge
  // creates a fresh Edge object; the pointer stored in BackEdge::edge is no
  // longer used here.
  for (const BackEdge& removed : back_edges_) {
    graph_->AddEdge(removed.src, removed.src_output, removed.dst,
                    removed.dst_input);
  }
  return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_
#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_
#include <vector>
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Helper class to temporarily remove, then replace, the back edges in a
// graph. Simple algorithms for shape inference don't work with cycles, and this
// class can be used to remove cycles before running inference and replace them
// after. Correct usage requires exactly one call to Remove(), followed by any
// number of calls to RemovedEdges() and at most one call to Replace(). The call
// to Replace() is optional if the graph will be discarded without being
// executed, e.g., if it is being used purely for a shape inference pass.
class BackEdgeHelper {
 public:
  // One removed back edge, recorded with enough information (endpoints and
  // port numbers) to re-create it later via Replace().
  struct BackEdge {
    const Edge* edge;  // The removed edge; not owned by this struct.
    Node* src;         // Source of the edge (a NextIteration node).
    int src_output;    // Output port on src.
    Node* dst;         // Destination of the edge (a Merge node).
    int dst_input;     // Input port on dst.
  };

  BackEdgeHelper() = default;
  // Disallows copy and assign.
  BackEdgeHelper(const BackEdgeHelper& other) = delete;
  BackEdgeHelper& operator=(const BackEdgeHelper& other) = delete;

  // Temporarily removes all the back edges in graph. Returns an error if
  // called more than once on the same helper instance. The graph must outlive
  // this helper.
  Status Remove(Graph* graph);

  // Gets the list of removed edges (empty before Remove has been called).
  const std::vector<BackEdge>& RemovedEdges() const;

  // Replaces the back edges removed by a prior call to Remove. Returns an
  // error if Remove has not been called first, or if Replace has already run.
  Status Replace();

 private:
  Graph* graph_ = nullptr;  // not owned
  std::vector<BackEdge> back_edges_;
  // Set once Replace has been called.
  bool replaced_ = false;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_