Add an options argument to EqualGraphDef and EqualNodeDef. Currently the only option is controlling whether internal attributes (whose names start with "_") are tested for equality.

Change: 145362690
This commit is contained in:
Peter Hawkins 2017-01-23 18:13:01 -08:00 committed by TensorFlower Gardener
parent cd4a96499b
commit 3b4e53b073
2 changed files with 20 additions and 9 deletions

View File

@ -25,7 +25,7 @@ limitations under the License.
namespace tensorflow {
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff) {
string* diff, const EqualGraphDefOptions& options) {
// Intentionally do not check that versions match so that this routine can
// be used for less brittle golden file tests.
@ -44,7 +44,9 @@ bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
return false;
}
if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false;
if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
return false;
}
actual_index.erase(actual_iter);
}
@ -75,8 +77,8 @@ string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
} // namespace
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
string* diff) {
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
const EqualGraphDefOptions& options) {
if (actual.name() != expected.name()) {
if (diff != nullptr) {
*diff = strings::StrCat("Actual node name '", actual.name(),
@ -156,13 +158,15 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
std::unordered_set<string> actual_attr;
for (const auto& a : actual.attr()) {
if (!a.first.empty() && a.first[0] == '_') {
if (options.ignore_internal_attrs && !a.first.empty() &&
a.first[0] == '_') {
continue;
}
actual_attr.insert(a.first);
}
for (const auto& e : expected.attr()) {
if (!e.first.empty() && e.first[0] == '_') {
if (options.ignore_internal_attrs && !e.first.empty() &&
e.first[0] == '_') {
continue;
}

View File

@ -22,20 +22,27 @@ limitations under the License.
namespace tensorflow {
struct EqualGraphDefOptions {
// Should internal attributes (attribute names that start with '_') be
// ignored?
bool ignore_internal_attrs = true;
};
// Determines if actual and expected are equal, ignoring versions and ordering
// of nodes, attrs, and control inputs. If the GraphDefs are different and
// diff != nullptr, *diff is set to an explanation of the difference. Note that
// we use node names to match up nodes between the graphs, and so the naming of
// nodes must be consistent.
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff);
string* diff, const EqualGraphDefOptions& options = {});
// Determines if actual and expected are equal, ignoring: ordering of
// attrs, internal attributes, and control inputs.
// attrs, internal attributes (if set in `options`), and control inputs.
//
// If the NodeDefs are different and
// diff != nullptr, *diff is set to an explanation of the difference.
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff);
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
const EqualGraphDefOptions& options = {});
#define TF_EXPECT_GRAPH_EQ(expected, actual) \
do { \