Add capability to restrict functionalization via node filter

For the MLIR-based TPU bridge, functionalization runs before TPU cluster
extraction as part of the graph-to-MLIR conversion. This was problematic
because previously non-TPU nodes were also functionalized at this stage, which
caused issues in the TF v1 session runtime, which assumes that certain nodes
are left unchanged.
This change adds the capability to restrict functionalization to certain loops
and conditions, according to a user-defined node filter. This can be used to
fix the above issues (a separate CL will apply the filter).

PiperOrigin-RevId: 318846935
Change-Id: I36078909c6091de083ffa5d57cdf63eca5f844ef
This commit is contained in:
Michael Gester 2020-06-29 10:50:39 -07:00 committed by TensorFlower Gardener
parent 777b6ad484
commit 4e0d3b117d
11 changed files with 494 additions and 274 deletions

View File

@ -781,6 +781,7 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -22,10 +22,10 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
@ -285,8 +285,12 @@ string StateMap::AncestorStateToString(const Node* node) const {
}
FunctionalizeCond::FunctionalizeCond(Graph* graph,
FunctionLibraryDefinition* library)
: state_map_(graph), library_(library), graph_(graph) {}
FunctionLibraryDefinition* library,
const NodeFilter& node_filter)
: state_map_(graph),
library_(library),
graph_(graph),
node_filter_(node_filter) {}
// Class representing the merge/switch nodes that will become a conditional.
class Conditional {
@ -807,11 +811,13 @@ Status Conditional::BuildIfNode(Graph* graph,
<< PartialTensorShapeUtils::PartialShapeListString(output_shapes);
builder.Attr("Tcond", DT_BOOL);
string outside_compilation;
if (GetNodeAttr(predicate_.node->def(), kXlaOutsideCompilationAttrName,
&outside_compilation)
.ok()) {
builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
// Add all underscore attributes, these need to be propagated.
for (const auto& attr : predicate_.node->def().attr()) {
const string& name(attr.first);
const AttrValue& value(attr.second);
if (absl::StartsWith(name, "_")) {
builder.Attr(name, value);
}
}
builder.Device(predicate_.node->assigned_device_name());
// Conditional should be the first input ...
@ -1076,7 +1082,7 @@ StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
// Determine the flow state when joining two states for a merge
// node. Combining the two states for a merge node is effectively performing a
// disjunction of the states along the different input edges. For a merge that
// can be transformed into a If the two inputs paths have to have a predicate
// can be transformed into an If the two inputs paths have to have a predicate
// on which they differ (e.g., along one edge predicate `p` has to hold while
// on another it should not). This function first determines this predicate
// and then the resultant state is the common path between the two inputs
@ -1368,8 +1374,9 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes(
deleted[graph_->kSourceId] = true;
deleted[graph_->kSinkId] = true;
// All remaining Switch nodes are not reachable from a Merge node and
// removed. This is to account for dead Switch nodes.
// All remaining switch nodes that were not excluded from functionalization
// according to `node_filter_` are not reachable from a merge node and
// removed. This is to account for dead switch nodes.
for (int s_id : switch_ids_) {
Node* s = graph_->FindNodeId(s_id);
if (s == nullptr) continue;
@ -1379,11 +1386,17 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes(
// conditional.
if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
}
deleted[s_id] = true;
graph_->RemoveNode(s);
// Only remove switch node if we have functionalized the corresponding
// condition before (according to `node_filter_`).
if (!node_filter_ || node_filter_(s)) {
VLOG(2) << "Removing obsolete switch node " << s->name();
deleted[s_id] = true;
graph_->RemoveNode(s);
}
}
// All merge nodes should have been transformed at this point and we remove
// All merge nodes that were not excluded from functionalization according to
// `node_filter_` should have been transformed at this point and we remove
// them from the graph here.
for (Node* m : merge_order) {
for (const Edge* e : m->out_edges()) {
@ -1393,8 +1406,13 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes(
// being removed in AddOutputEdges.
if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
}
deleted[m->id()] = true;
graph_->RemoveNode(m);
// Only remove merge node if we have functionalized the corresponding
// condition before (according to `node_filter_`).
if (!node_filter_ || node_filter_(m)) {
VLOG(2) << "Removing obsolete merge node " << m->name();
deleted[m->id()] = true;
graph_->RemoveNode(m);
}
}
// Enqueue all the dead nodes.
@ -1403,7 +1421,7 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes(
delete_nodes.push_back(n->id());
}
}
// Remove dead nodes and nodes that are reachable from dead nodes.
while (!delete_nodes.empty()) {
int d_id = delete_nodes.front();
delete_nodes.pop_front();
@ -1414,6 +1432,7 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes(
for (const Edge* e : d->out_edges()) {
delete_nodes.push_back(e->dst()->id());
}
VLOG(2) << "Removing obsolete node " << d->name();
deleted[d_id] = true;
graph_->RemoveNode(d);
}
@ -1454,6 +1473,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
// AncestorState from the innermost to the outermost into IfOps;
// Note: In the above only nodes that feed into a merge node will be
// considered for functionalization.
// Note: Nodes for which `node_filter_` returns false are excluded.
// Perform a DFS over the graph and
// * Determine the reverse topological order of the nodes (there should be no
@ -1463,12 +1483,18 @@ Status FunctionalizeCond::FunctionalizeInternal() {
std::vector<Node*> rev_topo_order;
std::vector<Node*> merge_order;
DFS(*graph_, nullptr, [&](Node* n) {
if (IsSwitch(n)) {
AddSwitchId(n->id());
}
if (IsMerge(n)) {
merge_order.push_back(n);
    // Only collect switch and merge nodes that are not filtered out; those
    // form the conditions that will be functionalized.
if (!node_filter_ || node_filter_(n)) {
if (IsSwitch(n)) {
AddSwitchId(n->id());
}
if (IsMerge(n)) {
merge_order.push_back(n);
}
}
// Collect all other nodes here, independent of `node_filter_`, because they
// might belong to a condition that should be functionalized.
if (n->IsOp()) {
rev_topo_order.push_back(n);
}
@ -1571,19 +1597,22 @@ void FunctionalizeCond::AddSwitchId(int switch_id) {
}
Status FunctionalizeCond::Functionalize(Graph* graph,
FunctionLibraryDefinition* library) {
FunctionLibraryDefinition* library,
const NodeFilter& node_filter) {
VLOG(1) << "FunctionalizeCond::Functionalize";
FunctionalizeCond fc(graph, library);
FunctionalizeCond fc(graph, library, node_filter);
return fc.FunctionalizeInternal();
}
} // namespace functionalize_cond
Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) {
Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
const NodeFilter& node_filter) {
  // FunctionalizeControlFlow is invoked for every function, so the loops'
// bodies and conditionals that were extracted into functions will be handled
// in successive invocations.
return functionalize_cond::FunctionalizeCond::Functionalize(graph, library);
return functionalize_cond::FunctionalizeCond::Functionalize(graph, library,
node_filter);
}
} // namespace tensorflow

View File

@ -17,6 +17,8 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
#include <deque>
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@ -26,8 +28,17 @@ namespace tensorflow {
// Functionalize all the switch-merge nodes of a loop-free graph into If
// nodes. That is, attempt to transform every remaining switch and merge nodes
// in the graph into If nodes.
// Precondition: All while loops have been removed from graph.
Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
//
// If `node_filter` is defined, then only conditions for whose nodes
// `node_filter` returns true are functionalized.
//
// Preconditions:
// a) Same as for `FunctionalizeControlFlow` (see comment there).
// b) While loops must have been functionalized before according to
// `node_filter` (e.g., by calling `FunctionalizeWhileLoop` with the same
// filter before calling this function).
Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
const NodeFilter& node_filter = {});
// Internal functions/classes exposed for testing purposes.
namespace functionalize_cond {
@ -172,11 +183,9 @@ class StateMap {
// of the given graph together.
class FunctionalizeCond {
public:
// Functionalize all the switch-merge nodes of a loop-free graph into If
// nodes. That is, attempt to transform every remaining switch and merge nodes
// in the graph into If nodes.
// Precondition: All while loops have been removed from graph.
static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);
// See comment for function `FunctionalizeCond`.
static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library,
const NodeFilter& node_filter);
// Build identity node with the same name as the merge that will be replaced
// in case the output is fetched/colocated.
@ -197,7 +206,8 @@ class FunctionalizeCond {
void AddSwitchId(int switch_id);
private:
FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
const NodeFilter& node_filter);
// Performs the actual cond functionalization. Iterate over groups of merge
// nodes (linked by common predicates & ancestor IDs), from innermost to
@ -268,6 +278,9 @@ class FunctionalizeCond {
friend class FunctionalizeCondTest;
std::vector<int> switch_ids_;
// Controls which nodes are skipped for functionalization.
NodeFilter node_filter_ = {};
};
} // namespace functionalize_cond

View File

@ -40,8 +40,8 @@ class FunctionalizeCondTest : public ::testing::Test {
graph_.reset(new Graph(OpRegistry::Global()));
flib_def_.reset(
new FunctionLibraryDefinition(OpRegistry::Global(), fdef_lib_));
fc_.reset(new functionalize_cond::FunctionalizeCond(graph_.get(),
flib_def_.get()));
fc_.reset(new functionalize_cond::FunctionalizeCond(
graph_.get(), flib_def_.get(), NodeFilter{}));
}
StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) {

View File

@ -46,20 +46,19 @@ limitations under the License.
namespace tensorflow {
// Transformation that converts TensorFlow's graph control flow constructs into
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library) {
FunctionLibraryDefinition* library,
const NodeFilter& node_filter) {
VLOG(2) << "FunctionalizeControlFlow (initial): "
<< DumpGraphToFile("functionalize_initial", *graph, library);
// Functionalize and remove while loops from graph.
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library));
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter));
  // FunctionalizeControlFlow is invoked for every function, so the loops'
// bodies and conditionals that were extracted into functions will be handled
// in successive invocations.
TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library));
TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library, node_filter));
VLOG(2) << "FunctionalizeControlFlow (final): "
<< DumpGraphToFile("functionalize_final", *graph, library);
@ -68,12 +67,13 @@ Status FunctionalizeControlFlow(Graph* graph,
}
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
FunctionLibraryDefinition* library) {
FunctionLibraryDefinition* library,
const NodeFilter& node_filter) {
FunctionDefLibrary function_lib = graph_def->library();
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library));
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter));
graph.ToGraphDef(graph_def);
std::swap(*graph_def->mutable_library(), function_lib);
return Status::OK();

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
@ -26,11 +27,27 @@ namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
// operators and tf.cond() conditionals into function If operators, suitable for
// XLA compilation.
//
// If `node_filter` is defined, then only loops and conditions for whose
// nodes `node_filter` returns true are functionalized.
//
// Precondition:
// For any node in a loop or condition for which `node_filter` returns true,
// all nodes inside of the same loop or condition must also return true
// (including nodes in other nested loops and conditions inside of that loop or
// condition).
// This means that a "not to be functionalized" loop or condition is not allowed
// inside a "to be functionalized" loop or condition.
//
// The user of this function is responsible for using a node filter that
// satisfies the above conditions.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library);
FunctionLibraryDefinition* library,
const NodeFilter& node_filter = {});
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
FunctionLibraryDefinition* library);
FunctionLibraryDefinition* library,
const NodeFilter& node_filter = {});
// This pass looks at the graph, and turns V1 control flow structure
// (Switch/Merge/etc.) into V2 control flow structure (If/While).

