Fix shape inference for outside_compilation clusters that include cycles.
PiperOrigin-RevId: 192637289
Commit dc2d1c297a (parent 151c31ce75)
tensorflow/compiler/jit/BUILD

@@ -183,6 +183,13 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "shape_inference_helpers",
+    srcs = ["shape_inference_helpers.cc"],
+    hdrs = ["shape_inference_helpers.h"],
+    deps = ["//tensorflow/core:graph"],
+)
+
 # Internal targets below this point.
 
 cc_library(

@@ -293,6 +300,7 @@ cc_library(
     deps = [
         ":common",
         ":graph_to_functiondef",
+        ":shape_inference_helpers",
        ":union_find",
         "//tensorflow/compiler/jit/graphcycles",
         "//tensorflow/compiler/jit/kernels:parallel_check_op",
tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc

@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/graph_to_functiondef.h"
 #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/compiler/jit/shape_inference_helpers.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
 #include "tensorflow/compiler/xla/status_macros.h"

@@ -36,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/tensor_id.h"

@@ -576,7 +578,8 @@ class Encapsulator {
   // satisfied, e.g., because send_node depends on a node that doesn't have a
   // registered shape inference function.
   Status DoStaticShapeInferenceForOutsideCompilationSend(
-      const Graph& graph_in, const ShapeRefiner& shape_refiner,
+      const Graph& graph_in, const BackEdgeHelper& back_edge_helper,
+      const ShapeRefiner& shape_refiner,
       const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
       FunctionLibraryDefinition* library,
       std::vector<TensorShapeProto>* static_shape_out,

@@ -599,7 +602,7 @@ class Encapsulator {
   // to nodes in pruned_graph.
   Status MakeGraphForOutsideCompilationSends(
       const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
-      ShapeRefiner* shape_refiner,
+      BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner,
       std::unordered_map<const Node*, Node*>* node_images,
       FunctionLibraryDefinition* library);
 
@@ -1712,9 +1715,13 @@ namespace {
 // matter because it will only be used subsequently for shape inference. (It
 // would be possible to add a switch statement over data_type to create a value
 // for the constant, but that would entail maintaining the logic as new types
-// are added, and is not necessary.)
-Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
-                         Graph* graph_out) {
+// are added, and is not necessary.) If the node being replaced was within a
+// control flow frame, adds appropriate Enter nodes so that the use of the Const
+// is well-formed.
+Node* AddDummyShapedNode(const Node* src_node, int src_port,
+                         const std::vector<ControlFlowInfo>& control_flow_info,
+                         const TensorShapeProto& shape, Graph* graph_out) {
+  DataType data_type = src_node->output_type(src_port);
   TensorProto dummy_proto;
   dummy_proto.set_dtype(data_type);
   *dummy_proto.mutable_tensor_shape() = shape;

@@ -1725,7 +1732,23 @@ Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
   NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
                            options.op_registry());
   node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
-  return options.FinalizeBuilder(&node_builder);
+  Node* node = options.FinalizeBuilder(&node_builder);
+  // Add any Enter nodes required to bring the constant to the correct control
+  // flow frame.
+  while (!control_flow_info[src_node->id()].frame_name.empty()) {
+    NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
+                              options.op_registry());
+    enter_builder.Attr("frame_name",
+                       control_flow_info[src_node->id()].frame_name);
+    enter_builder.Attr("is_constant", true);
+    enter_builder.Input(node, 0);
+    Node* enter_node = options.FinalizeBuilder(&enter_builder);
+    // Adopt the new Enter node as the value in the current frame.
+    node = enter_node;
+    // Recurse to the parent frame to see if more Enter nodes need to be added.
+    src_node = control_flow_info[src_node->id()].parent_frame;
+  }
+  return node;
 }
 
 // Adds a copy of node_in to graph_out and adds the mapping to
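For illustration only, not part of this commit: the rewritten AddDummyShapedNode decides how many Enter nodes to add by walking ControlFlowInfo parent links until it reaches the root frame, whose frame_name is empty. Below is a minimal sketch of that walk, assuming control_flow_info was produced by BuildControlFlowInfo as in the hunks that follow; the helper name EnclosingFrames is hypothetical.

#include <string>
#include <vector>

#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Returns the names of the control flow frames enclosing `node`, innermost
// first. One Enter node per entry is what AddDummyShapedNode wraps around the
// dummy constant so that its use is well-formed inside the loop body.
std::vector<std::string> EnclosingFrames(
    const Node* node, const std::vector<ControlFlowInfo>& control_flow_info) {
  std::vector<std::string> frames;
  while (!control_flow_info[node->id()].frame_name.empty()) {
    frames.push_back(control_flow_info[node->id()].frame_name);
    node = control_flow_info[node->id()].parent_frame;
  }
  return frames;
}

}  // namespace tensorflow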
@@ -1767,17 +1790,30 @@ Status CopyShapeInferenceNodeToGraph(
       }
     }
   }
+  // Work around the fact that Enter nodes refuse to propagate shape information
+  // unless they are marked loop invariant. Since we are never going to execute
+  // this graph, marking them all loop invariant is fine.
+  if (node_out->type_string() == "Enter") {
+    node_out->ClearAttr("is_constant");
+    node_out->AddAttr("is_constant", true);
+  }
   return Status::OK();
 }
 
 }  // namespace
 
 Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
-    const Graph& graph_in, const ShapeRefiner& shape_refiner,
+    const Graph& graph_in, const BackEdgeHelper& back_edge_helper,
+    const ShapeRefiner& shape_refiner,
     const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
     FunctionLibraryDefinition* library,
     std::vector<TensorShapeProto>* static_shape_out,
     std::unique_ptr<Graph>* graph_out) {
+  // Get the control flow structure of the input graph so we can build
+  // well-formed output graphs.
+  std::vector<ControlFlowInfo> control_flow_info;
+  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph_in, &control_flow_info));
+
   // Maps from nodes in graph_in to nodes in graph_out.
   //
   // When an edge has fully defined shape the source node in graph_in is
@@ -1802,7 +1838,6 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
 
   // We don't use the standard ReverseDFS because we want to cut off traversal
   // whenever we find an output with fully defined shape.
-  // TODO(misard) make this work properly in the presence of control flow.
   struct Work {
     Node* node;
     bool leave;  // Are we entering or leaving node?

@@ -1840,8 +1875,9 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
         TensorShapeProto proto;
         context->ShapeHandleToProto(shape, &proto);
         if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
-          dummy_node_images[src_node] = AddDummyShapedNode(
-              src_node->output_type(src_port), proto, graph_out->get());
+          dummy_node_images[src_node] =
+              AddDummyShapedNode(src_node, src_port, control_flow_info,
+                                 proto, graph_out->get());
         }
         // The final input to the send node is the dynamic key, which we
         // don't include in the static shapes.
@@ -1889,6 +1925,38 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
     }
   }
 
+  for (const auto edge : back_edge_helper.RemovedEdges()) {
+    if (copied_node_images.find(edge.dst) != copied_node_images.end()) {
+      // The destination of this back edge was added to the inference graph, so
+      // fix it up.
+      Node* dst = copied_node_images[edge.dst];
+      if (dst->type_string() != "Merge") {
+        return errors::InvalidArgument(
+            "outside_compilation cluster contains a back-edge to node ",
+            dst->name(), " of type ", dst->type_string(),
+            ". The analysis pass only supports back-edges to Merge nodes.");
+      }
+      const Edge* existing_input_edge;
+      if (edge.dst_input != 1 || dst->num_inputs() != 2 ||
+          !dst->input_edge(0, &existing_input_edge).ok()) {
+        // TODO(misard) if we see graphs built with a different structure, relax
+        // this constraint. Leaving it here for now to avoid writing unnecessary
+        // complex code since we believe graphs generated by front ends all have
+        // the back edge as the second input to the merge node.
+        return errors::Internal(
+            "Internal assumption failed while rewriting an outside_compilation "
+            "cluster that contains a while loop. Logic assumes back-edge is to "
+            "port 1 of a 2-input "
+            "Merge node.");
+      }
+      // Connect the existing edge to both inputs of the Merge node so that the
+      // graph will be well-formed.
+      (*graph_out)
+          ->AddEdge(existing_input_edge->src(),
+                    existing_input_edge->src_output(), dst, edge.dst_input);
+    }
+  }
+
   return Status::OK();
 }
 
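For illustration only, not part of this commit: the fix-up above relies on the invariant spelled out in its error messages, namely that each removed back edge feeds input 1 of a two-input Merge node whose forward input 0 is present. A hedged sketch of that invariant as a standalone predicate follows, checked here against the original-graph node for brevity (the pass itself checks the node's copied image in the inference graph); the name IsSupportedBackEdge is hypothetical.

#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// True if a back edge recorded by BackEdgeHelper::Remove has the shape the
// rewrite expects: the destination is a 2-input Merge, the back edge is its
// second input, and the forward (first) input edge exists.
bool IsSupportedBackEdge(const BackEdgeHelper::BackEdge& be) {
  const Edge* forward_edge = nullptr;
  return be.dst->type_string() == "Merge" && be.dst->num_inputs() == 2 &&
         be.dst_input == 1 && be.dst->input_edge(0, &forward_edge).ok();
}

}  // namespace tensorflow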
@@ -1956,7 +2024,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
 
 Status Encapsulator::MakeGraphForOutsideCompilationSends(
     const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
-    ShapeRefiner* shape_refiner,
+    BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner,
     std::unordered_map<const Node*, Node*>* node_images,
     FunctionLibraryDefinition* library) {
   // Find all the send_from_host nodes in all subgraphs, to use as roots for the

@@ -1978,10 +2046,15 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends(
   // nodes, inlining any functions as needed.
   TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
       graph, send_from_host_nodes, pruned_graph, node_images, library));
+  FixupSourceAndSinkEdges(pruned_graph->get());
+
+  // Remove back edges from any cycles in the pruned graph to simplify shape
+  // inference traversal. They will be fixed up in the per-subgraph shape
+  // inference graphs stored in the function library.
+  TF_RETURN_IF_ERROR(back_edge_helper->Remove(pruned_graph->get()));
 
   // Perform shape inference on the pruned graph.
   shape_refiner->set_require_shape_inference_fns(false);
-  FixupSourceAndSinkEdges(pruned_graph->get());
   std::vector<Node*> post_order;
   GetReversePostOrder(*(*pruned_graph), &post_order);
   for (auto node : post_order) {

@@ -1999,11 +2072,13 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends(
 
 Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
     Graph* graph_out, FunctionLibraryDefinition* library) {
+  BackEdgeHelper back_edge_helper;
   std::unique_ptr<Graph> pruned_graph;
   ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
   std::unordered_map<const Node*, Node*> node_images;
   TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
-      *graph_out, &pruned_graph, &shape_refiner, &node_images, library));
+      *graph_out, &pruned_graph, &back_edge_helper, &shape_refiner,
+      &node_images, library));
 
   if (VLOG_IS_ON(1)) {
     dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference",

@@ -2033,7 +2108,7 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
     std::unique_ptr<Graph> graph;
     if (send_node != nullptr) {
       TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
-          *pruned_graph, shape_refiner, recv_at_host_names,
+          *pruned_graph, back_edge_helper, shape_refiner, recv_at_host_names,
           node_images[send_node], library, &static_shape, &graph));
       if (graph == nullptr) {
         VLOG(2) << "Send node " << send_node->name() << " shapes";
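For illustration only, not part of this commit: taken together, the changes to MakeGraphForOutsideCompilationSends amount to breaking cycles with BackEdgeHelper before a single reverse-post-order ShapeRefiner pass over the pruned graph. A minimal sketch under that reading follows; the function name InferShapesIgnoringCycles is hypothetical, and the error handling is simplified relative to the real pass.

#include <vector>

#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Runs shape inference over a possibly cyclic graph that is used only for
// analysis: back edges are removed and never replaced.
Status InferShapesIgnoringCycles(Graph* graph, ShapeRefiner* shape_refiner) {
  BackEdgeHelper back_edge_helper;
  FixupSourceAndSinkEdges(graph);
  // Removing NextIteration->Merge edges makes the graph acyclic, so a single
  // reverse-post-order sweep visits every input before its consumer.
  TF_RETURN_IF_ERROR(back_edge_helper.Remove(graph));
  shape_refiner->set_require_shape_inference_fns(false);
  std::vector<Node*> post_order;
  GetReversePostOrder(*graph, &post_order);
  for (Node* node : post_order) {
    // AddNode evaluates the node's shape function using the shapes already
    // recorded for its inputs.
    TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
  }
  return Status::OK();
}

}  // namespace tensorflow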
tensorflow/compiler/jit/shape_inference_helpers.cc (new file, 66 lines)

@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains helpers for use in shape inference.
+
+#include "tensorflow/compiler/jit/shape_inference_helpers.h"
+
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+Status BackEdgeHelper::Remove(Graph* graph) {
+  if (graph_ != nullptr) {
+    return errors::Internal("BackEdgeHelper duplicate call to Remove.");
+  }
+  graph_ = graph;
+  for (Node* n : graph_->nodes()) {
+    if (n->IsMerge()) {
+      for (const Edge* e : n->in_edges()) {
+        if (e->src()->IsNextIteration()) {
+          back_edges_.push_back(
+              BackEdge{e, e->src(), e->src_output(), e->dst(), e->dst_input()});
+        }
+      }
+    }
+  }
+  for (const BackEdge& be : back_edges_) {
+    graph_->RemoveEdge(be.edge);
+  }
+  return Status::OK();
+}
+
+const std::vector<BackEdgeHelper::BackEdge>& BackEdgeHelper::RemovedEdges()
+    const {
+  return back_edges_;
+}
+
+Status BackEdgeHelper::Replace() {
+  if (graph_ == nullptr) {
+    return errors::Internal("BackEdgeHelper Replace called before Remove.");
+  }
+  if (replaced_) {
+    return errors::Internal("BackEdgeHelper Replace called more than once.");
+  }
+  replaced_ = true;
+  for (const BackEdge& be : back_edges_) {
+    graph_->AddEdge(be.src, be.src_output, be.dst, be.dst_input);
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
tensorflow/compiler/jit/shape_inference_helpers.h (new file, 65 lines)

@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_
+#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_
+
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Helper class to temporarily remove, then replace, the back edges in a
+// graph. Simple algorithms for shape inference don't work with cycles, and this
+// class can be used to remove cycles before running inference and replace them
+// after. Correct usage requires exactly one call to Remove(), followed by any
+// number of calls to RemovedEdges() and at most one call to Replace(). The call
+// to Replace() is optional if the graph will be discarded without being
+// executed, e.g., if it is being used purely for a shape inference pass.
+class BackEdgeHelper {
+ public:
+  struct BackEdge {
+    const Edge* edge;
+    Node* src;
+    int src_output;
+    Node* dst;
+    int dst_input;
+  };
+
+  BackEdgeHelper() = default;
+  // Disallows copy and assign.
+  BackEdgeHelper(const BackEdgeHelper& other) = delete;
+  BackEdgeHelper& operator=(const BackEdgeHelper& other) = delete;
+
+  // Temporarily removes all the back edges in graph.
+  Status Remove(Graph* graph);
+
+  // Gets the list of removed edges.
+  const std::vector<BackEdge>& RemovedEdges() const;
+
+  // Replaces the back edges removed by a prior call to Remove.
+  Status Replace();
+
+ private:
+  Graph* graph_ = nullptr;  // not owned
+  std::vector<BackEdge> back_edges_;
+  // Set once Replace has been called.
+  bool replaced_ = false;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_
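For illustration only, not part of this commit: a sketch of the usage contract documented in the class comment above, with exactly one Remove(), an analysis over the now-acyclic graph, and an optional Replace() for graphs that will still be executed. The wrapper name AnalyzeWithoutCycles and its callback parameter are assumptions made for the example.

#include <functional>

#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Removes back edges, runs an analysis that requires an acyclic graph, then
// restores the edges so the graph remains executable.
Status AnalyzeWithoutCycles(Graph* graph,
                            const std::function<Status(Graph*)>& analysis) {
  BackEdgeHelper back_edge_helper;
  TF_RETURN_IF_ERROR(back_edge_helper.Remove(graph));  // at most one Remove
  Status analysis_status = analysis(graph);
  // Replace() puts the recorded NextIteration->Merge edges back; it could be
  // skipped if the graph were discarded after analysis.
  TF_RETURN_IF_ERROR(back_edge_helper.Replace());
  return analysis_status;
}

}  // namespace tensorflow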