Factor out logic to extract WhileLoopFrames to

functionalize_control_flow_util.h file.

PiperOrigin-RevId: 263608965
This commit is contained in:
A. Unique TensorFlower 2019-08-15 11:44:08 -07:00 committed by TensorFlower Gardener
parent f2b374e8df
commit df59375fe8
3 changed files with 99 additions and 82 deletions

View File

@ -48,6 +48,43 @@ xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
return AddNodeDefToGraph(ret_def, graph);
}
Status ExtractWhileLoopFrames(
const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
std::unordered_map<string, WhileLoopFrame>* frames) {
for (Node* node : graph->op_nodes()) {
const ControlFlowInfo& cf = cf_info[node->id()];
VLOG(2) << "node: " << node->name() << " (" << node->id()
<< ") frame_name: " << cf.frame_name
<< " frame: " << (cf.frame ? cf.frame->name() : "---")
<< " parent_frame: "
<< (cf.parent_frame ? cf.parent_frame->name() : "---");
TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
WhileLoopFrame& frame = (*frames)[cf.frame_name];
WhileLoopFrame* parent =
&(*frames)[cf_info[cf.parent_frame->id()].frame_name];
if (frame.parent == nullptr) {
frame.parent = parent;
frame.name = cf.frame_name;
++parent->num_children;
}
if (IsEnter(node)) {
WhileLoopArg arg;
arg.enter = node;
TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
&arg.is_loop_invariant));
frame.args.push_back(arg);
} else if (IsLoopCond(node)) {
frame.loop_cond = node;
}
frame.nodes.insert(node);
}
return Status::OK();
}
// Check that the graph has no cycle containing the given node.
Status CheckNodeNotInCycle(const Node* node, const int num_nodes) {
std::vector<const Node*> ready;

View File

@ -18,12 +18,56 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
// Utility functions shared between functionalize cond and while.
// Utility functions shared between functionalize cond and while
// or used by other graph optimization passes.
namespace tensorflow {
// Information about a loop argument.
struct WhileLoopArg {
// Every loop argument has an Enter node.
Node* enter;
// Is the loop argument a loop-invariant value? Taken from the `is_constant`
// attribute on the Enter node.
bool is_loop_invariant;
// If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
// arguments must have all of the following nodes:
Node* merge = nullptr;
Node* switch_node = nullptr;
Node* next_iteration = nullptr;
Node* exit = nullptr;
};
// Information about a loop frame.
struct WhileLoopFrame {
string name;
// Pointer to the parent frame. The root frame has a pointer to itself.
WhileLoopFrame* parent = nullptr;
int num_children = 0;
// Arguments to this loop.
std::vector<WhileLoopArg> args;
// The loop condition of the loop. There should be exactly one loop condition
// in every loop.
Node* loop_cond = nullptr;
// Set of nodes that belong to the loop frame.
std::unordered_set<Node*> nodes;
};
// Extracts v1 while loops within a graph and creates a map of
// <ControlFLowInfo.name, WhileLoopFrame>.
Status ExtractWhileLoopFrames(
const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
std::unordered_map<string, WhileLoopFrame>* frames);
// Check that the graph has no cycle containing the given node.
Status CheckNodeNotInCycle(const Node* node, const int num_nodes);

View File

@ -42,42 +42,6 @@ namespace {
using xla::StatusOr;
// Information about a loop argument.
struct Arg {
// Every loop argument has an Enter node.
Node* enter;
// Is the loop argument a loop-invariant value? Taken from the `is_constant`
// attribute on the Enter node.
bool is_loop_invariant;
// If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
// arguments must have all of the following nodes:
Node* merge = nullptr;
Node* switch_node = nullptr;
Node* next_iteration = nullptr;
Node* exit = nullptr;
};
// Information about a loop frame.
struct Frame {
string name;
// Pointer to the parent frame. The root frame has a pointer to itself.
Frame* parent = nullptr;
int num_children = 0;
// Arguments to this loop.
std::vector<Arg> args;
// The loop condition of the loop. There should be exactly one loop condition
// in every loop.
Node* loop_cond = nullptr;
// Set of nodes that belong to the loop frame.
std::unordered_set<Node*> nodes;
};
// Copies a subgraph from `graph` to `output` by performing a reverse DFS
// starting at nodes in vector `stack`.
// `node_map` is a vector indexed by source node ID to dest nodes.
@ -93,7 +57,7 @@ struct Frame {
// taking from the Switch node was not necessarily the first output, but _Arg
// nodes only have one output. By adding the Switch node to `squash_src_outputs`
// we rewrite the src_output of the corresponding edge to be 0.
Status CopySubgraph(const Graph& graph, const Frame* frame,
Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame,
std::vector<Node*> stack,
const std::vector<bool>& squash_src_outputs,
std::vector<Node*>* node_map, Graph* output) {
@ -154,7 +118,7 @@ StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
}
// Builds a graph for the loop condition.
Status BuildLoopCondition(const Graph& graph, Frame* frame,
Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame,
std::unique_ptr<Graph>* cond_output) {
VLOG(2) << "Building loop condition for " << frame->name;
*cond_output = absl::make_unique<Graph>(graph.op_registry());
@ -166,7 +130,7 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame,
// Build one _Arg node for each Enter node.
for (int i = 0; i < frame->args.size(); ++i) {
const Arg& arg = frame->args[i];
const WhileLoopArg& arg = frame->args[i];
TF_ASSIGN_OR_RETURN(Node * arg_node,
BuildArgNode(output, arg.enter->input_type(0), i));
@ -190,7 +154,7 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame,
}
// Builds a graph for the loop body.
Status BuildLoopBody(const Graph& graph, Frame* frame,
Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
DataTypeVector* arg_types,
std::unique_ptr<Graph>* body_output) {
VLOG(2) << "Building loop body for " << frame->name;
@ -206,7 +170,7 @@ Status BuildLoopBody(const Graph& graph, Frame* frame,
next_iterations.reserve(frame->args.size());
arg_types->reserve(frame->args.size());
for (int i = 0; i < frame->args.size(); ++i) {
const Arg& arg = frame->args[i];
const WhileLoopArg& arg = frame->args[i];
DataType dtype = arg.enter->input_type(0);
arg_types->push_back(dtype);
@ -297,7 +261,7 @@ Status AddMissingFunctionDef(const FunctionDef& fdef,
}
Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
Graph* graph, Frame* frame,
Graph* graph, WhileLoopFrame* frame,
FunctionLibraryDefinition* library) {
VLOG(2) << "Frame " << frame->name << " before: "
<< DumpGraphToFile("functionalize_before", *graph, library);
@ -307,8 +271,8 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// shared Enter node. We clone Enter nodes with multiple successors to
// maintain the invariant of a unique Enter node per argument of the final
// loop.
std::vector<Arg> args;
for (const Arg& arg : frame->args) {
std::vector<WhileLoopArg> args;
for (const WhileLoopArg& arg : frame->args) {
if (arg.is_loop_invariant) {
args.push_back(arg);
} else {
@ -319,7 +283,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
continue;
}
TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
Arg new_arg;
WhileLoopArg new_arg;
new_arg.is_loop_invariant = false;
if (i == 0) {
new_arg.enter = arg.enter;
@ -342,7 +306,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
frame->args = std::move(args);
std::sort(frame->args.begin(), frame->args.end(),
[](const Arg& a, const Arg& b) {
[](const WhileLoopArg& a, const WhileLoopArg& b) {
return NodeCmpByNameResourcesLast()(a.enter, b.enter);
});
@ -368,7 +332,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// ^ ^
// | |
// ... ...
for (Arg& arg : frame->args) {
for (WhileLoopArg& arg : frame->args) {
if (!arg.is_loop_invariant) {
// Follow the edge from the Enter to Merge.
const Edge* enter_merge = nullptr;
@ -537,7 +501,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
}
std::vector<NodeDefBuilder::NodeOut> inputs;
for (int i = 0; i < frame->args.size(); ++i) {
const Arg& arg = frame->args[i];
const WhileLoopArg& arg = frame->args[i];
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
@ -553,7 +517,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// Copies edges to the Enter nodes and from the Exit nodes onto the While.
for (int i = 0; i < frame->args.size(); ++i) {
const Arg& arg = frame->args[i];
const WhileLoopArg& arg = frame->args[i];
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
@ -613,39 +577,11 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
}
// Builds Frames, indexed by name.
std::unordered_map<string, Frame> frames;
for (Node* node : graph->op_nodes()) {
const ControlFlowInfo& cf = cf_info[node->id()];
VLOG(2) << "node: " << node->name() << " (" << node->id()
<< ") frame_name: " << cf.frame_name
<< " frame: " << (cf.frame ? cf.frame->name() : "---")
<< " parent_frame: "
<< (cf.parent_frame ? cf.parent_frame->name() : "---");
TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
Frame& frame = frames[cf.frame_name];
Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
if (frame.parent == nullptr) {
frame.parent = parent;
frame.name = cf.frame_name;
++parent->num_children;
}
if (IsEnter(node)) {
Arg arg;
arg.enter = node;
TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
&arg.is_loop_invariant));
frame.args.push_back(arg);
} else if (IsLoopCond(node)) {
frame.loop_cond = node;
}
frame.nodes.insert(node);
}
std::unordered_map<string, WhileLoopFrame> frames;
TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
// Adds frames with no children (i.e., the innermost frames) to a worklist.
std::deque<Frame*> worklist;
std::deque<WhileLoopFrame*> worklist;
for (auto& frame : frames) {
if (frame.second.num_children == 0) {
worklist.push_back(&frame.second);
@ -654,7 +590,7 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
// Eliminate loops from innermost to outermost.
while (!worklist.empty()) {
Frame* frame = worklist.front();
WhileLoopFrame* frame = worklist.front();
worklist.pop_front();
if (frame->parent == frame) {
// Skip the root frame.