From c7621b86bc425245eb01557faded795cebc649ee Mon Sep 17 00:00:00 2001
From: Eugene Brevdo <ebrevdo@google.com>
Date: Fri, 8 Feb 2019 16:24:59 -0800
Subject: [PATCH] Expose CriticalSection in core as tf.CriticalSection.

PiperOrigin-RevId: 233146763
---
 tensorflow/contrib/framework/BUILD            | 22 ------
 tensorflow/contrib/framework/__init__.py      |  2 -
 .../contrib/framework/python/ops/__init__.py  |  1 -
 tensorflow/python/BUILD                       | 17 +++++
 tensorflow/python/kernel_tests/BUILD          | 21 ++++++
 .../kernel_tests}/critical_section_test.py    | 67 ++++++++++++++-----
 .../python/ops/critical_section_ops.py        | 48 +++++++------
 tensorflow/python/ops/standard_ops.py         |  1 +
 .../v1/tensorflow.-critical-section.pbtxt     | 17 +++++
 .../tools/api/golden/v1/tensorflow.pbtxt      |  4 ++
 .../v2/tensorflow.-critical-section.pbtxt     | 17 +++++
 .../tools/api/golden/v2/tensorflow.pbtxt      |  4 ++
 .../tools/compatibility/tf_upgrade_v2.py      |  2 +
 .../tools/compatibility/tf_upgrade_v2_test.py |  6 ++
 14 files changed, 164 insertions(+), 65 deletions(-)
 rename tensorflow/{contrib/framework/python/ops => python/kernel_tests}/critical_section_test.py (87%)
 rename tensorflow/{contrib/framework => }/python/ops/critical_section_ops.py (92%)
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.-critical-section.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-critical-section.pbtxt

diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index c99f84789e0..8fd2b5f39bc 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -32,7 +32,6 @@ tf_custom_op_py_library(
         "python/ops/arg_scope.py",
         "python/ops/audio_ops.py",
         "python/ops/checkpoint_ops.py",
-        "python/ops/critical_section_ops.py",
         "python/ops/ops.py",
         "python/ops/prettyprint_ops.py",
         "python/ops/script_ops.py",
@@ -172,27 +171,6 @@ py_test(
     ],
 )
 
-cuda_py_test(
-    name = "critical_section_test",
-    size = "medium",
-    srcs = ["python/ops/critical_section_test.py"],
-    additional_deps = [
-        "//tensorflow/python:client_testlib",
-        ":framework_py",
-        "//tensorflow/python/data/experimental/ops:prefetching_ops",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:gradients",
-        "//tensorflow/python:platform_test",
-        "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:tensor_array_ops",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/eager:context",
-    ],
-)
-
 py_test(
     name = "ops_test",
     size = "small",
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index fc2334d5d7f..94fb35b3346 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -94,8 +94,6 @@
 @@smart_constant_value
 @@smart_case
 
-@@CriticalSection
-
 @@BoundedTensorSpec
 @@TensorSpec
 
diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py
index c4976497f5f..8113bf7c095 100644
--- a/tensorflow/contrib/framework/python/ops/__init__.py
+++ b/tensorflow/contrib/framework/python/ops/__init__.py
@@ -22,7 +22,6 @@ from __future__ import print_function
 # pylint: disable=wildcard-import
 from tensorflow.contrib.framework.python.ops.arg_scope import *
 from tensorflow.contrib.framework.python.ops.checkpoint_ops import *
-from tensorflow.contrib.framework.python.ops.critical_section_ops import *
 from tensorflow.contrib.framework.python.ops.ops import *
 from tensorflow.contrib.framework.python.ops.prettyprint_ops import *
 from tensorflow.contrib.framework.python.ops.script_ops import *
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ca6a09cdd49..d3254cc2e48 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2737,6 +2737,22 @@ py_library(
     ],
 )
 
