diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 5d260c62ec9..8874e99078a 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -423,12 +423,21 @@ Status InstantiateNode(const NodeDef& fnode, return errors::InvalidArgument("Expected input[", i, "] == '", input, "' to be a control input."); } - const NameInfoItem* item = gtl::FindOrNull(name_info, input.substr(1)); - if (item == nullptr) { + int nid = -1; + const string node_name = input.substr(1); + const string node_colon = node_name + ":"; + for (const auto& p : name_info) { + if (p.first == node_name || + tensorflow::StringPiece(p.first).starts_with(node_colon)) { + nid = p.second.nid; + break; + } + } + if (nid == -1) { return errors::InvalidArgument("input[", i, "] == '", input, "', is not found."); } - gnode->add_input(Dep(item->nid)); + gnode->add_input(Dep(nid)); } // Attrs. diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 8fe11df7294..eb5aa9a5343 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -137,6 +137,51 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { EXPECT_EQ(DebugString(result.gdef), e2); } +TEST(TFunc, ControlDepNodeDef) { + auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. + // Name + "ControlDep", + // Inputs + {"x: int32"}, + // Outputs + {"y: int32"}, + // Attrs + {}, + // Nodes + {// a = Identity(x) + {{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}}, + // o = NoOp(^a) + {{"o"}, "NoOp", {"^a"}, {}}, + // y = Identity(a, ^o) + {{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}}, + // Returns + {{"y", "y:output:0"}}); + + const char* e = R"P( +ControlDep(x:int32) -> (y:int32) { + a = Identity[T=int32](x) + o = NoOp() @ a + y = Identity[T=int32](a:output:0) @ o + return y = y:output:0 +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + // Instantiate one with T=float + InstantiationResult result; + TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result)); + const char* e2 = R"P( +(n0:int32) -> (n3:int32) { + n1 = Identity[T=int32](n0) + n2 = NoOp() @ n1 + n3 = Identity[T=int32](n1) @ n2 +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + REGISTER_OP("HasDefaultType") .Output("out: T") .Attr("T: {float, double, int32, int64} = DT_FLOAT");