Do not assume number of outputs for While node is always the same as number of inputs.

PiperOrigin-RevId: 231620915
This commit is contained in:
Tong Shen 2019-01-30 10:32:00 -08:00 committed by TensorFlower Gardener
parent 1924fe1c4d
commit 19c79e944b
3 changed files with 43 additions and 0 deletions

View File

@ -317,11 +317,13 @@ tf_cc_test(
":tf2xla_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -265,6 +265,13 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
}
// Check if i-th retval's input comes from i-th arg directly.
// For resource variable input of While nodes, TF2XLA convention is to place
// them at the end of all inputs (after all data inputs), and *not* return
// them. So number of While node inputs might be larger than number of its
// outputs.
if (i >= body_func->signature().output_arg_size()) {
continue;
}
const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
auto output_arg_input = body_func->ret().find(output_arg.name());
if (output_arg_input == body_func->ret().end()) {

View File

@ -21,11 +21,13 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
@ -329,5 +331,37 @@ TEST(CachedFunctionHandles, Basic) {
TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles());
}
TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) {
FunctionLibraryDefinition fld(OpRegistry::Global(), {});
{
// Cond graph & body graph.
Scope scope = Scope::NewRootScope().ExitOnError();
auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0);
auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1);
auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0);
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
FunctionDef cond_fdef;
TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef));
TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef));
FunctionDef body_fdef;
TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef));
TF_ASSERT_OK(fld.AddFunctionDef(body_fdef));
}
Scope scope = Scope::NewRootScope().ExitOnError();
auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({}));
auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({}));
NameAttrList cond_fn, body_fn;
cond_fn.set_name("cond");
body_fn.set_name("body");
auto while_op =
ops::While(scope.WithOpName("while"),
std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));
}
} // namespace
} // namespace tensorflow