+py_library(
+    name = "critical_section_ops",
+    srcs = ["ops/critical_section_ops.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":array_ops",
+        ":control_flow_ops",
+        ":dtypes",
+        ":framework_ops",
+        ":resource_variable_ops_gen",
+        ":tensor_array_ops",
+        ":util",
+        "//tensorflow/python/eager:context",
+    ],
+)
+
 py_library(
     name = "list_ops",
     srcs = ["ops/list_ops.py"],
@@ -3148,6 +3164,7 @@ py_library(
         ":clip_ops",
         ":confusion_matrix",
         ":control_flow_ops",
+        ":critical_section_ops",
         ":cudnn_rnn_grad",
         ":data_flow_grad",
         ":data_flow_ops",
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c660633a768..d98a6dd59c6 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -3592,3 +3592,24 @@ cuda_py_test(
     grpc_enabled = True,
     xla_enable_strict_auto_jit = True,
 )
+
+cuda_py_test(
+    name = "critical_section_test",
+    size = "medium",
+    srcs = ["critical_section_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/data/experimental/ops:prefetching_ops",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:critical_section_ops",
+        "//tensorflow/python:tensor_array_ops",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/eager:context",
+    ],
+)
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/python/kernel_tests/critical_section_test.py
similarity index 87%
rename from tensorflow/contrib/framework/python/ops/critical_section_test.py
rename to tensorflow/python/kernel_tests/critical_section_test.py
index d2bb4f476f4..7b1519c5e3c 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_test.py
+++ b/tensorflow/python/kernel_tests/critical_section_test.py
@@ -18,14 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.framework.python.ops import critical_section_ops
 from tensorflow.python.data.experimental.ops import prefetching_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import critical_section_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
@@ -48,7 +49,7 @@ class CriticalSectionTest(test.TestCase):
           return array_ops.identity(c)
 
     num_concurrent = 100
-    r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)]
+    r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)]
     self.evaluate(v.initializer)
     r_value = self.evaluate(r)
     self.assertAllClose([2.0 * i for i in range(num_concurrent)],
@@ -74,7 +75,7 @@ class CriticalSectionTest(test.TestCase):
               array_ops.identity(inner_cond), true_fn, lambda: c)
 
         def execute():
-          return cs.execute(fn, 1.0, 2.0)
+          return cs.execute(lambda: fn(1.0, 2.0))
 
         r = [
             control_flow_ops.cond(array_ops.identity(outer_cond),
@@ -92,6 +93,7 @@ class CriticalSectionTest(test.TestCase):
         else:
           self.assertAllClose([0] * num_concurrent, r_value)
 
+  @test_util.run_v1_only("b/123990562 Sees CancelledError on some calls")
   def testCriticalSectionInParallelDoesntDeadlockOnError(self):
     # No eager mode execution of this test because eager does not
     # run fn() in parallel, which is where the deadlock could
@@ -103,12 +105,23 @@ class CriticalSectionTest(test.TestCase):
       error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
       with ops.control_dependencies([error]):
         return v.read_value()
+
     num_concurrent = 2
-    r = [cs.execute(fn, i) for i in range(num_concurrent)]
+
+    @def_function.function(autograph=False)
+    def run_concurrently():
+      return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]
+
+    if not context.executing_eagerly():
+      run_concurrently = run_concurrently()
+
     self.evaluate(v.initializer)
     for _ in range(100):
       with self.assertRaisesOpError("Error"):
-        self.evaluate(r)
+        if context.executing_eagerly():
+          run_concurrently()
+        else:
+          self.evaluate(run_concurrently)
 
   @test_util.run_in_graph_and_eager_modes
   def testCreateCriticalSectionFnReturnsOp(self):
@@ -123,17 +136,20 @@ class CriticalSectionTest(test.TestCase):
           return control_flow_ops.no_op()
 
     num_concurrent = 100
-    r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)]
+    r = [cs.execute(lambda: fn_return_op(1.0, 2.0))
+         for _ in range(num_concurrent)]
     self.evaluate(v.initializer)
     self.evaluate(r)
     final_v = self.evaluate(v)
     self.assertAllClose(2.0 * num_concurrent, final_v)
 
