In FunctionalizeCond pass, check if Const node has been added before copying over to Graph.
This caused an error in the generated FunctionDef from GraphToFunctionDef. PiperOrigin-RevId: 276343151 Change-Id: Id3a435f3195d3c35efdd8a3acdd7aa27c637cef0
This commit is contained in:
parent
2e84f72d2d
commit
4c9d9580c5
@ -706,6 +706,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:resource_variable_ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/compiler/tf2xla/cc:xla_ops",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -713,10 +714,14 @@ tf_cc_test(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:resource_variable_ops_op_lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/platform:test",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -668,7 +668,11 @@ Status Conditional::ExtractBodies(Graph* graph) {
|
||||
// * constant nodes copy them;
|
||||
// * non-constant nodes, insert a switch along the edge;
|
||||
if (IsConstant(src)) {
|
||||
node_map.at(src->id()) = output->CopyNode(src);
|
||||
// Check if constant node was added already. It is possible to have
|
||||
// multiple uses of a constant node.
|
||||
if (node_map.at(src->id()) == nullptr) {
|
||||
node_map.at(src->id()) = output->CopyNode(src);
|
||||
}
|
||||
} else {
|
||||
StateMap::CondState state = *dst_id;
|
||||
state.erase(predicate_);
|
||||
|
@ -17,9 +17,16 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -112,6 +119,45 @@ TEST_F(FunctionalizeCondTest, JoinCondStatesMergeWithInputNotInCondContext) {
|
||||
EXPECT_FALSE(joined_or.ok());
|
||||
}
|
||||
|
||||
TEST(FunctionalizeCond, DuplicateConstNodes) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
auto const_op = ops::Const(root.WithOpName("const"), 1);
|
||||
auto arg_0_op = ops::_Arg(root.WithOpName("arg_0"), DT_BOOL, 0);
|
||||
auto arg_1_op = ops::_Arg(root.WithOpName("arg_1"), DT_INT32, 1);
|
||||
auto switch_op = ops::Switch(root.WithOpName("switch"), arg_1_op, arg_0_op);
|
||||
auto identity_n_false_op =
|
||||
ops::IdentityN(root.WithOpName("identity_n_0"),
|
||||
{switch_op.output_false, const_op, const_op});
|
||||
auto identity_n_true_op =
|
||||
ops::IdentityN(root.WithOpName("identity_n_1"),
|
||||
{switch_op.output_true, const_op, const_op});
|
||||
auto merge_op = ops::Merge(
|
||||
root.WithOpName("merge"),
|
||||
{identity_n_false_op.output.front(), identity_n_true_op.output.front()});
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
Graph graph(OpRegistry::Global());
|
||||
GraphConstructorOptions options;
|
||||
TF_EXPECT_OK(ConvertGraphDefToGraph(options, graph_def, &graph));
|
||||
|
||||
FunctionDefLibrary fdef_lib;
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
|
||||
auto status = tensorflow::FunctionalizeCond(&graph, &flib_def);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
FunctionDefLibrary flib_def_proto = flib_def.ToProto();
|
||||
for (const auto& fdef : flib_def_proto.function()) {
|
||||
absl::flat_hash_set<absl::string_view> node_names;
|
||||
for (const auto& node : fdef.node_def()) {
|
||||
EXPECT_TRUE(node_names.insert(node.name()).second)
|
||||
<< node.op() << " with duplicate node name '" << node.name()
|
||||
<< "' found.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace functionalize_cond
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user