View File

@ -62,7 +62,18 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
// z = control_flow_ops.cond(
// math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
// lambda: math_ops.add(x, 23))
TEST(FunctionalizeControlFlow, Conditional) {
//
// Tests different node filters.
class ConditionalTestFixture : public ::testing::TestWithParam<bool> {
protected:
void SetUp() override { restrict_to_tpu_nodes_ = GetParam(); }
void RunTest();
private:
bool restrict_to_tpu_nodes_ = false;
};
void ConditionalTestFixture::RunTest() {
Graph graph(OpRegistry::Global());
{
Scope scope = Scope::NewRootScope().ExitOnError();
@ -92,14 +103,25 @@ TEST(FunctionalizeControlFlow, Conditional) {
std::initializer_list<Input>{add, mul});
TF_EXPECT_OK(scope.ToGraph(&graph));
// Set `_tpu_replicate` attribute for all nodes.
for (Node* n : graph.nodes()) {
n->AddAttr("_tpu_replicate", "cluster");
}
}
// If `restrict_to_tpu_nodes_` is true let filter function return true for
// `_tpu_replicate` nodes.
NodeFilter node_filter =
restrict_to_tpu_nodes_
? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
: NodeFilter{};
FunctionLibraryDefinition library(OpRegistry::Global(), {});
GraphDef optimized_graph_def;
graph.ToGraphDef(&optimized_graph_def);
TF_ASSERT_OK(
FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(&optimized_graph_def,
&library, node_filter));
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library, node_filter));
GraphDef converted_graph_def;
graph.ToGraphDef(&converted_graph_def);
@ -180,6 +202,13 @@ TEST(FunctionalizeControlFlow, Conditional) {
}
}
TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); }
INSTANTIATE_TEST_SUITE_P(
FunctionalizeControlFlow, ConditionalTestFixture, ::testing::Bool(),
[](const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>&
info) { return info.param ? "with_filter" : "without_filter"; });
// Returns the names of the "cond" and "body" functions for the While node
// in a graph.
Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
@ -758,25 +787,75 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) {
}
}
// Example with nesting, loop-invariant arguments, and resource variables.
//
// accum = resource_variable_ops.ResourceVariable(1)
// x = array_ops.placeholder(2, dtype=dtypes.int32)
// y = 3 + x
//
// def inner_body(j, k):
// add = state_ops.assign_add(accum, k * j + x)
// with ops.control_dependencies([add]):
// return [j + 1, k]
//
// def body(i):
// m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
// [1, y], name="inner")
// with ops.control_dependencies(m):
// return [i + 1]
//
// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
TEST(FunctionalizeControlFlow, Complex) {
// More complex example with nesting, loop-invariant arguments, and resource
// variables. Used for multiple tests with different node filters.
class ComplexTestFixture
: public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {
protected:
void SetUp() override {
restrict_to_tpu_nodes_ = std::get<0>(GetParam());
mark_inner_loop_tpu_ = std::get<1>(GetParam());
mark_outer_loop_tpu_ = std::get<2>(GetParam());
}
void RunTest();
private:
void CheckOuterNodesFunctionalized(const GraphDef& graph_def,
const FunctionLibraryDefinition& library,
NameAttrList& inner_cond_fn,
NameAttrList& inner_body_fn);
void CheckInnerNodesFunctionalized(const GraphDef& graph_def,
const FunctionLibraryDefinition& library,
const NameAttrList& inner_cond_fn,
const NameAttrList& inner_body_fn);
bool restrict_to_tpu_nodes_ = false;
bool mark_inner_loop_tpu_ = false;
bool mark_outer_loop_tpu_ = false;
};
TEST_P(ComplexTestFixture, ComplexTests) { RunTest(); }
INSTANTIATE_TEST_SUITE_P(
FunctionalizeControlFlow, ComplexTestFixture,
::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool()),
[](const ::testing::TestParamInfo<ComplexTestFixture::ParamType>& info) {
bool restrict_to_tpu_nodes = std::get<0>(info.param);
bool mark_inner_loop_tpu = std::get<1>(info.param);
bool mark_outer_loop_tpu = std::get<2>(info.param);
string node_string;
if (mark_inner_loop_tpu && mark_outer_loop_tpu)
node_string = "both_loops_tpu";
else if (!mark_inner_loop_tpu && !mark_outer_loop_tpu)
node_string = "no_loop_tpu";
else
node_string = mark_inner_loop_tpu ? "inner_loop_tpu" : "outer_loop_tpu";
string name = absl::StrCat(
restrict_to_tpu_nodes ? "restricted_" : "unrestricted_", node_string);
return name;
});
void ComplexTestFixture::RunTest() {
// Graph:
//
// accum = resource_variable_ops.ResourceVariable(1)
// x = array_ops.placeholder(2, dtype=dtypes.int32)
// y = 3 + x
//
// def inner_body(j, k):
// add = state_ops.assign_add(accum, k * j + x)
// with ops.control_dependencies([add]):
// return [j + 1, k]
//
// def body(i):
// m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
// [1, y], name="inner")
// with ops.control_dependencies(m):
// return [i + 1]
//
// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
Graph graph(OpRegistry::Global());
{
Scope scope = Scope::NewRootScope().ExitOnError();
@ -846,7 +925,8 @@ TEST(FunctionalizeControlFlow, Complex) {
5);
auto less_j =
ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j);
auto loop_cond =
ops::LoopCond(scope.WithOpName("outer/inner/LoopCond"), less_j);
auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
merge_j.output, loop_cond);
@ -906,193 +986,246 @@ TEST(FunctionalizeControlFlow, Complex) {
TF_EXPECT_OK(scope.ToGraph(&graph));
}
// Add '_tpu_replicate' attributes as specified.
for (Node* n : graph.nodes()) {
string name = n->name();
bool is_inner_node = name.find("outer/inner/") != string::npos;
bool is_outer_node = !is_inner_node && name.find("outer/") != string::npos;
if ((is_inner_node && mark_inner_loop_tpu_) ||
(is_outer_node && mark_outer_loop_tpu_)) {
n->AddAttr("_tpu_replicate", "cluster");
}
}
FunctionLibraryDefinition library(OpRegistry::Global(), {});
GraphDef optimized_graph_def;
graph.ToGraphDef(&optimized_graph_def);
TF_ASSERT_OK(
FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
GraphDef converted_graph_def;
graph.ToGraphDef(&converted_graph_def);
GraphDef orig_graph_def, optimized_graph_def;
graph.ToGraphDef(&orig_graph_def);
optimized_graph_def = orig_graph_def;
// If `restrict_to_tpu_nodes_` is true let filter function return true for
// `_tpu_replicate` nodes, otherwise don't set filter.
NodeFilter node_filter =
restrict_to_tpu_nodes_
? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
: NodeFilter{};
for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
NameAttrList outer_cond_fn, outer_body_fn;
TF_EXPECT_OK(
FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
Status status1 = FunctionalizeControlFlowForGraphDef(&optimized_graph_def,
&library, node_filter);
Status status2 = FunctionalizeControlFlow(&graph, &library, node_filter);
ASSERT_EQ(status1, status2);
if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
// This case violates the precondition of `FunctionalizeControlFlow`, we
// expect an internal error.
ASSERT_EQ(errors::IsInternal(status1), true);
return;
} else {
// Supported cases, no error expected.
TF_ASSERT_OK(status1);
}
// Outer graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
auto y = ops::Add(scope.WithOpName("y"), x, three);
auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
TensorShape({}));
auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
std::initializer_list<Input>{zero, y, x, var},
outer_cond_fn, outer_body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
// Outer condition graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
auto ten = ops::Const<int32>(
scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
10);
auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
// Outer body graph.
GraphDef optimized_converted_graph_def;
graph.ToGraphDef(&optimized_converted_graph_def);
for (const GraphDef& graph_def :
{optimized_graph_def, optimized_converted_graph_def}) {
NameAttrList inner_cond_fn, inner_body_fn;
{
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
// Find the inner condition and body names.
TF_EXPECT_OK(
FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
auto one_j = ops::Const<int32>(
scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
auto while_op =
ops::While(scope.WithOpName("outer/LoopCond_1"),
std::initializer_list<Input>{one_j, arg1, arg2, arg3},
inner_cond_fn, inner_body_fn);
auto one_outer = ops::Const<int32>(
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i),
1);
auto add_i =
ops::Add(scope.WithOpName("outer/add")
.WithControlDependencies(absl::Span<const Operation>{
while_op[0].op(), while_op[1].op()}),
identity_i, one_outer);
auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_i, 0);
auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), arg1, 1);
auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
if (!restrict_to_tpu_nodes_ ||
(restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ &&
mark_inner_loop_tpu_)) {
// We expect that both inner and outer nodes have been functionalized.
CheckOuterNodesFunctionalized(graph_def, library, inner_cond_fn,
inner_body_fn);
CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
inner_body_fn);
} else /*restrict_to_tpu_nodes_ == true*/ {
if (!mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
// Graph has no TPU nodes so we expect no functionalization.
TF_EXPECT_GRAPH_EQ(orig_graph_def, graph_def);
} else if (!mark_outer_loop_tpu_ && mark_inner_loop_tpu_) {
// We expect that only inner nodes have been functionalized.
TF_EXPECT_OK(
FindWhileCondAndBody(graph_def, &inner_cond_fn, &inner_body_fn));
CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
inner_body_fn);
}
}
}
}
// Inner condition graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
void ComplexTestFixture::CheckOuterNodesFunctionalized(
const GraphDef& graph_def, const FunctionLibraryDefinition& library,
NameAttrList& inner_cond_fn, NameAttrList& inner_body_fn) {
NameAttrList outer_cond_fn, outer_body_fn;
TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
auto five = ops::Const<int32>(
scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0),
5);
auto less_j =
ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0);
// Outer graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
auto y = ops::Add(scope.WithOpName("y"), x, three);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
TensorShape({}));
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
std::initializer_list<Input>{zero, y, x, var},
outer_cond_fn, outer_body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
// Inner body graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
// Outer condition graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
auto identity_j =
ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
auto identity_k =
ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
auto ten = ops::Const<int32>(
scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
10);
auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
auto mul_jk =
ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
auto add_jkx =
ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
auto assign = ops::AssignAddVariableOp(
scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
auto one = ops::Const<int32>(
scope.WithOpName("outer/inner/One")
.WithControlDependencies(
absl::Span<const Operation>{assign.operation}),
1);
auto add_j =
ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_j, 0);
auto retval1 =
ops::_Retval(scope.WithOpName("retval1_RetVal"), identity_k, 1);
auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
// Outer body graph.
{
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
// Find the inner condition and body names.
TF_EXPECT_OK(
FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
auto one_j = ops::Const<int32>(
scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
auto while_op =
ops::While(scope.WithOpName("outer/inner/LoopCond"),
std::initializer_list<Input>{one_j, arg1, arg2, arg3},
inner_cond_fn, inner_body_fn);
auto one_outer = ops::Const<int32>(
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
auto add_i =
ops::Add(scope.WithOpName("outer/add")
.WithControlDependencies(absl::Span<const Operation>{
while_op[0].op(), while_op[1].op()}),
identity_i, one_outer);
auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_i, 0);
auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), arg1, 1);
auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
}
void ComplexTestFixture::CheckInnerNodesFunctionalized(
const GraphDef& graph_def, const FunctionLibraryDefinition& library,
const NameAttrList& inner_cond_fn, const NameAttrList& inner_body_fn) {
// Inner condition graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
auto five = ops::Const<int32>(
scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
// Inner body graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
auto identity_j =
ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
auto identity_k =
ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
auto mul_jk =
ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
auto assign = ops::AssignAddVariableOp(
scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
auto one = ops::Const<int32>(
scope.WithOpName("outer/inner/One")
.WithControlDependencies(
absl::Span<const Operation>{assign.operation}),
1);
auto add_j =
ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_j, 0);
auto retval1 =
ops::_Retval(scope.WithOpName("retval1_RetVal"), identity_k, 1);
auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
TF_EXPECT_OK(
InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.arg_types);
EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
result.ret_types);
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
}

