Fix shape inference for outside_compilation clusters that include cycles.
PiperOrigin-RevId: 192637289
This commit is contained in:
parent
151c31ce75
commit
dc2d1c297a
@ -183,6 +183,13 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "shape_inference_helpers",
|
||||
srcs = ["shape_inference_helpers.cc"],
|
||||
hdrs = ["shape_inference_helpers.h"],
|
||||
deps = ["//tensorflow/core:graph"],
|
||||
)
|
||||
|
||||
# Internal targets below this point.
|
||||
|
||||
cc_library(
|
||||
@ -293,6 +300,7 @@ cc_library(
|
||||
deps = [
|
||||
":common",
|
||||
":graph_to_functiondef",
|
||||
":shape_inference_helpers",
|
||||
":union_find",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/jit/kernels:parallel_check_op",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
|
||||
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -36,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
@ -576,7 +578,8 @@ class Encapsulator {
|
||||
// satisfied, e.g., because send_node depends on a node that doesn't have a
|
||||
// registered shape inference function.
|
||||
Status DoStaticShapeInferenceForOutsideCompilationSend(
|
||||
const Graph& graph_in, const ShapeRefiner& shape_refiner,
|
||||
const Graph& graph_in, const BackEdgeHelper& back_edge_helper,
|
||||
const ShapeRefiner& shape_refiner,
|
||||
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
|
||||
FunctionLibraryDefinition* library,
|
||||
std::vector<TensorShapeProto>* static_shape_out,
|
||||
@ -599,7 +602,7 @@ class Encapsulator {
|
||||
// to nodes in pruned_graph.
|
||||
Status MakeGraphForOutsideCompilationSends(
|
||||
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
|
||||
ShapeRefiner* shape_refiner,
|
||||
BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner,
|
||||
std::unordered_map<const Node*, Node*>* node_images,
|
||||
FunctionLibraryDefinition* library);
|
||||
|
||||
@ -1712,9 +1715,13 @@ namespace {
|
||||
// matter because it will only be used subsequently for shape inference. (It
|
||||
// would be possible to add a switch statement over data_type to create a value
|
||||
// for the constant, but that would entail maintaining the logic as new types
|
||||
// are added, and is not necessary.)
|
||||
Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
|
||||
Graph* graph_out) {
|
||||
// are added, and is not necessary.) If the node being replaced was within a
|
||||
// control flow frame, adds appropriate Enter nodes so that the use of the Const
|
||||
// is well-formed.
|
||||
Node* AddDummyShapedNode(const Node* src_node, int src_port,
|
||||
const std::vector<ControlFlowInfo>& control_flow_info,
|
||||
const TensorShapeProto& shape, Graph* graph_out) {
|
||||
DataType data_type = src_node->output_type(src_port);
|
||||
TensorProto dummy_proto;
|
||||
dummy_proto.set_dtype(data_type);
|
||||
*dummy_proto.mutable_tensor_shape() = shape;
|
||||
@ -1725,7 +1732,23 @@ Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
|
||||
NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
|
||||
options.op_registry());
|
||||
node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
|
||||
return options.FinalizeBuilder(&node_builder);
|
||||
Node* node = options.FinalizeBuilder(&node_builder);
|
||||
// Add any Enter nodes required to bring the constant to the correct control
|
||||
// flow frame.
|
||||
while (!control_flow_info[src_node->id()].frame_name.empty()) {
|
||||
NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
|
||||
options.op_registry());
|
||||
enter_builder.Attr("frame_name",
|
||||
control_flow_info[src_node->id()].frame_name);
|
||||
enter_builder.Attr("is_constant", true);
|
||||
enter_builder.Input(node, 0);
|
||||
Node* enter_node = options.FinalizeBuilder(&enter_builder);
|
||||
// Adopt the new Enter node as the value in the current frame.
|
||||
node = enter_node;
|
||||
// Recurse to the parent frame to see if more Enter nodes need to be added.
|
||||
src_node = control_flow_info[src_node->id()].parent_frame;
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
// Adds a copy of node_in to graph_out and adds the mapping to
|
||||
@ -1767,17 +1790,30 @@ Status CopyShapeInferenceNodeToGraph(
|
||||
}
|
||||
}
|
||||
}
|
||||
// Work around the fact that Enter nodes refuse to propagate shape information
|
||||
// unless they are marked loop invariant. Since we are never going to execute
|
||||
// this graph, marking them all loop invariant is fine.
|
||||
if (node_out->type_string() == "Enter") {
|
||||
node_out->ClearAttr("is_constant");
|
||||
node_out->AddAttr("is_constant", true);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
|
||||
const Graph& graph_in, const ShapeRefiner& shape_refiner,
|
||||
const Graph& graph_in, const BackEdgeHelper& back_edge_helper,
|
||||
const ShapeRefiner& shape_refiner,
|
||||
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
|
||||
FunctionLibraryDefinition* library,
|
||||
std::vector<TensorShapeProto>* static_shape_out,
|
||||
std::unique_ptr<Graph>* graph_out) {
|
||||
// Get the control flow structure of the input graph so we can build
|
||||
// well-formed output graphs.
|
||||
std::vector<ControlFlowInfo> control_flow_info;
|
||||
TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph_in, &control_flow_info));
|
||||
|
||||
// Maps from nodes in graph_in to nodes in graph_out.
|
||||
//
|
||||
// When an edge has fully defined shape the source node in graph_in is
|
||||
@ -1802,7 +1838,6 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
|
||||
|
||||
// We don't use the standard ReverseDFS because we want to cut off traversal
|
||||
// whenever we find an output with fully defined shape.
|
||||
// TODO(misard) make this work properly in the presence of control flow.
|
||||
struct Work {
|
||||
Node* node;
|
||||
bool leave; // Are we entering or leaving node?
|
||||
@ -1840,8 +1875,9 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
|
||||
TensorShapeProto proto;
|
||||
context->ShapeHandleToProto(shape, &proto);
|
||||
if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
|
||||
dummy_node_images[src_node] = AddDummyShapedNode(
|
||||
src_node->output_type(src_port), proto, graph_out->get());
|
||||
dummy_node_images[src_node] =
|
||||
AddDummyShapedNode(src_node, src_port, control_flow_info,
|
||||
proto, graph_out->get());
|
||||
}
|
||||
// The final input to the send node is the dynamic key, which we
|
||||
// don't include in the static shapes.
|
||||
@ -1889,6 +1925,38 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto edge : back_edge_helper.RemovedEdges()) {
|
||||
if (copied_node_images.find(edge.dst) != copied_node_images.end()) {
|
||||
// The destination of this back edge was added to the inference graph, so
|
||||
// fix it up.
|
||||
Node* dst = copied_node_images[edge.dst];
|
||||
if (dst->type_string() != "Merge") {
|
||||
return errors::InvalidArgument(
|
||||
"outside_compilation cluster contains a back-edge to node ",
|
||||
dst->name(), " of type ", dst->type_string(),
|
||||
". The analysis pass only supports back-edges to Merge nodes.");
|
||||
}
|
||||
const Edge* existing_input_edge;
|
||||
if (edge.dst_input != 1 || dst->num_inputs() != 2 ||
|
||||
!dst->input_edge(0, &existing_input_edge).ok()) {
|
||||
// TODO(misard) if we see graphs built with a different structure, relax
|
||||
// this constraint. Leaving it here for now to avoid writing unnecessary
|
||||
// complex code since we believe graphs generated by front ends all have
|
||||
// the back edge as the second input to the merge node.
|
||||
return errors::Internal(
|
||||
"Internal assumption failed while rewriting an outside_compilation "
|
||||
"cluster that contains a while loop. Logic assumes back-edge is to "
|
||||
"port 1 of a 2-input "
|
||||
"Merge node.");
|
||||
}
|
||||
// Connect the existing edge to both inputs of the Merge node so that the
|
||||
// graph will be well-formed.
|
||||
(*graph_out)
|
||||
->AddEdge(existing_input_edge->src(),
|
||||
existing_input_edge->src_output(), dst, edge.dst_input);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1956,7 +2024,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
|
||||
|
||||
Status Encapsulator::MakeGraphForOutsideCompilationSends(
|
||||
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
|
||||
ShapeRefiner* shape_refiner,
|
||||
BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner,
|
||||
std::unordered_map<const Node*, Node*>* node_images,
|
||||
FunctionLibraryDefinition* library) {
|
||||
// Find all the send_from_host nodes in all subgraphs, to use as roots for the
|
||||
@ -1978,10 +2046,15 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends(
|
||||
// nodes, inlining any functions as needed.
|
||||
TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
|
||||
graph, send_from_host_nodes, pruned_graph, node_images, library));
|
||||
FixupSourceAndSinkEdges(pruned_graph->get());
|
||||
|
||||
// Remove back edges from any cycles in the pruned graph to simplify shape
|
||||
// inference traversal. They will be fixed up in the per-subgraph shape
|
||||
// inference graphs stored in the function library.
|
||||
TF_RETURN_IF_ERROR(back_edge_helper->Remove(pruned_graph->get()));
|
||||
|
||||
// Perform shape inference on the pruned graph.
|
||||
shape_refiner->set_require_shape_inference_fns(false);
|
||||
FixupSourceAndSinkEdges(pruned_graph->get());
|
||||
std::vector<Node*> post_order;
|
||||
GetReversePostOrder(*(*pruned_graph), &post_order);
|
||||
for (auto node : post_order) {
|
||||
@ -1999,11 +2072,13 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends(
|
||||
|
||||
Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
|
||||
Graph* graph_out, FunctionLibraryDefinition* library) {
|
||||
BackEdgeHelper back_edge_helper;
|
||||
std::unique_ptr<Graph> pruned_graph;
|
||||
ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
|
||||
std::unordered_map<const Node*, Node*> node_images;
|
||||
TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
|
||||
*graph_out, &pruned_graph, &shape_refiner, &node_images, library));
|
||||
*graph_out, &pruned_graph, &back_edge_helper, &shape_refiner,
|
||||
&node_images, library));
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference",
|
||||
@ -2033,7 +2108,7 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
|
||||
std::unique_ptr<Graph> graph;
|
||||
if (send_node != nullptr) {
|
||||
TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
|
||||
*pruned_graph, shape_refiner, recv_at_host_names,
|
||||
*pruned_graph, back_edge_helper, shape_refiner, recv_at_host_names,
|
||||
node_images[send_node], library, &static_shape, &graph));
|
||||
if (graph == nullptr) {
|
||||
VLOG(2) << "Send node " << send_node->name() << " shapes";
|
||||
|
66
tensorflow/compiler/jit/shape_inference_helpers.cc
Normal file
66
tensorflow/compiler/jit/shape_inference_helpers.cc
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Contains helpers for use in shape inference.
|
||||
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status BackEdgeHelper::Remove(Graph* graph) {
|
||||
if (graph_ != nullptr) {
|
||||
return errors::Internal("BackEdgeHelper duplicate call to Remove.");
|
||||
}
|
||||
graph_ = graph;
|
||||
for (Node* n : graph_->nodes()) {
|
||||
if (n->IsMerge()) {
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
if (e->src()->IsNextIteration()) {
|
||||
back_edges_.push_back(
|
||||
BackEdge{e, e->src(), e->src_output(), e->dst(), e->dst_input()});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const BackEdge& be : back_edges_) {
|
||||
graph_->RemoveEdge(be.edge);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const std::vector<BackEdgeHelper::BackEdge>& BackEdgeHelper::RemovedEdges()
|
||||
const {
|
||||
return back_edges_;
|
||||
}
|
||||
|
||||
Status BackEdgeHelper::Replace() {
|
||||
if (graph_ == nullptr) {
|
||||
return errors::Internal("BackEdgeHelper Replace called before Remove.");
|
||||
}
|
||||
if (replaced_) {
|
||||
return errors::Internal("BackEdgeHelper Replace called more than once.");
|
||||
}
|
||||
replaced_ = true;
|
||||
for (const BackEdge& be : back_edges_) {
|
||||
graph_->AddEdge(be.src, be.src_output, be.dst, be.dst_input);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
65
tensorflow/compiler/jit/shape_inference_helpers.h
Normal file
65
tensorflow/compiler/jit/shape_inference_helpers.h
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Helper class to temporarily remove, then replace, the back edges in a
|
||||
// graph. Simple algorithms for shape inference don't work with cycles, and this
|
||||
// class can be used to remove cycles before running inference and replace them
|
||||
// after. Correct usage requires exactly one call to Remove(), followed by any
|
||||
// number of calls to RemovedEdges() and at most one call to Replace(). The call
|
||||
// to Replace() is optional if the graph will be discarded without being
|
||||
// executed, e.g., if it is being used purely for a shape inference pass.
|
||||
class BackEdgeHelper {
|
||||
public:
|
||||
struct BackEdge {
|
||||
const Edge* edge;
|
||||
Node* src;
|
||||
int src_output;
|
||||
Node* dst;
|
||||
int dst_input;
|
||||
};
|
||||
|
||||
BackEdgeHelper() = default;
|
||||
// Disallows copy and assign.
|
||||
BackEdgeHelper(const BackEdgeHelper& other) = delete;
|
||||
BackEdgeHelper& operator=(const BackEdgeHelper& other) = delete;
|
||||
|
||||
// Temporarily removes all the back edges in graph.
|
||||
Status Remove(Graph* graph);
|
||||
|
||||
// Gets the list of removed edges.
|
||||
const std::vector<BackEdge>& RemovedEdges() const;
|
||||
|
||||
// Replaces the back edges removed by a prior call to Remove.
|
||||
Status Replace();
|
||||
|
||||
private:
|
||||
Graph* graph_ = nullptr; // not owned
|
||||
std::vector<BackEdge> back_edges_;
|
||||
// Set once Replace has been called.
|
||||
bool replaced_ = false;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_
|
Loading…
Reference in New Issue
Block a user