Add an optimization for While V2.

Find the following pattern in the graph:
1) EmptyTensorList -> forward While op -> backward While op,
2) in the forward While op, a Const node is pushed onto the tensor list,
3) in the backward While op, data is popped from the tensor list,
and rewrite the backward While op to use the Const node instead of the
TensorListPopBack result.
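
Schematically (an illustrative sketch of the pattern, not part of the change):

  EmptyTensorList -> While(fwd_body) -> While(bwd_body)
    fwd_body: arg -> TensorListPushBack(arg, Const)
    bwd_body: arg -> TensorListPopBack -> element consumers

After the rewrite, the element consumers in bwd_body read a fresh copy of the
Const node; the TensorListPopBack element output is left unused.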

PiperOrigin-RevId: 248382047
Tong Shen, 2019-05-15 12:25:48 -07:00, committed by TensorFlower Gardener
parent 4fa7eec9e5
commit 5e52f95bdb
3 changed files with 233 additions and 0 deletions

tensorflow/compiler/tf2xla/tf2xla_util.cc

@@ -786,4 +786,144 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g,
}
return Status::OK();
}
Status RewriteTensorListWithConstElement(Graph* g,
FunctionLibraryDefinition* fld) {
for (Node* n : g->nodes()) {
if (n->type_string() != "EmptyTensorList") {
continue;
}
// Find the forward While op.
std::vector<const Edge*> fwd_while_edges;
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge() && e->dst()->type_string() == "While") {
fwd_while_edges.push_back(e);
}
}
if (fwd_while_edges.size() != 1) {
// No forward While op found, or multiple forward While ops.
continue;
}
// Find the backward While op.
Node* fwd_while = fwd_while_edges[0]->dst();
int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
std::vector<const Edge*> bwd_while_edges;
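// A While op's i-th output carries its i-th loop variable, so the tensor
// list flows out of the forward While at the same index it flowed in.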
for (const Edge* e : fwd_while->out_edges()) {
if (e->src_output() == fwd_while_dst_input &&
e->dst()->type_string() == "While") {
bwd_while_edges.push_back(e);
}
}
if (bwd_while_edges.size() != 1) {
// No backward While op found, or multiple backward While ops.
continue;
}
Node* bwd_while = bwd_while_edges[0]->dst();
int bwd_while_dst_input = bwd_while_edges[0]->dst_input();
// Look into forward While body function and check if TensorListPushBack op
// has a Const input.
NameAttrList fwd_body_attr;
TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
if (!fwd_body) {
return errors::InvalidArgument("Cannot find function ",
fwd_body_attr.name(), " for While node ",
fwd_while->DebugString());
}
std::unique_ptr<FunctionBody> fwd_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));
// Find the TensorListPushBack node; it's one of fwd_arg's successors.
Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
std::vector<Node*> tl_push_nodes;
for (const Edge* out_edge : fwd_arg->out_edges()) {
if (out_edge->dst()->type_string() == "TensorListPushBack") {
tl_push_nodes.push_back(out_edge->dst());
}
}
if (tl_push_nodes.size() != 1) {
// No TensorListPushBack found, or multiple TensorListPushBack nodes.
continue;
}
// Get input for the TensorListPushBack node.
Node* input_node;
TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
if (input_node->type_string() != "Const") {
// Input for the TensorList is not Const node.
continue;
}
NodeDef const_input_nodedef = input_node->def();
// Rewrite backward While body function, replace usages of
// TensorListPopBack with a Const node.
NameAttrList bwd_body_attr;
TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
if (!bwd_body) {
return errors::InvalidArgument("Cannot find function ",
bwd_body_attr.name(), " for While node ",
bwd_while->DebugString());
}
std::unique_ptr<FunctionBody> bwd_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(
*bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));
// Find the TensorListPopBack node; it's one of bwd_arg's successors.
Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
std::vector<Node*> tl_pop_nodes;
for (const Edge* out_edge : bwd_arg->out_edges()) {
if (out_edge->dst()->type_string() == "TensorListPopBack") {
tl_pop_nodes.push_back(out_edge->dst());
}
}
if (tl_pop_nodes.size() != 1) {
// No TensorListPopBack found, or multiple TensorListPopBack nodes.
continue;
}
// Replace TensorListPopBack usages with Const node.
std::vector<const Edge*> edges_to_replace;
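// TensorListPopBack has two outputs: output 0 is the updated list handle,
// output 1 is the popped element.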
for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
if (e->src_output() == 1) {
edges_to_replace.push_back(e);
}
}
if (edges_to_replace.empty()) {
continue;
}
Status s;
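// Clone the Const into the backward body under a fresh name so it cannot
// collide with existing nodes.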
const_input_nodedef.set_name(
bwd_fbody->graph->NewName(const_input_nodedef.name()));
Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
TF_RETURN_IF_ERROR(s);
for (const Edge* e : edges_to_replace) {
Node* dst = e->dst();
int dst_input = e->dst_input();
bwd_fbody->graph->RemoveEdge(e);
bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
}
// Add rewritten backward While body function.
FunctionDef new_fdef;
string new_name = fld->UniqueFunctionName(
absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
// Change backward While op to use the new body function.
bwd_body_attr.set_name(new_name);
bwd_while->ClearAttr("body");
bwd_while->AddAttr("body", bwd_body_attr);
}
return Status::OK();
}
} // namespace tensorflow
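
For reference, a minimal sketch of invoking the new pass (mirroring the unit
test below; the setup lines are illustrative):

  Graph graph(OpRegistry::Global());
  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
  // ... build the EmptyTensorList -> forward While -> backward While pattern ...
  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(&graph, &fld));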

tensorflow/compiler/tf2xla/tf2xla_util.h

@@ -202,6 +202,16 @@ Status PropagateConstIntoFunctionalNodes(
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
FunctionLibraryDefinition* fld);
// Finds the following pattern in the graph:
// 1) EmptyTensorList -> forward While op -> backward While op,
// 2) in the forward While op, a Const node is pushed onto the tensor list,
// 3) in the backward While op, data is popped from the tensor list,
// and rewrites the backward While op to use the Const node instead of the
// TensorListPopBack result.
// TODO(b/128633174) remove the TensorList and related TensorList ops.
Status RewriteTensorListWithConstElement(Graph* g,
FunctionLibraryDefinition* fld);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_

tensorflow/compiler/tf2xla/tf2xla_util_test.cc

@@ -22,8 +22,10 @@ limitations under the License.
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
@@ -416,5 +418,86 @@ TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) {
EXPECT_EQ(const_def->second.op(), "Const");
}
TEST(PropagateConstIntoFunctionalNodes, RewriteTensorListWithConstMember) {
FunctionLibraryDefinition fld(OpRegistry::Global(), {});
{
// Cond graph
Scope scope = Scope::NewRootScope().ExitOnError();
auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
auto result =
ops::Const(scope.WithOpName("result"), false, TensorShape({}));
auto ret = ops::_Retval(scope.WithOpName("ret"), result, 0);
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &fdef));
TF_ASSERT_OK(fld.AddFunctionDef(fdef));
}
{
// Forward body graph
Scope scope = Scope::NewRootScope().ExitOnError();
auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
auto element = ops::Const(scope.WithOpName("element"), 0, TensorShape({}));
auto push =
ops::TensorListPushBack(scope.WithOpName("push"), input, element);
auto ret = ops::_Retval(scope.WithOpName("ret"), push.output_handle, 0);
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(graph, "fwd_body", &fdef));
TF_ASSERT_OK(fld.AddFunctionDef(fdef));
}
{
// Backward body graph
Scope scope = Scope::NewRootScope().ExitOnError();
auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({}));
auto pop =
ops::TensorListPopBack(scope.WithOpName("pop"), input, shape, DT_INT32);
auto identity = ops::Identity(scope.WithOpName("identity"), pop.tensor);
auto ret = ops::_Retval(scope.WithOpName("ret"), pop.output_handle, 0);
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(graph, "bwd_body", &fdef));
TF_ASSERT_OK(fld.AddFunctionDef(fdef));
}
Scope scope = Scope::NewRootScope().ExitOnError();
auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({}));
auto max_num_elements =
ops::Const(scope.WithOpName("max_num_elements"), 10, TensorShape({}));
auto tl = ops::EmptyTensorList(scope.WithOpName("tl"), shape,
max_num_elements, DT_INT32);
NameAttrList cond_fn, fwd_body_fn, bwd_body_fn;
cond_fn.set_name("cond");
fwd_body_fn.set_name("fwd_body");
bwd_body_fn.set_name("bwd_body");
auto fwd_while_op =
ops::While(scope.WithOpName("fwd_while"),
std::initializer_list<Input>{tl}, cond_fn, fwd_body_fn);
auto bwd_while_op =
ops::While(scope.WithOpName("bwd_while"),
std::initializer_list<Input>{fwd_while_op.output[0]}, cond_fn,
bwd_body_fn);
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
TF_EXPECT_OK(RewriteTensorListWithConstElement(&graph, &fld));
// Check that in the rewritten backward While body function, the Identity
// node now has the Const node as input.
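// (The rewritten function is registered under a unique name built from the
// original name plus the "_tl_rewrite_" suffix.)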
const FunctionDef* bwd_body = fld.Find("bwd_body_tl_rewrite_0");
ASSERT_NE(bwd_body, nullptr);
std::unique_ptr<FunctionBody> bwd_fbody;
TF_CHECK_OK(
FunctionDefToBodyHelper(*bwd_body, AttrSlice(), &fld, &bwd_fbody));
auto node_name_index = bwd_fbody->graph->BuildNodeNameIndex();
const Node* identity = node_name_index.at("identity");
ASSERT_NE(identity, nullptr);
const Node* input;
TF_ASSERT_OK(identity->input_node(0, &input));
EXPECT_EQ(input->type_string(), "Const");
}
} // namespace
} // namespace tensorflow