Make sure while body DT_RESOURCE _Retval comes from _Arg with same index.

PiperOrigin-RevId: 245816414
This commit is contained in:
Tong Shen 2019-04-29 13:52:17 -07:00 committed by TensorFlower Gardener
parent a6aa3e6c20
commit b146fdcdf1
2 changed files with 101 additions and 0 deletions

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@ -211,4 +212,67 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
EXPECT_EQ(input_node->name(), "while");
}
TEST_F(RearrangeFunctionArgumentForFunctionTest,
WhileResourceRetvalFromDifferentArgUnimplemented) {
FunctionDefLibrary fdl;
{
// Function for While's "body".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
// "ret0" = "arg1"
// "ret1" = "arg0"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
auto ret2 = ops::_Retval(s.WithOpName("ret2"), arg2, 2);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
}
{
// Function for While's "cond".
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
// "ret0" = true
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
Output cond = ops::Const(s.WithOpName("const"), true, TensorShape({}));
auto ret0 = ops::_Retval(s.WithOpName("ret0"), cond, 0);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
}
{
// Build the XLA computation func.
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
// "arg0", "arg1" -> "while" (While)
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("f1");
body_fn.set_name("f2");
auto while_op = ops::While(s.WithOpName("while"),
std::initializer_list<Input>{arg0, arg1, arg2},
cond_fn, body_fn);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
bool modified;
protobuf::Map<string, tensorflow::AttrValue> attrs;
Status s = RearrangeFunctionArgumentTest("cluster", "cluster_rewritten",
attrs, &fld, &modified);
EXPECT_EQ(s.code(), error::UNIMPLEMENTED);
}
} // namespace tensorflow

View File

@ -309,6 +309,43 @@ Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
// Check that resource _Arg nodes for While node are always returned with
// the same index, and we don't have cases like this:
// tf.while_loop(
// cond,
// lambda resource_var1, resource_var2: [resource_var2, resource_var1],
// [resource_var1, resource_var2])
if (attr_name == "body") {
for (int i = 0; i < fbody->ret_nodes.size(); i++) {
Node* n = fbody->ret_nodes[i];
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
if (dtype != DT_RESOURCE) {
continue;
}
Node* input_node;
TF_RETURN_IF_ERROR(n->input_node(0, &input_node));
while (input_node->IsIdentity()) {
TF_RETURN_IF_ERROR(input_node->input_node(0, &input_node));
}
if (input_node->IsArg()) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
if (index != i) {
return errors::Unimplemented("While node ", n->DebugString(),
" has resource _Retval[", i,
"] coming from _Arg[", index, "]");
}
} else {
return errors::Unimplemented("Encountered node ",
input_node->DebugString(),
" while tracing _Arg node for _Retval[",
i, "] of while node ", n->DebugString());
}
}
}
RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
if (attr_name == "body") {
for (int i = 0; i < fbody->ret_nodes.size(); i++) {