diff --git a/tensorflow/core/platform/protobuf.h b/tensorflow/core/platform/protobuf.h index 371912cc2b7..6950cb9a1b6 100644 --- a/tensorflow/core/platform/protobuf.h +++ b/tensorflow/core/platform/protobuf.h @@ -37,7 +37,9 @@ limitations under the License. #include "google/protobuf/message.h" #include "google/protobuf/repeated_field.h" #include "google/protobuf/text_format.h" +#include "google/protobuf/util/field_comparator.h" #include "google/protobuf/util/json_util.h" +#include "google/protobuf/util/message_differencer.h" #include "google/protobuf/util/type_resolver_util.h" namespace tensorflow { @@ -116,7 +118,6 @@ class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream { tstring* target_; }; - } // namespace tensorflow #endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_ diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c4a15c55f4d..7b1ccdae12e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -965,6 +965,7 @@ py_library( ], srcs_version = "PY3", deps = [ + ":_proto_comparators", ":dtypes", ":framework_ops", ":platform", @@ -6583,6 +6584,21 @@ tf_gen_op_wrapper_private_py( ], ) +tf_python_pybind_extension( + name = "_proto_comparators", + srcs = ["framework/proto_comparators.cc"], + module_name = "_proto_comparators", + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:protobuf", + "//tensorflow/python/lib/core:pybind11_status", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@pybind11", + ], +) + py_library( name = "proto_ops", srcs = ["ops/proto_ops.py"], diff --git a/tensorflow/python/framework/graph_util.py b/tensorflow/python/framework/graph_util.py index c5cc1107343..117926a4ab7 100644 --- a/tensorflow/python/framework/graph_util.py +++ b/tensorflow/python/framework/graph_util.py @@ -24,6 +24,7 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.python.framework.graph_util_impl import convert_variables_to_constants from tensorflow.python.framework.graph_util_impl import extract_sub_graph +from tensorflow.python.framework.graph_util_impl import graph_defs_equal from tensorflow.python.framework.graph_util_impl import must_run_on_cpu from tensorflow.python.framework.graph_util_impl import remove_training_nodes from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index 4ef26fc8539..a2055241517 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -25,6 +25,7 @@ import six from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.python import _proto_comparators from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.util import deprecation @@ -374,3 +375,37 @@ def remove_training_nodes(input_graph, protected_nodes=None): output_graph = graph_pb2.GraphDef() output_graph.node.extend(nodes_after_splicing) return output_graph + + +def graph_defs_equal(graph_def_1: graph_pb2.GraphDef, + graph_def_2: graph_pb2.GraphDef, + treat_nan_as_equal: bool = False) -> bool: + """Returns True iff the graph def arguments are structurally equivalent. + + The notion of equivalence encoded here checks that the set of NodeDefs in + the GraphDef's function library and main graph body are identical. + Additionally, it checks that the functions in the function library are equal + as sets. + + Args: + graph_def_1: Instance of `graph_pb2.GraphDef` to compare. + graph_def_2: Instance of `graph_pb2.GraphDef` to compare. + treat_nan_as_equal: Boolean indicating whether or not to treat nan + floating-point values as equal. This is crucial for any equivalence + relation defined over GraphDefs, to ensure symmetry. + + Returns: + Boolean indicating structural equivalence as described above. + + Raises: + TypeError: If either of the GraphDefs are not instances of + `graph_pb2.GraphDef`. + """ + if not isinstance(graph_def_1, graph_pb2.GraphDef): + raise TypeError("graph_def_1 must be a graph_pb2.GraphDef proto.") + if not isinstance(graph_def_2, graph_pb2.GraphDef): + raise TypeError("graph_def_2 must be a graph_pb2.GraphDef proto.") + options = _proto_comparators.ProtoComparisonOptions(treat_nan_as_equal) + return _proto_comparators.EqualsGraphDef(graph_def_1.SerializeToString(), + graph_def_2.SerializeToString(), + options) diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py index 4957ee7d97e..b23bdaf3b59 100644 --- a/tensorflow/python/framework/graph_util_test.py +++ b/tensorflow/python/framework/graph_util_test.py @@ -19,27 +19,30 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import graph_util from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_state_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test # Utility device function to use for testing -def test_device_func_pin_variable_to_cpu(op): +def TestDeviceFuncPinVariableToCpu(op): if op.device: return op.device return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device -class DeviceFunctionsTest(test.TestCase): +class GraphUtilTest(test.TestCase): def testTwoDeviceFunctions(self): with ops.Graph().as_default() as g: @@ -49,7 +52,7 @@ class DeviceFunctionsTest(test.TestCase): name="var_0", container="", shared_name="") - with g.device(test_device_func_pin_variable_to_cpu): + with g.device(TestDeviceFuncPinVariableToCpu): var_1 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, @@ -68,7 +71,7 @@ class DeviceFunctionsTest(test.TestCase): name="var_3", container="", shared_name="") - with g.device(test_device_func_pin_variable_to_cpu): + with g.device(TestDeviceFuncPinVariableToCpu): var_4 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, @@ -101,7 +104,7 @@ class DeviceFunctionsTest(test.TestCase): def testNestedDeviceFunctions(self): with ops.Graph().as_default(): var_0 = variables.VariableV1(0) - with ops.device(test_device_func_pin_variable_to_cpu): + with ops.device(TestDeviceFuncPinVariableToCpu): var_1 = variables.VariableV1(1) with ops.device(lambda op: "/device:GPU:0"): var_2 = variables.VariableV1(2) @@ -136,7 +139,7 @@ class DeviceFunctionsTest(test.TestCase): def testDefaultDevice(self): with ops.Graph().as_default() as g, g.device( - test_device_func_pin_variable_to_cpu): + TestDeviceFuncPinVariableToCpu): with g.device("/job:ps"): const_0 = constant_op.constant(5.0) with g.device("/device:GPU:0"): @@ -303,6 +306,162 @@ class DeviceFunctionsTest(test.TestCase): self.assertProtoEquals(graph_def, graph_util.remove_training_nodes(graph_def)) + def testSimpleGraphdefsCompareEqual(self): + graph_def1 = graph_pb2.GraphDef() + graph_def1.node.extend([ + self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + + graph_def2 = graph_pb2.GraphDef() + graph_def2.node.extend([ + self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + + self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2)) + + def testNodeDefsInDifferentOrderCompareEqual(self): + graph_def1 = graph_pb2.GraphDef() + graph_def1.node.extend([ + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []), + self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]), + ]) + + graph_def2 = graph_pb2.GraphDef() + graph_def2.node.extend([ + self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + + self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2)) + + def testDifferentGraphDefsCompareNotEqual(self): + graph_def1 = graph_pb2.GraphDef() + graph_def1.node.extend([ + self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + + graph_def2 = graph_pb2.GraphDef() + graph_def2.node.extend([ + self.create_constant_node_def("C", 2, dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2)) + + def testGraphdefsWithNanCompareNonEqual(self): + graph_def1 = graph_pb2.GraphDef() + graph_def1.node.extend([ + self.create_constant_node_def( + "C", float("nan"), dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + + graph_def2 = graph_pb2.GraphDef() + graph_def2.node.extend([ + self.create_constant_node_def( + "C", float("nan"), dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2)) + + def testSimpleGraphdefEqualityWithNansEqual(self): + graph_def1 = graph_pb2.GraphDef() + graph_def1.node.extend([ + self.create_constant_node_def( + "C", float("nan"), dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + + graph_def2 = graph_pb2.GraphDef() + graph_def2.node.extend([ + self.create_constant_node_def( + "C", float("nan"), dtypes.float32, inputs=["^I"]), + self.create_node_def("Identity", "I", ["Base"]), + self.create_node_def("BaseOp", "Base", []) + ]) + self.assertTrue( + graph_util.graph_defs_equal( + graph_def1, graph_def2, treat_nan_as_equal=True)) + + def testGraphDefsWithFunctionLibsCompareEqual(self): + + @function.Defun(dtypes.float32) + def F1(x): + return math_ops.exp(x) - math_ops.exp(-x) + + library = function_pb2.FunctionDefLibrary() + library.function.extend([F1.definition]) + + graph_def1 = graph_pb2.GraphDef() + graph_def1.library.CopyFrom(library) + + graph_def2 = graph_pb2.GraphDef() + graph_def2.library.CopyFrom(library) + + self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2)) + + def testGraphDefsWithPermutedFunctionsCompareEqual(self): + + @function.Defun(dtypes.float32) + def F1(x): + return math_ops.exp(x) - math_ops.exp(-x) + + @function.Defun(dtypes.float32) + def F2(x): + return math_ops.exp(x) + + definition_1 = F1.definition + definition_2 = F2.definition + library = function_pb2.FunctionDefLibrary() + library.function.extend([definition_1, definition_2]) + + graph_def1 = graph_pb2.GraphDef() + graph_def1.library.CopyFrom(library) + + reversed_library = function_pb2.FunctionDefLibrary() + reversed_library.function.extend([definition_2, definition_1]) + graph_def2 = graph_pb2.GraphDef() + graph_def2.library.CopyFrom(reversed_library) + + self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2)) + + def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self): + + @function.Defun(dtypes.float32) + def F1(x): + return math_ops.exp(x) - math_ops.exp(-x) + + f1_def = F1.definition + + library = function_pb2.FunctionDefLibrary() + library.function.extend([f1_def]) + + graph_def1 = graph_pb2.GraphDef() + graph_def1.library.CopyFrom(library) + + reversed_function = function_pb2.FunctionDef() + reversed_function.CopyFrom(f1_def) + # Clear the node_def attribute. + del reversed_function.node_def[:] + reversed_function.node_def.extend(reversed(f1_def.node_def)) + reversed_library = function_pb2.FunctionDefLibrary() + reversed_library.function.extend([reversed_function]) + graph_def2 = graph_pb2.GraphDef() + graph_def2.library.CopyFrom(reversed_library) + + self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/proto_comparators.cc b/tensorflow/python/framework/proto_comparators.cc new file mode 100644 index 00000000000..5d457e9d91b --- /dev/null +++ b/tensorflow/python/framework/proto_comparators.cc @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <Python.h> + +#include <memory> +#include <string> + +#include "pybind11/detail/common.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +struct ProtoComparisonOptions; // Forward declaration + +namespace tensorflow { + +namespace { + +namespace py = pybind11; +namespace tf = tensorflow; + +struct ProtoComparisonOptions { + bool treat_nan_as_equal; +}; + +bool EqualsGraphDef(string graphdef_string1, string graphdef_string2, + const ProtoComparisonOptions& options) { + GraphDef graph_def_1; + if (!graph_def_1.ParseFromString(graphdef_string1)) { + MaybeRaiseFromStatus(errors::InvalidArgument( + "Couldn't interpret first argument as a GraphDef")); + } + GraphDef graph_def_2; + if (!graph_def_2.ParseFromString(graphdef_string2)) { + MaybeRaiseFromStatus(errors::InvalidArgument( + "Couldn't interpret second argument as a GraphDef")); + } + tf::protobuf::util::MessageDifferencer differencer; + // Order doesnt matter in node defs, or functions in the function library and + // their nested nodes. + differencer.TreatAsSet(GraphDef::descriptor()->FindFieldByName("node")); + differencer.TreatAsSet( + FunctionDefLibrary::descriptor()->FindFieldByName("function")); + differencer.TreatAsSet( + FunctionDefLibrary::descriptor()->FindFieldByName("gradient")); + differencer.TreatAsSet( + FunctionDef::descriptor()->FindFieldByName("node_def")); + tf::protobuf::util::DefaultFieldComparator comparator; + comparator.set_treat_nan_as_equal(options.treat_nan_as_equal); + differencer.set_field_comparator(&comparator); + return differencer.Compare(graph_def_1, graph_def_2); +} + +PYBIND11_MODULE(_proto_comparators, m) { + py::class_<tensorflow::ProtoComparisonOptions>(m, "ProtoComparisonOptions") + .def(py::init<const bool&>()); + m.def("EqualsGraphDef", &EqualsGraphDef, + "GraphDef equality test taking comparison options."); +} + +} // anonymous namespace + +} // namespace tensorflow