Do not assume number of outputs for While node is always the same as number of inputs.
PiperOrigin-RevId: 231620915
This commit is contained in:
parent
1924fe1c4d
commit
19c79e944b
@ -317,11 +317,13 @@ tf_cc_test(
|
||||
":tf2xla_util",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:math_ops_op_lib",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
||||
@ -265,6 +265,13 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
|
||||
}
|
||||
|
||||
// Check if i-th retval's input comes from i-th arg directly.
|
||||
// For resource variable input of While nodes, TF2XLA convention is to place
|
||||
// them at the end of all inputs (after all data inputs), and *not* return
|
||||
// them. So number of While node inputs might be larger than number of its
|
||||
// outputs.
|
||||
if (i >= body_func->signature().output_arg_size()) {
|
||||
continue;
|
||||
}
|
||||
const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
|
||||
auto output_arg_input = body_func->ret().find(output_arg.name());
|
||||
if (output_arg_input == body_func->ret().end()) {
|
||||
|
||||
@ -21,11 +21,13 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/data_flow_ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -329,5 +331,37 @@ TEST(CachedFunctionHandles, Basic) {
|
||||
TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles());
|
||||
}
|
||||
|
||||
TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) {
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), {});
|
||||
{
|
||||
// Cond graph & body graph.
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0);
|
||||
auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1);
|
||||
auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0);
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
FunctionDef cond_fdef;
|
||||
TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef));
|
||||
TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef));
|
||||
FunctionDef body_fdef;
|
||||
TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef));
|
||||
TF_ASSERT_OK(fld.AddFunctionDef(body_fdef));
|
||||
}
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({}));
|
||||
auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({}));
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("cond");
|
||||
body_fn.set_name("body");
|
||||
auto while_op =
|
||||
ops::While(scope.WithOpName("while"),
|
||||
std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
|
||||
TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user