Fix bug with control deps in (experimental) NodeDef function bodies.

Change: 139494948
This commit is contained in:
A. Unique TensorFlower 2016-11-17 12:50:25 -08:00 committed by TensorFlower Gardener
parent 4dcd9c50c7
commit f12b8253d6
2 changed files with 57 additions and 3 deletions

View File

@ -423,12 +423,21 @@ Status InstantiateNode(const NodeDef& fnode,
return errors::InvalidArgument("Expected input[", i, "] == '", input,
"' to be a control input.");
}
const NameInfoItem* item = gtl::FindOrNull(name_info, input.substr(1));
if (item == nullptr) {
int nid = -1;
const string node_name = input.substr(1);
const string node_colon = node_name + ":";
for (const auto& p : name_info) {
if (p.first == node_name ||
tensorflow::StringPiece(p.first).starts_with(node_colon)) {
nid = p.second.nid;
break;
}
}
if (nid == -1) {
return errors::InvalidArgument("input[", i, "] == '", input,
"', is not found.");
}
gnode->add_input(Dep(item->nid));
gnode->add_input(Dep(nid));
}
// Attrs.

View File

@ -137,6 +137,51 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, ControlDepNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"ControlDep",
// Inputs
{"x: int32"},
// Outputs
{"y: int32"},
// Attrs
{},
// Nodes
{// a = Identity<int32>(x)
{{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}},
// o = NoOp(^a)
{{"o"}, "NoOp", {"^a"}, {}},
// y = Identity<int32>(a, ^o)
{{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}},
// Returns
{{"y", "y:output:0"}});
const char* e = R"P(
ControlDep(x:int32) -> (y:int32) {
a = Identity[T=int32](x)
o = NoOp() @ a
y = Identity[T=int32](a:output:0) @ o
return y = y:output:0
}
)P";
EXPECT_EQ(DebugString(fdef), e);
// Instantiate one with T=float
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
const char* e2 = R"P(
(n0:int32) -> (n3:int32) {
n1 = Identity[T=int32](n0)
n2 = NoOp() @ n1
n3 = Identity[T=int32](n1) @ n2
}
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
EXPECT_EQ(DebugString(result.gdef), e2);
}
REGISTER_OP("HasDefaultType")
.Output("out: T")
.Attr("T: {float, double, int32, int64} = DT_FLOAT");