Adds proto comparators and options for nan-equivalence to graph util.

PiperOrigin-RevId: 353883024
Change-Id: I05423b52de13571294867b9284901f1567577b2f
This commit is contained in:
Keith Rush 2021-01-26 09:31:12 -08:00 committed by TensorFlower Gardener
parent e3e5a24f93
commit 85d50ab3a0
6 changed files with 298 additions and 7 deletions

View File

@ -37,7 +37,9 @@ limitations under the License.
#include "google/protobuf/message.h"
#include "google/protobuf/repeated_field.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/util/field_comparator.h"
#include "google/protobuf/util/json_util.h"
#include "google/protobuf/util/message_differencer.h"
#include "google/protobuf/util/type_resolver_util.h"
namespace tensorflow {
@ -116,7 +118,6 @@ class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream {
tstring* target_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_

View File

@ -965,6 +965,7 @@ py_library(
],
srcs_version = "PY3",
deps = [
":_proto_comparators",
":dtypes",
":framework_ops",
":platform",
@ -6583,6 +6584,21 @@ tf_gen_op_wrapper_private_py(
],
)
tf_python_pybind_extension(
name = "_proto_comparators",
srcs = ["framework/proto_comparators.cc"],
module_name = "_proto_comparators",
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:protobuf",
"//tensorflow/python/lib/core:pybind11_status",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@pybind11",
],
)
py_library(
name = "proto_ops",
srcs = ["ops/proto_ops.py"],

View File

@ -24,6 +24,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.framework.graph_util_impl import convert_variables_to_constants
from tensorflow.python.framework.graph_util_impl import extract_sub_graph
from tensorflow.python.framework.graph_util_impl import graph_defs_equal
from tensorflow.python.framework.graph_util_impl import must_run_on_cpu
from tensorflow.python.framework.graph_util_impl import remove_training_nodes
from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name

View File

@ -25,6 +25,7 @@ import six
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python import _proto_comparators
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
@ -374,3 +375,37 @@ def remove_training_nodes(input_graph, protected_nodes=None):
output_graph = graph_pb2.GraphDef()
output_graph.node.extend(nodes_after_splicing)
return output_graph
def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
graph_def_2: graph_pb2.GraphDef,
treat_nan_as_equal: bool = False) -> bool:
"""Returns True iff the graph def arguments are structurally equivalent.
The notion of equivalence encoded here checks that the set of NodeDefs in
the GraphDef's function library and main graph body are identical.
Additionally, it checks that the functions in the function library are equal
as sets.
Args:
graph_def_1: Instance of `graph_pb2.GraphDef` to compare.
graph_def_2: Instance of `graph_pb2.GraphDef` to compare.
treat_nan_as_equal: Boolean indicating whether or not to treat nan
floating-point values as equal. This is crucial for any equivalence
relation defined over GraphDefs, to ensure symmetry.
Returns:
Boolean indicating structural equivalence as described above.
Raises:
TypeError: If either of the GraphDefs are not instances of
`graph_pb2.GraphDef`.
"""
if not isinstance(graph_def_1, graph_pb2.GraphDef):
raise TypeError("graph_def_1 must be a graph_pb2.GraphDef proto.")
if not isinstance(graph_def_2, graph_pb2.GraphDef):
raise TypeError("graph_def_2 must be a graph_pb2.GraphDef proto.")
options = _proto_comparators.ProtoComparisonOptions(treat_nan_as_equal)
return _proto_comparators.EqualsGraphDef(graph_def_1.SerializeToString(),
graph_def_2.SerializeToString(),
options)

View File

@ -19,27 +19,30 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# Utility device function to use for testing
def test_device_func_pin_variable_to_cpu(op):
def TestDeviceFuncPinVariableToCpu(op):
if op.device:
return op.device
return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device
class DeviceFunctionsTest(test.TestCase):
class GraphUtilTest(test.TestCase):
def testTwoDeviceFunctions(self):
with ops.Graph().as_default() as g:
@ -49,7 +52,7 @@ class DeviceFunctionsTest(test.TestCase):
name="var_0",
container="",
shared_name="")
with g.device(test_device_func_pin_variable_to_cpu):
with g.device(TestDeviceFuncPinVariableToCpu):
var_1 = gen_state_ops.variable(
shape=[1],
dtype=dtypes.float32,
@ -68,7 +71,7 @@ class DeviceFunctionsTest(test.TestCase):
name="var_3",
container="",
shared_name="")
with g.device(test_device_func_pin_variable_to_cpu):
with g.device(TestDeviceFuncPinVariableToCpu):
var_4 = gen_state_ops.variable(
shape=[1],
dtype=dtypes.float32,
@ -101,7 +104,7 @@ class DeviceFunctionsTest(test.TestCase):
def testNestedDeviceFunctions(self):
with ops.Graph().as_default():
var_0 = variables.VariableV1(0)
with ops.device(test_device_func_pin_variable_to_cpu):
with ops.device(TestDeviceFuncPinVariableToCpu):
var_1 = variables.VariableV1(1)
with ops.device(lambda op: "/device:GPU:0"):
var_2 = variables.VariableV1(2)
@ -136,7 +139,7 @@ class DeviceFunctionsTest(test.TestCase):
def testDefaultDevice(self):
with ops.Graph().as_default() as g, g.device(
test_device_func_pin_variable_to_cpu):
TestDeviceFuncPinVariableToCpu):
with g.device("/job:ps"):
const_0 = constant_op.constant(5.0)
with g.device("/device:GPU:0"):
@ -303,6 +306,162 @@ class DeviceFunctionsTest(test.TestCase):
self.assertProtoEquals(graph_def,
graph_util.remove_training_nodes(graph_def))
def testSimpleGraphdefsCompareEqual(self):
graph_def1 = graph_pb2.GraphDef()
graph_def1.node.extend([
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
graph_def2 = graph_pb2.GraphDef()
graph_def2.node.extend([
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
def testNodeDefsInDifferentOrderCompareEqual(self):
graph_def1 = graph_pb2.GraphDef()
graph_def1.node.extend([
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", []),
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
])
graph_def2 = graph_pb2.GraphDef()
graph_def2.node.extend([
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
def testDifferentGraphDefsCompareNotEqual(self):
graph_def1 = graph_pb2.GraphDef()
graph_def1.node.extend([
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
graph_def2 = graph_pb2.GraphDef()
graph_def2.node.extend([
self.create_constant_node_def("C", 2, dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2))
def testGraphdefsWithNanCompareNonEqual(self):
graph_def1 = graph_pb2.GraphDef()
graph_def1.node.extend([
self.create_constant_node_def(
"C", float("nan"), dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
graph_def2 = graph_pb2.GraphDef()
graph_def2.node.extend([
self.create_constant_node_def(
"C", float("nan"), dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2))
def testSimpleGraphdefEqualityWithNansEqual(self):
graph_def1 = graph_pb2.GraphDef()
graph_def1.node.extend([
self.create_constant_node_def(
"C", float("nan"), dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
graph_def2 = graph_pb2.GraphDef()
graph_def2.node.extend([
self.create_constant_node_def(
"C", float("nan"), dtypes.float32, inputs=["^I"]),
self.create_node_def("Identity", "I", ["Base"]),
self.create_node_def("BaseOp", "Base", [])
])
self.assertTrue(
graph_util.graph_defs_equal(
graph_def1, graph_def2, treat_nan_as_equal=True))
def testGraphDefsWithFunctionLibsCompareEqual(self):
@function.Defun(dtypes.float32)
def F1(x):
return math_ops.exp(x) - math_ops.exp(-x)
library = function_pb2.FunctionDefLibrary()
library.function.extend([F1.definition])
graph_def1 = graph_pb2.GraphDef()
graph_def1.library.CopyFrom(library)
graph_def2 = graph_pb2.GraphDef()
graph_def2.library.CopyFrom(library)
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
def testGraphDefsWithPermutedFunctionsCompareEqual(self):
@function.Defun(dtypes.float32)
def F1(x):
return math_ops.exp(x) - math_ops.exp(-x)
@function.Defun(dtypes.float32)
def F2(x):
return math_ops.exp(x)
definition_1 = F1.definition
definition_2 = F2.definition
library = function_pb2.FunctionDefLibrary()
library.function.extend([definition_1, definition_2])
graph_def1 = graph_pb2.GraphDef()
graph_def1.library.CopyFrom(library)
reversed_library = function_pb2.FunctionDefLibrary()
reversed_library.function.extend([definition_2, definition_1])
graph_def2 = graph_pb2.GraphDef()
graph_def2.library.CopyFrom(reversed_library)
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):
@function.Defun(dtypes.float32)
def F1(x):
return math_ops.exp(x) - math_ops.exp(-x)
f1_def = F1.definition
library = function_pb2.FunctionDefLibrary()
library.function.extend([f1_def])
graph_def1 = graph_pb2.GraphDef()
graph_def1.library.CopyFrom(library)
reversed_function = function_pb2.FunctionDef()
reversed_function.CopyFrom(f1_def)
# Clear the node_def attribute.
del reversed_function.node_def[:]
reversed_function.node_def.extend(reversed(f1_def.node_def))
reversed_library = function_pb2.FunctionDefLibrary()
reversed_library.function.extend([reversed_function])
graph_def2 = graph_pb2.GraphDef()
graph_def2.library.CopyFrom(reversed_library)
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,79 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <Python.h>
#include <memory>
#include <string>
#include "pybind11/detail/common.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
struct ProtoComparisonOptions; // Forward declaration
namespace tensorflow {
namespace {
namespace py = pybind11;
namespace tf = tensorflow;
struct ProtoComparisonOptions {
bool treat_nan_as_equal;
};
bool EqualsGraphDef(string graphdef_string1, string graphdef_string2,
const ProtoComparisonOptions& options) {
GraphDef graph_def_1;
if (!graph_def_1.ParseFromString(graphdef_string1)) {
MaybeRaiseFromStatus(errors::InvalidArgument(
"Couldn't interpret first argument as a GraphDef"));
}
GraphDef graph_def_2;
if (!graph_def_2.ParseFromString(graphdef_string2)) {
MaybeRaiseFromStatus(errors::InvalidArgument(
"Couldn't interpret second argument as a GraphDef"));
}
tf::protobuf::util::MessageDifferencer differencer;
// Order doesnt matter in node defs, or functions in the function library and
// their nested nodes.
differencer.TreatAsSet(GraphDef::descriptor()->FindFieldByName("node"));
differencer.TreatAsSet(
FunctionDefLibrary::descriptor()->FindFieldByName("function"));
differencer.TreatAsSet(
FunctionDefLibrary::descriptor()->FindFieldByName("gradient"));
differencer.TreatAsSet(
FunctionDef::descriptor()->FindFieldByName("node_def"));
tf::protobuf::util::DefaultFieldComparator comparator;
comparator.set_treat_nan_as_equal(options.treat_nan_as_equal);
differencer.set_field_comparator(&comparator);
return differencer.Compare(graph_def_1, graph_def_2);
}
PYBIND11_MODULE(_proto_comparators, m) {
py::class_<tensorflow::ProtoComparisonOptions>(m, "ProtoComparisonOptions")
.def(py::init<const bool&>());
m.def("EqualsGraphDef", &EqualsGraphDef,
"GraphDef equality test taking comparison options.");
}
} // anonymous namespace
} // namespace tensorflow