View File

@ -51,7 +51,8 @@ xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
Status ExtractWhileLoopFrames(
const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
std::unordered_map<string, WhileLoopFrame>* frames) {
std::unordered_map<string, WhileLoopFrame>* frames,
const NodeFilter& node_filter) {
for (Node* node : graph->op_nodes()) {
const ControlFlowInfo& cf = cf_info[node->id()];
@ -81,6 +82,9 @@ Status ExtractWhileLoopFrames(
frame.loop_cond = node;
}
frame.nodes.insert(node);
if (node->IsControlFlow() && node_filter && !node_filter(node)) {
frame.should_be_functionalized = false;
}
}
return Status::OK();

View File

@ -26,6 +26,8 @@ limitations under the License.
namespace tensorflow {
using NodeFilter = std::function<bool(const Node*)>;
// Information about a loop argument.
struct WhileLoopArg {
// Every loop argument has an Enter node.
@ -60,13 +62,22 @@ struct WhileLoopFrame {
// Set of nodes that belong to the loop frame.
std::unordered_set<Node*> nodes;
// After `ExtractWhileLoopFrames` this is true if for all control flow nodes
// of this frame `node_filter` returns true, i.e., the frame should be
// functionalized, and false otherwise.
bool should_be_functionalized = true;
};
// Extracts v1 while loops within a graph and creates a map of
// <ControlFlowInfo.name, WhileLoopFrame>.
// If `node_filter` is defined, then we keep track of frames that should be
// functionalized according to the filter (see comment for
// `FunctionalizeControlFlow` for more details about node filters).
Status ExtractWhileLoopFrames(
const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
std::unordered_map<string, WhileLoopFrame>* frames);
std::unordered_map<string, WhileLoopFrame>* frames,
const NodeFilter& node_filter = {});
// Check that the graph has no cycle containing the given node.
Status CheckNodeNotInCycle(const Node* node, const int num_nodes);

View File

@ -22,11 +22,10 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
@ -162,7 +161,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
*body_output = absl::make_unique<Graph>(graph.op_registry());
Graph* output = body_output->get();
// Map from nodes in the original graph to the condition graph.
// Map from nodes in the original graph to the body graph.
std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
@ -212,7 +211,14 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
}
Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
FunctionLibraryDefinition* library) {
FunctionLibraryDefinition* library,
const NodeFilter& node_filter) {
if (node_filter && !frame->should_be_functionalized) {
VLOG(2) << "Skipping functionalization for frame " << frame->name
<< " because it has control flow nodes that are filtered out by "
"the specified node filter.";
return Status::OK();
}
VLOG(2) << "Frame " << frame->name << " before: "
<< DumpGraphToFile("functionalize_before", *graph, library);
@ -349,10 +355,6 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
return errors::InvalidArgument("Missing Switch successor to ",
FormatNodeForError(*arg.merge));
}
// Update the device on the Identity outputs of the switch to match their
// target. These Identity outputs do not
// Loop over the switch node's output to:
// - Find the Exit successor.
// - Set the sharding on all Identity outputs of the switch. These
@ -402,12 +404,12 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
std::unique_ptr<Graph> cond_graph;
TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
FixupSourceAndSinkEdges(cond_graph.get());
TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library));
TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library, node_filter));
DataTypeVector arg_types;
std::unique_ptr<Graph> body_graph;
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
FixupSourceAndSinkEdges(body_graph.get());
TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library));
TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library, node_filter));
VLOG(2) << "Frame " << frame->name << " condition: "
<< DumpGraphToFile("loop_condition", *cond_graph, library)
@ -433,17 +435,13 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
string outside_compilation;
string frontend_attributes;
if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
&frontend_attributes)
.ok()) {
builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
}
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
&outside_compilation)
.ok()) {
builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
// Add all underscore attributes, these need to be propagated.
for (const auto& attr : frame->loop_cond->def().attr()) {
const string& name(attr.first);
const AttrValue& value(attr.second);
if (absl::StartsWith(name, "_")) {
builder.Attr(name, value);
}
}
std::vector<NodeDefBuilder::NodeOut> inputs;
for (int i = 0; i < frame->args.size(); ++i) {
@ -495,6 +493,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
// Remove the old nodes from the graph, and add the while node to the parent
// frame.
for (Node* node : frame->nodes) {
VLOG(2) << "Removing obsolete node " << node->name();
graph->RemoveNode(node);
}
frame->nodes.clear();
@ -507,8 +506,8 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
}
} // namespace
Status FunctionalizeWhileLoop(Graph* graph,
FunctionLibraryDefinition* library) {
Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library,
const NodeFilter& node_filter) {
// Note: BuildControlFlowInfo() requires that the graph's source node is
// connected to all source nodes in the graph. Many graphs violate this
// invariant.
@ -523,7 +522,8 @@ Status FunctionalizeWhileLoop(Graph* graph,
// Builds Frames, indexed by name.
std::unordered_map<string, WhileLoopFrame> frames;
TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
TF_RETURN_IF_ERROR(
ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter));
// Adds frames with no children (i.e., the innermost frames) to a worklist.
std::deque<WhileLoopFrame*> worklist;
@ -533,7 +533,9 @@ Status FunctionalizeWhileLoop(Graph* graph,
}
}
// Eliminate loops from innermost to outermost.
// Eliminate loops from innermost to outermost. Note that the precondition for
// `node_filter` in `FunctionalizeControlFlow` makes sure that this approach
// works.
while (!worklist.empty()) {
WhileLoopFrame* frame = worklist.front();
worklist.pop_front();
@ -542,7 +544,7 @@ Status FunctionalizeWhileLoop(Graph* graph,
continue;
}
TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library, node_filter));
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;
@ -551,14 +553,16 @@ Status FunctionalizeWhileLoop(Graph* graph,
}
}
// There should be no cycle at this point, since while loops have been removed
// from graph.
// Check that the newly added While nodes don't feed into themselves.
for (const Node* node : graph->op_nodes()) {
if (node->def().op() == "While") {
TF_RETURN_WITH_CONTEXT_IF_ERROR(
CheckNodeNotInCycle(node, graph->num_node_ids()),
"Functionalizing loop failed.");
if (!node_filter) {
// There should be no cycle at this point, since while loops have been
// removed from graph. Check that the newly added While nodes don't feed
// into themselves.
for (const Node* node : graph->op_nodes()) {
if (node->def().op() == "While") {
TF_RETURN_WITH_CONTEXT_IF_ERROR(
CheckNodeNotInCycle(node, graph->num_node_ids()),
"Functionalizing loop failed.");
}
}
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@ -24,7 +25,14 @@ namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
// operators, suitable for XLA compilation. If lookup_library is provided, use
// it to make the library for control flow self-contained.
Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library);
//
// If `node_filter` is defined, then only loops for whose nodes `node_filter`
// returns true are functionalized.
//
// Preconditions:
// Same as for `FunctionalizeControlFlow` (see comment there).
Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library,
const NodeFilter& node_filter = {});
} // namespace tensorflow