Make function instantiation use std::vector<NodeDef> instead of GraphDef

It's about to turn into std::vector<NodeInfoPtr>; this change gets us partway there.

RELNOTES: n/a
PiperOrigin-RevId: 157771141
This commit is contained in:
Geoffrey Irving 2017-06-01 15:24:16 -07:00 committed by TensorFlower Gardener
parent 2e44be35dc
commit 8032e1f75d
11 changed files with 162 additions and 88 deletions

View File

@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper(
GraphConstructorOptions opts; GraphConstructorOptions opts;
opts.allow_internal_ops = true; opts.allow_internal_ops = true;
opts.expect_device_spec = false; opts.expect_device_spec = false;
Status s = ConvertGraphDefToGraph(opts, result.gdef, graph); Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
if (!s.ok()) { if (!s.ok()) {
delete graph; delete graph;
} else { } else {

View File

@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test {
GraphConstructorOptions opts; GraphConstructorOptions opts;
opts.allow_internal_ops = true; opts.allow_internal_ops = true;
opts.expect_device_spec = false; opts.expect_device_spec = false;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g)); TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g));
const int version = g->versions().producer(); const int version = g->versions().producer();
LocalExecutorParams params; LocalExecutorParams params;
@ -949,7 +949,7 @@ GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
GraphConstructorOptions opts; GraphConstructorOptions opts;
opts.allow_internal_ops = true; opts.allow_internal_ops = true;
opts.expect_device_spec = false; opts.expect_device_spec = false;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get())); TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
pass(g.get()); pass(g.get());
std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global())); std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
CopyGraph(*g, g1.get()); CopyGraph(*g, g1.get());

View File

