Make function instantiation use std::vector<NodeDef> instead of GraphDef

It's about to turn into std::vector<NodeInfoPtr>; this change gets us partway there.

RELNOTES: n/a
PiperOrigin-RevId: 157771141
This commit is contained in:
Geoffrey Irving 2017-06-01 15:24:16 -07:00 committed by TensorFlower Gardener
parent 2e44be35dc
commit 8032e1f75d
11 changed files with 162 additions and 88 deletions

View File

@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper(
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
if (!s.ok()) {
delete graph;
} else {

View File

@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test {
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g));
TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g));
const int version = g->versions().producer();
LocalExecutorParams params;
@ -949,7 +949,7 @@ GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get()));
TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
pass(g.get());
std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
CopyGraph(*g, g1.get());

View File

@ -140,7 +140,7 @@ class FunctionInstantiationHelper {
FunctionInstantiationHelper(GetFunctionSignature get_function,
InstantiationResult* result)
: get_function_(std ::move(get_function)), result_(*result) {
result_.gdef.Clear();
result_.nodes.clear();
}
// Builds index for nodes that can be used as node's input arguments.
@ -151,15 +151,14 @@ class FunctionInstantiationHelper {
TF_RETURN_IF_ERROR(
ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
CHECK_GE(dtypes.size(), size_t{1});
GraphDef* gdef = &result_.gdef;
int arg_index = gdef->node_size();
int arg_index = result_.nodes.size();
TF_RETURN_IF_ERROR(
AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
// Creates dtypes.size() nodes in the gdef.
// Creates dtypes.size() nodes in the graph.
for (size_t i = 0; i < dtypes.size(); ++i) {
TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
{true, arg_index, 0, false, {dtypes[i]}}));
DCHECK_EQ(arg_index, gdef->node_size());
DCHECK_EQ(arg_index, result_.nodes.size());
string name = arg_def.name();
if (dtypes.size() > 1) {
strings::StrAppend(&name, "_", i);
@ -332,13 +331,13 @@ class FunctionInstantiationHelper {
// Adds the actual node inputs to the result graph by converting indexes to
// the node names.
void AddNodeInputs() {
for (int i = 0; i < result_.gdef.node_size(); i++) {
for (int i = 0; i < result_.nodes.size(); i++) {
NodeInfo& node_info = nodes_[i];
for (const auto& p : node_info.data_inputs) {
result_.gdef.mutable_node(i)->add_input(Name(p.first, p.second));
result_.nodes[i].add_input(Name(p.first, p.second));
}
for (int index : node_info.control_inputs) {
result_.gdef.mutable_node(i)->add_input(Dep(index));
result_.nodes[i].add_input(Dep(index));
}
}
}
@ -348,11 +347,10 @@ class FunctionInstantiationHelper {
// node's input arguments.
//
// If is_func_arg is true, the name is a function's argument. In
// this case, the produced graph def has gdef.node[nid ... nid +
// dtype.size()).
// this case, the produced graph def has node[nid:nid + dtype.size()].
//
// Otherwise, the name is a function body's node return value. In
// this case, the produced graph def has one node gdef.node[nid] and
// this case, the produced graph def has one node node[nid] and
// the node's output index [idx ... idx + num) corresponds to the
// named outputs.
//
@ -398,10 +396,11 @@ class FunctionInstantiationHelper {
}
NodeDef* AddNode(const string& name) {
NodeDef* gnode = result_.gdef.add_node();
result_.nodes.emplace_back();
NodeDef* gnode = &result_.nodes.back();
gnode->set_name(name);
nodes_.push_back({name, {}, {}});
CHECK_EQ(result_.gdef.node_size(), nodes_.size());
CHECK_EQ(result_.nodes.size(), nodes_.size());
return gnode;
}
@ -429,7 +428,7 @@ class FunctionInstantiationHelper {
// Control inputs (dependencies).
std::vector<int> control_inputs;
};
// nodes_[i] is the information about result_.gdef.node(i).
// nodes_[i] is the information about result_.nodes[i].
std::vector<NodeInfo> nodes_;
};
@ -545,17 +544,17 @@ string Print(const FunctionDef& fdef) {
return out;
}
string Print(const GraphDef& gdef) {
string Print(gtl::ArraySlice<const NodeDef*> nodes) {
std::vector<const NodeDef*> arg;
std::vector<const NodeDef*> ret;
std::vector<const NodeDef*> body;
for (const NodeDef& n : gdef.node()) {
if (n.op() == "_Arg") {
arg.push_back(&n);
} else if (n.op() == "_Retval") {
ret.push_back(&n);
for (const NodeDef* n : nodes) {
if (n->op() == "_Arg") {
arg.push_back(n);
} else if (n->op() == "_Retval") {
ret.push_back(n);
} else {
body.push_back(&n);
body.push_back(n);
}
}
auto comp = [](const NodeDef* x, const NodeDef* y) {
@ -570,12 +569,11 @@ string Print(const GraphDef& gdef) {
string out;
strings::StrAppend(&out, "\n(");
auto get_type = [](const NodeDef& n) {
for (auto a : n.attr()) {
if (a.first == "T") {
return DataTypeString(a.second.type());
}
DataType dt;
if (!GetNodeAttr(n, "T", &dt).ok()) {
dt = DT_INVALID;
}
return DataTypeString(DT_INVALID);
return DataTypeString(dt);
};
for (size_t i = 0; i < arg.size(); ++i) {
const NodeDef* n = arg[i];
@ -663,13 +661,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
for (int i = 0; i < fdef.node_def_size(); ++i) {
s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
result->gdef.node_size() + i);
result->nodes.size() + i);
if (!s.ok()) {
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
return s;
}
}
// Emits one gdef.node for each fdef.node_def.
// Emits one node for each fdef.node_def.
for (int i = 0; i < fdef.node_def_size(); ++i) {
s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
if (!s.ok()) {
@ -697,7 +695,19 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
string DebugString(const FunctionDef& func_def) { return Print(func_def); }
string DebugString(const GraphDef& instantiated_func_def) {
return Print(instantiated_func_def);
std::vector<const NodeDef*> ptrs;
for (const NodeDef& n : instantiated_func_def.node()) {
ptrs.push_back(&n);
}
return Print(ptrs);
}
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
std::vector<const NodeDef*> ptrs;
for (const NodeDef& n : instantiated_func_nodes) {
ptrs.push_back(&n);
}
return Print(ptrs);
}
string DebugStringWhole(const GraphDef& gdef) {

View File

@ -200,7 +200,7 @@ typedef std::function<Status(const string&, const OpDef**)>
struct InstantiationResult {
DataTypeVector arg_types;
DataTypeVector ret_types;
GraphDef gdef;
std::vector<NodeDef> nodes;
};
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
GetFunctionSignature get_function,
@ -216,6 +216,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
// etc.)
string DebugString(const FunctionDef& func_def);
string DebugString(const GraphDef& instantiated_func_def);
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
// Returns a debug string for a top level graph (the main program and
// its supporting functions defined in its library).

View File

@ -108,7 +108,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2);
EXPECT_EQ(DebugString(result.nodes), e2);
}
TEST(TFunc, ControlDep) {
@ -154,7 +154,7 @@ ControlDep(x:int32) -> (y:int32) {
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
EXPECT_EQ(DebugString(result.gdef), e2);
EXPECT_EQ(DebugString(result.nodes), e2);
}
REGISTER_OP("HasDefaultType")
@ -198,7 +198,7 @@ BackCompat() -> (y:float) {
)P";
EXPECT_EQ(result.arg_types, DataTypeVector());
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2);
EXPECT_EQ(DebugString(result.nodes), e2);
}
TEST(TFunc, NTimesT) {
@ -234,7 +234,7 @@ NTimesT(x:float, y:float) -> (z:float) {
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2);
EXPECT_EQ(DebugString(result.nodes), e2);
}
// NOTE: This is the simplest Map op. It takes a f:T->U.
@ -299,7 +299,7 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2);
EXPECT_EQ(DebugString(result.nodes), e2);
}
TEST(TFunc, ControlDeps) {
@ -344,7 +344,7 @@ ControlDeps(x:float) -> () {
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({}));
EXPECT_EQ(DebugString(result.gdef), e2);
EXPECT_EQ(DebugString(result.nodes), e2);
}
TEST(TFunc, XTimesTwo) {
@ -425,7 +425,7 @@ Test(i:float) -> (o:float) {
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2);
EXPECT_EQ(DebugString(result.nodes), e2);
}
REGISTER_OP("Cond")
@ -493,7 +493,7 @@ MySelect(x:float) -> (z:float) {
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
EXPECT_EQ(DebugString(result.gdef), e2);
EXPECT_EQ(DebugString(result.nodes), e2);
}
static void HasError(const Status& s, const string& substr) {
@ -1028,7 +1028,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) {
*proto.add_gradient() = grad;
FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
TF_EXPECT_OK(lib_def.AddLibrary(lib_def3));
};
}
TEST(FunctionLibraryDefinitionTest, ToProto) {
FunctionDefLibrary proto1;

View File

@ -39,6 +39,14 @@ string SummarizeGraphDef(const GraphDef& graph_def) {
return ret;
}
string SummarizeGraphDef(gtl::ArraySlice<NodeDef> node_defs) {
string ret;
for (const NodeDef& node : node_defs) {
strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
}
return ret;
}
Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
for (const NodeDef& node : graph_def.node()) {
TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));

View File

@ -27,6 +27,7 @@ namespace tensorflow {
// Produce a human-readable version of a GraphDef that is more concise
// than a text-format proto.
string SummarizeGraphDef(const GraphDef& graph_def);
string SummarizeGraphDef(gtl::ArraySlice<NodeDef> node_defs);
// Validates the syntax of a GraphDef provided externally.
//

View File

@ -91,24 +91,36 @@ class GraphConstructor {
bool importing;
};
static Status Construct(const Options& opts, const GraphDef* gdef, Graph* g,
typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
// versions and library may be nullptr
static Status Construct(const Options& opts, NodeDefSlice node_defs,
const VersionDef* versions,
const FunctionDefLibrary* library, Graph* g,
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors) {
TF_RETURN_IF_ERROR(CheckVersions(gdef->versions(), TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
"GraphDef", "graph"));
GraphConstructor c(opts, gdef, g, refiner, return_tensors);
if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
"GraphDef", "graph"));
}
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
return_tensors);
const Status s = c.TryImport();
if (!s.ok()) c.Undo();
return s;
}
private:
GraphConstructor(const Options& opts, const GraphDef* gdef, Graph* g,
GraphConstructor(const Options& opts, NodeDefSlice node_defs,
const VersionDef* versions,
const FunctionDefLibrary* library, Graph* g,
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors)
: opts_(opts),
gdef_(gdef),
node_defs_(node_defs),
versions_(versions),
library_(library),
g_(g),
original_versions_(g->versions()),
refiner_(refiner),
@ -159,7 +171,9 @@ class GraphConstructor {
// From constructor
const Options opts_;
const GraphDef* gdef_;
const NodeDefSlice node_defs_;
const VersionDef* versions_;
const FunctionDefLibrary* library_;
Graph* g_;
const VersionDef original_versions_;
@ -168,7 +182,7 @@ class GraphConstructor {
// May be null. Not owned.
std::vector<std::pair<Node*, int>>* return_tensors_;
// Mapping from node name to the index within gdef_
// Mapping from node name to the index within node_defs_
struct NodeInfo {
explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
// std::unordered_map<> requires that we have a default constructor.
@ -183,18 +197,18 @@ class GraphConstructor {
// Mapping from node name to the existing node in g_
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> existing_nodes_;
// Index of NodeDefs in gdef_ with all inputs already converted.
// Index of NodeDefs in node_defs_ with all inputs already converted.
std::vector<int> ready_;
// Mapping between index within gdef_ and the number of inputs that
// Mapping between index within node_defs_ and the number of inputs that
// still need to be converted.
std::vector<int> pending_count_;
// Mapping between index within gdef_ and the index within gdef_ of
// Mapping between index within node_defs_ and the index within node_defs_ of
// all nodes it outputs to.
std::vector<gtl::InlinedVector<int, 4>> outputs_;
// Used in the conversion from gdef_ to g_ to represent the ith input
// Used in the conversion from node_defs_ to g_ to represent the ith input
// of a node.
struct InputInfo {
explicit InputInfo(const string& node_name, Node* n, int i)
@ -205,7 +219,7 @@ class GraphConstructor {
int index;
};
// Used in the conversion from gdef_ to g_ to represent an edge from
// Used in the conversion from node_defs_ to g_ to represent an edge from
// the node named 'name' to node 'n'.
struct EdgeInfo {
explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
@ -254,8 +268,8 @@ Status GraphConstructor::EnsureNoNameCollisions() {
}
}
if (opts_.prefix.empty() && opts_.importing) {
for (int n = 0; n < gdef_->node_size(); ++n) {
const string& name = gdef_->node(n).name();
for (const NodeDef* n : node_defs_) {
const string& name = n->name();
if (existing_nodes_.find(name) != existing_nodes_.end()) {
return errors::InvalidArgument("Node '", name,
"' already exists in the Graph");
@ -312,8 +326,8 @@ Status GraphConstructor::ValidateInputMapAndControlDependencies() {
Status GraphConstructor::BuildNodeIndex() {
// Validate the node names and add them to gdef_nodes_.
for (int n = 0; n < gdef_->node_size(); ++n) {
const NodeDef& node_def(gdef_->node(n));
for (int n = 0; n < node_defs_.size(); ++n) {
const NodeDef& node_def = *node_defs_[n];
if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
return errors::InvalidArgument(
"Node '", node_def.name(),
@ -351,13 +365,13 @@ Status GraphConstructor::BuildNodeIndex() {
}
Status GraphConstructor::InitFromEdges() {
const int num_nodes = gdef_->node_size();
const int num_nodes = node_defs_.size();
pending_count_.reserve(num_nodes);
outputs_.resize(num_nodes);
// Parse the inputs for each node.
for (int n = 0; n < num_nodes; ++n) {
const NodeDef& node_def(gdef_->node(n));
const NodeDef& node_def = *node_defs_[n];
if (IsMerge(node_def)) {
// for merge only wait for one non-control input.
int32 num_control_edges = 0;
@ -489,7 +503,9 @@ Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
AddDefaultsToNodeDef(*op_def, node_def);
TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, gdef_->versions().producer()));
if (versions_) {
TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer()));
}
return Status::OK();
}
@ -608,7 +624,9 @@ void GraphConstructor::AddPrefixToNodeDef(
Status GraphConstructor::Convert() {
// Import functions before adding nodes, since imported nodes may refer to
// functions
TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(gdef_->library()));
if (library_) {
TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_));
}
std::vector<InputInfo> inputs;
int processed = 0;
@ -626,14 +644,14 @@ Status GraphConstructor::Convert() {
inputs.clear();
bool has_data_back_edge = false;
const NodeDef& original_node_def = gdef_->node(o);
const NodeDef& original_node_def = *node_defs_[o];
NodeDef imported_node_def;
const NodeDef* node_def;
// input_already_exists[i] is true iff the i-th input of the node we're
// importing refers to a preexisting node in g_ (i.e. input[i] existed prior
// to importing gdef_). Conversely, input_already_exists[i] is false iff
// the input refers to a node in gdef_.
// to importing node_defs_). Conversely, input_already_exists[i] is false
// iff the input refers to a node in node_defs_.
input_already_exists.clear();
input_already_exists.resize(original_node_def.input_size(), false);
@ -731,8 +749,8 @@ Status GraphConstructor::Convert() {
}
}
if (processed < gdef_->node_size()) {
return errors::InvalidArgument(gdef_->node_size() - processed,
if (processed < node_defs_.size()) {
return errors::InvalidArgument(node_defs_.size() - processed,
" nodes in a cycle");
}
return Status::OK();
@ -756,20 +774,21 @@ Status GraphConstructor::AddBackEdges() {
}
Status GraphConstructor::UpdateVersionDef() {
if (versions_ == nullptr) return Status::OK();
if (!opts_.importing) {
g_->set_versions(gdef_->versions());
g_->set_versions(*versions_);
return Status::OK();
}
VersionDef versions = g_->versions();
versions.set_producer(
std::min(versions.producer(), gdef_->versions().producer()));
versions.set_producer(std::min(versions.producer(), versions_->producer()));
versions.set_min_consumer(
std::max(versions.min_consumer(), gdef_->versions().min_consumer()));
if (gdef_->versions().bad_consumers_size() > 0) {
std::max(versions.min_consumer(), versions_->min_consumer()));
if (versions_->bad_consumers_size() > 0) {
std::set<int> bad(versions.bad_consumers().begin(),
versions.bad_consumers().end());
bad.insert(gdef_->versions().bad_consumers().begin(),
gdef_->versions().bad_consumers().end());
bad.insert(versions_->bad_consumers().begin(),
versions_->bad_consumers().end());
versions.clear_bad_consumers();
for (int v : bad) {
versions.add_bad_consumers(v);
@ -837,7 +856,20 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g) {
ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
return GraphConstructor::Construct(opts, &gdef, g, &refiner, nullptr);
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
&gdef.library(), g, &refiner, nullptr);
}
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
gtl::ArraySlice<NodeDef> nodes, Graph* g) {
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
// TODO(irving): Copy will go away once NodeInfo exists
std::vector<const NodeDef*> node_defs;
for (const auto& n : nodes) {
node_defs.push_back(&n);
}
return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
&refiner, nullptr);
}
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
@ -886,7 +918,9 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
refiner->set_graph_def_version(
std::min(refiner->graph_def_version(), gdef.versions().producer()));
return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors);
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
&gdef.library(), g, refiner,
return_tensors);
}
void CopyGraph(const Graph& src, Graph* dest) {

View File

@ -46,6 +46,12 @@ struct GraphConstructorOptions {
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g);
// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function
// instantiation.
// TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
gtl::ArraySlice<NodeDef> nodes, Graph* g);
// Add the graph in GraphDef gdef into an existing Graph *g.
//
// On error, returns non-OK and leaves *g unmodified.

View File

@ -24,16 +24,10 @@ limitations under the License.
namespace tensorflow {
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options) {
// Intentionally do not check that versions match so that this routine can
// be used for less brittle golden file tests.
return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
}
bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
const protobuf::RepeatedPtrField<NodeDef>& expected,
string* diff, const EqualGraphDefOptions& options) {
template <class NodeDefs>
static bool EqualNodeDefsHelper(
const NodeDefs& actual, const protobuf::RepeatedPtrField<NodeDef>& expected,
string* diff, const EqualGraphDefOptions& options) {
std::unordered_map<string, const NodeDef*> actual_index;
for (const NodeDef& node : actual) {
actual_index[node.name()] = &node;
@ -68,6 +62,24 @@ bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
return true;
}
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options) {
// Intentionally do not check that versions match so that this routine can
// be used for less brittle golden file tests.
return EqualNodeDefsHelper(actual.node(), expected.node(), diff, options);
}
bool EqualGraphDef(gtl::ArraySlice<NodeDef> actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options) {
return EqualNodeDefsHelper(actual, expected.node(), diff, options);
}
bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
const protobuf::RepeatedPtrField<NodeDef>& expected,
string* diff, const EqualGraphDefOptions& options) {
return EqualNodeDefsHelper(actual, expected, diff, options);
}
namespace {
string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {

View File

@ -36,6 +36,8 @@ struct EqualGraphDefOptions {
// nodes must be consistent.
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options = {});
bool EqualGraphDef(gtl::ArraySlice<NodeDef> actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options = {});
// Determines if actual and expected are equal, ignoring: ordering of
// attrs, internal attributes (if set in `options`), and control inputs.