Adds proto comparators and options for nan-equivalence to graph util.
PiperOrigin-RevId: 353883024 Change-Id: I05423b52de13571294867b9284901f1567577b2f
This commit is contained in:
parent
e3e5a24f93
commit
85d50ab3a0
@ -37,7 +37,9 @@ limitations under the License.
|
|||||||
#include "google/protobuf/message.h"
|
#include "google/protobuf/message.h"
|
||||||
#include "google/protobuf/repeated_field.h"
|
#include "google/protobuf/repeated_field.h"
|
||||||
#include "google/protobuf/text_format.h"
|
#include "google/protobuf/text_format.h"
|
||||||
|
#include "google/protobuf/util/field_comparator.h"
|
||||||
#include "google/protobuf/util/json_util.h"
|
#include "google/protobuf/util/json_util.h"
|
||||||
|
#include "google/protobuf/util/message_differencer.h"
|
||||||
#include "google/protobuf/util/type_resolver_util.h"
|
#include "google/protobuf/util/type_resolver_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -116,7 +118,6 @@ class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream {
|
|||||||
|
|
||||||
tstring* target_;
|
tstring* target_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
|
#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
|
||||||
|
@ -965,6 +965,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":_proto_comparators",
|
||||||
":dtypes",
|
":dtypes",
|
||||||
":framework_ops",
|
":framework_ops",
|
||||||
":platform",
|
":platform",
|
||||||
@ -6583,6 +6584,21 @@ tf_gen_op_wrapper_private_py(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_proto_comparators",
|
||||||
|
srcs = ["framework/proto_comparators.cc"],
|
||||||
|
module_name = "_proto_comparators",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/platform:errors",
|
||||||
|
"//tensorflow/core/platform:protobuf",
|
||||||
|
"//tensorflow/python/lib/core:pybind11_status",
|
||||||
|
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "proto_ops",
|
name = "proto_ops",
|
||||||
srcs = ["ops/proto_ops.py"],
|
srcs = ["ops/proto_ops.py"],
|
||||||
|
@ -24,6 +24,7 @@ from __future__ import print_function
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from tensorflow.python.framework.graph_util_impl import convert_variables_to_constants
|
from tensorflow.python.framework.graph_util_impl import convert_variables_to_constants
|
||||||
from tensorflow.python.framework.graph_util_impl import extract_sub_graph
|
from tensorflow.python.framework.graph_util_impl import extract_sub_graph
|
||||||
|
from tensorflow.python.framework.graph_util_impl import graph_defs_equal
|
||||||
from tensorflow.python.framework.graph_util_impl import must_run_on_cpu
|
from tensorflow.python.framework.graph_util_impl import must_run_on_cpu
|
||||||
from tensorflow.python.framework.graph_util_impl import remove_training_nodes
|
from tensorflow.python.framework.graph_util_impl import remove_training_nodes
|
||||||
from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name
|
from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name
|
||||||
|
@ -25,6 +25,7 @@ import six
|
|||||||
|
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.framework import node_def_pb2
|
from tensorflow.core.framework import node_def_pb2
|
||||||
|
from tensorflow.python import _proto_comparators
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
@ -374,3 +375,37 @@ def remove_training_nodes(input_graph, protected_nodes=None):
|
|||||||
output_graph = graph_pb2.GraphDef()
|
output_graph = graph_pb2.GraphDef()
|
||||||
output_graph.node.extend(nodes_after_splicing)
|
output_graph.node.extend(nodes_after_splicing)
|
||||||
return output_graph
|
return output_graph
|
||||||
|
|
||||||
|
|
||||||
|
def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
|
||||||
|
graph_def_2: graph_pb2.GraphDef,
|
||||||
|
treat_nan_as_equal: bool = False) -> bool:
|
||||||
|
"""Returns True iff the graph def arguments are structurally equivalent.
|
||||||
|
|
||||||
|
The notion of equivalence encoded here checks that the set of NodeDefs in
|
||||||
|
the GraphDef's function library and main graph body are identical.
|
||||||
|
Additionally, it checks that the functions in the function library are equal
|
||||||
|
as sets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph_def_1: Instance of `graph_pb2.GraphDef` to compare.
|
||||||
|
graph_def_2: Instance of `graph_pb2.GraphDef` to compare.
|
||||||
|
treat_nan_as_equal: Boolean indicating whether or not to treat nan
|
||||||
|
floating-point values as equal. This is crucial for any equivalence
|
||||||
|
relation defined over GraphDefs, to ensure symmetry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Boolean indicating structural equivalence as described above.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If either of the GraphDefs are not instances of
|
||||||
|
`graph_pb2.GraphDef`.
|
||||||
|
"""
|
||||||
|
if not isinstance(graph_def_1, graph_pb2.GraphDef):
|
||||||
|
raise TypeError("graph_def_1 must be a graph_pb2.GraphDef proto.")
|
||||||
|
if not isinstance(graph_def_2, graph_pb2.GraphDef):
|
||||||
|
raise TypeError("graph_def_2 must be a graph_pb2.GraphDef proto.")
|
||||||
|
options = _proto_comparators.ProtoComparisonOptions(treat_nan_as_equal)
|
||||||
|
return _proto_comparators.EqualsGraphDef(graph_def_1.SerializeToString(),
|
||||||
|
graph_def_2.SerializeToString(),
|
||||||
|
options)
|
||||||
|
@ -19,27 +19,30 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
|
from tensorflow.core.framework import function_pb2
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.framework import node_def_pb2
|
from tensorflow.core.framework import node_def_pb2
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import graph_util
|
from tensorflow.python.framework import graph_util
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gen_state_ops
|
from tensorflow.python.ops import gen_state_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
# Utility device function to use for testing
|
# Utility device function to use for testing
|
||||||
def test_device_func_pin_variable_to_cpu(op):
|
def TestDeviceFuncPinVariableToCpu(op):
|
||||||
if op.device:
|
if op.device:
|
||||||
return op.device
|
return op.device
|
||||||
return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device
|
return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device
|
||||||
|
|
||||||
|
|
||||||
class DeviceFunctionsTest(test.TestCase):
|
class GraphUtilTest(test.TestCase):
|
||||||
|
|
||||||
def testTwoDeviceFunctions(self):
|
def testTwoDeviceFunctions(self):
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
@ -49,7 +52,7 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
name="var_0",
|
name="var_0",
|
||||||
container="",
|
container="",
|
||||||
shared_name="")
|
shared_name="")
|
||||||
with g.device(test_device_func_pin_variable_to_cpu):
|
with g.device(TestDeviceFuncPinVariableToCpu):
|
||||||
var_1 = gen_state_ops.variable(
|
var_1 = gen_state_ops.variable(
|
||||||
shape=[1],
|
shape=[1],
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
@ -68,7 +71,7 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
name="var_3",
|
name="var_3",
|
||||||
container="",
|
container="",
|
||||||
shared_name="")
|
shared_name="")
|
||||||
with g.device(test_device_func_pin_variable_to_cpu):
|
with g.device(TestDeviceFuncPinVariableToCpu):
|
||||||
var_4 = gen_state_ops.variable(
|
var_4 = gen_state_ops.variable(
|
||||||
shape=[1],
|
shape=[1],
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
@ -101,7 +104,7 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
def testNestedDeviceFunctions(self):
|
def testNestedDeviceFunctions(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
var_0 = variables.VariableV1(0)
|
var_0 = variables.VariableV1(0)
|
||||||
with ops.device(test_device_func_pin_variable_to_cpu):
|
with ops.device(TestDeviceFuncPinVariableToCpu):
|
||||||
var_1 = variables.VariableV1(1)
|
var_1 = variables.VariableV1(1)
|
||||||
with ops.device(lambda op: "/device:GPU:0"):
|
with ops.device(lambda op: "/device:GPU:0"):
|
||||||
var_2 = variables.VariableV1(2)
|
var_2 = variables.VariableV1(2)
|
||||||
@ -136,7 +139,7 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
|
|
||||||
def testDefaultDevice(self):
|
def testDefaultDevice(self):
|
||||||
with ops.Graph().as_default() as g, g.device(
|
with ops.Graph().as_default() as g, g.device(
|
||||||
test_device_func_pin_variable_to_cpu):
|
TestDeviceFuncPinVariableToCpu):
|
||||||
with g.device("/job:ps"):
|
with g.device("/job:ps"):
|
||||||
const_0 = constant_op.constant(5.0)
|
const_0 = constant_op.constant(5.0)
|
||||||
with g.device("/device:GPU:0"):
|
with g.device("/device:GPU:0"):
|
||||||
@ -303,6 +306,162 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
self.assertProtoEquals(graph_def,
|
self.assertProtoEquals(graph_def,
|
||||||
graph_util.remove_training_nodes(graph_def))
|
graph_util.remove_training_nodes(graph_def))
|
||||||
|
|
||||||
|
def testSimpleGraphdefsCompareEqual(self):
|
||||||
|
graph_def1 = graph_pb2.GraphDef()
|
||||||
|
graph_def1.node.extend([
|
||||||
|
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
|
||||||
|
graph_def2 = graph_pb2.GraphDef()
|
||||||
|
graph_def2.node.extend([
|
||||||
|
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
|
||||||
|
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||||
|
|
||||||
|
def testNodeDefsInDifferentOrderCompareEqual(self):
|
||||||
|
graph_def1 = graph_pb2.GraphDef()
|
||||||
|
graph_def1.node.extend([
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", []),
|
||||||
|
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||||
|
])
|
||||||
|
|
||||||
|
graph_def2 = graph_pb2.GraphDef()
|
||||||
|
graph_def2.node.extend([
|
||||||
|
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
|
||||||
|
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||||
|
|
||||||
|
def testDifferentGraphDefsCompareNotEqual(self):
|
||||||
|
graph_def1 = graph_pb2.GraphDef()
|
||||||
|
graph_def1.node.extend([
|
||||||
|
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
|
||||||
|
graph_def2 = graph_pb2.GraphDef()
|
||||||
|
graph_def2.node.extend([
|
||||||
|
self.create_constant_node_def("C", 2, dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||||
|
|
||||||
|
def testGraphdefsWithNanCompareNonEqual(self):
|
||||||
|
graph_def1 = graph_pb2.GraphDef()
|
||||||
|
graph_def1.node.extend([
|
||||||
|
self.create_constant_node_def(
|
||||||
|
"C", float("nan"), dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
|
||||||
|
graph_def2 = graph_pb2.GraphDef()
|
||||||
|
graph_def2.node.extend([
|
||||||
|
self.create_constant_node_def(
|
||||||
|
"C", float("nan"), dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||||
|
|
||||||
|
def testSimpleGraphdefEqualityWithNansEqual(self):
|
||||||
|
graph_def1 = graph_pb2.GraphDef()
|
||||||
|
graph_def1.node.extend([
|
||||||
|
self.create_constant_node_def(
|
||||||
|
"C", float("nan"), dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
|
||||||
|
graph_def2 = graph_pb2.GraphDef()
|
||||||
|
graph_def2.node.extend([
|
||||||
|
self.create_constant_node_def(
|
||||||
|
"C", float("nan"), dtypes.float32, inputs=["^I"]),
|
||||||
|
self.create_node_def("Identity", "I", ["Base"]),
|
||||||
|
self.create_node_def("BaseOp", "Base", [])
|
||||||
|
])
|
||||||
|
self.assertTrue(
|
||||||
|
graph_util.graph_defs_equal(
|
||||||
|
graph_def1, graph_def2, treat_nan_as_equal=True))
|
||||||
|
|
||||||
|
def testGraphDefsWithFunctionLibsCompareEqual(self):
|
||||||
|
|
||||||
|
@function.Defun(dtypes.float32)
|
||||||
|
def F1(x):
|
||||||
|
return math_ops.exp(x) - math_ops.exp(-x)
|
||||||
|
|
||||||
|
library = function_pb2.FunctionDefLibrary()
|
||||||
|
library.function.extend([F1.definition])
|
||||||
|
|
||||||
|
graph_def1 = graph_pb2.GraphDef()
|
||||||
|
graph_def1.library.CopyFrom(library)
|
||||||
|
|
||||||
|
graph_def2 = graph_pb2.GraphDef()
|
||||||
|
graph_def2.library.CopyFrom(library)
|
||||||
|
|
||||||
|
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||||
|
|
||||||
|
def testGraphDefsWithPermutedFunctionsCompareEqual(self):
|
||||||
|
|
||||||
|
@function.Defun(dtypes.float32)
|
||||||
|
def F1(x):
|
||||||
|
return math_ops.exp(x) - math_ops.exp(-x)
|
||||||
|
|
||||||
|
@function.Defun(dtypes.float32)
|
||||||
|
def F2(x):
|
||||||
|
return math_ops.exp(x)
|
||||||
|
|
||||||
|
definition_1 = F1.definition
|
||||||
|
definition_2 = F2.definition
|
||||||
|
library = function_pb2.FunctionDefLibrary()
|
||||||
|
library.function.extend([definition_1, definition_2])
|
||||||
|
|
||||||
|
graph_def1 = graph_pb2.GraphDef()
|
||||||
|
graph_def1.library.CopyFrom(library)
|
||||||
|
|
||||||
|
reversed_library = function_pb2.FunctionDefLibrary()
|
||||||
|
reversed_library.function.extend([definition_2, definition_1])
|
||||||
|
graph_def2 = graph_pb2.GraphDef()
|
||||||
|
graph_def2.library.CopyFrom(reversed_library)
|
||||||
|
|
||||||
|
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||||
|
|
||||||
|
def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):
|
||||||
|
|
||||||
|
@function.Defun(dtypes.float32)
|
||||||
|
def F1(x):
|
||||||
|
return math_ops.exp(x) - math_ops.exp(-x)
|
||||||
|
|
||||||
|
f1_def = F1.definition
|
||||||
|
|
||||||
|
library = function_pb2.FunctionDefLibrary()
|
||||||
|
library.function.extend([f1_def])
|
||||||
|
|
||||||
|
graph_def1 = graph_pb2.GraphDef()
|
||||||
|
graph_def1.library.CopyFrom(library)
|
||||||
|
|
||||||
|
reversed_function = function_pb2.FunctionDef()
|
||||||
|
reversed_function.CopyFrom(f1_def)
|
||||||
|
# Clear the node_def attribute.
|
||||||
|
del reversed_function.node_def[:]
|
||||||
|
reversed_function.node_def.extend(reversed(f1_def.node_def))
|
||||||
|
reversed_library = function_pb2.FunctionDefLibrary()
|
||||||
|
reversed_library.function.extend([reversed_function])
|
||||||
|
graph_def2 = graph_pb2.GraphDef()
|
||||||
|
graph_def2.library.CopyFrom(reversed_library)
|
||||||
|
|
||||||
|
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
79
tensorflow/python/framework/proto_comparators.cc
Normal file
79
tensorflow/python/framework/proto_comparators.cc
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "pybind11/detail/common.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/pytypes.h"
|
||||||
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||||
|
|
||||||
|
struct ProtoComparisonOptions; // Forward declaration
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
namespace tf = tensorflow;
|
||||||
|
|
||||||
|
struct ProtoComparisonOptions {
|
||||||
|
bool treat_nan_as_equal;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool EqualsGraphDef(string graphdef_string1, string graphdef_string2,
|
||||||
|
const ProtoComparisonOptions& options) {
|
||||||
|
GraphDef graph_def_1;
|
||||||
|
if (!graph_def_1.ParseFromString(graphdef_string1)) {
|
||||||
|
MaybeRaiseFromStatus(errors::InvalidArgument(
|
||||||
|
"Couldn't interpret first argument as a GraphDef"));
|
||||||
|
}
|
||||||
|
GraphDef graph_def_2;
|
||||||
|
if (!graph_def_2.ParseFromString(graphdef_string2)) {
|
||||||
|
MaybeRaiseFromStatus(errors::InvalidArgument(
|
||||||
|
"Couldn't interpret second argument as a GraphDef"));
|
||||||
|
}
|
||||||
|
tf::protobuf::util::MessageDifferencer differencer;
|
||||||
|
// Order doesnt matter in node defs, or functions in the function library and
|
||||||
|
// their nested nodes.
|
||||||
|
differencer.TreatAsSet(GraphDef::descriptor()->FindFieldByName("node"));
|
||||||
|
differencer.TreatAsSet(
|
||||||
|
FunctionDefLibrary::descriptor()->FindFieldByName("function"));
|
||||||
|
differencer.TreatAsSet(
|
||||||
|
FunctionDefLibrary::descriptor()->FindFieldByName("gradient"));
|
||||||
|
differencer.TreatAsSet(
|
||||||
|
FunctionDef::descriptor()->FindFieldByName("node_def"));
|
||||||
|
tf::protobuf::util::DefaultFieldComparator comparator;
|
||||||
|
comparator.set_treat_nan_as_equal(options.treat_nan_as_equal);
|
||||||
|
differencer.set_field_comparator(&comparator);
|
||||||
|
return differencer.Compare(graph_def_1, graph_def_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_proto_comparators, m) {
|
||||||
|
py::class_<tensorflow::ProtoComparisonOptions>(m, "ProtoComparisonOptions")
|
||||||
|
.def(py::init<const bool&>());
|
||||||
|
m.def("EqualsGraphDef", &EqualsGraphDef,
|
||||||
|
"GraphDef equality test taking comparison options.");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user