diff --git a/tensorflow/core/platform/protobuf.h b/tensorflow/core/platform/protobuf.h
index 371912cc2b7..6950cb9a1b6 100644
--- a/tensorflow/core/platform/protobuf.h
+++ b/tensorflow/core/platform/protobuf.h
@@ -37,7 +37,9 @@ limitations under the License.
 #include "google/protobuf/message.h"
 #include "google/protobuf/repeated_field.h"
 #include "google/protobuf/text_format.h"
+#include "google/protobuf/util/field_comparator.h"
 #include "google/protobuf/util/json_util.h"
+#include "google/protobuf/util/message_differencer.h"
 #include "google/protobuf/util/type_resolver_util.h"
 
 namespace tensorflow {
@@ -116,7 +118,6 @@ class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream {
 
   tstring* target_;
 };
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c4a15c55f4d..7b1ccdae12e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -965,6 +965,7 @@ py_library(
     ],
     srcs_version = "PY3",
     deps = [
+        ":_proto_comparators",
         ":dtypes",
         ":framework_ops",
         ":platform",
@@ -6583,6 +6584,21 @@ tf_gen_op_wrapper_private_py(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_proto_comparators",
+    srcs = ["framework/proto_comparators.cc"],
+    module_name = "_proto_comparators",
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:protobuf",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@pybind11",
+    ],
+)
+
 py_library(
     name = "proto_ops",
     srcs = ["ops/proto_ops.py"],
diff --git a/tensorflow/python/framework/graph_util.py b/tensorflow/python/framework/graph_util.py
index c5cc1107343..117926a4ab7 100644
--- a/tensorflow/python/framework/graph_util.py
+++ b/tensorflow/python/framework/graph_util.py
@@ -24,6 +24,7 @@ from __future__ import print_function
 # pylint: disable=unused-import
 from tensorflow.python.framework.graph_util_impl import convert_variables_to_constants
 from tensorflow.python.framework.graph_util_impl import extract_sub_graph
+from tensorflow.python.framework.graph_util_impl import graph_defs_equal
 from tensorflow.python.framework.graph_util_impl import must_run_on_cpu
 from tensorflow.python.framework.graph_util_impl import remove_training_nodes
 from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index 4ef26fc8539..a2055241517 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -25,6 +25,7 @@ import six
 
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import node_def_pb2
+from tensorflow.python import _proto_comparators
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.util import deprecation
@@ -374,3 +375,37 @@ def remove_training_nodes(input_graph, protected_nodes=None):
   output_graph = graph_pb2.GraphDef()
   output_graph.node.extend(nodes_after_splicing)
   return output_graph
+
+
+def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
+                     graph_def_2: graph_pb2.GraphDef,
+                     treat_nan_as_equal: bool = False) -> bool:
+  """Returns True iff the graph def arguments are structurally equivalent.
+
+  The notion of equivalence encoded here checks that the set of NodeDefs in
+  the GraphDef's function library and main graph body are identical.
+  Additionally, it checks that the functions in the function library are equal
+  as sets.
+
+  Args:
+    graph_def_1: Instance of `graph_pb2.GraphDef` to compare.
+    graph_def_2: Instance of `graph_pb2.GraphDef` to compare.
+    treat_nan_as_equal: Boolean indicating whether or not to treat nan
+      floating-point values as equal. This is crucial for any equivalence
+      relation defined over GraphDefs, to ensure symmetry.
+
+  Returns:
+    Boolean indicating structural equivalence as described above.
+
+  Raises:
+    TypeError: If either of the GraphDefs are not instances of
+      `graph_pb2.GraphDef`.
+  """
+  if not isinstance(graph_def_1, graph_pb2.GraphDef):
+    raise TypeError("graph_def_1 must be a graph_pb2.GraphDef proto.")
+  if not isinstance(graph_def_2, graph_pb2.GraphDef):
+    raise TypeError("graph_def_2 must be a graph_pb2.GraphDef proto.")
+  options = _proto_comparators.ProtoComparisonOptions(treat_nan_as_equal)
+  return _proto_comparators.EqualsGraphDef(graph_def_1.SerializeToString(),
+                                           graph_def_2.SerializeToString(),
+                                           options)
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index 4957ee7d97e..b23bdaf3b59 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -19,27 +19,30 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import node_def_pb2
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
 from tensorflow.python.framework import graph_util
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import gen_state_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
 # Utility device function to use for testing
-def test_device_func_pin_variable_to_cpu(op):
+def TestDeviceFuncPinVariableToCpu(op):
   if op.device:
     return op.device
   return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device
 
 
-class DeviceFunctionsTest(test.TestCase):
+class GraphUtilTest(test.TestCase):
 
   def testTwoDeviceFunctions(self):
     with ops.Graph().as_default() as g:
@@ -49,7 +52,7 @@ class DeviceFunctionsTest(test.TestCase):
           name="var_0",
           container="",
           shared_name="")
