Add an optimization for While V2.
Find the following pattern in the graph: 1) EmptyTensorList -> forward While op -> backward While op, 2) in forward While op, a Const node is pushed, 3) in backward While op, data is popped from the tensor list. And rewrites backward While op to use Const node instead of TensorListPopBack result. PiperOrigin-RevId: 248382047
This commit is contained in:
parent
4fa7eec9e5
commit
5e52f95bdb
@ -786,4 +786,144 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g,
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RewriteTensorListWithConstElement(Graph* g,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
for (Node* n : g->nodes()) {
|
||||
if (n->type_string() != "EmptyTensorList") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find the forward While op.
|
||||
std::vector<const Edge*> fwd_while_edges;
|
||||
for (const Edge* e : n->out_edges()) {
|
||||
if (!e->IsControlEdge() && e->dst()->type_string() == "While") {
|
||||
fwd_while_edges.push_back(e);
|
||||
}
|
||||
}
|
||||
if (fwd_while_edges.size() != 1) {
|
||||
// No forward While op found, or multiple forward While ops.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find the backward While op.
|
||||
Node* fwd_while = fwd_while_edges[0]->dst();
|
||||
int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
|
||||
std::vector<const Edge*> bwd_while_edges;
|
||||
for (const Edge* e : fwd_while->out_edges()) {
|
||||
if (e->src_output() == fwd_while_dst_input &&
|
||||
e->dst()->type_string() == "While") {
|
||||
bwd_while_edges.push_back(e);
|
||||
}
|
||||
}
|
||||
if (bwd_while_edges.size() != 1) {
|
||||
// No backward While op found, or multiple backward While ops.
|
||||
continue;
|
||||
}
|
||||
|
||||
Node* bwd_while = bwd_while_edges[0]->dst();
|
||||
int bwd_while_dst_input = bwd_while_edges[0]->dst_input();
|
||||
|
||||
// Look into forward While body function and check if TensorListPushBack op
|
||||
// has a Const input.
|
||||
NameAttrList fwd_body_attr;
|
||||
TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
|
||||
const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
|
||||
if (!fwd_body) {
|
||||
return errors::InvalidArgument("Cannot find function ",
|
||||
fwd_body_attr.name(), " for While node ",
|
||||
fwd_while->DebugString());
|
||||
}
|
||||
std::unique_ptr<FunctionBody> fwd_fbody;
|
||||
TF_CHECK_OK(FunctionDefToBodyHelper(
|
||||
*fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));
|
||||
|
||||
// Find the TensorListPushBack node; it's one of fwd_arg's successors.
|
||||
Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
|
||||
std::vector<Node*> tl_push_nodes;
|
||||
for (const Edge* out_edge : fwd_arg->out_edges()) {
|
||||
if (out_edge->dst()->type_string() == "TensorListPushBack") {
|
||||
tl_push_nodes.push_back(out_edge->dst());
|
||||
}
|
||||
}
|
||||
if (tl_push_nodes.size() != 1) {
|
||||
// No TensorListPushBack found, or multiple TensorListPushBack.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get input for the TensorListPushBack node.
|
||||
Node* input_node;
|
||||
TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
|
||||
if (input_node->type_string() != "Const") {
|
||||
// Input for the TensorList is not Const node.
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeDef const_input_nodedef = input_node->def();
|
||||
|
||||
// Rewrite backward While body function, replace usages of
|
||||
// TensorListPopBack with a Const node.
|
||||
NameAttrList bwd_body_attr;
|
||||
TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
|
||||
const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
|
||||
if (!bwd_body) {
|
||||
return errors::InvalidArgument("Cannot find function ",
|
||||
bwd_body_attr.name(), " for While node ",
|
||||
bwd_while->DebugString());
|
||||
}
|
||||
std::unique_ptr<FunctionBody> bwd_fbody;
|
||||
TF_CHECK_OK(FunctionDefToBodyHelper(
|
||||
*bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));
|
||||
|
||||
// Find the TensorListPopBack node; it's one of bwd_arg's successors.
|
||||
Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
|
||||
std::vector<Node*> tl_pop_nodes;
|
||||
for (const Edge* out_edge : bwd_arg->out_edges()) {
|
||||
if (out_edge->dst()->type_string() == "TensorListPopBack") {
|
||||
tl_pop_nodes.push_back(out_edge->dst());
|
||||
}
|
||||
}
|
||||
if (tl_pop_nodes.size() != 1) {
|
||||
// No TensorListPopBack found, or multiple TensorListPopBack.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Replace TensorListPopBack usages with Const node.
|
||||
std::vector<const Edge*> edges_to_replace;
|
||||
for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
|
||||
if (e->src_output() == 1) {
|
||||
edges_to_replace.push_back(e);
|
||||
}
|
||||
}
|
||||
if (edges_to_replace.empty()) {
|
||||
continue;
|
||||
}
|
||||
Status s;
|
||||
const_input_nodedef.set_name(
|
||||
bwd_fbody->graph->NewName(const_input_nodedef.name()));
|
||||
Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
for (const Edge* e : edges_to_replace) {
|
||||
Node* dst = e->dst();
|
||||
int dst_input = e->dst_input();
|
||||
bwd_fbody->graph->RemoveEdge(e);
|
||||
bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
|
||||
}
|
||||
|
||||
// Add rewritten backward While body function.
|
||||
FunctionDef new_fdef;
|
||||
string new_name = fld->UniqueFunctionName(
|
||||
absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
|
||||
|
||||
// Change backward While op to use the new body function.
|
||||
bwd_body_attr.set_name(new_name);
|
||||
bwd_while->ClearAttr("body");
|
||||
bwd_while->AddAttr("body", bwd_body_attr);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -202,6 +202,16 @@ Status PropagateConstIntoFunctionalNodes(
|
||||
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
|
||||
FunctionLibraryDefinition* fld);
|
||||
|
||||
// Finds the following pattern in the graph:
|
||||
// 1) EmptyTensorList -> forward While op -> backward While op,
|
||||
// 2) in forward While op, a Const node is pushed,
|
||||
// 3) in backward While op, data is popped from the tensor list.
|
||||
// And rewrites backward While op to use Const node instead of TensorListPopBack
|
||||
// result.
|
||||
// TODO(b/128633174) remove the TensorList and related TensorList ops.
|
||||
Status RewriteTensorListWithConstElement(Graph* g,
|
||||
FunctionLibraryDefinition* fld);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
|
||||
|
@ -22,8 +22,10 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/data_flow_ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/list_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -416,5 +418,86 @@ TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) {
|
||||
EXPECT_EQ(const_def->second.op(), "Const");
|
||||
}
|
||||
|
||||
TEST(PropagateConstIntoFunctionalNodes, RewriteTensorListWithConstMember) {
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), {});
|
||||
{
|
||||
// Cond graph
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
|
||||
auto result =
|
||||
ops::Const(scope.WithOpName("result"), false, TensorShape({}));
|
||||
auto ret = ops::_Retval(scope.WithOpName("ret"), result, 0);
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
FunctionDef fdef;
|
||||
TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &fdef));
|
||||
TF_ASSERT_OK(fld.AddFunctionDef(fdef));
|
||||
}
|
||||
{
|
||||
// Forward body graph
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
|
||||
auto element = ops::Const(scope.WithOpName("element"), 0, TensorShape({}));
|
||||
auto push =
|
||||
ops::TensorListPushBack(scope.WithOpName("push"), input, element);
|
||||
auto ret = ops::_Retval(scope.WithOpName("ret"), push.output_handle, 0);
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
FunctionDef fdef;
|
||||
TF_ASSERT_OK(GraphToFunctionDef(graph, "fwd_body", &fdef));
|
||||
TF_ASSERT_OK(fld.AddFunctionDef(fdef));
|
||||
}
|
||||
{
|
||||
// Backward body graph
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
|
||||
auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({}));
|
||||
auto pop =
|
||||
ops::TensorListPopBack(scope.WithOpName("pop"), input, shape, DT_INT32);
|
||||
auto identity = ops::Identity(scope.WithOpName("identity"), pop.tensor);
|
||||
auto ret = ops::_Retval(scope.WithOpName("ret"), pop.output_handle, 0);
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
FunctionDef fdef;
|
||||
TF_ASSERT_OK(GraphToFunctionDef(graph, "bwd_body", &fdef));
|
||||
TF_ASSERT_OK(fld.AddFunctionDef(fdef));
|
||||
}
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({}));
|
||||
auto max_num_elements =
|
||||
ops::Const(scope.WithOpName("max_num_elements"), 10, TensorShape({}));
|
||||
auto tl = ops::EmptyTensorList(scope.WithOpName("tl"), shape,
|
||||
max_num_elements, DT_INT32);
|
||||
NameAttrList cond_fn, fwd_body_fn, bwd_body_fn;
|
||||
cond_fn.set_name("cond");
|
||||
fwd_body_fn.set_name("fwd_body");
|
||||
bwd_body_fn.set_name("bwd_body");
|
||||
auto fwd_while_op =
|
||||
ops::While(scope.WithOpName("fwd_while"),
|
||||
std::initializer_list<Input>{tl}, cond_fn, fwd_body_fn);
|
||||
auto bwd_while_op =
|
||||
ops::While(scope.WithOpName("bwd_while"),
|
||||
std::initializer_list<Input>{fwd_while_op.output[0]}, cond_fn,
|
||||
bwd_body_fn);
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
|
||||
TF_EXPECT_OK(RewriteTensorListWithConstElement(&graph, &fld));
|
||||
|
||||
// Check that in rewritten backward While body function, the Identity node now
|
||||
// has Const node as input.
|
||||
const FunctionDef* bwd_body = fld.Find("bwd_body_tl_rewrite_0");
|
||||
ASSERT_NE(bwd_body, nullptr);
|
||||
std::unique_ptr<FunctionBody> bwd_fbody;
|
||||
TF_CHECK_OK(
|
||||
FunctionDefToBodyHelper(*bwd_body, AttrSlice(), &fld, &bwd_fbody));
|
||||
auto node_name_index = bwd_fbody->graph->BuildNodeNameIndex();
|
||||
const Node* identity = node_name_index.at("identity");
|
||||
ASSERT_NE(identity, nullptr);
|
||||
const Node* input;
|
||||
TF_ASSERT_OK(identity->input_node(0, &input));
|
||||
EXPECT_EQ(input->type_string(), "Const");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user