Do not assume number of outputs for While node is always the same as number of inputs.
PiperOrigin-RevId: 231620915
This commit is contained in:
parent
1924fe1c4d
commit
19c79e944b
@ -317,11 +317,13 @@ tf_cc_test(
|
|||||||
":tf2xla_util",
|
":tf2xla_util",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
"//tensorflow/cc:function_ops",
|
"//tensorflow/cc:function_ops",
|
||||||
|
"//tensorflow/cc:functional_ops",
|
||||||
"//tensorflow/cc:ops",
|
"//tensorflow/cc:ops",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:math_ops_op_lib",
|
"//tensorflow/core:math_ops_op_lib",
|
||||||
|
"//tensorflow/core:ops",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
|||||||
@ -265,6 +265,13 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if i-th retval's input comes from i-th arg directly.
|
// Check if i-th retval's input comes from i-th arg directly.
|
||||||
|
// For resource variable input of While nodes, TF2XLA convention is to place
|
||||||
|
// them at the end of all inputs (after all data inputs), and *not* return
|
||||||
|
// them. So number of While node inputs might be larger than number of its
|
||||||
|
// outputs.
|
||||||
|
if (i >= body_func->signature().output_arg_size()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
|
const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
|
||||||
auto output_arg_input = body_func->ret().find(output_arg.name());
|
auto output_arg_input = body_func->ret().find(output_arg.name());
|
||||||
if (output_arg_input == body_func->ret().end()) {
|
if (output_arg_input == body_func->ret().end()) {
|
||||||
|
|||||||
@ -21,11 +21,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/framework/ops.h"
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
#include "tensorflow/cc/ops/data_flow_ops.h"
|
#include "tensorflow/cc/ops/data_flow_ops.h"
|
||||||
#include "tensorflow/cc/ops/function_ops.h"
|
#include "tensorflow/cc/ops/function_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/functional_ops.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -329,5 +331,37 @@ TEST(CachedFunctionHandles, Basic) {
|
|||||||
TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles());
|
TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) {
|
||||||
|
FunctionLibraryDefinition fld(OpRegistry::Global(), {});
|
||||||
|
{
|
||||||
|
// Cond graph & body graph.
|
||||||
|
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||||
|
auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0);
|
||||||
|
auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1);
|
||||||
|
auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0);
|
||||||
|
Graph graph(OpRegistry::Global());
|
||||||
|
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||||
|
FunctionDef cond_fdef;
|
||||||
|
TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef));
|
||||||
|
TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef));
|
||||||
|
FunctionDef body_fdef;
|
||||||
|
TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef));
|
||||||
|
TF_ASSERT_OK(fld.AddFunctionDef(body_fdef));
|
||||||
|
}
|
||||||
|
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||||
|
auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({}));
|
||||||
|
auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({}));
|
||||||
|
NameAttrList cond_fn, body_fn;
|
||||||
|
cond_fn.set_name("cond");
|
||||||
|
body_fn.set_name("body");
|
||||||
|
auto while_op =
|
||||||
|
ops::While(scope.WithOpName("while"),
|
||||||
|
std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
|
||||||
|
Graph graph(OpRegistry::Global());
|
||||||
|
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||||
|
|
||||||
|
TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user