Delete the FunctionDef.Node code paths, now that we have switched
to the NodeDef representation. Change: 144281952
This commit is contained in:
parent
ded01d2902
commit
873473ef0f
@ -211,43 +211,6 @@ Status AddRetName(NameInfoIndex* name_info, const string& ret,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildNodeOutputIndex(const FunctionDef::Node& node,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
GetFunctionSignature get_function,
|
||||
const int arg_index, NameInfoIndex* name_info) {
|
||||
const OpDef* node_sig = nullptr;
|
||||
TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig));
|
||||
if (node_sig->output_arg_size() == 0) {
|
||||
// This node produces no output.
|
||||
if (node.ret_size() != 1) {
|
||||
return errors::InvalidArgument("Expect one ret name.");
|
||||
}
|
||||
return AddRetName(name_info, node.ret(0), {false, arg_index, 0, false, {}});
|
||||
}
|
||||
const int num_retval = node_sig->output_arg_size();
|
||||
if (num_retval != node.ret_size()) {
|
||||
return errors::InvalidArgument("Malformed function node (#ret): ",
|
||||
num_retval, " vs. ", node.ret_size());
|
||||
}
|
||||
int start = 0;
|
||||
bool is_type_list;
|
||||
DataTypeVector dtypes;
|
||||
for (int i = 0; i < num_retval; ++i) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetName(name_info, node.ret(i),
|
||||
{false, arg_index, start, is_type_list, dtypes}));
|
||||
for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetName(name_info, strings::StrCat(node.ret(i), ":", j),
|
||||
{false, arg_index, start + j, false, {dtypes[j]}}));
|
||||
}
|
||||
start += dtypes.size();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildNodeOutputIndex(const NodeDef& node,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
GetFunctionSignature get_function,
|
||||
@ -280,85 +243,6 @@ Status BuildNodeOutputIndex(const NodeDef& node,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InstantiateNode(const FunctionDef::Node& fnode,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
GetFunctionSignature get_function,
|
||||
const NameInfoIndex& name_info, GraphDef* gdef) {
|
||||
const OpDef* fnode_sig = nullptr;
|
||||
TF_CHECK_OK(get_function(fnode.op(), &fnode_sig));
|
||||
NodeDef* gnode = gdef->add_node();
|
||||
gnode->set_name(Name(gdef->node_size() - 1));
|
||||
gnode->set_op(fnode.op());
|
||||
|
||||
// Input
|
||||
const int num_args = fnode_sig->input_arg_size();
|
||||
bool is_type_list;
|
||||
DataTypeVector dtypes;
|
||||
int fnode_arg_index = 0;
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
|
||||
if (!is_type_list) {
|
||||
const NameInfoItem* item =
|
||||
gtl::FindOrNull(name_info, fnode.arg(fnode_arg_index));
|
||||
if (item == nullptr) {
|
||||
return errors::InvalidArgument("arg[", i, "] is not found: ",
|
||||
ProtoShortDebugString(fnode));
|
||||
}
|
||||
if (dtypes != item->dtypes) {
|
||||
return errors::InvalidArgument("Invalid arg(", i,
|
||||
") for function arg: ",
|
||||
DataTypeSliceString(dtypes), " vs. ",
|
||||
DataTypeSliceString(item->dtypes), ".");
|
||||
}
|
||||
for (size_t j = 0; j < dtypes.size(); ++j) {
|
||||
if (item->is_func_arg) {
|
||||
gnode->add_input(Name(item->nid + j));
|
||||
} else {
|
||||
gnode->add_input(Name(item->nid, item->idx + j));
|
||||
}
|
||||
}
|
||||
++fnode_arg_index;
|
||||
} else {
|
||||
for (size_t j = 0; j < dtypes.size(); ++j) {
|
||||
const NameInfoItem* item =
|
||||
gtl::FindOrNull(name_info, fnode.arg(fnode_arg_index + j));
|
||||
if (item == nullptr) {
|
||||
return errors::InvalidArgument("arg[", i + j, "] is not found: ",
|
||||
ProtoShortDebugString(fnode));
|
||||
}
|
||||
if (item->dtypes.size() != 1 || (item->dtypes[0] != dtypes[j])) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid typelist arg(", i + j, ") for function arg: ",
|
||||
DataTypeSliceString(dtypes), " vs. ",
|
||||
DataTypeSliceString(item->dtypes), ".");
|
||||
}
|
||||
if (item->is_func_arg) {
|
||||
gnode->add_input(Name(item->nid));
|
||||
} else {
|
||||
gnode->add_input(Name(item->nid, item->idx));
|
||||
}
|
||||
}
|
||||
fnode_arg_index += dtypes.size();
|
||||
}
|
||||
}
|
||||
// Control deps.
|
||||
for (int i = 0; i < fnode.dep_size(); ++i) {
|
||||
const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i));
|
||||
if (item == nullptr) {
|
||||
return errors::InvalidArgument("dep[", i, "] is not found.");
|
||||
}
|
||||
gnode->add_input(Dep(item->nid));
|
||||
}
|
||||
|
||||
// Attrs.
|
||||
for (const auto& p : attrs) {
|
||||
(*gnode->mutable_attr())[p.first] = p.second;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InstantiateNode(const NodeDef& fnode,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
GetFunctionSignature get_function,
|
||||
@ -448,38 +332,6 @@ Status InstantiateNode(const NodeDef& fnode,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// FunctionDef::Node version
|
||||
Status AddReturnNode(const OpDef::ArgDef& ret_def,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
const NameInfoIndex& name_info, int* ret_index,
|
||||
InstantiationResult* result) {
|
||||
bool is_type_list;
|
||||
DataTypeVector dtypes;
|
||||
TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
|
||||
CHECK_GE(dtypes.size(), size_t{1});
|
||||
const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name());
|
||||
if (item == nullptr) {
|
||||
return errors::InvalidArgument("ret is not found.");
|
||||
}
|
||||
if (dtypes != item->dtypes) {
|
||||
return errors::InvalidArgument("Invalid ret types ", ret_def.name(), " : ",
|
||||
DataTypeVectorString(dtypes), " vs. ",
|
||||
DataTypeVectorString(item->dtypes));
|
||||
}
|
||||
GraphDef* gdef = &result->gdef;
|
||||
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||
NodeDef* gnode = gdef->add_node();
|
||||
gnode->set_name(Name(gdef->node_size() - 1));
|
||||
gnode->set_op("_Retval");
|
||||
gnode->add_input(Name(item->nid, item->idx + i));
|
||||
AddAttr("T", dtypes[i], gnode);
|
||||
AddAttr("index", (*ret_index)++, gnode);
|
||||
result->ret_types.push_back(dtypes[i]);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// NodeDef version
|
||||
Status AddReturnNode(const OpDef::ArgDef& ret_def,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
const ::tensorflow::protobuf::Map<string, string>& ret_map,
|
||||
@ -561,38 +413,6 @@ string Print(const AttrValue& attr_value) {
|
||||
return SummarizeAttrValue(attr_value);
|
||||
}
|
||||
|
||||
string Print(const FunctionDef::Node& node) {
|
||||
string out;
|
||||
for (int i = 0; i < node.ret_size(); ++i) {
|
||||
const auto& name = node.ret(i);
|
||||
if (i > 0) strings::StrAppend(&out, ", ");
|
||||
strings::StrAppend(&out, name);
|
||||
}
|
||||
strings::StrAppend(&out, " = ", node.op());
|
||||
if (node.attr_size() > 0) {
|
||||
std::vector<string> entries;
|
||||
for (auto p : node.attr()) {
|
||||
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
|
||||
}
|
||||
sort(entries.begin(), entries.end());
|
||||
strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
|
||||
}
|
||||
strings::StrAppend(&out, "(");
|
||||
for (int i = 0; i < node.arg_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&out, ", ");
|
||||
strings::StrAppend(&out, node.arg(i));
|
||||
}
|
||||
strings::StrAppend(&out, ")");
|
||||
if (node.dep_size() > 0) {
|
||||
strings::StrAppend(&out, " @ ");
|
||||
for (int i = 0; i < node.dep_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&out, ", ");
|
||||
strings::StrAppend(&out, node.dep(i));
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// TODO(josh11b): Merge this with SummarizeNodeDef().
|
||||
string Print(const NodeDef& n) {
|
||||
string out;
|
||||
@ -650,18 +470,12 @@ string Print(const FunctionDef& fdef) {
|
||||
strings::StrAppend(&out, Print(sig.output_arg(i)));
|
||||
}
|
||||
strings::StrAppend(&out, ") {\n");
|
||||
if (fdef.node_def_size() > 0 || fdef.ret_size() > 0) {
|
||||
for (const auto& n : fdef.node_def()) {
|
||||
strings::StrAppend(&out, " ", Print(n), "\n");
|
||||
}
|
||||
for (const auto& r : fdef.ret()) {
|
||||
strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
|
||||
}
|
||||
} else { // TODO(josh11b): Eventually remove this case.
|
||||
for (const auto& n : fdef.node()) {
|
||||
strings::StrAppend(&out, " ", Print(n), "\n");
|
||||
}
|
||||
}
|
||||
strings::StrAppend(&out, "}\n");
|
||||
return out;
|
||||
}
|
||||
@ -772,7 +586,6 @@ Status InstantiateFunction(const FunctionDef& fdef,
|
||||
// Makes a copy of all attrs in fdef and substitutes placeholders.
|
||||
// After this step, every attr is bound to a concrete value.
|
||||
std::vector<InstantiateAttrValueMap> node_attrs;
|
||||
if (fdef.node_def_size() > 0 || fdef.ret_size() > 0) {
|
||||
node_attrs.resize(fdef.node_def_size());
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
for (auto attr : fdef.node_def(i).attr()) {
|
||||
@ -816,50 +629,6 @@ Status InstantiateFunction(const FunctionDef& fdef,
|
||||
return s;
|
||||
}
|
||||
}
|
||||
} else { // TODO(josh11b): Eventually remove this case.
|
||||
node_attrs.resize(fdef.node_size());
|
||||
for (int i = 0; i < fdef.node_size(); ++i) {
|
||||
for (auto attr : fdef.node(i).attr()) {
|
||||
if (!SubstitutePlaceholders(substitute, &attr.second)) {
|
||||
return errors::InvalidArgument("Failed to bind all placeholders in ",
|
||||
SummarizeAttrValue(attr.second));
|
||||
}
|
||||
if (!node_attrs[i].insert(attr).second) {
|
||||
return errors::Internal("Somehow duplicated: ", attr.first);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddDefaultAttrs(fdef.node(i).op(), get_function, &node_attrs[i]));
|
||||
}
|
||||
|
||||
for (int i = 0; i < fdef.node_size(); ++i) {
|
||||
s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function,
|
||||
gdef->node_size() + i, &name_info);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", Print(fdef.node(i)));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
// Emits one gdef.node for each fdef.node.
|
||||
for (int i = 0; i < fdef.node_size(); ++i) {
|
||||
s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info,
|
||||
gdef);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", Print(fdef.node(i)));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
// Emits nodes for the function's return values.
|
||||
int ret_index = 0;
|
||||
for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
|
||||
s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In function output ", Print(ret_def));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -30,61 +30,7 @@ message FunctionDef {
|
||||
// Attributes specific to this function definition.
|
||||
map<string, AttrValue> attr = 5;
|
||||
|
||||
// TO BE REPLACED
|
||||
|
||||
// The body of the function.
|
||||
repeated Node node = 2; // function.node.ret[*] are unique.
|
||||
|
||||
// A node is a multi-value assignment:
|
||||
// (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
|
||||
//
|
||||
// By convention, "func" is resolved by consulting with a user-defined
|
||||
// library first. If not resolved, "func" is assumed to be a builtin op.
|
||||
message Node {
|
||||
// This node produces multiple outputs. They are named ret[0],
|
||||
// ret[1], ..., etc.
|
||||
//
|
||||
// REQUIRES: function.node.ret[*] are unique across all nodes.
|
||||
// REQUIRES: ret.size == func/op def's number of output args.
|
||||
repeated string ret = 1;
|
||||
|
||||
// The op/function name.
|
||||
string op = 2;
|
||||
|
||||
// Arguments passed to this func/op.
|
||||
//
|
||||
// arg[i] must be either one of
|
||||
// function.signature.input_args[*].name or one of
|
||||
// function.node[*].ret[*].
|
||||
//
|
||||
// REQUIRES: arg.size == func/op def's number of input args.
|
||||
repeated string arg = 3;
|
||||
|
||||
// Control dependencies.
|
||||
//
|
||||
// dep[i] must be one of function.node[*].ret[*] or one of
|
||||
// function.signature.input_args[*].name.
|
||||
repeated string dep = 4;
|
||||
|
||||
// Attrs.
|
||||
//
|
||||
// 'attr' maps names defined by 'func's attr defs to attr values.
|
||||
// attr values may have placeholders which are substituted
|
||||
// recursively by concrete values when this node is instantiated.
|
||||
// These placeholders must name an attr listed in the FunctionDef's
|
||||
// signature.
|
||||
map<string, AttrValue> attr = 5;
|
||||
}
|
||||
|
||||
// WILL REPLACE THE ABOVE
|
||||
|
||||
// If node_def is present, and the consumer is at GraphDef version
|
||||
// >= 12, then these fields are used and `node` is ignored. If the
|
||||
// consumer's GraphDef version is < 12 or this field is empty, then
|
||||
// `node` is used. This allows producers to fill both fields to
|
||||
// remain compatible with old consumers. At some future GraphDef
|
||||
// version, `node` will be ignored even if `node_def` is empty.
|
||||
// TODO(josh11b): Finish this transition.
|
||||
// NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21.
|
||||
|
||||
// In both of the following fields, there is the need to specify an
|
||||
// output that is used as either the input to another node (in
|
||||
@ -120,6 +66,10 @@ message FunctionDef {
|
||||
// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
|
||||
// may have values of type `placeholder` and the `input` field uses
|
||||
// the "output" format above.
|
||||
|
||||
// By convention, "op" in node_def is resolved by consulting with a
|
||||
// user-defined library first. If not resolved, "func" is assumed to
|
||||
// be a builtin op.
|
||||
repeated NodeDef node_def = 3;
|
||||
|
||||
// A mapping from the output arg names from `signature` to the
|
||||
|
@ -48,52 +48,8 @@ y: A scalar in type T.
|
||||
|
||||
static InstantiateAttrValueMap kNoAttrs;
|
||||
|
||||
TEST(TFunc, SquarePlusOneOld) {
|
||||
auto fdef = FDH::Define( // Create a FunctionDef using Function::Nodes.
|
||||
// Name
|
||||
"SquarePlusOne",
|
||||
// Args
|
||||
{"x: T"},
|
||||
// Return values
|
||||
{"y: T"},
|
||||
// Attrs
|
||||
{"T: {float, double, int32, int64}"},
|
||||
// Nodes
|
||||
{// a = Square<T>(x)
|
||||
{{"a"}, "Square", {"x"}, {{"T", "$T"}}},
|
||||
// o = One<T>()
|
||||
// NOTE: We can also have a Cast<Tin, Tout>(x) instead.
|
||||
{{"o"}, "One", {}, {{"T", "$T"}}},
|
||||
// y = Add<T>(a, o)
|
||||
{{"y"}, "Add", {"a", "o"}, {{"T", "$T"}}}});
|
||||
|
||||
const char* e = R"P(
|
||||
SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
a = Square[T=$T](x)
|
||||
o = One[T=$T]()
|
||||
y = Add[T=$T](a:y:0, o:y:0)
|
||||
return y = y:z:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
// Instantiate one with T=float
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(n0:float) -> (n3:float) {
|
||||
n1 = Square[T=float](n0)
|
||||
n2 = One[T=float]()
|
||||
n3 = Add[T=float](n1, n2)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, SquarePlusOneNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(TFunc, SquarePlusOne) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"SquarePlusOne",
|
||||
// Inputs
|
||||
@ -138,8 +94,8 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, ControlDepNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(TFunc, ControlDep) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"ControlDep",
|
||||
// Inputs
|
||||
@ -190,44 +146,8 @@ REGISTER_OP("HasDefaultType")
|
||||
// This verifies that a function using an op before a type attr (with
|
||||
// a default) is added, still works. This is important for backwards
|
||||
// compatibilty.
|
||||
TEST(TFunc, MissingTypeAttrOld) {
|
||||
auto fdef = FDH::Define( // Create a FunctionDef using Function::Nodes.
|
||||
// Name
|
||||
"BackCompat",
|
||||
// Args
|
||||
{},
|
||||
// Return values
|
||||
{"y: float"},
|
||||
// Attrs
|
||||
{},
|
||||
// Nodes
|
||||
{// y = HasDefaultType(x), T missing, defaults to float
|
||||
{{"y"}, "HasDefaultType", {}, {}}});
|
||||
|
||||
const char* e = R"P(
|
||||
BackCompat() -> (y:float) {
|
||||
y = HasDefaultType()
|
||||
return y = y:out:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(
|
||||
InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result));
|
||||
// Should get T=float from Op's default.
|
||||
const char* e2 = R"P(
|
||||
() -> (n0:float) {
|
||||
n0 = HasDefaultType[T=float]()
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector());
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, MissingTypeAttrNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(TFunc, MissingTypeAttr) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"BackCompat",
|
||||
// Args
|
||||
@ -264,11 +184,8 @@ BackCompat() -> (y:float) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, NTimesTNodeDef) {
|
||||
// Note that the equivalent FunctionDef using FunctionDef::Node requires
|
||||
// using a _ListToArray to package up the two inputs to AddN as a single
|
||||
// N*T edge.
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(TFunc, NTimesT) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"NTimesT",
|
||||
// Inputs
|
||||
@ -790,8 +707,8 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
|
||||
"input unknown is not found");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputs) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TooManyInputs) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TooManyInputs",
|
||||
// Inputs
|
||||
@ -811,8 +728,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputs) {
|
||||
"Expected input[2] == 'x' to be a control input.");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TooFewInputs) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TooFewInputs) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TooFewInputs",
|
||||
// Inputs
|
||||
@ -832,8 +749,8 @@ TEST(InstantiateErrors, NodeDef_TooFewInputs) {
|
||||
"Attempt to access beyond input size: 2 >= 2");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TooManyInputsFromArray1) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TooManyInputsFromArray",
|
||||
// Inputs
|
||||
@ -860,8 +777,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
|
||||
"Expected input[1] == 'y' to be a control input.");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TooManyInputsFromArray2) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TooManyInputsFromArray",
|
||||
// Inputs
|
||||
@ -888,8 +805,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
|
||||
"Input a:output too long for inputs");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TypeMismatch) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TypeMismatch) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TypeMismatch",
|
||||
// Inputs
|
||||
|
@ -178,15 +178,9 @@ void OpsUsedByGraph(const GraphDef& graph_def,
|
||||
while (!functions_to_process.empty()) {
|
||||
const FunctionDef* fun = functions_to_process.back();
|
||||
functions_to_process.pop_back();
|
||||
if (fun->node_def_size() > 0) {
|
||||
for (const auto& node : fun->node_def()) {
|
||||
mark_op_as_used(node.op());
|
||||
}
|
||||
} else { // TODO(josh11b): Eventually drop support for this.
|
||||
for (const auto& node : fun->node()) {
|
||||
mark_op_as_used(node.op());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out function names to produce output.
|
||||
|
@ -79,10 +79,12 @@ limitations under the License.
|
||||
// used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is
|
||||
// now used by tf.concat_v2 (and soon tf.concat). Graphs use flooring
|
||||
// division and mod semantics. TensorArrayV3. (12dec2016)
|
||||
// 21. Dropped FunctionDef.Node support, switched to node_def introduced
|
||||
// in version 12. (11jan2017)
|
||||
|
||||
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
||||
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
||||
#define TF_GRAPH_DEF_VERSION 20
|
||||
#define TF_GRAPH_DEF_VERSION 21
|
||||
|
||||
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
||||
//
|
||||
|
@ -25,8 +25,6 @@ import hashlib
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
@ -67,82 +65,6 @@ def _get_node_def(op):
|
||||
return op._node_def # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _add_input_array(op, start, limit, dtype, func):
|
||||
"""Adds a _ListToArray node in the func for op.inputs[start:limit]."""
|
||||
node = function_pb2.FunctionDef.Node()
|
||||
node.op = "_ListToArray"
|
||||
ret_name = op.name + "_L2A_" + str(start)
|
||||
node.ret.extend([ret_name])
|
||||
node.arg.extend(
|
||||
[_make_argname_from_tensor_name(x.name) for x in op.inputs[start:limit]])
|
||||
num = limit - start
|
||||
node.attr["Tin"].CopyFrom(
|
||||
attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
|
||||
type=[dtype] * num)))
|
||||
node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
|
||||
node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
|
||||
func.node.extend([node])
|
||||
return ret_name
|
||||
|
||||
|
||||
def _add_identity_dtype_proto(func, src, dst, dtype_proto):
|
||||
node = function_pb2.FunctionDef.Node()
|
||||
node.op = "Identity"
|
||||
node.arg.append(src)
|
||||
node.ret.append(dst)
|
||||
node.attr["T"].CopyFrom(dtype_proto)
|
||||
func.node.extend([node])
|
||||
|
||||
|
||||
def _add_identity_dtype_enum(func, src, dst, dtype):
|
||||
dtype_proto = attr_value_pb2.AttrValue(type=dtype)
|
||||
_add_identity_dtype_proto(func, src, dst, dtype_proto)
|
||||
|
||||
|
||||
def _add_output_array(op, start, limit, dtype, func):
|
||||
"""Adds a _ArrayToList node in the func for op.outputs[start:limit]."""
|
||||
dtype_proto = attr_value_pb2.AttrValue(type=dtype)
|
||||
# A node converting N*T to list(T)
|
||||
node = function_pb2.FunctionDef.Node()
|
||||
node.op = "_ArrayToList"
|
||||
arg_name = op.name + "_A2L_" + str(start)
|
||||
ret_name = arg_name + "_out"
|
||||
node.ret.append(ret_name)
|
||||
node.arg.append(arg_name)
|
||||
node.attr["T"].CopyFrom(dtype_proto)
|
||||
num = limit - start
|
||||
node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
|
||||
node.attr["out_types"].CopyFrom(
|
||||
attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
|
||||
type=[dtype] * num)))
|
||||
func.node.extend([node])
|
||||
num = limit - start
|
||||
# Adds an identity node for each element in the array N*T so that
|
||||
# uses of each element can be added easily later. These Identity
|
||||
# will be eliminated before graph execution.
|
||||
for i in xrange(num):
|
||||
_add_identity_dtype_proto(
|
||||
func, ret_name + ":" + str(i),
|
||||
_make_argname_from_tensor_name(op.outputs[i].name), dtype_proto)
|
||||
return arg_name
|
||||
|
||||
|
||||
def _add_output_list(op, start, limit, dtype_lst, func):
|
||||
"""Adds a _ArrayToList node in the func for op.outputs[start:limit]."""
|
||||
ret_name = op.name + "_Lst_" + str(start) + "_" + str(limit)
|
||||
num = limit - start
|
||||
assert len(dtype_lst) == num
|
||||
# Adds an identity node for each element in the array N*T so that
|
||||
# uses of each element can be added easily later. These Identity
|
||||
# will be eliminated before graph execution.
|
||||
for i in xrange(num):
|
||||
_add_identity_dtype_enum(func,
|
||||
ret_name + ":" + str(i),
|
||||
_make_argname_from_tensor_name(op.outputs[i].name),
|
||||
dtype_lst[i])
|
||||
return ret_name
|
||||
|
||||
|
||||
def _get_op_def(op):
|
||||
# pylint: disable=protected-access
|
||||
if hasattr(op, "_sig"):
|
||||
@ -197,76 +119,6 @@ def _add_op_node(op, func, input_dict):
|
||||
"%s missing from %s" % (node_def.input[i], input_dict.items()))
|
||||
node_def.input[i] = input_dict[node_def.input[i]]
|
||||
|
||||
# To support legacy consumers, add an entry in func.node.
|
||||
# TODO(josh11b): Delete this.
|
||||
node = function_pb2.FunctionDef.Node()
|
||||
node.op = op.type
|
||||
op_def = _get_op_def(op)
|
||||
attrs = node_def.attr
|
||||
if not op_def.output_arg:
|
||||
node.ret.append(_make_argname_from_tensor_name(op.name))
|
||||
else:
|
||||
out_index = 0
|
||||
for arg_def in op_def.output_arg:
|
||||
if arg_def.number_attr:
|
||||
dtype = arg_def.type or attrs[arg_def.type_attr].type
|
||||
num = attrs[arg_def.number_attr].i
|
||||
node.ret.append(
|
||||
_add_output_array(op, out_index, out_index + num, dtype, func))
|
||||
out_index += num
|
||||
elif arg_def.type_list_attr:
|
||||
dtype_lst = attrs[arg_def.type_list_attr].list.type
|
||||
num = len(dtype_lst)
|
||||
node.ret.append(
|
||||
_add_output_list(op, out_index, out_index + num, dtype_lst, func))
|
||||
out_index += num
|
||||
else:
|
||||
node.ret.append(
|
||||
_make_argname_from_tensor_name(op.outputs[out_index].name))
|
||||
out_index += 1
|
||||
inp_index = 0
|
||||
for arg_def in op_def.input_arg:
|
||||
if arg_def.number_attr:
|
||||
dtype = arg_def.type or attrs[arg_def.type_attr].type
|
||||
num = attrs[arg_def.number_attr].i
|
||||
node.arg.append(
|
||||
_add_input_array(op, inp_index, inp_index + num, dtype, func))
|
||||
inp_index += num
|
||||
elif arg_def.type_list_attr:
|
||||
num = len(attrs[arg_def.type_list_attr].list.type)
|
||||
node.arg.extend([
|
||||
_make_argname_from_tensor_name(op.inputs[i].name)
|
||||
for i in range(inp_index, inp_index + num)
|
||||
])
|
||||
inp_index += num
|
||||
else:
|
||||
node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name))
|
||||
inp_index += 1
|
||||
node.dep.extend(
|
||||
[_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
|
||||
for k, v in attrs.items():
|
||||
node.attr[k].CopyFrom(v)
|
||||
func.node.extend([node])
|
||||
|
||||
|
||||
def _replace_ret(func, original, replacement):
|
||||
for n in func.node:
|
||||
for i, r in enumerate(n.ret):
|
||||
if r == original:
|
||||
n.ret[i] = replacement
|
||||
return
|
||||
raise ValueError("Could not find ret == '%s'" % original)
|
||||
|
||||
|
||||
def _replace_arg(func, original, replacement):
|
||||
for n in func.node:
|
||||
for i, a in enumerate(n.arg):
|
||||
if a == original:
|
||||
n.arg[i] = replacement
|
||||
for i, d in enumerate(n.dep):
|
||||
if d == original:
|
||||
n.dep[i] = replacement
|
||||
|
||||
|
||||
def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
||||
"""Returns `graph` as a `FunctionDef` protocol buffer.
|
||||
@ -323,20 +175,9 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
||||
for index, o in enumerate(outputs):
|
||||
k = func.signature.output_arg[index].name
|
||||
func.ret[k] = input_dict[o.name]
|
||||
# TODO(josh11b): Delete this once we switch fully to NodeDefs for
|
||||
# function bodies.
|
||||
orig = _make_argname_from_tensor_name(o.name)
|
||||
if k != orig:
|
||||
_add_identity_dtype_enum(func, orig, k,
|
||||
func.signature.output_arg[index].type)
|
||||
else:
|
||||
for o, n in zip(outputs, out_names):
|
||||
func.ret[n] = input_dict[o.name]
|
||||
# TODO(josh11b): Delete this once we switch fully to NodeDefs for
|
||||
# function bodies.
|
||||
k = _make_argname_from_tensor_name(o.name)
|
||||
_replace_ret(func, k, n)
|
||||
_replace_arg(func, k, n)
|
||||
|
||||
return func
|
||||
|
||||
|
@ -2299,6 +2299,9 @@ class Graph(object):
|
||||
if (function.grad_func_name is not None) and (
|
||||
function.python_grad_func is not None):
|
||||
raise ValueError("Gradient defined twice for function %s" % name)
|
||||
# Need a new-enough consumer to support the functions we add to the graph.
|
||||
if self._graph_def_versions.min_consumer < 12:
|
||||
self._graph_def_versions.min_consumer = 12
|
||||
self._functions[name] = function
|
||||
|
||||
@property
|
||||
|
Loading…
x
Reference in New Issue
Block a user