Add an options argument to EqualGraphDef and EqualNodeDef. Currently the only option is controlling whether internal attributes (whose names start with "_") are tested for equality.
Change: 145362690
This commit is contained in:
parent
cd4a96499b
commit
3b4e53b073
@ -25,7 +25,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||
string* diff) {
|
||||
string* diff, const EqualGraphDefOptions& options) {
|
||||
// Intentionally do not check that versions match so that this routine can
|
||||
// be used for less brittle golden file tests.
|
||||
|
||||
@ -44,7 +44,9 @@ bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false;
|
||||
if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
actual_index.erase(actual_iter);
|
||||
}
|
||||
@ -75,8 +77,8 @@ string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
|
||||
|
||||
} // namespace
|
||||
|
||||
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
|
||||
string* diff) {
|
||||
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
|
||||
const EqualGraphDefOptions& options) {
|
||||
if (actual.name() != expected.name()) {
|
||||
if (diff != nullptr) {
|
||||
*diff = strings::StrCat("Actual node name '", actual.name(),
|
||||
@ -156,13 +158,15 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
|
||||
|
||||
std::unordered_set<string> actual_attr;
|
||||
for (const auto& a : actual.attr()) {
|
||||
if (!a.first.empty() && a.first[0] == '_') {
|
||||
if (options.ignore_internal_attrs && !a.first.empty() &&
|
||||
a.first[0] == '_') {
|
||||
continue;
|
||||
}
|
||||
actual_attr.insert(a.first);
|
||||
}
|
||||
for (const auto& e : expected.attr()) {
|
||||
if (!e.first.empty() && e.first[0] == '_') {
|
||||
if (options.ignore_internal_attrs && !e.first.empty() &&
|
||||
e.first[0] == '_') {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -22,20 +22,27 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
struct EqualGraphDefOptions {
|
||||
// Should internal attributes (attribute names that start with '_') be
|
||||
// ignored?
|
||||
bool ignore_internal_attrs = true;
|
||||
};
|
||||
|
||||
// Determines if actual and expected are equal, ignoring versions and ordering
|
||||
// of nodes, attrs, and control inputs. If the GraphDefs are different and
|
||||
// diff != nullptr, *diff is set to an explanation of the difference. Note that
|
||||
// we use node names to match up nodes between the graphs, and so the naming of
|
||||
// nodes must be consistent.
|
||||
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||
string* diff);
|
||||
string* diff, const EqualGraphDefOptions& options = {});
|
||||
|
||||
// Determines if actual and expected are equal, ignoring: ordering of
|
||||
// attrs, internal attributes, and control inputs.
|
||||
// attrs, internal attributes (if set in `options`), and control inputs.
|
||||
//
|
||||
// If the NodeDefs are different and
|
||||
// diff != nullptr, *diff is set to an explanation of the difference.
|
||||
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff);
|
||||
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
|
||||
const EqualGraphDefOptions& options = {});
|
||||
|
||||
#define TF_EXPECT_GRAPH_EQ(expected, actual) \
|
||||
do { \
|
||||
|
Loading…
Reference in New Issue
Block a user