Add an options argument to EqualGraphDef and EqualNodeDef. Currently the only option is controlling whether internal attributes (whose names start with "_") are tested for equality.
Change: 145362690
This commit is contained in:
parent
cd4a96499b
commit
3b4e53b073
tensorflow/core/graph
@ -25,7 +25,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||||
string* diff) {
|
string* diff, const EqualGraphDefOptions& options) {
|
||||||
// Intentionally do not check that versions match so that this routine can
|
// Intentionally do not check that versions match so that this routine can
|
||||||
// be used for less brittle golden file tests.
|
// be used for less brittle golden file tests.
|
||||||
|
|
||||||
@ -44,7 +44,9 @@ bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false;
|
if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
actual_index.erase(actual_iter);
|
actual_index.erase(actual_iter);
|
||||||
}
|
}
|
||||||
@ -75,8 +77,8 @@ string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
|
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
|
||||||
string* diff) {
|
const EqualGraphDefOptions& options) {
|
||||||
if (actual.name() != expected.name()) {
|
if (actual.name() != expected.name()) {
|
||||||
if (diff != nullptr) {
|
if (diff != nullptr) {
|
||||||
*diff = strings::StrCat("Actual node name '", actual.name(),
|
*diff = strings::StrCat("Actual node name '", actual.name(),
|
||||||
@ -156,13 +158,15 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
|
|||||||
|
|
||||||
std::unordered_set<string> actual_attr;
|
std::unordered_set<string> actual_attr;
|
||||||
for (const auto& a : actual.attr()) {
|
for (const auto& a : actual.attr()) {
|
||||||
if (!a.first.empty() && a.first[0] == '_') {
|
if (options.ignore_internal_attrs && !a.first.empty() &&
|
||||||
|
a.first[0] == '_') {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
actual_attr.insert(a.first);
|
actual_attr.insert(a.first);
|
||||||
}
|
}
|
||||||
for (const auto& e : expected.attr()) {
|
for (const auto& e : expected.attr()) {
|
||||||
if (!e.first.empty() && e.first[0] == '_') {
|
if (options.ignore_internal_attrs && !e.first.empty() &&
|
||||||
|
e.first[0] == '_') {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,20 +22,27 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
struct EqualGraphDefOptions {
|
||||||
|
// Should internal attributes (attribute names that start with '_') be
|
||||||
|
// ignored?
|
||||||
|
bool ignore_internal_attrs = true;
|
||||||
|
};
|
||||||
|
|
||||||
// Determines if actual and expected are equal, ignoring versions and ordering
|
// Determines if actual and expected are equal, ignoring versions and ordering
|
||||||
// of nodes, attrs, and control inputs. If the GraphDefs are different and
|
// of nodes, attrs, and control inputs. If the GraphDefs are different and
|
||||||
// diff != nullptr, *diff is set to an explanation of the difference. Note that
|
// diff != nullptr, *diff is set to an explanation of the difference. Note that
|
||||||
// we use node names to match up nodes between the graphs, and so the naming of
|
// we use node names to match up nodes between the graphs, and so the naming of
|
||||||
// nodes must be consistent.
|
// nodes must be consistent.
|
||||||
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||||
string* diff);
|
string* diff, const EqualGraphDefOptions& options = {});
|
||||||
|
|
||||||
// Determines if actual and expected are equal, ignoring: ordering of
|
// Determines if actual and expected are equal, ignoring: ordering of
|
||||||
// attrs, internal attributes, and control inputs.
|
// attrs, internal attributes (if set in `options`), and control inputs.
|
||||||
//
|
//
|
||||||
// If the NodeDefs are different and
|
// If the NodeDefs are different and
|
||||||
// diff != nullptr, *diff is set to an explanation of the difference.
|
// diff != nullptr, *diff is set to an explanation of the difference.
|
||||||
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff);
|
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
|
||||||
|
const EqualGraphDefOptions& options = {});
|
||||||
|
|
||||||
#define TF_EXPECT_GRAPH_EQ(expected, actual) \
|
#define TF_EXPECT_GRAPH_EQ(expected, actual) \
|
||||||
do { \
|
do { \
|
||||||
|
Loading…
Reference in New Issue
Block a user