Adds proto comparators and options for nan-equivalence to graph util.
PiperOrigin-RevId: 353883024 Change-Id: I05423b52de13571294867b9284901f1567577b2f
This commit is contained in:
parent
e3e5a24f93
commit
85d50ab3a0
tensorflow
core/platform
python
@ -37,7 +37,9 @@ limitations under the License.
|
||||
#include "google/protobuf/message.h"
|
||||
#include "google/protobuf/repeated_field.h"
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "google/protobuf/util/field_comparator.h"
|
||||
#include "google/protobuf/util/json_util.h"
|
||||
#include "google/protobuf/util/message_differencer.h"
|
||||
#include "google/protobuf/util/type_resolver_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -116,7 +118,6 @@ class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream {
|
||||
|
||||
tstring* target_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
|
||||
|
@ -965,6 +965,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":_proto_comparators",
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":platform",
|
||||
@ -6583,6 +6584,21 @@ tf_gen_op_wrapper_private_py(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_proto_comparators",
|
||||
srcs = ["framework/proto_comparators.cc"],
|
||||
module_name = "_proto_comparators",
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:protobuf",
|
||||
"//tensorflow/python/lib/core:pybind11_status",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "proto_ops",
|
||||
srcs = ["ops/proto_ops.py"],
|
||||
|
@ -24,6 +24,7 @@ from __future__ import print_function
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.framework.graph_util_impl import convert_variables_to_constants
|
||||
from tensorflow.python.framework.graph_util_impl import extract_sub_graph
|
||||
from tensorflow.python.framework.graph_util_impl import graph_defs_equal
|
||||
from tensorflow.python.framework.graph_util_impl import must_run_on_cpu
|
||||
from tensorflow.python.framework.graph_util_impl import remove_training_nodes
|
||||
from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name
|
||||
|
@ -25,6 +25,7 @@ import six
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.python import _proto_comparators
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.util import deprecation
|
||||
@ -374,3 +375,37 @@ def remove_training_nodes(input_graph, protected_nodes=None):
|
||||
output_graph = graph_pb2.GraphDef()
|
||||
output_graph.node.extend(nodes_after_splicing)
|
||||
return output_graph
|
||||
|
||||
|
||||
def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
|
||||
graph_def_2: graph_pb2.GraphDef,
|
||||
treat_nan_as_equal: bool = False) -> bool:
|
||||
"""Returns True iff the graph def arguments are structurally equivalent.
|
||||
|
||||
The notion of equivalence encoded here checks that the set of NodeDefs in
|
||||
the GraphDef's function library and main graph body are identical.
|
||||
Additionally, it checks that the functions in the function library are equal
|
||||
as sets.
|
||||
|
||||
Args:
|
||||
graph_def_1: Instance of `graph_pb2.GraphDef` to compare.
|
||||
graph_def_2: Instance of `graph_pb2.GraphDef` to compare.
|
||||
treat_nan_as_equal: Boolean indicating whether or not to treat nan
|
||||
floating-point values as equal. This is crucial for any equivalence
|
||||
relation defined over GraphDefs, to ensure symmetry.
|
||||
|
||||
Returns:
|
||||
Boolean indicating structural equivalence as described above.
|
||||
|
||||
Raises:
|
||||
TypeError: If either of the GraphDefs are not instances of
|
||||
`graph_pb2.GraphDef`.
|
||||
"""
|
||||
if not isinstance(graph_def_1, graph_pb2.GraphDef):
|
||||
raise TypeError("graph_def_1 must be a graph_pb2.GraphDef proto.")
|
||||
if not isinstance(graph_def_2, graph_pb2.GraphDef):
|
||||
raise TypeError("graph_def_2 must be a graph_pb2.GraphDef proto.")
|
||||
options = _proto_comparators.ProtoComparisonOptions(treat_nan_as_equal)
|
||||
return _proto_comparators.EqualsGraphDef(graph_def_1.SerializeToString(),
|
||||
graph_def_2.SerializeToString(),
|
||||
options)
|
||||
|
@ -19,27 +19,30 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_state_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# Utility device function to use for testing
|
||||
def test_device_func_pin_variable_to_cpu(op):
|
||||
def TestDeviceFuncPinVariableToCpu(op):
|
||||
if op.device:
|
||||
return op.device
|
||||
return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device
|
||||
|
||||
|
||||
class DeviceFunctionsTest(test.TestCase):
|
||||
class GraphUtilTest(test.TestCase):
|
||||
|
||||
def testTwoDeviceFunctions(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
@ -49,7 +52,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
name="var_0",
|
||||
container="",
|
||||
shared_name="")
|
||||
with g.device(test_device_func_pin_variable_to_cpu):
|
||||
with g.device(TestDeviceFuncPinVariableToCpu):
|
||||
var_1 = gen_state_ops.variable(
|
||||
shape=[1],
|
||||
dtype=dtypes.float32,
|
||||
@ -68,7 +71,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
name="var_3",
|
||||
container="",
|
||||
shared_name="")
|
||||
with g.device(test_device_func_pin_variable_to_cpu):
|
||||
with g.device(TestDeviceFuncPinVariableToCpu):
|
||||
var_4 = gen_state_ops.variable(
|
||||
shape=[1],
|
||||
dtype=dtypes.float32,
|
||||
@ -101,7 +104,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
def testNestedDeviceFunctions(self):
|
||||
with ops.Graph().as_default():
|
||||
var_0 = variables.VariableV1(0)
|
||||
with ops.device(test_device_func_pin_variable_to_cpu):
|
||||
with ops.device(TestDeviceFuncPinVariableToCpu):
|
||||
var_1 = variables.VariableV1(1)
|
||||
with ops.device(lambda op: "/device:GPU:0"):
|
||||
var_2 = variables.VariableV1(2)
|
||||
@ -136,7 +139,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
|
||||
def testDefaultDevice(self):
|
||||
with ops.Graph().as_default() as g, g.device(
|
||||
test_device_func_pin_variable_to_cpu):
|
||||
TestDeviceFuncPinVariableToCpu):
|
||||
with g.device("/job:ps"):
|
||||
const_0 = constant_op.constant(5.0)
|
||||
with g.device("/device:GPU:0"):
|
||||
@ -303,6 +306,162 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
self.assertProtoEquals(graph_def,
|
||||
graph_util.remove_training_nodes(graph_def))
|
||||
|
||||
def testSimpleGraphdefsCompareEqual(self):
|
||||
graph_def1 = graph_pb2.GraphDef()
|
||||
graph_def1.node.extend([
|
||||
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
|
||||
graph_def2 = graph_pb2.GraphDef()
|
||||
graph_def2.node.extend([
|
||||
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
|
||||
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||
|
||||
def testNodeDefsInDifferentOrderCompareEqual(self):
|
||||
graph_def1 = graph_pb2.GraphDef()
|
||||
graph_def1.node.extend([
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", []),
|
||||
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||
])
|
||||
|
||||
graph_def2 = graph_pb2.GraphDef()
|
||||
graph_def2.node.extend([
|
||||
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
|
||||
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||
|
||||
def testDifferentGraphDefsCompareNotEqual(self):
|
||||
graph_def1 = graph_pb2.GraphDef()
|
||||
graph_def1.node.extend([
|
||||
self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
|
||||
graph_def2 = graph_pb2.GraphDef()
|
||||
graph_def2.node.extend([
|
||||
self.create_constant_node_def("C", 2, dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||
|
||||
def testGraphdefsWithNanCompareNonEqual(self):
|
||||
graph_def1 = graph_pb2.GraphDef()
|
||||
graph_def1.node.extend([
|
||||
self.create_constant_node_def(
|
||||
"C", float("nan"), dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
|
||||
graph_def2 = graph_pb2.GraphDef()
|
||||
graph_def2.node.extend([
|
||||
self.create_constant_node_def(
|
||||
"C", float("nan"), dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||
|
||||
def testSimpleGraphdefEqualityWithNansEqual(self):
|
||||
graph_def1 = graph_pb2.GraphDef()
|
||||
graph_def1.node.extend([
|
||||
self.create_constant_node_def(
|
||||
"C", float("nan"), dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
|
||||
graph_def2 = graph_pb2.GraphDef()
|
||||
graph_def2.node.extend([
|
||||
self.create_constant_node_def(
|
||||
"C", float("nan"), dtypes.float32, inputs=["^I"]),
|
||||
self.create_node_def("Identity", "I", ["Base"]),
|
||||
self.create_node_def("BaseOp", "Base", [])
|
||||
])
|
||||
self.assertTrue(
|
||||
graph_util.graph_defs_equal(
|
||||
graph_def1, graph_def2, treat_nan_as_equal=True))
|
||||
|
||||
def testGraphDefsWithFunctionLibsCompareEqual(self):
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def F1(x):
|
||||
return math_ops.exp(x) - math_ops.exp(-x)
|
||||
|
||||
library = function_pb2.FunctionDefLibrary()
|
||||
library.function.extend([F1.definition])
|
||||
|
||||
graph_def1 = graph_pb2.GraphDef()
|
||||
graph_def1.library.CopyFrom(library)
|
||||
|
||||
graph_def2 = graph_pb2.GraphDef()
|
||||
graph_def2.library.CopyFrom(library)
|
||||
|
||||
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||
|
||||
def testGraphDefsWithPermutedFunctionsCompareEqual(self):
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def F1(x):
|
||||
return math_ops.exp(x) - math_ops.exp(-x)
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def F2(x):
|
||||
return math_ops.exp(x)
|
||||
|
||||
definition_1 = F1.definition
|
||||
definition_2 = F2.definition
|
||||
library = function_pb2.FunctionDefLibrary()
|
||||
library.function.extend([definition_1, definition_2])
|
||||
|
||||
graph_def1 = graph_pb2.GraphDef()
|
||||
graph_def1.library.CopyFrom(library)
|
||||
|
||||
reversed_library = function_pb2.FunctionDefLibrary()
|
||||
reversed_library.function.extend([definition_2, definition_1])
|
||||
graph_def2 = graph_pb2.GraphDef()
|
||||
graph_def2.library.CopyFrom(reversed_library)
|
||||
|
||||
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||
|
||||
def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def F1(x):
|
||||
return math_ops.exp(x) - math_ops.exp(-x)
|
||||
|
||||
f1_def = F1.definition
|
||||
|
||||
library = function_pb2.FunctionDefLibrary()
|
||||
library.function.extend([f1_def])
|
||||
|
||||
graph_def1 = graph_pb2.GraphDef()
|
||||
graph_def1.library.CopyFrom(library)
|
||||
|
||||
reversed_function = function_pb2.FunctionDef()
|
||||
reversed_function.CopyFrom(f1_def)
|
||||
# Clear the node_def attribute.
|
||||
del reversed_function.node_def[:]
|
||||
reversed_function.node_def.extend(reversed(f1_def.node_def))
|
||||
reversed_library = function_pb2.FunctionDefLibrary()
|
||||
reversed_library.function.extend([reversed_function])
|
||||
graph_def2 = graph_pb2.GraphDef()
|
||||
graph_def2.library.CopyFrom(reversed_library)
|
||||
|
||||
self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
79
tensorflow/python/framework/proto_comparators.cc
Normal file
79
tensorflow/python/framework/proto_comparators.cc
Normal file
@ -0,0 +1,79 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <Python.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "pybind11/detail/common.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
struct ProtoComparisonOptions; // Forward declaration
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace tf = tensorflow;
|
||||
|
||||
struct ProtoComparisonOptions {
|
||||
bool treat_nan_as_equal;
|
||||
};
|
||||
|
||||
bool EqualsGraphDef(string graphdef_string1, string graphdef_string2,
|
||||
const ProtoComparisonOptions& options) {
|
||||
GraphDef graph_def_1;
|
||||
if (!graph_def_1.ParseFromString(graphdef_string1)) {
|
||||
MaybeRaiseFromStatus(errors::InvalidArgument(
|
||||
"Couldn't interpret first argument as a GraphDef"));
|
||||
}
|
||||
GraphDef graph_def_2;
|
||||
if (!graph_def_2.ParseFromString(graphdef_string2)) {
|
||||
MaybeRaiseFromStatus(errors::InvalidArgument(
|
||||
"Couldn't interpret second argument as a GraphDef"));
|
||||
}
|
||||
tf::protobuf::util::MessageDifferencer differencer;
|
||||
// Order doesnt matter in node defs, or functions in the function library and
|
||||
// their nested nodes.
|
||||
differencer.TreatAsSet(GraphDef::descriptor()->FindFieldByName("node"));
|
||||
differencer.TreatAsSet(
|
||||
FunctionDefLibrary::descriptor()->FindFieldByName("function"));
|
||||
differencer.TreatAsSet(
|
||||
FunctionDefLibrary::descriptor()->FindFieldByName("gradient"));
|
||||
differencer.TreatAsSet(
|
||||
FunctionDef::descriptor()->FindFieldByName("node_def"));
|
||||
tf::protobuf::util::DefaultFieldComparator comparator;
|
||||
comparator.set_treat_nan_as_equal(options.treat_nan_as_equal);
|
||||
differencer.set_field_comparator(&comparator);
|
||||
return differencer.Compare(graph_def_1, graph_def_2);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_proto_comparators, m) {
|
||||
py::class_<tensorflow::ProtoComparisonOptions>(m, "ProtoComparisonOptions")
|
||||
.def(py::init<const bool&>());
|
||||
m.def("EqualsGraphDef", &EqualsGraphDef,
|
||||
"GraphDef equality test taking comparison options.");
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user