Delete the FunctionDef.Node code paths, now that we have switched
to the NodeDef representation. Change: 144281952
parent ded01d2902
commit 873473ef0f
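For orientation, the representation that survives this change is the NodeDef-based one built by FunctionDefHelper::Create in the tests below. The following is a minimal sketch only, not part of the commit: it assumes FDH is an alias for tensorflow::FunctionDefHelper (as in the test file) and mirrors the "y = Add[T=$T](a:y:0, o:y:0)" / "return y = y:z:0" form that the updated SquarePlusOne test expects.

#include "tensorflow/core/framework/function.h"

namespace tensorflow {
namespace {

using FDH = FunctionDefHelper;

// Sketch: the function body is a list of NodeDefs plus a ret map from the
// signature's output names to node outputs, instead of FunctionDef.Node.
FunctionDef SquarePlusOneAsNodeDefs() {
  return FDH::Create(
      // Name
      "SquarePlusOne",
      // Args
      {"x: T"},
      // Return values
      {"y: T"},
      // Attrs
      {"T: {float, double, int32, int64}"},
      // Nodes: each entry becomes an entry in FunctionDef.node_def.
      {{{"a"}, "Square", {"x"}, {{"T", "$T"}}},
       {{"o"}, "One", {}, {{"T", "$T"}}},
       {{"y"}, "Add", {"a:y:0", "o:y:0"}, {{"T", "$T"}}}},
      // Ret map: signature output name -> node output ("node:out_arg:index").
      {{"y", "y:z:0"}});
}

}  // namespace
}  // namespace tensorflow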
@@ -211,43 +211,6 @@ Status AddRetName(NameInfoIndex* name_info, const string& ret,
   return Status::OK();
 }
 
-Status BuildNodeOutputIndex(const FunctionDef::Node& node,
-                            const InstantiateAttrValueMap& attrs,
-                            GetFunctionSignature get_function,
-                            const int arg_index, NameInfoIndex* name_info) {
-  const OpDef* node_sig = nullptr;
-  TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig));
-  if (node_sig->output_arg_size() == 0) {
-    // This node produces no output.
-    if (node.ret_size() != 1) {
-      return errors::InvalidArgument("Expect one ret name.");
-    }
-    return AddRetName(name_info, node.ret(0), {false, arg_index, 0, false, {}});
-  }
-  const int num_retval = node_sig->output_arg_size();
-  if (num_retval != node.ret_size()) {
-    return errors::InvalidArgument("Malformed function node (#ret): ",
-                                   num_retval, " vs. ", node.ret_size());
-  }
-  int start = 0;
-  bool is_type_list;
-  DataTypeVector dtypes;
-  for (int i = 0; i < num_retval; ++i) {
-    TF_RETURN_IF_ERROR(
-        ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
-    TF_RETURN_IF_ERROR(
-        AddRetName(name_info, node.ret(i),
-                   {false, arg_index, start, is_type_list, dtypes}));
-    for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
-      TF_RETURN_IF_ERROR(
-          AddRetName(name_info, strings::StrCat(node.ret(i), ":", j),
-                     {false, arg_index, start + j, false, {dtypes[j]}}));
-    }
-    start += dtypes.size();
-  }
-  return Status::OK();
-}
-
 Status BuildNodeOutputIndex(const NodeDef& node,
                             const InstantiateAttrValueMap& attrs,
                             GetFunctionSignature get_function,
@@ -280,85 +243,6 @@ Status BuildNodeOutputIndex(const NodeDef& node,
   return Status::OK();
 }
 
-Status InstantiateNode(const FunctionDef::Node& fnode,
-                       const InstantiateAttrValueMap& attrs,
-                       GetFunctionSignature get_function,
-                       const NameInfoIndex& name_info, GraphDef* gdef) {
-  const OpDef* fnode_sig = nullptr;
-  TF_CHECK_OK(get_function(fnode.op(), &fnode_sig));
-  NodeDef* gnode = gdef->add_node();
-  gnode->set_name(Name(gdef->node_size() - 1));
-  gnode->set_op(fnode.op());
-
-  // Input
-  const int num_args = fnode_sig->input_arg_size();
-  bool is_type_list;
-  DataTypeVector dtypes;
-  int fnode_arg_index = 0;
-  for (int i = 0; i < num_args; ++i) {
-    TF_RETURN_IF_ERROR(
-        ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
-    if (!is_type_list) {
-      const NameInfoItem* item =
-          gtl::FindOrNull(name_info, fnode.arg(fnode_arg_index));
-      if (item == nullptr) {
-        return errors::InvalidArgument("arg[", i, "] is not found: ",
-                                       ProtoShortDebugString(fnode));
-      }
-      if (dtypes != item->dtypes) {
-        return errors::InvalidArgument("Invalid arg(", i,
-                                       ") for function arg: ",
-                                       DataTypeSliceString(dtypes), " vs. ",
-                                       DataTypeSliceString(item->dtypes), ".");
-      }
-      for (size_t j = 0; j < dtypes.size(); ++j) {
-        if (item->is_func_arg) {
-          gnode->add_input(Name(item->nid + j));
-        } else {
-          gnode->add_input(Name(item->nid, item->idx + j));
-        }
-      }
-      ++fnode_arg_index;
-    } else {
-      for (size_t j = 0; j < dtypes.size(); ++j) {
-        const NameInfoItem* item =
-            gtl::FindOrNull(name_info, fnode.arg(fnode_arg_index + j));
-        if (item == nullptr) {
-          return errors::InvalidArgument("arg[", i + j, "] is not found: ",
-                                         ProtoShortDebugString(fnode));
-        }
-        if (item->dtypes.size() != 1 || (item->dtypes[0] != dtypes[j])) {
-          return errors::InvalidArgument(
-              "Invalid typelist arg(", i + j, ") for function arg: ",
-              DataTypeSliceString(dtypes), " vs. ",
-              DataTypeSliceString(item->dtypes), ".");
-        }
-        if (item->is_func_arg) {
-          gnode->add_input(Name(item->nid));
-        } else {
-          gnode->add_input(Name(item->nid, item->idx));
-        }
-      }
-      fnode_arg_index += dtypes.size();
-    }
-  }
-  // Control deps.
-  for (int i = 0; i < fnode.dep_size(); ++i) {
-    const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i));
-    if (item == nullptr) {
-      return errors::InvalidArgument("dep[", i, "] is not found.");
-    }
-    gnode->add_input(Dep(item->nid));
-  }
-
-  // Attrs.
-  for (const auto& p : attrs) {
-    (*gnode->mutable_attr())[p.first] = p.second;
-  }
-
-  return Status::OK();
-}
-
 Status InstantiateNode(const NodeDef& fnode,
                        const InstantiateAttrValueMap& attrs,
                        GetFunctionSignature get_function,
@@ -448,38 +332,6 @@ Status InstantiateNode(const NodeDef& fnode,
   return Status::OK();
 }
 
-// FunctionDef::Node version
-Status AddReturnNode(const OpDef::ArgDef& ret_def,
-                     const InstantiateAttrValueMap& attrs,
-                     const NameInfoIndex& name_info, int* ret_index,
-                     InstantiationResult* result) {
-  bool is_type_list;
-  DataTypeVector dtypes;
-  TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
-  CHECK_GE(dtypes.size(), size_t{1});
-  const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name());
-  if (item == nullptr) {
-    return errors::InvalidArgument("ret is not found.");
-  }
-  if (dtypes != item->dtypes) {
-    return errors::InvalidArgument("Invalid ret types ", ret_def.name(), " : ",
-                                   DataTypeVectorString(dtypes), " vs. ",
-                                   DataTypeVectorString(item->dtypes));
-  }
-  GraphDef* gdef = &result->gdef;
-  for (size_t i = 0; i < dtypes.size(); ++i) {
-    NodeDef* gnode = gdef->add_node();
-    gnode->set_name(Name(gdef->node_size() - 1));
-    gnode->set_op("_Retval");
-    gnode->add_input(Name(item->nid, item->idx + i));
-    AddAttr("T", dtypes[i], gnode);
-    AddAttr("index", (*ret_index)++, gnode);
-    result->ret_types.push_back(dtypes[i]);
-  }
-  return Status::OK();
-}
-
-// NodeDef version
 Status AddReturnNode(const OpDef::ArgDef& ret_def,
                      const InstantiateAttrValueMap& attrs,
                      const ::tensorflow::protobuf::Map<string, string>& ret_map,
@@ -561,38 +413,6 @@ string Print(const AttrValue& attr_value) {
   return SummarizeAttrValue(attr_value);
 }
 
-string Print(const FunctionDef::Node& node) {
-  string out;
-  for (int i = 0; i < node.ret_size(); ++i) {
-    const auto& name = node.ret(i);
-    if (i > 0) strings::StrAppend(&out, ", ");
-    strings::StrAppend(&out, name);
-  }
-  strings::StrAppend(&out, " = ", node.op());
-  if (node.attr_size() > 0) {
-    std::vector<string> entries;
-    for (auto p : node.attr()) {
-      entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
-    }
-    sort(entries.begin(), entries.end());
-    strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
-  }
-  strings::StrAppend(&out, "(");
-  for (int i = 0; i < node.arg_size(); ++i) {
-    if (i > 0) strings::StrAppend(&out, ", ");
-    strings::StrAppend(&out, node.arg(i));
-  }
-  strings::StrAppend(&out, ")");
-  if (node.dep_size() > 0) {
-    strings::StrAppend(&out, " @ ");
-    for (int i = 0; i < node.dep_size(); ++i) {
-      if (i > 0) strings::StrAppend(&out, ", ");
-      strings::StrAppend(&out, node.dep(i));
-    }
-  }
-  return out;
-}
-
 // TODO(josh11b): Merge this with SummarizeNodeDef().
 string Print(const NodeDef& n) {
   string out;
@@ -650,18 +470,12 @@ string Print(const FunctionDef& fdef) {
     strings::StrAppend(&out, Print(sig.output_arg(i)));
   }
   strings::StrAppend(&out, ") {\n");
-  if (fdef.node_def_size() > 0 || fdef.ret_size() > 0) {
   for (const auto& n : fdef.node_def()) {
     strings::StrAppend(&out, " ", Print(n), "\n");
   }
   for (const auto& r : fdef.ret()) {
     strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
   }
-  } else {  // TODO(josh11b): Eventually remove this case.
-    for (const auto& n : fdef.node()) {
-      strings::StrAppend(&out, " ", Print(n), "\n");
-    }
-  }
   strings::StrAppend(&out, "}\n");
   return out;
 }
@@ -772,7 +586,6 @@ Status InstantiateFunction(const FunctionDef& fdef,
   // Makes a copy of all attrs in fdef and substitutes placeholders.
   // After this step, every attr is bound to a concrete value.
   std::vector<InstantiateAttrValueMap> node_attrs;
-  if (fdef.node_def_size() > 0 || fdef.ret_size() > 0) {
   node_attrs.resize(fdef.node_def_size());
   for (int i = 0; i < fdef.node_def_size(); ++i) {
     for (auto attr : fdef.node_def(i).attr()) {
@@ -816,50 +629,6 @@
       return s;
     }
   }
-  } else {  // TODO(josh11b): Eventually remove this case.
-    node_attrs.resize(fdef.node_size());
-    for (int i = 0; i < fdef.node_size(); ++i) {
-      for (auto attr : fdef.node(i).attr()) {
-        if (!SubstitutePlaceholders(substitute, &attr.second)) {
-          return errors::InvalidArgument("Failed to bind all placeholders in ",
-                                         SummarizeAttrValue(attr.second));
-        }
-        if (!node_attrs[i].insert(attr).second) {
-          return errors::Internal("Somehow duplicated: ", attr.first);
-        }
-      }
-      TF_RETURN_IF_ERROR(
-          AddDefaultAttrs(fdef.node(i).op(), get_function, &node_attrs[i]));
-    }
-
-    for (int i = 0; i < fdef.node_size(); ++i) {
-      s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function,
-                               gdef->node_size() + i, &name_info);
-      if (!s.ok()) {
-        errors::AppendToMessage(&s, "In ", Print(fdef.node(i)));
-        return s;
-      }
-    }
-    // Emits one gdef.node for each fdef.node.
-    for (int i = 0; i < fdef.node_size(); ++i) {
-      s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info,
-                          gdef);
-      if (!s.ok()) {
-        errors::AppendToMessage(&s, "In ", Print(fdef.node(i)));
-        return s;
-      }
-    }
-
-    // Emits nodes for the function's return values.
-    int ret_index = 0;
-    for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
-      s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result);
-      if (!s.ok()) {
-        errors::AppendToMessage(&s, "In function output ", Print(ret_def));
-        return s;
-      }
-    }
-  }
 
   return Status::OK();
 }
@@ -30,61 +30,7 @@ message FunctionDef {
   // Attributes specific to this function definition.
   map<string, AttrValue> attr = 5;
 
-  // TO BE REPLACED
-
-  // The body of the function.
-  repeated Node node = 2;  // function.node.ret[*] are unique.
-
-  // A node is a multi-value assignment:
-  //   (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
-  //
-  // By convention, "func" is resolved by consulting with a user-defined
-  // library first. If not resolved, "func" is assumed to be a builtin op.
-  message Node {
-    // This node produces multiple outputs. They are named ret[0],
-    // ret[1], ..., etc.
-    //
-    // REQUIRES: function.node.ret[*] are unique across all nodes.
-    // REQUIRES: ret.size == func/op def's number of output args.
-    repeated string ret = 1;
-
-    // The op/function name.
-    string op = 2;
-
-    // Arguments passed to this func/op.
-    //
-    // arg[i] must be either one of
-    // function.signature.input_args[*].name or one of
-    // function.node[*].ret[*].
-    //
-    // REQUIRES: arg.size == func/op def's number of input args.
-    repeated string arg = 3;
-
-    // Control dependencies.
-    //
-    // dep[i] must be one of function.node[*].ret[*] or one of
-    // function.signature.input_args[*].name.
-    repeated string dep = 4;
-
-    // Attrs.
-    //
-    // 'attr' maps names defined by 'func's attr defs to attr values.
-    // attr values may have placeholders which are substituted
-    // recursively by concrete values when this node is instantiated.
-    // These placeholders must name an attr listed in the FunctionDef's
-    // signature.
-    map<string, AttrValue> attr = 5;
-  }
-
-  // WILL REPLACE THE ABOVE
-
-  // If node_def is present, and the consumer is at GraphDef version
-  // >= 12, then these fields are used and `node` is ignored. If the
-  // consumer's GraphDef version is < 12 or this field is empty, then
-  // `node` is used. This allows producers to fill both fields to
-  // remain compatible with old consumers. At some future GraphDef
-  // version, `node` will be ignored even if `node_def` is empty.
-  // TODO(josh11b): Finish this transition.
+  // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21.
 
   // In both of the following fields, there is the need to specify an
   // output that is used as either the input to another node (in
@@ -120,6 +66,10 @@ message FunctionDef {
   // The body of the function. Unlike the NodeDefs in a GraphDef, attrs
   // may have values of type `placeholder` and the `input` field uses
   // the "output" format above.
+
+  // By convention, "op" in node_def is resolved by consulting with a
+  // user-defined library first. If not resolved, "func" is assumed to
+  // be a builtin op.
   repeated NodeDef node_def = 3;
 
   // A mapping from the output arg names from `signature` to the
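As the comment added above notes, attrs inside node_def may be placeholders (such as "$T") that are only bound when the function is instantiated. A hedged sketch of that step, using the InstantiateFunction / InstantiateAttrValueMap API that this commit's function.cc and test changes exercise; the wrapper name and the get_function callback here are illustrative, not part of the commit:

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace {

// Sketch: bind the "$T" placeholder to a concrete DT_FLOAT and expand the
// FunctionDef's node_def entries into result->gdef. get_function is any
// GetFunctionSignature callback that can look up an OpDef by op name.
Status InstantiateWithFloat(const FunctionDef& fdef,
                            GetFunctionSignature get_function,
                            InstantiationResult* result) {
  InstantiateAttrValueMap attrs;
  attrs["T"].set_type(DT_FLOAT);
  return InstantiateFunction(fdef, attrs, get_function, result);
}

}  // namespace
}  // namespace tensorflow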
@@ -48,52 +48,8 @@ y: A scalar in type T.
 
 static InstantiateAttrValueMap kNoAttrs;
 
-TEST(TFunc, SquarePlusOneOld) {
-  auto fdef = FDH::Define(  // Create a FunctionDef using Function::Nodes.
+TEST(TFunc, SquarePlusOne) {
+  auto fdef = FDH::Create(
-      // Name
-      "SquarePlusOne",
-      // Args
-      {"x: T"},
-      // Return values
-      {"y: T"},
-      // Attrs
-      {"T: {float, double, int32, int64}"},
-      // Nodes
-      {// a = Square<T>(x)
-       {{"a"}, "Square", {"x"}, {{"T", "$T"}}},
-       // o = One<T>()
-       // NOTE: We can also have a Cast<Tin, Tout>(x) instead.
-       {{"o"}, "One", {}, {{"T", "$T"}}},
-       // y = Add<T>(a, o)
-       {{"y"}, "Add", {"a", "o"}, {{"T", "$T"}}}});
-
-  const char* e = R"P(
-SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
-  a = Square[T=$T](x)
-  o = One[T=$T]()
-  y = Add[T=$T](a:y:0, o:y:0)
-  return y = y:z:0
-}
-)P";
-  EXPECT_EQ(DebugString(fdef), e);
-
-  // Instantiate one with T=float
-  InstantiationResult result;
-  TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
-  const char* e2 = R"P(
-(n0:float) -> (n3:float) {
-  n1 = Square[T=float](n0)
-  n2 = One[T=float]()
-  n3 = Add[T=float](n1, n2)
-}
-)P";
-  EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
-  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
-  EXPECT_EQ(DebugString(result.gdef), e2);
-}
-
-TEST(TFunc, SquarePlusOneNodeDef) {
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "SquarePlusOne",
       // Inputs
@@ -138,8 +94,8 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
   EXPECT_EQ(DebugString(result.gdef), e2);
 }
 
-TEST(TFunc, ControlDepNodeDef) {
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
+TEST(TFunc, ControlDep) {
+  auto fdef = FDH::Create(
       // Name
      "ControlDep",
       // Inputs
@@ -190,44 +146,8 @@ REGISTER_OP("HasDefaultType")
 // This verifies that a function using an op before a type attr (with
 // a default) is added, still works. This is important for backwards
 // compatibilty.
-TEST(TFunc, MissingTypeAttrOld) {
-  auto fdef = FDH::Define(  // Create a FunctionDef using Function::Nodes.
+TEST(TFunc, MissingTypeAttr) {
+  auto fdef = FDH::Create(
-      // Name
-      "BackCompat",
-      // Args
-      {},
-      // Return values
-      {"y: float"},
-      // Attrs
-      {},
-      // Nodes
-      {// y = HasDefaultType(x), T missing, defaults to float
-       {{"y"}, "HasDefaultType", {}, {}}});
-
-  const char* e = R"P(
-BackCompat() -> (y:float) {
-  y = HasDefaultType()
-  return y = y:out:0
-}
-)P";
-  EXPECT_EQ(DebugString(fdef), e);
-
-  InstantiationResult result;
-  TF_ASSERT_OK(
-      InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result));
-  // Should get T=float from Op's default.
-  const char* e2 = R"P(
-() -> (n0:float) {
-  n0 = HasDefaultType[T=float]()
-}
-)P";
-  EXPECT_EQ(result.arg_types, DataTypeVector());
-  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
-  EXPECT_EQ(DebugString(result.gdef), e2);
-}
-
-TEST(TFunc, MissingTypeAttrNodeDef) {
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "BackCompat",
       // Args
@@ -264,11 +184,8 @@ BackCompat() -> (y:float) {
   EXPECT_EQ(DebugString(result.gdef), e2);
 }
 
-TEST(TFunc, NTimesTNodeDef) {
-  // Note that the equivalent FunctionDef using FunctionDef::Node requires
-  // using a _ListToArray to package up the two inputs to AddN as a single
-  // N*T edge.
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
+TEST(TFunc, NTimesT) {
+  auto fdef = FDH::Create(
       // Name
       "NTimesT",
       // Inputs
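The comment removed above records why the NodeDef form is simpler here: with FunctionDef.Node, AddN's single N*T input had to be packaged through a helper _ListToArray node, whereas a NodeDef can list the two input tensors directly. A sketch of the direct form follows; the full NTimesT test body is outside this hunk, so the AddN output-arg name "sum" and the exact attr spellings are assumptions, not quoted from the commit:

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace {

using FDH = FunctionDefHelper;

// Sketch: an AddN node in the NodeDef-based body lists both inputs directly;
// no _ListToArray packaging node is needed.
FunctionDef NTimesTSketch() {
  return FDH::Create(
      // Name
      "NTimesT",
      // Inputs
      {"x: float", "y: float"},
      // Outputs
      {"z: float"},
      // Attrs
      {},
      // Nodes
      {{{"a"}, "AddN", {"x", "y"}, {{"N", 2}, {"T", DT_FLOAT}}}},
      // Ret map (assumes AddN's output arg is named "sum")
      {{"z", "a:sum:0"}});
}

}  // namespace
}  // namespace tensorflow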
@@ -790,8 +707,8 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
                "input unknown is not found");
 }
 
-TEST(InstantiateErrors, NodeDef_TooManyInputs) {
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
+TEST(InstantiateErrors, TooManyInputs) {
+  auto fdef = FDH::Create(
       // Name
       "TooManyInputs",
       // Inputs
@@ -811,8 +728,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputs) {
                "Expected input[2] == 'x' to be a control input.");
 }
 
-TEST(InstantiateErrors, NodeDef_TooFewInputs) {
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
+TEST(InstantiateErrors, TooFewInputs) {
+  auto fdef = FDH::Create(
       // Name
       "TooFewInputs",
       // Inputs
@@ -832,8 +749,8 @@ TEST(InstantiateErrors, NodeDef_TooFewInputs) {
                "Attempt to access beyond input size: 2 >= 2");
 }
 
-TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
+TEST(InstantiateErrors, TooManyInputsFromArray1) {
+  auto fdef = FDH::Create(
       // Name
       "TooManyInputsFromArray",
       // Inputs
@@ -860,8 +777,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
                "Expected input[1] == 'y' to be a control input.");
 }
 
-TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
+TEST(InstantiateErrors, TooManyInputsFromArray2) {
+  auto fdef = FDH::Create(
       // Name
       "TooManyInputsFromArray",
       // Inputs
@@ -888,8 +805,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
                "Input a:output too long for inputs");
 }
 
-TEST(InstantiateErrors, NodeDef_TypeMismatch) {
-  auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
+TEST(InstantiateErrors, TypeMismatch) {
+  auto fdef = FDH::Create(
       // Name
       "TypeMismatch",
       // Inputs
@@ -178,15 +178,9 @@ void OpsUsedByGraph(const GraphDef& graph_def,
   while (!functions_to_process.empty()) {
     const FunctionDef* fun = functions_to_process.back();
     functions_to_process.pop_back();
-    if (fun->node_def_size() > 0) {
     for (const auto& node : fun->node_def()) {
       mark_op_as_used(node.op());
     }
-    } else {  // TODO(josh11b): Eventually drop support for this.
-      for (const auto& node : fun->node()) {
-        mark_op_as_used(node.op());
-      }
-    }
   }
 
   // Filter out function names to produce output.
@@ -79,10 +79,12 @@ limitations under the License.
 // used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is
 // now used by tf.concat_v2 (and soon tf.concat). Graphs use flooring
 // division and mod semantics. TensorArrayV3. (12dec2016)
+// 21. Dropped FunctionDef.Node support, switched to node_def introduced
+//     in version 12. (11jan2017)
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 20
+#define TF_GRAPH_DEF_VERSION 21
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
@@ -25,8 +25,6 @@ import hashlib
 import inspect
 import re
 
-from six.moves import xrange  # pylint: disable=redefined-builtin
-
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import op_def_pb2
@@ -67,82 +65,6 @@ def _get_node_def(op):
   return op._node_def  # pylint: disable=protected-access
 
 
-def _add_input_array(op, start, limit, dtype, func):
-  """Adds a _ListToArray node in the func for op.inputs[start:limit]."""
-  node = function_pb2.FunctionDef.Node()
-  node.op = "_ListToArray"
-  ret_name = op.name + "_L2A_" + str(start)
-  node.ret.extend([ret_name])
-  node.arg.extend(
-      [_make_argname_from_tensor_name(x.name) for x in op.inputs[start:limit]])
-  num = limit - start
-  node.attr["Tin"].CopyFrom(
-      attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
-          type=[dtype] * num)))
-  node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
-  node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
-  func.node.extend([node])
-  return ret_name
-
-
-def _add_identity_dtype_proto(func, src, dst, dtype_proto):
-  node = function_pb2.FunctionDef.Node()
-  node.op = "Identity"
-  node.arg.append(src)
-  node.ret.append(dst)
-  node.attr["T"].CopyFrom(dtype_proto)
-  func.node.extend([node])
-
-
-def _add_identity_dtype_enum(func, src, dst, dtype):
-  dtype_proto = attr_value_pb2.AttrValue(type=dtype)
-  _add_identity_dtype_proto(func, src, dst, dtype_proto)
-
-
-def _add_output_array(op, start, limit, dtype, func):
-  """Adds a _ArrayToList node in the func for op.outputs[start:limit]."""
-  dtype_proto = attr_value_pb2.AttrValue(type=dtype)
-  # A node converting N*T to list(T)
-  node = function_pb2.FunctionDef.Node()
-  node.op = "_ArrayToList"
-  arg_name = op.name + "_A2L_" + str(start)
-  ret_name = arg_name + "_out"
-  node.ret.append(ret_name)
-  node.arg.append(arg_name)
-  node.attr["T"].CopyFrom(dtype_proto)
-  num = limit - start
-  node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
-  node.attr["out_types"].CopyFrom(
-      attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
-          type=[dtype] * num)))
-  func.node.extend([node])
-  num = limit - start
-  # Adds an identity node for each element in the array N*T so that
-  # uses of each element can be added easily later. These Identity
-  # will be eliminated before graph execution.
-  for i in xrange(num):
-    _add_identity_dtype_proto(
-        func, ret_name + ":" + str(i),
-        _make_argname_from_tensor_name(op.outputs[i].name), dtype_proto)
-  return arg_name
-
-
-def _add_output_list(op, start, limit, dtype_lst, func):
-  """Adds a _ArrayToList node in the func for op.outputs[start:limit]."""
-  ret_name = op.name + "_Lst_" + str(start) + "_" + str(limit)
-  num = limit - start
-  assert len(dtype_lst) == num
-  # Adds an identity node for each element in the array N*T so that
-  # uses of each element can be added easily later. These Identity
-  # will be eliminated before graph execution.
-  for i in xrange(num):
-    _add_identity_dtype_enum(func,
-                             ret_name + ":" + str(i),
-                             _make_argname_from_tensor_name(op.outputs[i].name),
-                             dtype_lst[i])
-  return ret_name
-
-
 def _get_op_def(op):
   # pylint: disable=protected-access
   if hasattr(op, "_sig"):
@@ -197,76 +119,6 @@ def _add_op_node(op, func, input_dict):
           "%s missing from %s" % (node_def.input[i], input_dict.items()))
     node_def.input[i] = input_dict[node_def.input[i]]
 
-  # To support legacy consumers, add an entry in func.node.
-  # TODO(josh11b): Delete this.
-  node = function_pb2.FunctionDef.Node()
-  node.op = op.type
-  op_def = _get_op_def(op)
-  attrs = node_def.attr
-  if not op_def.output_arg:
-    node.ret.append(_make_argname_from_tensor_name(op.name))
-  else:
-    out_index = 0
-    for arg_def in op_def.output_arg:
-      if arg_def.number_attr:
-        dtype = arg_def.type or attrs[arg_def.type_attr].type
-        num = attrs[arg_def.number_attr].i
-        node.ret.append(
-            _add_output_array(op, out_index, out_index + num, dtype, func))
-        out_index += num
-      elif arg_def.type_list_attr:
-        dtype_lst = attrs[arg_def.type_list_attr].list.type
-        num = len(dtype_lst)
-        node.ret.append(
-            _add_output_list(op, out_index, out_index + num, dtype_lst, func))
-        out_index += num
-      else:
-        node.ret.append(
-            _make_argname_from_tensor_name(op.outputs[out_index].name))
-        out_index += 1
-  inp_index = 0
-  for arg_def in op_def.input_arg:
-    if arg_def.number_attr:
-      dtype = arg_def.type or attrs[arg_def.type_attr].type
-      num = attrs[arg_def.number_attr].i
-      node.arg.append(
-          _add_input_array(op, inp_index, inp_index + num, dtype, func))
-      inp_index += num
-    elif arg_def.type_list_attr:
-      num = len(attrs[arg_def.type_list_attr].list.type)
-      node.arg.extend([
-          _make_argname_from_tensor_name(op.inputs[i].name)
-          for i in range(inp_index, inp_index + num)
-      ])
-      inp_index += num
-    else:
-      node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name))
-      inp_index += 1
-  node.dep.extend(
-      [_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
-  for k, v in attrs.items():
-    node.attr[k].CopyFrom(v)
-  func.node.extend([node])
-
-
-def _replace_ret(func, original, replacement):
-  for n in func.node:
-    for i, r in enumerate(n.ret):
-      if r == original:
-        n.ret[i] = replacement
-        return
-  raise ValueError("Could not find ret == '%s'" % original)
-
-
-def _replace_arg(func, original, replacement):
-  for n in func.node:
-    for i, a in enumerate(n.arg):
-      if a == original:
-        n.arg[i] = replacement
-    for i, d in enumerate(n.dep):
-      if d == original:
-        n.dep[i] = replacement
-
 
 def _graph_to_function_def(graph, inputs, outputs, out_names=None):
   """Returns `graph` as a `FunctionDef` protocol buffer.
@@ -323,20 +175,9 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
     for index, o in enumerate(outputs):
       k = func.signature.output_arg[index].name
       func.ret[k] = input_dict[o.name]
-      # TODO(josh11b): Delete this once we switch fully to NodeDefs for
-      # function bodies.
-      orig = _make_argname_from_tensor_name(o.name)
-      if k != orig:
-        _add_identity_dtype_enum(func, orig, k,
-                                 func.signature.output_arg[index].type)
   else:
     for o, n in zip(outputs, out_names):
       func.ret[n] = input_dict[o.name]
-      # TODO(josh11b): Delete this once we switch fully to NodeDefs for
-      # function bodies.
-      k = _make_argname_from_tensor_name(o.name)
-      _replace_ret(func, k, n)
-      _replace_arg(func, k, n)
 
   return func
@@ -2299,6 +2299,9 @@ class Graph(object):
     if (function.grad_func_name is not None) and (
         function.python_grad_func is not None):
       raise ValueError("Gradient defined twice for function %s" % name)
+    # Need a new-enough consumer to support the functions we add to the graph.
+    if self._graph_def_versions.min_consumer < 12:
+      self._graph_def_versions.min_consumer = 12
     self._functions[name] = function
 
   @property