From 3b4e53b0739804af7e8f51412bac366dd842a3f1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 23 Jan 2017 18:13:01 -0800 Subject: [PATCH] Add an options argument to EqualGraphDef and EqualNodeDef. Currently the only option is controlling whether internal attributes (whose names start with "_") are tested for equality. Change: 145362690 --- tensorflow/core/graph/equal_graph_def.cc | 16 ++++++++++------ tensorflow/core/graph/equal_graph_def.h | 13 ++++++++++--- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/graph/equal_graph_def.cc b/tensorflow/core/graph/equal_graph_def.cc index 0c019fc5c18..21b6d55ca85 100644 --- a/tensorflow/core/graph/equal_graph_def.cc +++ b/tensorflow/core/graph/equal_graph_def.cc @@ -25,7 +25,7 @@ limitations under the License. namespace tensorflow { bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, - string* diff) { + string* diff, const EqualGraphDefOptions& options) { // Intentionally do not check that versions match so that this routine can // be used for less brittle golden file tests. @@ -44,7 +44,9 @@ bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, return false; } - if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false; + if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) { + return false; + } actual_index.erase(actual_iter); } @@ -75,8 +77,8 @@ string JoinStringField(const protobuf::RepeatedPtrField& f) { } // namespace -bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, - string* diff) { +bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff, + const EqualGraphDefOptions& options) { if (actual.name() != expected.name()) { if (diff != nullptr) { *diff = strings::StrCat("Actual node name '", actual.name(), @@ -156,13 +158,15 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, std::unordered_set actual_attr; for (const auto& a : actual.attr()) { - if (!a.first.empty() && a.first[0] == '_') { + if (options.ignore_internal_attrs && !a.first.empty() && + a.first[0] == '_') { continue; } actual_attr.insert(a.first); } for (const auto& e : expected.attr()) { - if (!e.first.empty() && e.first[0] == '_') { + if (options.ignore_internal_attrs && !e.first.empty() && + e.first[0] == '_') { continue; } diff --git a/tensorflow/core/graph/equal_graph_def.h b/tensorflow/core/graph/equal_graph_def.h index 8d997fdff8c..82f8bd0713b 100644 --- a/tensorflow/core/graph/equal_graph_def.h +++ b/tensorflow/core/graph/equal_graph_def.h @@ -22,20 +22,27 @@ limitations under the License. namespace tensorflow { +struct EqualGraphDefOptions { + // Should internal attributes (attribute names that start with '_') be + // ignored? + bool ignore_internal_attrs = true; +}; + // Determines if actual and expected are equal, ignoring versions and ordering // of nodes, attrs, and control inputs. If the GraphDefs are different and // diff != nullptr, *diff is set to an explanation of the difference. Note that // we use node names to match up nodes between the graphs, and so the naming of // nodes must be consistent. bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, - string* diff); + string* diff, const EqualGraphDefOptions& options = {}); // Determines if actual and expected are equal, ignoring: ordering of -// attrs, internal attributes, and control inputs. +// attrs, internal attributes (if set in `options`), and control inputs. // // If the NodeDefs are different and // diff != nullptr, *diff is set to an explanation of the difference. -bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff); +bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff, + const EqualGraphDefOptions& options = {}); #define TF_EXPECT_GRAPH_EQ(expected, actual) \ do { \