Delete the FunctionDef.Node code paths, now that we have switched

to the NodeDef representation.
Change: 144281952
This commit is contained in:
A. Unique TensorFlower 2017-01-11 20:06:17 -08:00 committed by TensorFlower Gardener
parent ded01d2902
commit 873473ef0f
7 changed files with 71 additions and 595 deletions

View File

@ -211,43 +211,6 @@ Status AddRetName(NameInfoIndex* name_info, const string& ret,
return Status::OK();
}
Status BuildNodeOutputIndex(const FunctionDef::Node& node,
const InstantiateAttrValueMap& attrs,
GetFunctionSignature get_function,
const int arg_index, NameInfoIndex* name_info) {
const OpDef* node_sig = nullptr;
TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig));
if (node_sig->output_arg_size() == 0) {
// This node produces no output.
if (node.ret_size() != 1) {
return errors::InvalidArgument("Expect one ret name.");
}
return AddRetName(name_info, node.ret(0), {false, arg_index, 0, false, {}});
}
const int num_retval = node_sig->output_arg_size();
if (num_retval != node.ret_size()) {
return errors::InvalidArgument("Malformed function node (#ret): ",
num_retval, " vs. ", node.ret_size());
}
int start = 0;
bool is_type_list;
DataTypeVector dtypes;
for (int i = 0; i < num_retval; ++i) {
TF_RETURN_IF_ERROR(
ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
TF_RETURN_IF_ERROR(
AddRetName(name_info, node.ret(i),
{false, arg_index, start, is_type_list, dtypes}));
for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
TF_RETURN_IF_ERROR(
AddRetName(name_info, strings::StrCat(node.ret(i), ":", j),
{false, arg_index, start + j, false, {dtypes[j]}}));
}
start += dtypes.size();
}
return Status::OK();
}
Status BuildNodeOutputIndex(const NodeDef& node,
const InstantiateAttrValueMap& attrs,
GetFunctionSignature get_function,
@ -280,85 +243,6 @@ Status BuildNodeOutputIndex(const NodeDef& node,
return Status::OK();
}
Status InstantiateNode(const FunctionDef::Node& fnode,
const InstantiateAttrValueMap& attrs,
GetFunctionSignature get_function,
const NameInfoIndex& name_info, GraphDef* gdef) {
const OpDef* fnode_sig = nullptr;
TF_CHECK_OK(get_function(fnode.op(), &fnode_sig));
NodeDef* gnode = gdef->add_node();
gnode->set_name(Name(gdef->node_size() - 1));
gnode->set_op(fnode.op());
// Input
const int num_args = fnode_sig->input_arg_size();
bool is_type_list;
DataTypeVector dtypes;
int fnode_arg_index = 0;
for (int i = 0; i < num_args; ++i) {
TF_RETURN_IF_ERROR(
ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
if (!is_type_list) {
const NameInfoItem* item =
gtl::FindOrNull(name_info, fnode.arg(fnode_arg_index));
if (item == nullptr) {
return errors::InvalidArgument("arg[", i, "] is not found: ",
ProtoShortDebugString(fnode));
}
if (dtypes != item->dtypes) {
return errors::InvalidArgument("Invalid arg(", i,
") for function arg: ",
DataTypeSliceString(dtypes), " vs. ",
DataTypeSliceString(item->dtypes), ".");
}
for (size_t j = 0; j < dtypes.size(); ++j) {
if (item->is_func_arg) {
gnode->add_input(Name(item->nid + j));
} else {
gnode->add_input(Name(item->nid, item->idx + j));
}
}
++fnode_arg_index;
} else {
for (size_t j = 0; j < dtypes.size(); ++j) {
const NameInfoItem* item =
gtl::FindOrNull(name_info, fnode.arg(fnode_arg_index + j));
if (item == nullptr) {
return errors::InvalidArgument("arg[", i + j, "] is not found: ",
ProtoShortDebugString(fnode));
}
if (item->dtypes.size() != 1 || (item->dtypes[0] != dtypes[j])) {
return errors::InvalidArgument(
"Invalid typelist arg(", i + j, ") for function arg: ",
DataTypeSliceString(dtypes), " vs. ",
DataTypeSliceString(item->dtypes), ".");
}
if (item->is_func_arg) {
gnode->add_input(Name(item->nid));
} else {
gnode->add_input(Name(item->nid, item->idx));
}
}
fnode_arg_index += dtypes.size();
}
}
// Control deps.
for (int i = 0; i < fnode.dep_size(); ++i) {
const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i));
if (item == nullptr) {
return errors::InvalidArgument("dep[", i, "] is not found.");
}
gnode->add_input(Dep(item->nid));
}
// Attrs.
for (const auto& p : attrs) {
(*gnode->mutable_attr())[p.first] = p.second;
}
return Status::OK();
}
Status InstantiateNode(const NodeDef& fnode,
const InstantiateAttrValueMap& attrs,
GetFunctionSignature get_function,
@ -448,38 +332,6 @@ Status InstantiateNode(const NodeDef& fnode,
return Status::OK();
}
// FunctionDef::Node version
Status AddReturnNode(const OpDef::ArgDef& ret_def,
const InstantiateAttrValueMap& attrs,
const NameInfoIndex& name_info, int* ret_index,
InstantiationResult* result) {
bool is_type_list;
DataTypeVector dtypes;
TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
CHECK_GE(dtypes.size(), size_t{1});
const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name());
if (item == nullptr) {
return errors::InvalidArgument("ret is not found.");
}
if (dtypes != item->dtypes) {
return errors::InvalidArgument("Invalid ret types ", ret_def.name(), " : ",
DataTypeVectorString(dtypes), " vs. ",
DataTypeVectorString(item->dtypes));
}
GraphDef* gdef = &result->gdef;
for (size_t i = 0; i < dtypes.size(); ++i) {
NodeDef* gnode = gdef->add_node();
gnode->set_name(Name(gdef->node_size() - 1));
gnode->set_op("_Retval");
gnode->add_input(Name(item->nid, item->idx + i));
AddAttr("T", dtypes[i], gnode);
AddAttr("index", (*ret_index)++, gnode);
result->ret_types.push_back(dtypes[i]);
}
return Status::OK();
}
// NodeDef version
Status AddReturnNode(const OpDef::ArgDef& ret_def,
const InstantiateAttrValueMap& attrs,
const ::tensorflow::protobuf::Map<string, string>& ret_map,
@ -561,38 +413,6 @@ string Print(const AttrValue& attr_value) {
return SummarizeAttrValue(attr_value);
}
string Print(const FunctionDef::Node& node) {
string out;
for (int i = 0; i < node.ret_size(); ++i) {
const auto& name = node.ret(i);
if (i > 0) strings::StrAppend(&out, ", ");
strings::StrAppend(&out, name);
}
strings::StrAppend(&out, " = ", node.op());
if (node.attr_size() > 0) {
std::vector<string> entries;
for (auto p : node.attr()) {
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
}
sort(entries.begin(), entries.end());
strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
}
strings::StrAppend(&out, "(");
for (int i = 0; i < node.arg_size(); ++i) {
if (i > 0) strings::StrAppend(&out, ", ");
strings::StrAppend(&out, node.arg(i));
}
strings::StrAppend(&out, ")");
if (node.dep_size() > 0) {
strings::StrAppend(&out, " @ ");
for (int i = 0; i < node.dep_size(); ++i) {
if (i > 0) strings::StrAppend(&out, ", ");
strings::StrAppend(&out, node.dep(i));
}
}
return out;
}
// TODO(josh11b): Merge this with SummarizeNodeDef().
string Print(const NodeDef& n) {
string out;
@ -650,18 +470,12 @@ string Print(const FunctionDef& fdef) {
strings::StrAppend(&out, Print(sig.output_arg(i)));
}
strings::StrAppend(&out, ") {\n");
if (fdef.node_def_size() > 0 || fdef.ret_size() > 0) {
for (const auto& n : fdef.node_def()) {
strings::StrAppend(&out, " ", Print(n), "\n");
}
for (const auto& r : fdef.ret()) {
strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
}
} else { // TODO(josh11b): Eventually remove this case.
for (const auto& n : fdef.node()) {
strings::StrAppend(&out, " ", Print(n), "\n");
}
}
strings::StrAppend(&out, "}\n");
return out;
}
@ -772,7 +586,6 @@ Status InstantiateFunction(const FunctionDef& fdef,
// Makes a copy of all attrs in fdef and substitutes placeholders.
// After this step, every attr is bound to a concrete value.
std::vector<InstantiateAttrValueMap> node_attrs;
if (fdef.node_def_size() > 0 || fdef.ret_size() > 0) {
node_attrs.resize(fdef.node_def_size());
for (int i = 0; i < fdef.node_def_size(); ++i) {
for (auto attr : fdef.node_def(i).attr()) {
@ -816,50 +629,6 @@ Status InstantiateFunction(const FunctionDef& fdef,
return s;
}
}
} else { // TODO(josh11b): Eventually remove this case.
node_attrs.resize(fdef.node_size());
for (int i = 0; i < fdef.node_size(); ++i) {
for (auto attr : fdef.node(i).attr()) {
if (!SubstitutePlaceholders(substitute, &attr.second)) {
return errors::InvalidArgument("Failed to bind all placeholders in ",
SummarizeAttrValue(attr.second));
}
if (!node_attrs[i].insert(attr).second) {
return errors::Internal("Somehow duplicated: ", attr.first);
}
}
TF_RETURN_IF_ERROR(
AddDefaultAttrs(fdef.node(i).op(), get_function, &node_attrs[i]));
}
for (int i = 0; i < fdef.node_size(); ++i) {
s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function,
gdef->node_size() + i, &name_info);
if (!s.ok()) {
errors::AppendToMessage(&s, "In ", Print(fdef.node(i)));
return s;
}
}
// Emits one gdef.node for each fdef.node.
for (int i = 0; i < fdef.node_size(); ++i) {
s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info,
gdef);
if (!s.ok()) {
errors::AppendToMessage(&s, "In ", Print(fdef.node(i)));
return s;
}
}
// Emits nodes for the function's return values.
int ret_index = 0;
for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result);
if (!s.ok()) {
errors::AppendToMessage(&s, "In function output ", Print(ret_def));
return s;
}
}
}
return Status::OK();
}

