In FunctionalizeCond pass, check if Const node has been added before copying over to Graph.

This caused an error in the generated FunctionDef from GraphToFunctionDef.

PiperOrigin-RevId: 276343151
Change-Id: Id3a435f3195d3c35efdd8a3acdd7aa27c637cef0
This commit is contained in:
Andy Ly 2019-10-23 13:46:59 -07:00 committed by TensorFlower Gardener
parent 2e84f72d2d
commit 4c9d9580c5
3 changed files with 56 additions and 1 deletions

View File

@ -706,6 +706,7 @@ tf_cc_test(
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:scope",
"//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
@ -713,10 +714,14 @@ tf_cc_test(
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:resource_variable_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:test",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)

View File

@ -668,7 +668,11 @@ Status Conditional::ExtractBodies(Graph* graph) {
// * constant nodes copy them;
// * non-constant nodes, insert a switch along the edge;
if (IsConstant(src)) {
node_map.at(src->id()) = output->CopyNode(src);
// Check if constant node was added already. It is possible to have
// multiple uses of a constant node.
if (node_map.at(src->id()) == nullptr) {
node_map.at(src->id()) = output->CopyNode(src);
}
} else {
StateMap::CondState state = *dst_id;
state.erase(predicate_);

View File

@ -17,9 +17,16 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -112,6 +119,45 @@ TEST_F(FunctionalizeCondTest, JoinCondStatesMergeWithInputNotInCondContext) {
EXPECT_FALSE(joined_or.ok());
}
TEST(FunctionalizeCond, DuplicateConstNodes) {
Scope root = Scope::NewRootScope().ExitOnError();
auto const_op = ops::Const(root.WithOpName("const"), 1);
auto arg_0_op = ops::_Arg(root.WithOpName("arg_0"), DT_BOOL, 0);
auto arg_1_op = ops::_Arg(root.WithOpName("arg_1"), DT_INT32, 1);
auto switch_op = ops::Switch(root.WithOpName("switch"), arg_1_op, arg_0_op);
auto identity_n_false_op =
ops::IdentityN(root.WithOpName("identity_n_0"),
{switch_op.output_false, const_op, const_op});
auto identity_n_true_op =
ops::IdentityN(root.WithOpName("identity_n_1"),
{switch_op.output_true, const_op, const_op});
auto merge_op = ops::Merge(
root.WithOpName("merge"),
{identity_n_false_op.output.front(), identity_n_true_op.output.front()});
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
Graph graph(OpRegistry::Global());
GraphConstructorOptions options;
TF_EXPECT_OK(ConvertGraphDefToGraph(options, graph_def, &graph));
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
auto status = tensorflow::FunctionalizeCond(&graph, &flib_def);
TF_ASSERT_OK(status);
FunctionDefLibrary flib_def_proto = flib_def.ToProto();
for (const auto& fdef : flib_def_proto.function()) {
absl::flat_hash_set<absl::string_view> node_names;
for (const auto& node : fdef.node_def()) {
EXPECT_TRUE(node_names.insert(node.name()).second)
<< node.op() << " with duplicate node name '" << node.name()
<< "' found.";
}
}
}
} // namespace
} // namespace functionalize_cond
} // namespace tensorflow