+  @test_util.run_v1_only("Collections don't exist in TF2")
   def testCollection(self):
     cs = critical_section_ops.CriticalSection(shared_name="cs")
     self.assertIn(
         cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
-    execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute")
+    add = lambda x: x + 1
+    execute = cs.execute(lambda: add(1.0), name="my_execute")
     execute_op = [
         x for x in execute.graph.get_operations()
         if "my_execute" in x.name and "MutexLock" in x.type
@@ -143,18 +159,21 @@ class CriticalSectionTest(test.TestCase):
         [signature.op for signature in
          ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
 
+  @test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
   def testRecursiveCriticalSectionAccessIsIllegal(self):
     # This does not work properly in eager mode.  Eager users will
     # just hit a deadlock if they do this.  But at least it'll be easier
     # to debug.
     cs = critical_section_ops.CriticalSection()
+    add = lambda y: y + 1
     def fn(x):
-      return cs.execute(lambda y: y + 1, x)
+      return cs.execute(lambda: add(x))
+
     with self.assertRaisesRegexp(
         ValueError,
         r"attempts to directly access the CriticalSection in which it "
         r"would be running"):
-      cs.execute(fn, 1.0)
+      cs.execute(lambda: fn(1.0))
 
   def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
     # This one is subtle; and we're being overly cautious here.  The
@@ -174,24 +193,24 @@ class CriticalSectionTest(test.TestCase):
     # operations are finished before anything runs within the critical section.
     cs = critical_section_ops.CriticalSection(shared_name="cs")
     fn = array_ops.identity
-    to_capture = cs.execute(fn, 1.0)
+    to_capture = cs.execute(lambda: fn(1.0))
     fn_captures = lambda x: x + to_capture
     to_capture_too = array_ops.identity(to_capture)
 
-    ex_0 = cs.execute(fn_captures, 1.0)
+    ex_0 = cs.execute(lambda: fn_captures(1.0))
 
     with ops.control_dependencies([to_capture]):
       # This is OK because to_capture will execute before this next call
-      ex_1 = cs.execute(fn_captures, 1.0)
+      ex_1 = cs.execute(lambda: fn_captures(1.0))
 
     dependency = array_ops.identity(to_capture)
 
     fn_captures_dependency = lambda x: x + dependency
 
-    ex_2 = cs.execute(fn_captures_dependency, 1.0)
+    ex_2 = cs.execute(lambda: fn_captures_dependency(1.0))
 
     with ops.control_dependencies([to_capture_too]):
-      ex_3 = cs.execute(fn_captures_dependency, 1.0)
+      ex_3 = cs.execute(lambda: fn_captures_dependency(1.0))
 
     # Ensure there's no actual deadlock on to_execute.
     self.assertEquals(2.0, self.evaluate(ex_0))
@@ -217,6 +236,8 @@ class CriticalSectionTest(test.TestCase):
         body_implicit_capture,
         [0, 0],
         parallel_iterations=25)
+    # For consistency between eager and graph mode.
+    i_n = array_ops.identity(i_n)
     logging.warn(
         "\n==============\nRunning "
         "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
@@ -242,6 +263,8 @@ class CriticalSectionTest(test.TestCase):
         body_implicit_capture_protected,
         [0, 0],
         parallel_iterations=25)
+    # For consistency between eager and graph mode.
+    i_n = array_ops.identity(i_n)
     logging.warn(
         "\n==============\nRunning "
         "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
@@ -258,13 +281,15 @@ class CriticalSectionTest(test.TestCase):
       # This version is ok because j is an argument to fn and we can
       # ensure there's a control dependency on j.
       fn = lambda x: x + 1
-      return (i + 1, cs.execute(fn, j))
+      return (i + 1, cs.execute(lambda: fn(j)))
 
     (i_n, j_n) = control_flow_ops.while_loop(
         lambda i, _: i < 1000,
         body_args_capture,
         [0, 0],
         parallel_iterations=25)
+    # For consistency between eager and graph mode.
+    i_n = array_ops.identity(i_n)
     logging.warn(
         "\n==============\nRunning "
         "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
@@ -277,20 +302,23 @@ class CriticalSectionTest(test.TestCase):
         "body_args_capture'\n"
         "==============\n")
 
+  @test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
   def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
     # This does not work properly in eager mode.  Eager users will
     # just hit a deadlock if they do this.  But at least it'll be easier
     # to debug.
     cs = critical_section_ops.CriticalSection(shared_name="cs")
     cs_same = critical_section_ops.CriticalSection(shared_name="cs")
+    add = lambda x: x + 1
     def fn(x):
-      return cs_same.execute(lambda x: x+1, x)
+      return cs_same.execute(lambda: add(x))
     with self.assertRaisesRegexp(
         ValueError,
         r"attempts to directly access the CriticalSection in which it "
         r"would be running"):
-      cs.execute(fn, 1.0)
+      cs.execute(lambda: fn(1.0))
 
+  @test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
   def testMultipleCSExecutionsRequestSameResource(self):
     cs0 = critical_section_ops.CriticalSection()
     cs1 = critical_section_ops.CriticalSection()
@@ -328,8 +356,11 @@ class CriticalSectionTest(test.TestCase):
     # Note, here v must be a resource variable (or something similar),
     # otherwise it gets hoisted into the while_loop by the time we add
     # control dependencies to the lock_op.
+    def body(i):
+      add_j = lambda j: v + j + 1
+      return cs.execute(lambda: add_j(i))
     out = control_flow_ops.while_loop(
-        lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0])
+        lambda i: i < 10, body, [0])
     self.evaluate(v.initializer)
     self.assertEqual(10, self.evaluate(out))
 
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/python/ops/critical_section_ops.py
similarity index 92%
rename from tensorflow/contrib/framework/python/ops/critical_section_ops.py
rename to tensorflow/python/ops/critical_section_ops.py
index 71ab755aa29..21872ffff13 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py
+++ b/tensorflow/python/ops/critical_section_ops.py
@@ -31,6 +31,10 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
+
+
+__all__ = ["CriticalSection"]
 
 
 # Graph Keys
@@ -66,6 +70,7 @@ def _get_colocation(op):
     return None
 
 
+@tf_export("CriticalSection")
 class CriticalSection(object):
   """Critical section.
 
@@ -179,37 +184,36 @@ class CriticalSection(object):
   def name(self):
     return self._handle.op.name
 
-  def execute(self, fn, *args, **kwargs):
-    """Execute function `fn(*args, **kwargs)` inside the CriticalSection.
+  def execute(self, fn, exclusive_resource_access=True, name=None):
+    """Execute function `fn()` inside the critical section.
+
+    `fn` should not accept any arguments.  To add extra arguments to when
+    calling `fn` in the critical section, create a lambda:
+
+    ```python
+    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
+    ```
 
     Args:
       fn: The function to execute.  Must return at least one tensor.
-      *args: Additional positional arguments to `fn`.
-      **kwargs: Additional keyword arguments to `fn`.
-        Several keywords are reserved for `execute`.  These are:
-
-        - name; The name to use when creating the execute operation.
-        - exclusive_resource_access; Whether the resources required by
-          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
-          You may want to set this to `False` if you will be accessing a
-          resource in read-only mode in two different CriticalSections.
+      exclusive_resource_access: Whether the resources required by
+        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
+        You may want to set this to `False` if you will be accessing a
+        resource in read-only mode in two different CriticalSections.
+      name: The name to use when creating the execute operation.
 
     Returns:
-      The tensors returned from `fn(*args, **kwargs)`.
+      The tensors returned from `fn()`.
 
     Raises:
       ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
         or lazy way that may cause a deadlock.
-      ValueError: If `exclusive_resource_access` is not provided (is `True`) and
+      ValueError: If `exclusive_resource_access == True` and
         another `CriticalSection` has an execution requesting the same
-        resources as in `*args`, `**kwargs`, and any additionally captured
-        inputs in `fn`.  Note, even if `exclusive_resource_access` is `True`,
-        if another execution in another `CriticalSection` was created without
-        `exclusive_resource_access=True`, a `ValueError` will be raised.
+        resources as `fn``.  Note, even if `exclusive_resource_access` is
+        `True`, if another execution in another `CriticalSection` was created
+        without `exclusive_resource_access=True`, a `ValueError` will be raised.
     """
-    name = kwargs.pop("name", None)
-    exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
-
     with ops.name_scope(name, "critical_section_execute", []):
 
       # Ensure that mutex locking only happens *after* all args and
@@ -222,7 +226,7 @@ class CriticalSection(object):
         with ops.get_default_graph()._lock:  # pylint: disable=protected-access
           existing_ops = ops.get_default_graph().get_operations()
           with ops.control_dependencies([lock]):
-            r = fn(*args, **kwargs)
+            r = fn()
           # TODO(ebrevdo): If creating critical sections in a python loop, this
           # makes graph creation time quadratic.  Revisit if this
           # becomes a problem.
@@ -230,7 +234,7 @@ class CriticalSection(object):
                          .difference(existing_ops))
       else:
         with ops.control_dependencies([lock]):