@ -140,7 +140,7 @@ class FunctionInstantiationHelper {
FunctionInstantiationHelper(GetFunctionSignature get_function, FunctionInstantiationHelper(GetFunctionSignature get_function,
InstantiationResult* result) InstantiationResult* result)
: get_function_(std ::move(get_function)), result_(*result) { : get_function_(std ::move(get_function)), result_(*result) {
result_.gdef.Clear(); result_.nodes.clear();
} }
// Builds index for nodes that can be used as node's input arguments. // Builds index for nodes that can be used as node's input arguments.
@ -151,15 +151,14 @@ class FunctionInstantiationHelper {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ArgNumType(attr_values, arg_def, &is_type_list, &dtypes)); ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
CHECK_GE(dtypes.size(), size_t{1}); CHECK_GE(dtypes.size(), size_t{1});
GraphDef* gdef = &result_.gdef; int arg_index = result_.nodes.size();
int arg_index = gdef->node_size();
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes})); AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
// Creates dtypes.size() nodes in the gdef. // Creates dtypes.size() nodes in the graph.
for (size_t i = 0; i < dtypes.size(); ++i) { for (size_t i = 0; i < dtypes.size(); ++i) {
TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i), TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
{true, arg_index, 0, false, {dtypes[i]}})); {true, arg_index, 0, false, {dtypes[i]}}));
DCHECK_EQ(arg_index, gdef->node_size()); DCHECK_EQ(arg_index, result_.nodes.size());
string name = arg_def.name(); string name = arg_def.name();
if (dtypes.size() > 1) { if (dtypes.size() > 1) {
strings::StrAppend(&name, "_", i); strings::StrAppend(&name, "_", i);
@ -332,13 +331,13 @@ class FunctionInstantiationHelper {
// Adds the actual node inputs to the result graph by converting indexes to // Adds the actual node inputs to the result graph by converting indexes to
// the node names. // the node names.
void AddNodeInputs() { void AddNodeInputs() {
for (int i = 0; i < result_.gdef.node_size(); i++) { for (int i = 0; i < result_.nodes.size(); i++) {
NodeInfo& node_info = nodes_[i]; NodeInfo& node_info = nodes_[i];
for (const auto& p : node_info.data_inputs) { for (const auto& p : node_info.data_inputs) {
result_.gdef.mutable_node(i)->add_input(Name(p.first, p.second)); result_.nodes[i].add_input(Name(p.first, p.second));
} }
for (int index : node_info.control_inputs) { for (int index : node_info.control_inputs) {
result_.gdef.mutable_node(i)->add_input(Dep(index)); result_.nodes[i].add_input(Dep(index));
} }
} }
} }
@ -348,11 +347,10 @@ class FunctionInstantiationHelper {
// node's input arguments. // node's input arguments.
// //
// If is_func_arg is true, the name is a function's argument. In // If is_func_arg is true, the name is a function's argument. In
// this case, the produced graph def has gdef.node[nid ... nid + // this case, the produced graph def has node[nid:nid + dtype.size()].
// dtype.size()).
// //
// Otherwise, the name is a function body's node return value. In // Otherwise, the name is a function body's node return value. In
// this case, the produced graph def has one node gdef.node[nid] and // this case, the produced graph def has one node node[nid] and
// the node's output index [idx ... idx + num) corresponds to the // the node's output index [idx ... idx + num) corresponds to the
// named outputs. // named outputs.
// //
@ -398,10 +396,11 @@ class FunctionInstantiationHelper {
} }
NodeDef* AddNode(const string& name) { NodeDef* AddNode(const string& name) {
NodeDef* gnode = result_.gdef.add_node(); result_.nodes.emplace_back();
NodeDef* gnode = &result_.nodes.back();
gnode->set_name(name); gnode->set_name(name);
nodes_.push_back({name, {}, {}}); nodes_.push_back({name, {}, {}});
CHECK_EQ(result_.gdef.node_size(), nodes_.size()); CHECK_EQ(result_.nodes.size(), nodes_.size());
return gnode; return gnode;
} }
@ -429,7 +428,7 @@ class FunctionInstantiationHelper {
// Control inputs (dependencies). // Control inputs (dependencies).
std::vector<int> control_inputs; std::vector<int> control_inputs;
}; };
// nodes_[i] is the information about result_.gdef.node(i). // nodes_[i] is the information about result_.nodes[i].
std::vector<NodeInfo> nodes_; std::vector<NodeInfo> nodes_;
}; };
@ -545,17 +544,17 @@ string Print(const FunctionDef& fdef) {
return out; return out;
} }
string Print(const GraphDef& gdef) { string Print(gtl::ArraySlice<const NodeDef*> nodes) {
std::vector<const NodeDef*> arg; std::vector<const NodeDef*> arg;
std::vector<const NodeDef*> ret; std::vector<const NodeDef*> ret;
std::vector<const NodeDef*> body; std::vector<const NodeDef*> body;
for (const NodeDef& n : gdef.node()) { for (const NodeDef* n : nodes) {
if (n.op() == "_Arg") { if (n->op() == "_Arg") {
arg.push_back(&n); arg.push_back(n);
} else if (n.op() == "_Retval") { } else if (n->op() == "_Retval") {
ret.push_back(&n); ret.push_back(n);
} else { } else {
body.push_back(&n); body.push_back(n);
} }
} }
auto comp = [](const NodeDef* x, const NodeDef* y) { auto comp = [](const NodeDef* x, const NodeDef* y) {
@ -570,12 +569,11 @@ string Print(const GraphDef& gdef) {
string out; string out;
strings::StrAppend(&out, "\n("); strings::StrAppend(&out, "\n(");
auto get_type = [](const NodeDef& n) { auto get_type = [](const NodeDef& n) {
for (auto a : n.attr()) { DataType dt;
if (a.first == "T") { if (!GetNodeAttr(n, "T", &dt).ok()) {
return DataTypeString(a.second.type()); dt = DT_INVALID;
} }
} return DataTypeString(dt);
return DataTypeString(DT_INVALID);
}; };
for (size_t i = 0; i < arg.size(); ++i) { for (size_t i = 0; i < arg.size(); ++i) {
const NodeDef* n = arg[i]; const NodeDef* n = arg[i];
@ -663,13 +661,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
for (int i = 0; i < fdef.node_def_size(); ++i) { for (int i = 0; i < fdef.node_def_size(); ++i) {
s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]), s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
result->gdef.node_size() + i); result->nodes.size() + i);
if (!s.ok()) { if (!s.ok()) {
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
return s; return s;
} }
} }
// Emits one gdef.node for each fdef.node_def. // Emits one node for each fdef.node_def.
for (int i = 0; i < fdef.node_def_size(); ++i) { for (int i = 0; i < fdef.node_def_size(); ++i) {
s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i])); s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
if (!s.ok()) { if (!s.ok()) {
@ -697,7 +695,19 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
string DebugString(const FunctionDef& func_def) { return Print(func_def); } string DebugString(const FunctionDef& func_def) { return Print(func_def); }
string DebugString(const GraphDef& instantiated_func_def) { string DebugString(const GraphDef& instantiated_func_def) {
return Print(instantiated_func_def); std::vector<const NodeDef*> ptrs;
for (const NodeDef& n : instantiated_func_def.node()) {
ptrs.push_back(&n);
}
return Print(ptrs);
}
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
std::vector<const NodeDef*> ptrs;
for (const NodeDef& n : instantiated_func_nodes) {
ptrs.push_back(&n);
}
return Print(ptrs);
} }
string DebugStringWhole(const GraphDef& gdef) { string DebugStringWhole(const GraphDef& gdef) {

View File

@ -200,7 +200,7 @@ typedef std::function<Status(const string&, const OpDef**)>
struct InstantiationResult { struct InstantiationResult {
DataTypeVector arg_types; DataTypeVector arg_types;
DataTypeVector ret_types; DataTypeVector ret_types;
GraphDef gdef; std::vector<NodeDef> nodes;
}; };
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
GetFunctionSignature get_function, GetFunctionSignature get_function,
@ -216,6 +216,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
// etc.) // etc.)
string DebugString(const FunctionDef& func_def); string DebugString(const FunctionDef& func_def);
string DebugString(const GraphDef& instantiated_func_def); string DebugString(const GraphDef& instantiated_func_def);
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
// Returns a debug string for a top level graph (the main program and // Returns a debug string for a top level graph (the main program and
// its supporting functions defined in its library). // its supporting functions defined in its library).

View File

@ -108,7 +108,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.nodes), e2);
} }
TEST(TFunc, ControlDep) { TEST(TFunc, ControlDep) {
@ -154,7 +154,7 @@ ControlDep(x:int32) -> (y:int32) {
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32})); EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.nodes), e2);
} }
REGISTER_OP("HasDefaultType") REGISTER_OP("HasDefaultType")
@ -198,7 +198,7 @@ BackCompat() -> (y:float) {
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector()); EXPECT_EQ(result.arg_types, DataTypeVector());
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.nodes), e2);
} }
TEST(TFunc, NTimesT) { TEST(TFunc, NTimesT) {
@ -234,7 +234,7 @@ NTimesT(x:float, y:float) -> (z:float) {
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT})); EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.nodes), e2);
} }
// NOTE: This is the simplest Map op. It takes a f:T->U. // NOTE: This is the simplest Map op. It takes a f:T->U.
@ -299,7 +299,7 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT})); EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.nodes), e2);
} }
TEST(TFunc, ControlDeps) { TEST(TFunc, ControlDeps) {
@ -344,7 +344,7 @@ ControlDeps(x:float) -> () {
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({})); EXPECT_EQ(result.ret_types, DataTypeVector({}));
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.nodes), e2);
} }
TEST(TFunc, XTimesTwo) { TEST(TFunc, XTimesTwo) {
@ -425,7 +425,7 @@ Test(i:float) -> (o:float) {
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.nodes), e2);
} }
REGISTER_OP("Cond") REGISTER_OP("Cond")
@ -493,7 +493,7 @@ MySelect(x:float) -> (z:float) {
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.nodes), e2);
} }
static void HasError(const Status& s, const string& substr) { static void HasError(const Status& s, const string& substr) {
@ -1028,7 +1028,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) {
*proto.add_gradient() = grad; *proto.add_gradient() = grad;
FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
TF_EXPECT_OK(lib_def.AddLibrary(lib_def3)); TF_EXPECT_OK(lib_def.AddLibrary(lib_def3));
}; }
TEST(FunctionLibraryDefinitionTest, ToProto) { TEST(FunctionLibraryDefinitionTest, ToProto) {
FunctionDefLibrary proto1; FunctionDefLibrary proto1;

View File

@ -39,6 +39,14 @@ string SummarizeGraphDef(const GraphDef& graph_def) {
return ret; return ret;
} }
string SummarizeGraphDef(gtl::ArraySlice<NodeDef> node_defs) {
string ret;
for (const NodeDef& node : node_defs) {
strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
}
return ret;
}
Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
for (const NodeDef& node : graph_def.node()) { for (const NodeDef& node : graph_def.node()) {
TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));

View File

@ -27,6 +27,7 @@ namespace tensorflow {
// Produce a human-readable version of a GraphDef that is more concise // Produce a human-readable version of a GraphDef that is more concise
// than a text-format proto. // than a text-format proto.
string SummarizeGraphDef(const GraphDef& graph_def); string SummarizeGraphDef(const GraphDef& graph_def);
string SummarizeGraphDef(gtl::ArraySlice<NodeDef> node_defs);
// Validates the syntax of a GraphDef provided externally. // Validates the syntax of a GraphDef provided externally.
// //

View File

@ -91,24 +91,36 @@ class GraphConstructor {
bool importing; bool importing;
}; };
static Status Construct(const Options& opts, const GraphDef* gdef, Graph* g, typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
// versions and library may be nullptr
static Status Construct(const Options& opts, NodeDefSlice node_defs,
const VersionDef* versions,
const FunctionDefLibrary* library, Graph* g,
ShapeRefiner* refiner, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors) { std::vector<std::pair<Node*, int>>* return_tensors) {
TF_RETURN_IF_ERROR(CheckVersions(gdef->versions(), TF_GRAPH_DEF_VERSION, if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER, TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
"GraphDef", "graph")); "GraphDef", "graph"));
GraphConstructor c(opts, gdef, g, refiner, return_tensors); }
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
return_tensors);
const Status s = c.TryImport(); const Status s = c.TryImport();
if (!s.ok()) c.Undo(); if (!s.ok()) c.Undo();
return s; return s;
} }
private: private:
GraphConstructor(const Options& opts, const GraphDef* gdef, Graph* g, GraphConstructor(const Options& opts, NodeDefSlice node_defs,
const VersionDef* versions,
const FunctionDefLibrary* library, Graph* g,
ShapeRefiner* refiner, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors) std::vector<std::pair<Node*, int>>* return_tensors)
: opts_(opts), : opts_(opts),
gdef_(gdef), node_defs_(node_defs),
versions_(versions),
library_(library),
g_(g), g_(g),
original_versions_(g->versions()), original_versions_(g->versions()),
refiner_(refiner), refiner_(refiner),
@ -159,7 +171,9 @@ class GraphConstructor {
// From constructor // From constructor
const Options opts_; const Options opts_;
const GraphDef* gdef_; const NodeDefSlice node_defs_;
const VersionDef* versions_;
const FunctionDefLibrary* library_;
Graph* g_; Graph* g_;
const VersionDef original_versions_; const VersionDef original_versions_;
@ -168,7 +182,7 @@ class GraphConstructor {
// May be null. Not owned. // May be null. Not owned.
std::vector<std::pair<Node*, int>>* return_tensors_; std::vector<std::pair<Node*, int>>* return_tensors_;
// Mapping from node name to the index within gdef_ // Mapping from node name to the index within node_defs_
struct NodeInfo { struct NodeInfo {
explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {} explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
// std::unordered_map<> requires that we have a default constructor. // std::unordered_map<> requires that we have a default constructor.
@ -183,18 +197,18 @@ class GraphConstructor {
// Mapping from node name to the existing node in g_ // Mapping from node name to the existing node in g_
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> existing_nodes_; std::unordered_map<StringPiece, Node*, StringPiece::Hasher> existing_nodes_;
// Index of NodeDefs in gdef_ with all inputs already converted. // Index of NodeDefs in node_defs_ with all inputs already converted.
std::vector<int> ready_; std::vector<int> ready_;
// Mapping between index within gdef_ and the number of inputs that // Mapping between index within node_defs_ and the number of inputs that
// still need to be converted. // still need to be converted.
std::vector<int> pending_count_; std::vector<int> pending_count_;
// Mapping between index within gdef_ and the index within gdef_ of // Mapping between index within node_defs_ and the index within node_defs_ of
// all nodes it outputs to. // all nodes it outputs to.
std::vector<gtl::InlinedVector<int, 4>> outputs_; std::vector<gtl::InlinedVector<int, 4>> outputs_;
// Used in the conversion from gdef_ to g_ to represent the ith input // Used in the conversion from node_defs_ to g_ to represent the ith input
// of a node. // of a node.
struct InputInfo { struct InputInfo {
explicit InputInfo(const string& node_name, Node* n, int i) explicit InputInfo(const string& node_name, Node* n, int i)
@ -205,7 +219,7 @@ class GraphConstructor {
int index; int index;
}; };
// Used in the conversion from gdef_ to g_ to represent an edge from // Used in the conversion from node_defs_ to g_ to represent an edge from
// the node named 'name' to node 'n'. // the node named 'name' to node 'n'.
struct EdgeInfo { struct EdgeInfo {
explicit EdgeInfo(const string& name, int i1, Node* n, int i2) explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
@ -254,8 +268,8 @@ Status GraphConstructor::EnsureNoNameCollisions() {
} }
} }
if (opts_.prefix.empty() && opts_.importing) { if (opts_.prefix.empty() && opts_.importing) {
for (int n = 0; n < gdef_->node_size(); ++n) { for (const NodeDef* n : node_defs_) {
const string& name = gdef_->node(n).name(); const string& name = n->name();
if (existing_nodes_.find(name) != existing_nodes_.end()) { if (existing_nodes_.find(name) != existing_nodes_.end()) {
return errors::InvalidArgument("Node '", name, return errors::InvalidArgument("Node '", name,
"' already exists in the Graph"); "' already exists in the Graph");
@ -312,8 +326,8 @@ Status GraphConstructor::ValidateInputMapAndControlDependencies() {
Status GraphConstructor::BuildNodeIndex() { Status GraphConstructor::BuildNodeIndex() {
// Validate the node names and add them to gdef_nodes_. // Validate the node names and add them to gdef_nodes_.
for (int n = 0; n < gdef_->node_size(); ++n) { for (int n = 0; n < node_defs_.size(); ++n) {
const NodeDef& node_def(gdef_->node(n)); const NodeDef& node_def = *node_defs_[n];
if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) { if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Node '", node_def.name(), "Node '", node_def.name(),
@ -351,13 +365,13 @@ Status GraphConstructor::BuildNodeIndex() {
} }
Status GraphConstructor::InitFromEdges() { Status GraphConstructor::InitFromEdges() {
const int num_nodes = gdef_->node_size(); const int num_nodes = node_defs_.size();
pending_count_.reserve(num_nodes); pending_count_.reserve(num_nodes);
outputs_.resize(num_nodes); outputs_.resize(num_nodes);
// Parse the inputs for each node. // Parse the inputs for each node.
for (int n = 0; n < num_nodes; ++n) { for (int n = 0; n < num_nodes; ++n) {
const NodeDef& node_def(gdef_->node(n)); const NodeDef& node_def = *node_defs_[n];
if (IsMerge(node_def)) { if (IsMerge(node_def)) {
// for merge only wait for one non-control input. // for merge only wait for one non-control input.
int32 num_control_edges = 0; int32 num_control_edges = 0;
@ -489,7 +503,9 @@ Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
AddDefaultsToNodeDef(*op_def, node_def); AddDefaultsToNodeDef(*op_def, node_def);
TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def)); TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, gdef_->versions().producer())); if (versions_) {
TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer()));
}
return Status::OK(); return Status::OK();
} }
@ -608,7 +624,9 @@ void GraphConstructor::AddPrefixToNodeDef(
Status GraphConstructor::Convert() { Status GraphConstructor::Convert() {
// Import functions before adding nodes, since imported nodes may refer to // Import functions before adding nodes, since imported nodes may refer to
// functions // functions
TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(gdef_->library())); if (library_) {
TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_));
}
std::vector<InputInfo> inputs; std::vector<InputInfo> inputs;
int processed = 0; int processed = 0;
@ -626,14 +644,14 @@ Status GraphConstructor::Convert() {
inputs.clear(); inputs.clear();
bool has_data_back_edge = false; bool has_data_back_edge = false;
const NodeDef& original_node_def = gdef_->node(o); const NodeDef& original_node_def = *node_defs_[o];
NodeDef imported_node_def; NodeDef imported_node_def;
const NodeDef* node_def; const NodeDef* node_def;
// input_already_exists[i] is true iff the i-th input of the node we're // input_already_exists[i] is true iff the i-th input of the node we're
// importing refers to a preexisting node in g_ (i.e. input[i] existed prior // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
// to importing gdef_). Conversely, input_already_exists[i] is false iff // to importing node_defs_). Conversely, input_already_exists[i] is false
// the input refers to a node in gdef_. // iff the input refers to a node in node_defs_.
input_already_exists.clear(); input_already_exists.clear();
input_already_exists.resize(original_node_def.input_size(), false); input_already_exists.resize(original_node_def.input_size(), false);
@ -731,8 +749,8 @@ Status GraphConstructor::Convert() {
} }
} }
if (processed < gdef_->node_size()) { if (processed < node_defs_.size()) {
return errors::InvalidArgument(gdef_->node_size() - processed, return errors::InvalidArgument(node_defs_.size() - processed,
" nodes in a cycle"); " nodes in a cycle");
} }
return Status::OK(); return Status::OK();
@ -756,20 +774,21 @@ Status GraphConstructor::AddBackEdges() {
} }
Status GraphConstructor::UpdateVersionDef() { Status GraphConstructor::UpdateVersionDef() {
if (versions_ == nullptr) return Status::OK();
if (!opts_.importing) { if (!opts_.importing) {
g_->set_versions(gdef_->versions()); g_->set_versions(*versions_);
return Status::OK(); return Status::OK();
} }
VersionDef versions = g_->versions(); VersionDef versions = g_->versions();
versions.set_producer( versions.set_producer(std::min(versions.producer(), versions_->producer()));
std::min(versions.producer(), gdef_->versions().producer()));
versions.set_min_consumer( versions.set_min_consumer(
std::max(versions.min_consumer(), gdef_->versions().min_consumer())); std::max(versions.min_consumer(), versions_->min_consumer()));
if (gdef_->versions().bad_consumers_size() > 0) { if (versions_->bad_consumers_size() > 0) {
std::set<int> bad(versions.bad_consumers().begin(), std::set<int> bad(versions.bad_consumers().begin(),
versions.bad_consumers().end()); versions.bad_consumers().end());
bad.insert(gdef_->versions().bad_consumers().begin(), bad.insert(versions_->bad_consumers().begin(),
gdef_->versions().bad_consumers().end()); versions_->bad_consumers().end());
versions.clear_bad_consumers(); versions.clear_bad_consumers();
for (int v : bad) { for (int v : bad) {
versions.add_bad_consumers(v); versions.add_bad_consumers(v);
@ -837,7 +856,20 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g) { const GraphDef& gdef, Graph* g) {
ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
return GraphConstructor::Construct(opts, &gdef, g, &refiner, nullptr); return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
&gdef.library(), g, &refiner, nullptr);
}
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
gtl::ArraySlice<NodeDef> nodes, Graph* g) {
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
// TODO(irving): Copy will go away once NodeInfo exists
std::vector<const NodeDef*> node_defs;
for (const auto& n : nodes) {
node_defs.push_back(&n);
}
return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
&refiner, nullptr);
} }
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
@ -886,7 +918,9 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
refiner->set_graph_def_version( refiner->set_graph_def_version(
std::min(refiner->graph_def_version(), gdef.versions().producer())); std::min(refiner->graph_def_version(), gdef.versions().producer()));
return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors); return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
&gdef.library(), g, refiner,
return_tensors);
} }
void CopyGraph(const Graph& src, Graph* dest) { void CopyGraph(const Graph& src, Graph* dest) {

View File

@ -46,6 +46,12 @@ struct GraphConstructorOptions {
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g); const GraphDef& gdef, Graph* g);
// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function
// instantiation.
// TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
gtl::ArraySlice<NodeDef> nodes, Graph* g);
// Add the graph in GraphDef gdef into an existing Graph *g. // Add the graph in GraphDef gdef into an existing Graph *g.
// //
// On error, returns non-OK and leaves *g unmodified. // On error, returns non-OK and leaves *g unmodified.

View File

@ -24,15 +24,9 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, template <class NodeDefs>
string* diff, const EqualGraphDefOptions& options) { static bool EqualNodeDefsHelper(
// Intentionally do not check that versions match so that this routine can const NodeDefs& actual, const protobuf::RepeatedPtrField<NodeDef>& expected,
// be used for less brittle golden file tests.
return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
}
bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
const protobuf::RepeatedPtrField<NodeDef>& expected,
string* diff, const EqualGraphDefOptions& options) { string* diff, const EqualGraphDefOptions& options) {
std::unordered_map<string, const NodeDef*> actual_index; std::unordered_map<string, const NodeDef*> actual_index;
for (const NodeDef& node : actual) { for (const NodeDef& node : actual) {
@ -68,6 +62,24 @@ bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
return true; return true;
} }
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options) {
// Intentionally do not check that versions match so that this routine can
// be used for less brittle golden file tests.
return EqualNodeDefsHelper(actual.node(), expected.node(), diff, options);
}
bool EqualGraphDef(gtl::ArraySlice<NodeDef> actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options) {
return EqualNodeDefsHelper(actual, expected.node(), diff, options);
}
bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
const protobuf::RepeatedPtrField<NodeDef>& expected,
string* diff, const EqualGraphDefOptions& options) {
return EqualNodeDefsHelper(actual, expected, diff, options);
}
namespace { namespace {
string JoinStringField(const protobuf::RepeatedPtrField<string>& f) { string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {

View File

@ -36,6 +36,8 @@ struct EqualGraphDefOptions {
// nodes must be consistent. // nodes must be consistent.
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options = {}); string* diff, const EqualGraphDefOptions& options = {});
bool EqualGraphDef(gtl::ArraySlice<NodeDef> actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options = {});
// Determines if actual and expected are equal, ignoring: ordering of // Determines if actual and expected are equal, ignoring: ordering of
// attrs, internal attributes (if set in `options`), and control inputs. // attrs, internal attributes (if set in `options`), and control inputs.