-      with g.device(test_device_func_pin_variable_to_cpu):
+      with g.device(TestDeviceFuncPinVariableToCpu):
         var_1 = gen_state_ops.variable(
             shape=[1],
             dtype=dtypes.float32,
@@ -68,7 +71,7 @@ class DeviceFunctionsTest(test.TestCase):
           name="var_3",
           container="",
           shared_name="")
-      with g.device(test_device_func_pin_variable_to_cpu):
+      with g.device(TestDeviceFuncPinVariableToCpu):
         var_4 = gen_state_ops.variable(
             shape=[1],
             dtype=dtypes.float32,
@@ -101,7 +104,7 @@ class DeviceFunctionsTest(test.TestCase):
   def testNestedDeviceFunctions(self):
     with ops.Graph().as_default():
       var_0 = variables.VariableV1(0)
-      with ops.device(test_device_func_pin_variable_to_cpu):
+      with ops.device(TestDeviceFuncPinVariableToCpu):
         var_1 = variables.VariableV1(1)
         with ops.device(lambda op: "/device:GPU:0"):
           var_2 = variables.VariableV1(2)
@@ -136,7 +139,7 @@ class DeviceFunctionsTest(test.TestCase):
 
   def testDefaultDevice(self):
     with ops.Graph().as_default() as g, g.device(
-        test_device_func_pin_variable_to_cpu):
+        TestDeviceFuncPinVariableToCpu):
       with g.device("/job:ps"):
         const_0 = constant_op.constant(5.0)
       with g.device("/device:GPU:0"):
@@ -303,6 +306,162 @@ class DeviceFunctionsTest(test.TestCase):
     self.assertProtoEquals(graph_def,
                            graph_util.remove_training_nodes(graph_def))
 