-          r = fn(*args, **kwargs)
+          r = fn()
 
       if not context.executing_eagerly():
         self._add_control_dependencies_to_lock(created_ops, lock.op)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index ba3bd094923..5e217d8ed2f 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -54,6 +54,7 @@ from tensorflow.python.ops.control_flow_ops import tuple  # pylint: disable=rede
 # pylint: enable=redefined-builtin
 from tensorflow.python.eager import wrap_function
 from tensorflow.python.ops.control_flow_ops import while_loop
+from tensorflow.python.ops.critical_section_ops import *
 from tensorflow.python.ops.data_flow_ops import *
 from tensorflow.python.ops.functional_ops import *
 from tensorflow.python.ops.gradients import *
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-critical-section.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-critical-section.pbtxt
new file mode 100644
index 00000000000..024a2083463
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-critical-section.pbtxt
@@ -0,0 +1,17 @@
+path: "tensorflow.CriticalSection"
+tf_class {
+  is_instance: "<class \'tensorflow.python.ops.critical_section_ops.CriticalSection\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'name\', \'shared_name\', \'critical_section_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "execute"
+    argspec: "args=[\'self\', \'fn\', \'exclusive_resource_access\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 64f572022ef..cb9d6a907f7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -32,6 +32,10 @@ tf_module {
     name: "ConfigProto"
     mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
   }
