Make sure while body DT_RESOURCE _Retval comes from _Arg with same index.
PiperOrigin-RevId: 245816414
This commit is contained in:
parent
a6aa3e6c20
commit
b146fdcdf1
tensorflow/compiler
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
@ -211,4 +212,67 @@ TEST_F(RearrangeFunctionArgumentForFunctionTest, Basic) {
|
||||
EXPECT_EQ(input_node->name(), "while");
|
||||
}
|
||||
|
||||
TEST_F(RearrangeFunctionArgumentForFunctionTest,
|
||||
WhileResourceRetvalFromDifferentArgUnimplemented) {
|
||||
FunctionDefLibrary fdl;
|
||||
{
|
||||
// Function for While's "body".
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
|
||||
// "ret0" = "arg1"
|
||||
// "ret1" = "arg0"
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
|
||||
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
|
||||
auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
|
||||
auto ret2 = ops::_Retval(s.WithOpName("ret2"), arg2, 2);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
|
||||
}
|
||||
{
|
||||
// Function for While's "cond".
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
|
||||
// "ret0" = true
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
|
||||
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
|
||||
Output cond = ops::Const(s.WithOpName("const"), true, TensorShape({}));
|
||||
auto ret0 = ops::_Retval(s.WithOpName("ret0"), cond, 0);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
|
||||
}
|
||||
{
|
||||
// Build the XLA computation func.
|
||||
// "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
|
||||
// "arg0", "arg1" -> "while" (While)
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
|
||||
Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
|
||||
Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("f1");
|
||||
body_fn.set_name("f2");
|
||||
auto while_op = ops::While(s.WithOpName("while"),
|
||||
std::initializer_list<Input>{arg0, arg1, arg2},
|
||||
cond_fn, body_fn);
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(s.ToGraph(g.get()));
|
||||
FunctionDef *xla_fdef = fdl.add_function();
|
||||
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
|
||||
}
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
|
||||
|
||||
bool modified;
|
||||
protobuf::Map<string, tensorflow::AttrValue> attrs;
|
||||
Status s = RearrangeFunctionArgumentTest("cluster", "cluster_rewritten",
|
||||
attrs, &fld, &modified);
|
||||
EXPECT_EQ(s.code(), error::UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -309,6 +309,43 @@ Status MaybeRewriteWhileNode(Graph* g, Node* n, FunctionLibraryDefinition* fld,
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionDefToBodyHelper(*fdef, AttrSlice(), fld, &fbody));
|
||||
|
||||
// Check that resource _Arg nodes for While node are always returned with
|
||||
// the same index, and we don't have cases like this:
|
||||
// tf.while_loop(
|
||||
// cond,
|
||||
// lambda resource_var1, resource_var2: [resource_var2, resource_var1],
|
||||
// [resource_var1, resource_var2])
|
||||
if (attr_name == "body") {
|
||||
for (int i = 0; i < fbody->ret_nodes.size(); i++) {
|
||||
Node* n = fbody->ret_nodes[i];
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
|
||||
if (dtype != DT_RESOURCE) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node* input_node;
|
||||
TF_RETURN_IF_ERROR(n->input_node(0, &input_node));
|
||||
while (input_node->IsIdentity()) {
|
||||
TF_RETURN_IF_ERROR(input_node->input_node(0, &input_node));
|
||||
}
|
||||
if (input_node->IsArg()) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
|
||||
if (index != i) {
|
||||
return errors::Unimplemented("While node ", n->DebugString(),
|
||||
" has resource _Retval[", i,
|
||||
"] coming from _Arg[", index, "]");
|
||||
}
|
||||
} else {
|
||||
return errors::Unimplemented("Encountered node ",
|
||||
input_node->DebugString(),
|
||||
" while tracing _Arg node for _Retval[",
|
||||
i, "] of while node ", n->DebugString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
|
||||
if (attr_name == "body") {
|
||||
for (int i = 0; i < fbody->ret_nodes.size(); i++) {
|
||||
|
Loading…
Reference in New Issue
Block a user