+  def testSimpleGraphdefsCompareEqual(self):
+    graph_def1 = graph_pb2.GraphDef()
+    graph_def1.node.extend([
+        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+
+    graph_def2 = graph_pb2.GraphDef()
+    graph_def2.node.extend([
+        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+
+    self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
+
+  def testNodeDefsInDifferentOrderCompareEqual(self):
+    graph_def1 = graph_pb2.GraphDef()
+    graph_def1.node.extend([
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", []),
+        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
+    ])
+
+    graph_def2 = graph_pb2.GraphDef()
+    graph_def2.node.extend([
+        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+
+    self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
+
+  def testDifferentGraphDefsCompareNotEqual(self):
+    graph_def1 = graph_pb2.GraphDef()
+    graph_def1.node.extend([
+        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+
+    graph_def2 = graph_pb2.GraphDef()
+    graph_def2.node.extend([
+        self.create_constant_node_def("C", 2, dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+    self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2))
+
+  def testGraphdefsWithNanCompareNonEqual(self):
+    graph_def1 = graph_pb2.GraphDef()
+    graph_def1.node.extend([
+        self.create_constant_node_def(
+            "C", float("nan"), dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+
+    graph_def2 = graph_pb2.GraphDef()
+    graph_def2.node.extend([
+        self.create_constant_node_def(
+            "C", float("nan"), dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+    self.assertFalse(graph_util.graph_defs_equal(graph_def1, graph_def2))
+
+  def testSimpleGraphdefEqualityWithNansEqual(self):
+    graph_def1 = graph_pb2.GraphDef()
+    graph_def1.node.extend([
+        self.create_constant_node_def(
+            "C", float("nan"), dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+
+    graph_def2 = graph_pb2.GraphDef()
+    graph_def2.node.extend([
+        self.create_constant_node_def(
+            "C", float("nan"), dtypes.float32, inputs=["^I"]),
+        self.create_node_def("Identity", "I", ["Base"]),
+        self.create_node_def("BaseOp", "Base", [])
+    ])
+    self.assertTrue(
+        graph_util.graph_defs_equal(
+            graph_def1, graph_def2, treat_nan_as_equal=True))
+
+  def testGraphDefsWithFunctionLibsCompareEqual(self):
+
+    @function.Defun(dtypes.float32)
+    def F1(x):
+      return math_ops.exp(x) - math_ops.exp(-x)
+
+    library = function_pb2.FunctionDefLibrary()
+    library.function.extend([F1.definition])
+
+    graph_def1 = graph_pb2.GraphDef()
+    graph_def1.library.CopyFrom(library)
+
+    graph_def2 = graph_pb2.GraphDef()
+    graph_def2.library.CopyFrom(library)
+
+    self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
+
+  def testGraphDefsWithPermutedFunctionsCompareEqual(self):
+
+    @function.Defun(dtypes.float32)
+    def F1(x):
+      return math_ops.exp(x) - math_ops.exp(-x)
+
+    @function.Defun(dtypes.float32)
+    def F2(x):
+      return math_ops.exp(x)
+
+    definition_1 = F1.definition
+    definition_2 = F2.definition
+    library = function_pb2.FunctionDefLibrary()
+    library.function.extend([definition_1, definition_2])
+
+    graph_def1 = graph_pb2.GraphDef()
+    graph_def1.library.CopyFrom(library)
+
+    reversed_library = function_pb2.FunctionDefLibrary()
+    reversed_library.function.extend([definition_2, definition_1])
+    graph_def2 = graph_pb2.GraphDef()
+    graph_def2.library.CopyFrom(reversed_library)
+
+    self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
+
+  def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):
+
+    @function.Defun(dtypes.float32)
+    def F1(x):
+      return math_ops.exp(x) - math_ops.exp(-x)
+
+    f1_def = F1.definition
+
+    library = function_pb2.FunctionDefLibrary()
+    library.function.extend([f1_def])
+
+    graph_def1 = graph_pb2.GraphDef()
+    graph_def1.library.CopyFrom(library)
+
+    reversed_function = function_pb2.FunctionDef()
+    reversed_function.CopyFrom(f1_def)
+    # Clear the node_def attribute.
+    del reversed_function.node_def[:]
+    reversed_function.node_def.extend(reversed(f1_def.node_def))
+    reversed_library = function_pb2.FunctionDefLibrary()
+    reversed_library.function.extend([reversed_function])
+    graph_def2 = graph_pb2.GraphDef()
+    graph_def2.library.CopyFrom(reversed_library)
+
+    self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/framework/proto_comparators.cc b/tensorflow/python/framework/proto_comparators.cc
new file mode 100644
index 00000000000..5d457e9d91b
--- /dev/null
+++ b/tensorflow/python/framework/proto_comparators.cc
@@ -0,0 +1,79 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <Python.h>
+
+#include <memory>
+#include <string>
+
+#include "pybind11/detail/common.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+struct ProtoComparisonOptions;  // Forward declaration
+
+namespace tensorflow {
+
+namespace {
+
+namespace py = pybind11;
+namespace tf = tensorflow;
+
+struct ProtoComparisonOptions {
+  bool treat_nan_as_equal;
+};
+
+bool EqualsGraphDef(string graphdef_string1, string graphdef_string2,
+                    const ProtoComparisonOptions& options) {
+  GraphDef graph_def_1;
+  if (!graph_def_1.ParseFromString(graphdef_string1)) {
+    MaybeRaiseFromStatus(errors::InvalidArgument(
+        "Couldn't interpret first argument as a GraphDef"));
+  }
+  GraphDef graph_def_2;
+  if (!graph_def_2.ParseFromString(graphdef_string2)) {
+    MaybeRaiseFromStatus(errors::InvalidArgument(
+        "Couldn't interpret second argument as a GraphDef"));
+  }
+  tf::protobuf::util::MessageDifferencer differencer;
+  // Order doesnt matter in node defs, or functions in the function library and
+  // their nested nodes.
+  differencer.TreatAsSet(GraphDef::descriptor()->FindFieldByName("node"));
+  differencer.TreatAsSet(
+      FunctionDefLibrary::descriptor()->FindFieldByName("function"));
+  differencer.TreatAsSet(
+      FunctionDefLibrary::descriptor()->FindFieldByName("gradient"));
+  differencer.TreatAsSet(
+      FunctionDef::descriptor()->FindFieldByName("node_def"));
+  tf::protobuf::util::DefaultFieldComparator comparator;
+  comparator.set_treat_nan_as_equal(options.treat_nan_as_equal);
+  differencer.set_field_comparator(&comparator);
+  return differencer.Compare(graph_def_1, graph_def_2);
+}
+
+PYBIND11_MODULE(_proto_comparators, m) {
+  py::class_<tensorflow::ProtoComparisonOptions>(m, "ProtoComparisonOptions")
+      .def(py::init<const bool&>());
+  m.def("EqualsGraphDef", &EqualsGraphDef,
+        "GraphDef equality test taking comparison options.");
+}
+
+}  // anonymous namespace
+
+}  // namespace tensorflow