Factor out logic to extract WhileLoopFrames to
functionalize_control_flow_util.h file. PiperOrigin-RevId: 263608965
This commit is contained in:
parent
f2b374e8df
commit
df59375fe8
tensorflow/compiler/tf2xla
@ -48,6 +48,43 @@ xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
|
||||
return AddNodeDefToGraph(ret_def, graph);
|
||||
}
|
||||
|
||||
// Groups the op nodes of `graph` into loop frames keyed by frame name.
//
// `cf_info` is indexed by node id. For every op node, the entry for its frame
// name is looked up in `*frames` (created on first use), linked to the entry
// of its parent frame, and the node is added to that frame's node set. Enter
// nodes are additionally recorded as loop arguments, with `is_loop_invariant`
// read from the node's "is_constant" attribute; a LoopCond node is recorded
// as the frame's loop condition.
//
// Returns an error if a node has no frame or parent frame in `cf_info`, or if
// reading the "is_constant" attribute fails. Entries already written to
// `*frames` before the failure are left in place.
Status ExtractWhileLoopFrames(
    const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
    std::unordered_map<string, WhileLoopFrame>* frames) {
  for (Node* op : graph->op_nodes()) {
    const ControlFlowInfo& cf = cf_info[op->id()];

    VLOG(2) << "node: " << op->name() << " (" << op->id()
            << ") frame_name: " << cf.frame_name
            << " frame: " << (cf.frame ? cf.frame->name() : "---")
            << " parent_frame: "
            << (cf.parent_frame ? cf.parent_frame->name() : "---");
    TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);

    // operator[] default-constructs a missing entry. The entry for this
    // node's frame is created before the entry for its parent frame.
    WhileLoopFrame& current = (*frames)[cf.frame_name];
    const string& enclosing_name = cf_info[cf.parent_frame->id()].frame_name;
    WhileLoopFrame& enclosing = (*frames)[enclosing_name];

    // A null parent means this is the first node seen for the frame, so link
    // it to its parent exactly once.
    if (!current.parent) {
      current.parent = &enclosing;
      current.name = cf.frame_name;
      enclosing.num_children += 1;
    }

    if (IsEnter(op)) {
      WhileLoopArg loop_arg;
      loop_arg.enter = op;
      TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "is_constant",
                                     &loop_arg.is_loop_invariant));
      current.args.push_back(loop_arg);
    } else if (IsLoopCond(op)) {
      current.loop_cond = op;
    }

    current.nodes.insert(op);
  }

  return Status::OK();
}
|
||||
|
||||
// Check that the graph has no cycle containing the given node.
|
||||
Status CheckNodeNotInCycle(const Node* node, const int num_nodes) {
|
||||
std::vector<const Node*> ready;
|
||||
|
@ -18,12 +18,56 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
// Utility functions shared between functionalize cond and while.
|
||||
// Utility functions shared between functionalize cond and while
|
||||
// or used by other graph optimization passes.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Information about a loop argument.
|
||||
struct WhileLoopArg {
|
||||
// Every loop argument has an Enter node.
|
||||
Node* enter;
|
||||
|
||||
// Is the loop argument a loop-invariant value? Taken from the `is_constant`
|
||||
// attribute on the Enter node.
|
||||
bool is_loop_invariant;
|
||||
|
||||
// If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
|
||||
// arguments must have all of the following nodes:
|
||||
Node* merge = nullptr;
|
||||
Node* switch_node = nullptr;
|
||||
Node* next_iteration = nullptr;
|
||||
Node* exit = nullptr;
|
||||
};
|
||||
|
||||
// Information about a loop frame.
|
||||
struct WhileLoopFrame {
|
||||
string name;
|
||||
|
||||
// Pointer to the parent frame. The root frame has a pointer to itself.
|
||||
WhileLoopFrame* parent = nullptr;
|
||||
int num_children = 0;
|
||||
|
||||
// Arguments to this loop.
|
||||
std::vector<WhileLoopArg> args;
|
||||
|
||||
// The loop condition of the loop. There should be exactly one loop condition
|
||||
// in every loop.
|
||||
Node* loop_cond = nullptr;
|
||||
|
||||
// Set of nodes that belong to the loop frame.
|
||||
std::unordered_set<Node*> nodes;
|
||||
};
|
||||
|
||||
// Extracts v1 while loops within a graph and creates a map of
|
||||
// <ControlFlowInfo.frame_name, WhileLoopFrame>.
|
||||
Status ExtractWhileLoopFrames(
|
||||
const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
|
||||
std::unordered_map<string, WhileLoopFrame>* frames);
|
||||
|
||||
// Check that the graph has no cycle containing the given node.
|
||||
Status CheckNodeNotInCycle(const Node* node, const int num_nodes);
|
||||
|
||||
|
@ -42,42 +42,6 @@ namespace {
|
||||
|
||||
using xla::StatusOr;
|
||||
|
||||
// Information about a loop argument.
|
||||
struct Arg {
|
||||
// Every loop argument has an Enter node.
|
||||
Node* enter;
|
||||
|
||||
// Is the loop argument a loop-invariant value? Taken from the `is_constant`
|
||||
// attribute on the Enter node.
|
||||
bool is_loop_invariant;
|
||||
|
||||
// If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
|
||||
// arguments must have all of the following nodes:
|
||||
Node* merge = nullptr;
|
||||
Node* switch_node = nullptr;
|
||||
Node* next_iteration = nullptr;
|
||||
Node* exit = nullptr;
|
||||
};
|
||||
|
||||
// Information about a loop frame.
|
||||
struct Frame {
|
||||
string name;
|
||||
|
||||
// Pointer to the parent frame. The root frame has a pointer to itself.
|
||||
Frame* parent = nullptr;
|
||||
int num_children = 0;
|
||||
|
||||
// Arguments to this loop.
|
||||
std::vector<Arg> args;
|
||||
|
||||
// The loop condition of the loop. There should be exactly one loop condition
|
||||
// in every loop.
|
||||
Node* loop_cond = nullptr;
|
||||
|
||||
// Set of nodes that belong to the loop frame.
|
||||
std::unordered_set<Node*> nodes;
|
||||
};
|
||||
|
||||
// Copies a subgraph from `graph` to `output` by performing a reverse DFS
|
||||
// starting at nodes in vector `stack`.
|
||||
// `node_map` is a vector indexed by source node ID to dest nodes.
|
||||
@ -93,7 +57,7 @@ struct Frame {
|
||||
// taking from the Switch node was not necessarily the first output, but _Arg
|
||||
// nodes only have one output. By adding the Switch node to `squash_src_outputs`
|
||||
// we rewrite the src_output of the corresponding edge to be 0.
|
||||
Status CopySubgraph(const Graph& graph, const Frame* frame,
|
||||
Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame,
|
||||
std::vector<Node*> stack,
|
||||
const std::vector<bool>& squash_src_outputs,
|
||||
std::vector<Node*>* node_map, Graph* output) {
|
||||
@ -154,7 +118,7 @@ StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
|
||||
}
|
||||
|
||||
// Builds a graph for the loop condition.
|
||||
Status BuildLoopCondition(const Graph& graph, Frame* frame,
|
||||
Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame,
|
||||
std::unique_ptr<Graph>* cond_output) {
|
||||
VLOG(2) << "Building loop condition for " << frame->name;
|
||||
*cond_output = absl::make_unique<Graph>(graph.op_registry());
|
||||
@ -166,7 +130,7 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame,
|
||||
|
||||
// Build one _Arg node for each Enter node.
|
||||
for (int i = 0; i < frame->args.size(); ++i) {
|
||||
const Arg& arg = frame->args[i];
|
||||
const WhileLoopArg& arg = frame->args[i];
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Node * arg_node,
|
||||
BuildArgNode(output, arg.enter->input_type(0), i));
|
||||
@ -190,7 +154,7 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame,
|
||||
}
|
||||
|
||||
// Builds a graph for the loop body.
|
||||
Status BuildLoopBody(const Graph& graph, Frame* frame,
|
||||
Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
|
||||
DataTypeVector* arg_types,
|
||||
std::unique_ptr<Graph>* body_output) {
|
||||
VLOG(2) << "Building loop body for " << frame->name;
|
||||
@ -206,7 +170,7 @@ Status BuildLoopBody(const Graph& graph, Frame* frame,
|
||||
next_iterations.reserve(frame->args.size());
|
||||
arg_types->reserve(frame->args.size());
|
||||
for (int i = 0; i < frame->args.size(); ++i) {
|
||||
const Arg& arg = frame->args[i];
|
||||
const WhileLoopArg& arg = frame->args[i];
|
||||
|
||||
DataType dtype = arg.enter->input_type(0);
|
||||
arg_types->push_back(dtype);
|
||||
@ -297,7 +261,7 @@ Status AddMissingFunctionDef(const FunctionDef& fdef,
|
||||
}
|
||||
|
||||
Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
Graph* graph, Frame* frame,
|
||||
Graph* graph, WhileLoopFrame* frame,
|
||||
FunctionLibraryDefinition* library) {
|
||||
VLOG(2) << "Frame " << frame->name << " before: "
|
||||
<< DumpGraphToFile("functionalize_before", *graph, library);
|
||||
@ -307,8 +271,8 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
// shared Enter node. We clone Enter nodes with multiple successors to
|
||||
// maintain the invariant of a unique Enter node per argument of the final
|
||||
// loop.
|
||||
std::vector<Arg> args;
|
||||
for (const Arg& arg : frame->args) {
|
||||
std::vector<WhileLoopArg> args;
|
||||
for (const WhileLoopArg& arg : frame->args) {
|
||||
if (arg.is_loop_invariant) {
|
||||
args.push_back(arg);
|
||||
} else {
|
||||
@ -319,7 +283,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
continue;
|
||||
}
|
||||
TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
|
||||
Arg new_arg;
|
||||
WhileLoopArg new_arg;
|
||||
new_arg.is_loop_invariant = false;
|
||||
if (i == 0) {
|
||||
new_arg.enter = arg.enter;
|
||||
@ -342,7 +306,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
frame->args = std::move(args);
|
||||
|
||||
std::sort(frame->args.begin(), frame->args.end(),
|
||||
[](const Arg& a, const Arg& b) {
|
||||
[](const WhileLoopArg& a, const WhileLoopArg& b) {
|
||||
return NodeCmpByNameResourcesLast()(a.enter, b.enter);
|
||||
});
|
||||
|
||||
@ -368,7 +332,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
// ^ ^
|
||||
// | |
|
||||
// ... ...
|
||||
for (Arg& arg : frame->args) {
|
||||
for (WhileLoopArg& arg : frame->args) {
|
||||
if (!arg.is_loop_invariant) {
|
||||
// Follow the edge from the Enter to Merge.
|
||||
const Edge* enter_merge = nullptr;
|
||||
@ -537,7 +501,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
}
|
||||
std::vector<NodeDefBuilder::NodeOut> inputs;
|
||||
for (int i = 0; i < frame->args.size(); ++i) {
|
||||
const Arg& arg = frame->args[i];
|
||||
const WhileLoopArg& arg = frame->args[i];
|
||||
const Edge* in_edge;
|
||||
TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
|
||||
if (in_edge->IsControlEdge()) {
|
||||
@ -553,7 +517,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
|
||||
// Copies edges to the Enter nodes and from the Exit nodes onto the While.
|
||||
for (int i = 0; i < frame->args.size(); ++i) {
|
||||
const Arg& arg = frame->args[i];
|
||||
const WhileLoopArg& arg = frame->args[i];
|
||||
const Edge* in_edge;
|
||||
TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
|
||||
if (in_edge->IsControlEdge()) {
|
||||
@ -613,39 +577,11 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
}
|
||||
|
||||
// Builds Frames, indexed by name.
|
||||
std::unordered_map<string, Frame> frames;
|
||||
for (Node* node : graph->op_nodes()) {
|
||||
const ControlFlowInfo& cf = cf_info[node->id()];
|
||||
|
||||
VLOG(2) << "node: " << node->name() << " (" << node->id()
|
||||
<< ") frame_name: " << cf.frame_name
|
||||
<< " frame: " << (cf.frame ? cf.frame->name() : "---")
|
||||
<< " parent_frame: "
|
||||
<< (cf.parent_frame ? cf.parent_frame->name() : "---");
|
||||
TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
|
||||
|
||||
Frame& frame = frames[cf.frame_name];
|
||||
Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
|
||||
if (frame.parent == nullptr) {
|
||||
frame.parent = parent;
|
||||
frame.name = cf.frame_name;
|
||||
++parent->num_children;
|
||||
}
|
||||
|
||||
if (IsEnter(node)) {
|
||||
Arg arg;
|
||||
arg.enter = node;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
|
||||
&arg.is_loop_invariant));
|
||||
frame.args.push_back(arg);
|
||||
} else if (IsLoopCond(node)) {
|
||||
frame.loop_cond = node;
|
||||
}
|
||||
frame.nodes.insert(node);
|
||||
}
|
||||
std::unordered_map<string, WhileLoopFrame> frames;
|
||||
TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
|
||||
|
||||
// Adds frames with no children (i.e., the innermost frames) to a worklist.
|
||||
std::deque<Frame*> worklist;
|
||||
std::deque<WhileLoopFrame*> worklist;
|
||||
for (auto& frame : frames) {
|
||||
if (frame.second.num_children == 0) {
|
||||
worklist.push_back(&frame.second);
|
||||
@ -654,7 +590,7 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
|
||||
// Eliminate loops from innermost to outermost.
|
||||
while (!worklist.empty()) {
|
||||
Frame* frame = worklist.front();
|
||||
WhileLoopFrame* frame = worklist.front();
|
||||
worklist.pop_front();
|
||||
if (frame->parent == frame) {
|
||||
// Skip the root frame.
|
||||
|
Loading…
Reference in New Issue
Block a user