View File

@ -30,61 +30,7 @@ message FunctionDef {
// Attributes specific to this function definition.
map<string, AttrValue> attr = 5;
// TO BE REPLACED
// The body of the function.
repeated Node node = 2; // function.node.ret[*] are unique.
// A node is a multi-value assignment:
// (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
//
// By convention, "func" is resolved by consulting with a user-defined
// library first. If not resolved, "func" is assumed to be a builtin op.
message Node {
// This node produces multiple outputs. They are named ret[0],
// ret[1], ..., etc.
//
// REQUIRES: function.node.ret[*] are unique across all nodes.
// REQUIRES: ret.size == func/op def's number of output args.
repeated string ret = 1;
// The op/function name.
string op = 2;
// Arguments passed to this func/op.
//
// arg[i] must be either one of
// function.signature.input_args[*].name or one of
// function.node[*].ret[*].
//
// REQUIRES: arg.size == func/op def's number of input args.
repeated string arg = 3;
// Control dependencies.
//
// dep[i] must be one of function.node[*].ret[*] or one of
// function.signature.input_args[*].name.
repeated string dep = 4;
// Attrs.
//
// 'attr' maps names defined by 'func's attr defs to attr values.
// attr values may have placeholders which are substituted
// recursively by concrete values when this node is instantiated.
// These placeholders must name an attr listed in the FunctionDef's
// signature.
map<string, AttrValue> attr = 5;
}
// WILL REPLACE THE ABOVE
// If node_def is present, and the consumer is at GraphDef version
// >= 12, then these fields are used and `node` is ignored. If the
// consumer's GraphDef version is < 12 or this field is empty, then
// `node` is used. This allows producers to fill both fields to
// remain compatible with old consumers. At some future GraphDef
// version, `node` will be ignored even if `node_def` is empty.
// TODO(josh11b): Finish this transition.
// NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21.
// In both of the following fields, there is the need to specify an
// output that is used as either the input to another node (in
@ -120,6 +66,10 @@ message FunctionDef {
// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
// may have values of type `placeholder` and the `input` field uses
// the "output" format above.
// By convention, "op" in node_def is resolved by consulting with a
// user-defined library first. If not resolved, "func" is assumed to
// be a builtin op.
repeated NodeDef node_def = 3;
// A mapping from the output arg names from `signature` to the

View File

@ -48,52 +48,8 @@ y: A scalar in type T.
static InstantiateAttrValueMap kNoAttrs;
TEST(TFunc, SquarePlusOneOld) {
auto fdef = FDH::Define( // Create a FunctionDef using Function::Nodes.
// Name
"SquarePlusOne",
// Args
{"x: T"},
// Return values
{"y: T"},
// Attrs
{"T: {float, double, int32, int64}"},
// Nodes
{// a = Square<T>(x)
{{"a"}, "Square", {"x"}, {{"T", "$T"}}},
// o = One<T>()
// NOTE: We can also have a Cast<Tin, Tout>(x) instead.
{{"o"}, "One", {}, {{"T", "$T"}}},
// y = Add<T>(a, o)
{{"y"}, "Add", {"a", "o"}, {{"T", "$T"}}}});
const char* e = R"P(
SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
a = Square[T=$T](x)
o = One[T=$T]()
y = Add[T=$T](a:y:0, o:y:0)
return y = y:z:0
}
)P";
EXPECT_EQ(DebugString(fdef), e);
// Instantiate one with T=float
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
const char* e2 = R"P(
(n0:float) -> (n3:float) {
n1 = Square[T=float](n0)
n2 = One[T=float]()
n3 = Add[T=float](n1, n2)
}
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, SquarePlusOneNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(TFunc, SquarePlusOne) {
auto fdef = FDH::Create(
// Name
"SquarePlusOne",
// Inputs
@ -138,8 +94,8 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, ControlDepNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(TFunc, ControlDep) {
auto fdef = FDH::Create(
// Name
"ControlDep",
// Inputs
@ -190,44 +146,8 @@ REGISTER_OP("HasDefaultType")
// This verifies that a function using an op before a type attr (with
// a default) is added, still works. This is important for backwards
// compatibilty.
TEST(TFunc, MissingTypeAttrOld) {
auto fdef = FDH::Define( // Create a FunctionDef using Function::Nodes.
// Name
"BackCompat",
// Args
{},
// Return values
{"y: float"},
// Attrs
{},
// Nodes
{// y = HasDefaultType(x), T missing, defaults to float
{{"y"}, "HasDefaultType", {}, {}}});
const char* e = R"P(
BackCompat() -> (y:float) {
y = HasDefaultType()
return y = y:out:0
}
)P";
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
TF_ASSERT_OK(
InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result));
// Should get T=float from Op's default.
const char* e2 = R"P(
() -> (n0:float) {
n0 = HasDefaultType[T=float]()
}
)P";
EXPECT_EQ(result.arg_types, DataTypeVector());
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, MissingTypeAttrNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(TFunc, MissingTypeAttr) {
auto fdef = FDH::Create(
// Name
"BackCompat",
// Args
@ -264,11 +184,8 @@ BackCompat() -> (y:float) {
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, NTimesTNodeDef) {
// Note that the equivalent FunctionDef using FunctionDef::Node requires
// using a _ListToArray to package up the two inputs to AddN as a single
// N*T edge.
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(TFunc, NTimesT) {
auto fdef = FDH::Create(
// Name
"NTimesT",
// Inputs
@ -790,8 +707,8 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
"input unknown is not found");
}
TEST(InstantiateErrors, NodeDef_TooManyInputs) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(InstantiateErrors, TooManyInputs) {
auto fdef = FDH::Create(
// Name
"TooManyInputs",
// Inputs
@ -811,8 +728,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputs) {
"Expected input[2] == 'x' to be a control input.");
}
TEST(InstantiateErrors, NodeDef_TooFewInputs) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(InstantiateErrors, TooFewInputs) {
auto fdef = FDH::Create(
// Name
"TooFewInputs",
// Inputs
@ -832,8 +749,8 @@ TEST(InstantiateErrors, NodeDef_TooFewInputs) {
"Attempt to access beyond input size: 2 >= 2");
}
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(InstantiateErrors, TooManyInputsFromArray1) {
auto fdef = FDH::Create(
// Name
"TooManyInputsFromArray",
// Inputs
@ -860,8 +777,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
"Expected input[1] == 'y' to be a control input.");
}
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(InstantiateErrors, TooManyInputsFromArray2) {
auto fdef = FDH::Create(
// Name
"TooManyInputsFromArray",
// Inputs
@ -888,8 +805,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
"Input a:output too long for inputs");
}
TEST(InstantiateErrors, NodeDef_TypeMismatch) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
TEST(InstantiateErrors, TypeMismatch) {
auto fdef = FDH::Create(
// Name
"TypeMismatch",
// Inputs

View File

@ -178,15 +178,9 @@ void OpsUsedByGraph(const GraphDef& graph_def,
while (!functions_to_process.empty()) {
const FunctionDef* fun = functions_to_process.back();
functions_to_process.pop_back();
if (fun->node_def_size() > 0) {
for (const auto& node : fun->node_def()) {
mark_op_as_used(node.op());
}
} else { // TODO(josh11b): Eventually drop support for this.
for (const auto& node : fun->node()) {
mark_op_as_used(node.op());
}
}
}
// Filter out function names to produce output.

View File

@ -79,10 +79,12 @@ limitations under the License.
// used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is
// now used by tf.concat_v2 (and soon tf.concat). Graphs use flooring
// division and mod semantics. TensorArrayV3. (12dec2016)
// 21. Dropped FunctionDef.Node support, switched to node_def introduced
// in version 12. (11jan2017)
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 20
#define TF_GRAPH_DEF_VERSION 21
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//

View File

@ -25,8 +25,6 @@ import hashlib
import inspect
import re
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import op_def_pb2
@ -67,82 +65,6 @@ def _get_node_def(op):
return op._node_def # pylint: disable=protected-access
def _add_input_array(op, start, limit, dtype, func):
"""Adds a _ListToArray node in the func for op.inputs[start:limit]."""
node = function_pb2.FunctionDef.Node()
node.op = "_ListToArray"
ret_name = op.name + "_L2A_" + str(start)
node.ret.extend([ret_name])
node.arg.extend(
[_make_argname_from_tensor_name(x.name) for x in op.inputs[start:limit]])
num = limit - start
node.attr["Tin"].CopyFrom(
attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
type=[dtype] * num)))
node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
func.node.extend([node])
return ret_name
def _add_identity_dtype_proto(func, src, dst, dtype_proto):
node = function_pb2.FunctionDef.Node()
node.op = "Identity"
node.arg.append(src)
node.ret.append(dst)
node.attr["T"].CopyFrom(dtype_proto)
func.node.extend([node])
def _add_identity_dtype_enum(func, src, dst, dtype):
dtype_proto = attr_value_pb2.AttrValue(type=dtype)
_add_identity_dtype_proto(func, src, dst, dtype_proto)
def _add_output_array(op, start, limit, dtype, func):
"""Adds a _ArrayToList node in the func for op.outputs[start:limit]."""
dtype_proto = attr_value_pb2.AttrValue(type=dtype)
# A node converting N*T to list(T)
node = function_pb2.FunctionDef.Node()
node.op = "_ArrayToList"
arg_name = op.name + "_A2L_" + str(start)
ret_name = arg_name + "_out"
node.ret.append(ret_name)
node.arg.append(arg_name)
node.attr["T"].CopyFrom(dtype_proto)
num = limit - start
node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
node.attr["out_types"].CopyFrom(
attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
type=[dtype] * num)))
func.node.extend([node])
num = limit - start
# Adds an identity node for each element in the array N*T so that
# uses of each element can be added easily later. These Identity
# will be eliminated before graph execution.
for i in xrange(num):
_add_identity_dtype_proto(
func, ret_name + ":" + str(i),
_make_argname_from_tensor_name(op.outputs[i].name), dtype_proto)
return arg_name
def _add_output_list(op, start, limit, dtype_lst, func):
"""Adds a _ArrayToList node in the func for op.outputs[start:limit]."""
ret_name = op.name + "_Lst_" + str(start) + "_" + str(limit)
num = limit - start
assert len(dtype_lst) == num
# Adds an identity node for each element in the array N*T so that
# uses of each element can be added easily later. These Identity
# will be eliminated before graph execution.
for i in xrange(num):
_add_identity_dtype_enum(func,
ret_name + ":" + str(i),
_make_argname_from_tensor_name(op.outputs[i].name),
dtype_lst[i])
return ret_name
def _get_op_def(op):
# pylint: disable=protected-access
if hasattr(op, "_sig"):
@ -197,76 +119,6 @@ def _add_op_node(op, func, input_dict):
"%s missing from %s" % (node_def.input[i], input_dict.items()))
node_def.input[i] = input_dict[node_def.input[i]]
# To support legacy consumers, add an entry in func.node.
# TODO(josh11b): Delete this.
node = function_pb2.FunctionDef.Node()
node.op = op.type
op_def = _get_op_def(op)
attrs = node_def.attr
if not op_def.output_arg:
node.ret.append(_make_argname_from_tensor_name(op.name))
else:
out_index = 0
for arg_def in op_def.output_arg:
if arg_def.number_attr:
dtype = arg_def.type or attrs[arg_def.type_attr].type
num = attrs[arg_def.number_attr].i
node.ret.append(
_add_output_array(op, out_index, out_index + num, dtype, func))
out_index += num
elif arg_def.type_list_attr:
dtype_lst = attrs[arg_def.type_list_attr].list.type
num = len(dtype_lst)
node.ret.append(
_add_output_list(op, out_index, out_index + num, dtype_lst, func))
out_index += num
else:
node.ret.append(
_make_argname_from_tensor_name(op.outputs[out_index].name))
out_index += 1
inp_index = 0
for arg_def in op_def.input_arg:
if arg_def.number_attr:
dtype = arg_def.type or attrs[arg_def.type_attr].type
num = attrs[arg_def.number_attr].i
node.arg.append(
_add_input_array(op, inp_index, inp_index + num, dtype, func))
inp_index += num
elif arg_def.type_list_attr:
num = len(attrs[arg_def.type_list_attr].list.type)
node.arg.extend([
_make_argname_from_tensor_name(op.inputs[i].name)
for i in range(inp_index, inp_index + num)
])
inp_index += num
else:
node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name))
inp_index += 1
node.dep.extend(
[_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
for k, v in attrs.items():
node.attr[k].CopyFrom(v)
func.node.extend([node])
def _replace_ret(func, original, replacement):
for n in func.node:
for i, r in enumerate(n.ret):
if r == original:
n.ret[i] = replacement
return
raise ValueError("Could not find ret == '%s'" % original)
def _replace_arg(func, original, replacement):
for n in func.node:
for i, a in enumerate(n.arg):
if a == original:
n.arg[i] = replacement
for i, d in enumerate(n.dep):
if d == original:
n.dep[i] = replacement
def _graph_to_function_def(graph, inputs, outputs, out_names=None):
"""Returns `graph` as a `FunctionDef` protocol buffer.
@ -323,20 +175,9 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
for index, o in enumerate(outputs):
k = func.signature.output_arg[index].name
func.ret[k] = input_dict[o.name]
# TODO(josh11b): Delete this once we switch fully to NodeDefs for
# function bodies.
orig = _make_argname_from_tensor_name(o.name)
if k != orig:
_add_identity_dtype_enum(func, orig, k,
func.signature.output_arg[index].type)
else:
for o, n in zip(outputs, out_names):
func.ret[n] = input_dict[o.name]
# TODO(josh11b): Delete this once we switch fully to NodeDefs for
# function bodies.
k = _make_argname_from_tensor_name(o.name)
_replace_ret(func, k, n)
_replace_arg(func, k, n)
return func

View File

@ -2299,6 +2299,9 @@ class Graph(object):
if (function.grad_func_name is not None) and (
function.python_grad_func is not None):
raise ValueError("Gradient defined twice for function %s" % name)
# Need a new-enough consumer to support the functions we add to the graph.
if self._graph_def_versions.min_consumer < 12:
self._graph_def_versions.min_consumer = 12
self._functions[name] = function
@property