Fix bug with control deps in (experimental) NodeDef function bodies.
Change: 139494948
This commit is contained in:
parent
4dcd9c50c7
commit
f12b8253d6
@ -423,12 +423,21 @@ Status InstantiateNode(const NodeDef& fnode,
|
|||||||
return errors::InvalidArgument("Expected input[", i, "] == '", input,
|
return errors::InvalidArgument("Expected input[", i, "] == '", input,
|
||||||
"' to be a control input.");
|
"' to be a control input.");
|
||||||
}
|
}
|
||||||
const NameInfoItem* item = gtl::FindOrNull(name_info, input.substr(1));
|
int nid = -1;
|
||||||
if (item == nullptr) {
|
const string node_name = input.substr(1);
|
||||||
|
const string node_colon = node_name + ":";
|
||||||
|
for (const auto& p : name_info) {
|
||||||
|
if (p.first == node_name ||
|
||||||
|
tensorflow::StringPiece(p.first).starts_with(node_colon)) {
|
||||||
|
nid = p.second.nid;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (nid == -1) {
|
||||||
return errors::InvalidArgument("input[", i, "] == '", input,
|
return errors::InvalidArgument("input[", i, "] == '", input,
|
||||||
"', is not found.");
|
"', is not found.");
|
||||||
}
|
}
|
||||||
gnode->add_input(Dep(item->nid));
|
gnode->add_input(Dep(nid));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attrs.
|
// Attrs.
|
||||||
|
@ -137,6 +137,51 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
|||||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TFunc, ControlDepNodeDef) {
|
||||||
|
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||||
|
// Name
|
||||||
|
"ControlDep",
|
||||||
|
// Inputs
|
||||||
|
{"x: int32"},
|
||||||
|
// Outputs
|
||||||
|
{"y: int32"},
|
||||||
|
// Attrs
|
||||||
|
{},
|
||||||
|
// Nodes
|
||||||
|
{// a = Identity<int32>(x)
|
||||||
|
{{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}},
|
||||||
|
// o = NoOp(^a)
|
||||||
|
{{"o"}, "NoOp", {"^a"}, {}},
|
||||||
|
// y = Identity<int32>(a, ^o)
|
||||||
|
{{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}},
|
||||||
|
// Returns
|
||||||
|
{{"y", "y:output:0"}});
|
||||||
|
|
||||||
|
const char* e = R"P(
|
||||||
|
ControlDep(x:int32) -> (y:int32) {
|
||||||
|
a = Identity[T=int32](x)
|
||||||
|
o = NoOp() @ a
|
||||||
|
y = Identity[T=int32](a:output:0) @ o
|
||||||
|
return y = y:output:0
|
||||||
|
}
|
||||||
|
)P";
|
||||||
|
EXPECT_EQ(DebugString(fdef), e);
|
||||||
|
|
||||||
|
// Instantiate one with T=float
|
||||||
|
InstantiationResult result;
|
||||||
|
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
|
||||||
|
const char* e2 = R"P(
|
||||||
|
(n0:int32) -> (n3:int32) {
|
||||||
|
n1 = Identity[T=int32](n0)
|
||||||
|
n2 = NoOp() @ n1
|
||||||
|
n3 = Identity[T=int32](n1) @ n2
|
||||||
|
}
|
||||||
|
)P";
|
||||||
|
EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
|
||||||
|
EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
|
||||||
|
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_OP("HasDefaultType")
|
REGISTER_OP("HasDefaultType")
|
||||||
.Output("out: T")
|
.Output("out: T")
|
||||||
.Attr("T: {float, double, int32, int64} = DT_FLOAT");
|
.Attr("T: {float, double, int32, int64} = DT_FLOAT");
|
||||||
|
Loading…
Reference in New Issue
Block a user