diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 92ba474fbcd..14bd9d4b138 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -317,11 +317,13 @@ tf_cc_test(
         ":tf2xla_util",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:functional_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:math_ops_op_lib",
+        "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 18d87727c50..c64f78e1a1b 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -265,6 +265,13 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
     }
 
     // Check if i-th retval's input comes from i-th arg directly.
+    // For resource variable inputs of While nodes, the TF2XLA convention is
+    // to place them at the end of all inputs (after all data inputs), and
+    // *not* return them. So the number of While node inputs might be larger
+    // than the number of its outputs.
+    if (i >= body_func->signature().output_arg_size()) {
+      continue;
+    }
     const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
     auto output_arg_input = body_func->ret().find(output_arg.name());
     if (output_arg_input == body_func->ret().end()) {
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
index 202e929315c..9e9c3cecee6 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
@@ -21,11 +21,13 @@ limitations under the License.
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/data_flow_ops.h"
 #include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
 #include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -329,5 +331,40 @@ TEST(CachedFunctionHandles, Basic) {
   TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles());
 }
 
+TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) {
+  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
+  {
+    // Cond graph & body graph.
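+    // The same graph is converted to both the "cond" and the "body"
+    // FunctionDef below: it takes two args (a bool and a resource) but
+    // returns only the bool, so the body has more inputs than outputs.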
+    Scope scope = Scope::NewRootScope().ExitOnError();
+    auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0);
+    auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1);
+    auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0);
+    Graph graph(OpRegistry::Global());
+    TF_ASSERT_OK(scope.ToGraph(&graph));
+    FunctionDef cond_fdef;
+    TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef));
+    TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef));
+    FunctionDef body_fdef;
+    TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef));
+    TF_ASSERT_OK(fld.AddFunctionDef(body_fdef));
+  }
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({}));
+  auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({}));
+  NameAttrList cond_fn, body_fn;
+  cond_fn.set_name("cond");
+  body_fn.set_name("body");
+  auto while_op =
+      ops::While(scope.WithOpName("while"),
+                 std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
+  Graph graph(OpRegistry::Global());
+  TF_ASSERT_OK(scope.ToGraph(&graph));
+
+  TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));
+}
+
 }  // namespace
 }  // namespace tensorflow