+  member {
+    name: "CriticalSection"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "DType"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-critical-section.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-critical-section.pbtxt
new file mode 100644
index 00000000000..024a2083463
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-critical-section.pbtxt
@@ -0,0 +1,17 @@
+path: "tensorflow.CriticalSection"
+tf_class {
+  is_instance: "<class \'tensorflow.python.ops.critical_section_ops.CriticalSection\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'name\', \'shared_name\', \'critical_section_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "execute"
+    argspec: "args=[\'self\', \'fn\', \'exclusive_resource_access\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index a75c0c27dcd..5db60220c76 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -4,6 +4,10 @@ tf_module {
     name: "AggregationMethod"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "CriticalSection"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "DType"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 7e86d6cfa50..080e9420a1b 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -540,6 +540,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
             "tf.data.experimental.unbatch",
         "tf.contrib.data.unique":
             "tf.data.experimental.unique",
+        "tf.contrib.framework.CriticalSection":
+            "tf.CriticalSection",
         "tf.contrib.framework.is_tensor":
             "tf.is_tensor",
         "tf.contrib.framework.nest.assert_same_structure":
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
index bf1e80a6816..c78db3fdf67 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
@@ -1157,6 +1157,12 @@ def _log_prob(self, x):
     _, _, _, new_text = self._upgrade(text)
     self.assertEqual(expected, new_text)
 
+  def test_CriticalSection_upgrade(self):
+    text = "tf.contrib.framework.CriticalSection(shared_name='blah')"
+    expected = "tf.CriticalSection(shared_name='blah')"
+    _, _, _, new_text = self._upgrade(text)
+    self.assertEqual(expected, new_text)
+
   def test_sample_distorted_bounding_box(self):
     # pylint: disable=line-too-long
     text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)"