diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 4cefc08645a..6edeb7047f9 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -183,6 +183,13 @@ cc_library( ], ) +cc_library( + name = "shape_inference_helpers", + srcs = ["shape_inference_helpers.cc"], + hdrs = ["shape_inference_helpers.h"], + deps = ["//tensorflow/core:graph"], +) + # Internal targets below this point. cc_library( @@ -293,6 +300,7 @@ cc_library( deps = [ ":common", ":graph_to_functiondef", + ":shape_inference_helpers", ":union_find", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index b04b333141a..9465385b585 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -36,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" @@ -576,7 +578,8 @@ class Encapsulator { // satisfied, e.g., because send_node depends on a node that doesn't have a // registered shape inference function. 
Status DoStaticShapeInferenceForOutsideCompilationSend( - const Graph& graph_in, const ShapeRefiner& shape_refiner, + const Graph& graph_in, const BackEdgeHelper& back_edge_helper, + const ShapeRefiner& shape_refiner, const std::unordered_set<string>& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector<TensorShapeProto>* static_shape_out, std::unique_ptr<Graph>* graph_out); @@ -599,7 +602,7 @@ class Encapsulator { // to nodes in pruned_graph. Status MakeGraphForOutsideCompilationSends( const Graph& graph, std::unique_ptr<Graph>* pruned_graph, - ShapeRefiner* shape_refiner, + BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner, std::unordered_map<const Node*, Node*>* node_images, FunctionLibraryDefinition* library); @@ -1712,9 +1715,13 @@ namespace { // matter because it will only be used subsequently for shape inference. (It // would be possible to add a switch statement over data_type to create a value // for the constant, but that would entail maintaining the logic as new types -// are added, and is not necessary.) -Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, - Graph* graph_out) { +// are added, and is not necessary.) If the node being replaced was within a +// control flow frame, adds appropriate Enter nodes so that the use of the Const +// is well-formed. 
+Node* AddDummyShapedNode(const Node* src_node, int src_port, + const std::vector<ControlFlowInfo>& control_flow_info, + const TensorShapeProto& shape, Graph* graph_out) { + DataType data_type = src_node->output_type(src_port); TensorProto dummy_proto; dummy_proto.set_dtype(data_type); *dummy_proto.mutable_tensor_shape() = shape; @@ -1725,7 +1732,23 @@ Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const", options.op_registry()); node_builder.Attr("dtype", data_type).Attr("value", dummy_proto); - return options.FinalizeBuilder(&node_builder); + Node* node = options.FinalizeBuilder(&node_builder); + // Add any Enter nodes required to bring the constant to the correct control + // flow frame. + while (!control_flow_info[src_node->id()].frame_name.empty()) { + NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter", + options.op_registry()); + enter_builder.Attr("frame_name", + control_flow_info[src_node->id()].frame_name); + enter_builder.Attr("is_constant", true); + enter_builder.Input(node, 0); + Node* enter_node = options.FinalizeBuilder(&enter_builder); + // Adopt the new Enter node as the value in the current frame. + node = enter_node; + // Recurse to the parent frame to see if more Enter nodes need to be added. + src_node = control_flow_info[src_node->id()].parent_frame; + } + return node; } // Adds a copy of node_in to graph_out and adds the mapping to @@ -1767,17 +1790,30 @@ Status CopyShapeInferenceNodeToGraph( } } } + // Work around the fact that Enter nodes refuse to propagate shape information + // unless they are marked loop invariant. Since we are never going to execute + // this graph, marking them all loop invariant is fine. 
+ if (node_out->type_string() == "Enter") { + node_out->ClearAttr("is_constant"); + node_out->AddAttr("is_constant", true); + } return Status::OK(); } } // namespace Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( - const Graph& graph_in, const ShapeRefiner& shape_refiner, + const Graph& graph_in, const BackEdgeHelper& back_edge_helper, + const ShapeRefiner& shape_refiner, const std::unordered_set<string>& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector<TensorShapeProto>* static_shape_out, std::unique_ptr<Graph>* graph_out) { + // Get the control flow structure of the input graph so we can build + // well-formed output graphs. + std::vector<ControlFlowInfo> control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph_in, &control_flow_info)); + // Maps from nodes in graph_in to nodes in graph_out. // // When an edge has fully defined shape the source node in graph_in is @@ -1802,7 +1838,6 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // We don't use the standard ReverseDFS because we want to cut off traversal // whenever we find an output with fully defined shape. - // TODO(misard) make this work properly in the presence of control flow. struct Work { Node* node; bool leave; // Are we entering or leaving node? }; @@ -1840,8 +1875,9 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); if (dummy_node_images.find(src_node) == dummy_node_images.end()) { - dummy_node_images[src_node] = AddDummyShapedNode( - src_node->output_type(src_port), proto, graph_out->get()); + dummy_node_images[src_node] = + AddDummyShapedNode(src_node, src_port, control_flow_info, + proto, graph_out->get()); } // The final input to the send node is the dynamic key, which we // don't include in the static shapes. 
@@ -1889,6 +1925,38 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( } } + for (const auto edge : back_edge_helper.RemovedEdges()) { + if (copied_node_images.find(edge.dst) != copied_node_images.end()) { + // The destination of this back edge was added to the inference graph, so + // fix it up. + Node* dst = copied_node_images[edge.dst]; + if (dst->type_string() != "Merge") { + return errors::InvalidArgument( + "outside_compilation cluster contains a back-edge to node ", + dst->name(), " of type ", dst->type_string(), + ". The analysis pass only supports back-edges to Merge nodes."); + } + const Edge* existing_input_edge; + if (edge.dst_input != 1 || dst->num_inputs() != 2 || + !dst->input_edge(0, &existing_input_edge).ok()) { + // TODO(misard) if we see graphs built with a different structure, relax + // this constraint. Leaving it here for now to avoid writing unnecessary + // complex code since we believe graphs generated by front ends all have + // the back edge as the second input to the merge node. + return errors::Internal( + "Internal assumption failed while rewriting an outside_compilation " + "cluster that contains a while loop. Logic assumes back-edge is to " + "port 1 of a 2-input " + "Merge node."); + } + // Connect the existing edge to both inputs of the Merge node so that the + // graph will be well-formed. 
+ (*graph_out) + ->AddEdge(existing_input_edge->src(), + existing_input_edge->src_output(), dst, edge.dst_input); + } + } + return Status::OK(); } @@ -1956,7 +2024,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( Status Encapsulator::MakeGraphForOutsideCompilationSends( const Graph& graph, std::unique_ptr<Graph>* pruned_graph, - ShapeRefiner* shape_refiner, + BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner, std::unordered_map<const Node*, Node*>* node_images, FunctionLibraryDefinition* library) { // Find all the send_from_host nodes in all subgraphs, to use as roots for the @@ -1978,10 +2046,15 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends( // nodes, inlining any functions as needed. TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( graph, send_from_host_nodes, pruned_graph, node_images, library)); + FixupSourceAndSinkEdges(pruned_graph->get()); + + // Remove back edges from any cycles in the pruned graph to simplify shape + // inference traversal. They will be fixed up in the per-subgraph shape + // inference graphs stored in the function library. + TF_RETURN_IF_ERROR(back_edge_helper->Remove(pruned_graph->get())); // Perform shape inference on the pruned graph. 
shape_refiner->set_require_shape_inference_fns(false); - FixupSourceAndSinkEdges(pruned_graph->get()); std::vector<Node*> post_order; GetReversePostOrder(*(*pruned_graph), &post_order); for (auto node : post_order) { @@ -1999,11 +2072,13 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( Graph* graph_out, FunctionLibraryDefinition* library) { + BackEdgeHelper back_edge_helper; std::unique_ptr<Graph> pruned_graph; ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); std::unordered_map<const Node*, Node*> node_images; TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( - *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + *graph_out, &pruned_graph, &back_edge_helper, &shape_refiner, + &node_images, library)); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference", @@ -2033,7 +2108,7 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( std::unique_ptr<Graph> graph; if (send_node != nullptr) { TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( - *pruned_graph, shape_refiner, recv_at_host_names, + *pruned_graph, back_edge_helper, shape_refiner, recv_at_host_names, node_images[send_node], library, &static_shape, &graph)); if (graph == nullptr) { VLOG(2) << "Send node " << send_node->name() << " shapes"; diff --git a/tensorflow/compiler/jit/shape_inference_helpers.cc b/tensorflow/compiler/jit/shape_inference_helpers.cc new file mode 100644 index 00000000000..d9cfa16526b --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference_helpers.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains helpers for use in shape inference. + +#include "tensorflow/compiler/jit/shape_inference_helpers.h" + +#include <vector> + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +Status BackEdgeHelper::Remove(Graph* graph) { + if (graph_ != nullptr) { + return errors::Internal("BackEdgeHelper duplicate call to Remove."); + } + graph_ = graph; + for (Node* n : graph_->nodes()) { + if (n->IsMerge()) { + for (const Edge* e : n->in_edges()) { + if (e->src()->IsNextIteration()) { + back_edges_.push_back( + BackEdge{e, e->src(), e->src_output(), e->dst(), e->dst_input()}); + } + } + } + } + for (const BackEdge& be : back_edges_) { + graph_->RemoveEdge(be.edge); + } + return Status::OK(); +} + +const std::vector<BackEdgeHelper::BackEdge>& BackEdgeHelper::RemovedEdges() + const { + return back_edges_; +} + +Status BackEdgeHelper::Replace() { + if (graph_ == nullptr) { + return errors::Internal("BackEdgeHelper Replace called before Remove."); + } + if (replaced_) { + return errors::Internal("BackEdgeHelper Replace called more than once."); + } + replaced_ = true; + for (const BackEdge& be : back_edges_) { + graph_->AddEdge(be.src, be.src_output, be.dst, be.dst_input); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference_helpers.h b/tensorflow/compiler/jit/shape_inference_helpers.h new file mode 100644 index 00000000000..2f053c9a45d --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference_helpers.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ +#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ + +#include <vector> + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Helper class to temporarily remove, then replace, the back edges in a +// graph. Simple algorithms for shape inference don't work with cycles, and this +// class can be used to remove cycles before running inference and replace them +// after. Correct usage requires exactly one call to Remove(), followed by any +// number of calls to RemovedEdges() and at most one call to Replace(). The call +// to Replace() is optional if the graph will be discarded without being +// executed, e.g., if it is being used purely for a shape inference pass. +class BackEdgeHelper { + public: + struct BackEdge { + const Edge* edge; + Node* src; + int src_output; + Node* dst; + int dst_input; + }; + + BackEdgeHelper() = default; + // Disallows copy and assign. + BackEdgeHelper(const BackEdgeHelper& other) = delete; + BackEdgeHelper& operator=(const BackEdgeHelper& other) = delete; + + // Temporarily removes all the back edges in graph. + Status Remove(Graph* graph); + + // Gets the list of removed edges. + const std::vector<BackEdge>& RemovedEdges() const; + + // Replaces the back edges removed by a prior call to Remove. 
+ Status Replace(); + + private: + Graph* graph_ = nullptr; // not owned + std::vector<BackEdge> back_edges_; + // Set once Replace has been called. + bool replaced_ = false; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_