Fix bug with control deps in (experimental) NodeDef function bodies.
Change: 139494948
This commit is contained in:
parent
4dcd9c50c7
commit
f12b8253d6
@ -423,12 +423,21 @@ Status InstantiateNode(const NodeDef& fnode,
|
||||
return errors::InvalidArgument("Expected input[", i, "] == '", input,
|
||||
"' to be a control input.");
|
||||
}
|
||||
const NameInfoItem* item = gtl::FindOrNull(name_info, input.substr(1));
|
||||
if (item == nullptr) {
|
||||
int nid = -1;
|
||||
const string node_name = input.substr(1);
|
||||
const string node_colon = node_name + ":";
|
||||
for (const auto& p : name_info) {
|
||||
if (p.first == node_name ||
|
||||
tensorflow::StringPiece(p.first).starts_with(node_colon)) {
|
||||
nid = p.second.nid;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (nid == -1) {
|
||||
return errors::InvalidArgument("input[", i, "] == '", input,
|
||||
"', is not found.");
|
||||
}
|
||||
gnode->add_input(Dep(item->nid));
|
||||
gnode->add_input(Dep(nid));
|
||||
}
|
||||
|
||||
// Attrs.
|
||||
|
@ -137,6 +137,51 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, ControlDepNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"ControlDep",
|
||||
// Inputs
|
||||
{"x: int32"},
|
||||
// Outputs
|
||||
{"y: int32"},
|
||||
// Attrs
|
||||
{},
|
||||
// Nodes
|
||||
{// a = Identity<int32>(x)
|
||||
{{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}},
|
||||
// o = NoOp(^a)
|
||||
{{"o"}, "NoOp", {"^a"}, {}},
|
||||
// y = Identity<int32>(a, ^o)
|
||||
{{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}},
|
||||
// Returns
|
||||
{{"y", "y:output:0"}});
|
||||
|
||||
const char* e = R"P(
|
||||
ControlDep(x:int32) -> (y:int32) {
|
||||
a = Identity[T=int32](x)
|
||||
o = NoOp() @ a
|
||||
y = Identity[T=int32](a:output:0) @ o
|
||||
return y = y:output:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
// Instantiate one with T=float
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(n0:int32) -> (n3:int32) {
|
||||
n1 = Identity[T=int32](n0)
|
||||
n2 = NoOp() @ n1
|
||||
n3 = Identity[T=int32](n1) @ n2
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
REGISTER_OP("HasDefaultType")
|
||||
.Output("out: T")
|
||||
.Attr("T: {float, double, int32, int64} = DT_FLOAT");
|
||||
|
Loading…
Reference in New Issue
Block a user