From c7621b86bc425245eb01557faded795cebc649ee Mon Sep 17 00:00:00 2001 From: Eugene Brevdo <ebrevdo@google.com> Date: Fri, 8 Feb 2019 16:24:59 -0800 Subject: [PATCH] Expose CriticalSection in core as tf.CriticalSection. PiperOrigin-RevId: 233146763 --- tensorflow/contrib/framework/BUILD | 22 ------ tensorflow/contrib/framework/__init__.py | 2 - .../contrib/framework/python/ops/__init__.py | 1 - tensorflow/python/BUILD | 17 +++++ tensorflow/python/kernel_tests/BUILD | 21 ++++++ .../kernel_tests}/critical_section_test.py | 67 ++++++++++++++----- .../python/ops/critical_section_ops.py | 48 +++++++------ tensorflow/python/ops/standard_ops.py | 1 + .../v1/tensorflow.-critical-section.pbtxt | 17 +++++ .../tools/api/golden/v1/tensorflow.pbtxt | 4 ++ .../v2/tensorflow.-critical-section.pbtxt | 17 +++++ .../tools/api/golden/v2/tensorflow.pbtxt | 4 ++ .../tools/compatibility/tf_upgrade_v2.py | 2 + .../tools/compatibility/tf_upgrade_v2_test.py | 6 ++ 14 files changed, 164 insertions(+), 65 deletions(-) rename tensorflow/{contrib/framework/python/ops => python/kernel_tests}/critical_section_test.py (87%) rename tensorflow/{contrib/framework => }/python/ops/critical_section_ops.py (92%) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.-critical-section.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-critical-section.pbtxt diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index c99f84789e0..8fd2b5f39bc 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -32,7 +32,6 @@ tf_custom_op_py_library( "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", - "python/ops/critical_section_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", "python/ops/script_ops.py", @@ -172,27 +171,6 @@ py_test( ], ) -cuda_py_test( - name = "critical_section_test", - size = "medium", - srcs = ["python/ops/critical_section_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - ":framework_py", - "//tensorflow/python/data/experimental/ops:prefetching_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:context", - ], -) - py_test( name = "ops_test", size = "small", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index fc2334d5d7f..94fb35b3346 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -94,8 +94,6 @@ @@smart_constant_value @@smart_case -@@CriticalSection - @@BoundedTensorSpec @@TensorSpec diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index c4976497f5f..8113bf7c095 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -22,7 +22,6 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * -from tensorflow.contrib.framework.python.ops.critical_section_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * from tensorflow.contrib.framework.python.ops.script_ops import * diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index ca6a09cdd49..d3254cc2e48 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2737,6 +2737,22 @@ py_library( ], ) +py_library( + name = "critical_section_ops", + srcs = ["ops/critical_section_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":control_flow_ops", + ":dtypes", + ":framework_ops", + ":resource_variable_ops_gen", + ":tensor_array_ops", + ":util", + "//tensorflow/python/eager:context", + ], +) + py_library( name = "list_ops", srcs = ["ops/list_ops.py"], @@ -3148,6 +3164,7 @@ py_library( ":clip_ops", ":confusion_matrix", ":control_flow_ops", + ":critical_section_ops", ":cudnn_rnn_grad", ":data_flow_grad", ":data_flow_ops", diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index c660633a768..d98a6dd59c6 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3592,3 +3592,24 @@ cuda_py_test( grpc_enabled = True, xla_enable_strict_auto_jit = True, ) + +cuda_py_test( + name = "critical_section_test", + size = "medium", + srcs = ["critical_section_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:prefetching_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:critical_section_ops", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + ], +) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/python/kernel_tests/critical_section_test.py similarity index 87% rename from tensorflow/contrib/framework/python/ops/critical_section_test.py rename to tensorflow/python/kernel_tests/critical_section_test.py index d2bb4f476f4..7b1519c5e3c 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/python/kernel_tests/critical_section_test.py @@ -18,14 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework.python.ops import critical_section_ops from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import critical_section_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -48,7 +49,7 @@ class CriticalSectionTest(test.TestCase): return array_ops.identity(c) num_concurrent = 100 - r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)] + r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)] self.evaluate(v.initializer) r_value = self.evaluate(r) self.assertAllClose([2.0 * i for i in range(num_concurrent)], @@ -74,7 +75,7 @@ class CriticalSectionTest(test.TestCase): array_ops.identity(inner_cond), true_fn, lambda: c) def execute(): - return cs.execute(fn, 1.0, 2.0) + return cs.execute(lambda: fn(1.0, 2.0)) r = [ control_flow_ops.cond(array_ops.identity(outer_cond), @@ -92,6 +93,7 @@ class CriticalSectionTest(test.TestCase): else: self.assertAllClose([0] * num_concurrent, r_value) + @test_util.run_v1_only("b/123990562 Sees CancelledError on some calls") def testCriticalSectionInParallelDoesntDeadlockOnError(self): # No eager mode execution of this test because eager does not # run fn() in parallel, which is where the deadlock could @@ -103,12 +105,23 @@ class CriticalSectionTest(test.TestCase): error = control_flow_ops.Assert((i % 2) == 1, ["Error"]) with ops.control_dependencies([error]): return v.read_value() + num_concurrent = 2 - r = [cs.execute(fn, i) for i in range(num_concurrent)] + + @def_function.function(autograph=False) + def run_concurrently(): + return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)] + + if not context.executing_eagerly(): + run_concurrently = run_concurrently() + self.evaluate(v.initializer) for _ in range(100): with self.assertRaisesOpError("Error"): - self.evaluate(r) + if context.executing_eagerly(): + run_concurrently() + else: + self.evaluate(run_concurrently) @test_util.run_in_graph_and_eager_modes def testCreateCriticalSectionFnReturnsOp(self): @@ -123,17 +136,20 @@ class CriticalSectionTest(test.TestCase): return control_flow_ops.no_op() num_concurrent = 100 - r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)] + r = [cs.execute(lambda: fn_return_op(1.0, 2.0)) + for _ in range(num_concurrent)] self.evaluate(v.initializer) self.evaluate(r) final_v = self.evaluate(v) self.assertAllClose(2.0 * num_concurrent, final_v) + @test_util.run_v1_only("Collections don't exist in TF2") def testCollection(self): cs = critical_section_ops.CriticalSection(shared_name="cs") self.assertIn( cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)) - execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute") + add = lambda x: x + 1 + execute = cs.execute(lambda: add(1.0), name="my_execute") execute_op = [ x for x in execute.graph.get_operations() if "my_execute" in x.name and "MutexLock" in x.type @@ -143,18 +159,21 @@ class CriticalSectionTest(test.TestCase): [signature.op for signature in ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)]) + @test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode") def testRecursiveCriticalSectionAccessIsIllegal(self): # This does not work properly in eager mode. Eager users will # just hit a deadlock if they do this. But at least it'll be easier # to debug. cs = critical_section_ops.CriticalSection() + add = lambda y: y + 1 def fn(x): - return cs.execute(lambda y: y + 1, x) + return cs.execute(lambda: add(x)) + with self.assertRaisesRegexp( ValueError, r"attempts to directly access the CriticalSection in which it " r"would be running"): - cs.execute(fn, 1.0) + cs.execute(lambda: fn(1.0)) def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self): # This one is subtle; and we're being overly cautious here. The @@ -174,24 +193,24 @@ class CriticalSectionTest(test.TestCase): # operations are finished before anything runs within the critical section. cs = critical_section_ops.CriticalSection(shared_name="cs") fn = array_ops.identity - to_capture = cs.execute(fn, 1.0) + to_capture = cs.execute(lambda: fn(1.0)) fn_captures = lambda x: x + to_capture to_capture_too = array_ops.identity(to_capture) - ex_0 = cs.execute(fn_captures, 1.0) + ex_0 = cs.execute(lambda: fn_captures(1.0)) with ops.control_dependencies([to_capture]): # This is OK because to_capture will execute before this next call - ex_1 = cs.execute(fn_captures, 1.0) + ex_1 = cs.execute(lambda: fn_captures(1.0)) dependency = array_ops.identity(to_capture) fn_captures_dependency = lambda x: x + dependency - ex_2 = cs.execute(fn_captures_dependency, 1.0) + ex_2 = cs.execute(lambda: fn_captures_dependency(1.0)) with ops.control_dependencies([to_capture_too]): - ex_3 = cs.execute(fn_captures_dependency, 1.0) + ex_3 = cs.execute(lambda: fn_captures_dependency(1.0)) # Ensure there's no actual deadlock on to_execute. self.assertEquals(2.0, self.evaluate(ex_0)) @@ -217,6 +236,8 @@ class CriticalSectionTest(test.TestCase): body_implicit_capture, [0, 0], parallel_iterations=25) + # For consistency between eager and graph mode. + i_n = array_ops.identity(i_n) logging.warn( "\n==============\nRunning " "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " @@ -242,6 +263,8 @@ class CriticalSectionTest(test.TestCase): body_implicit_capture_protected, [0, 0], parallel_iterations=25) + # For consistency between eager and graph mode. + i_n = array_ops.identity(i_n) logging.warn( "\n==============\nRunning " "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " @@ -258,13 +281,15 @@ class CriticalSectionTest(test.TestCase): # This version is ok because j is an argument to fn and we can # ensure there's a control dependency on j. fn = lambda x: x + 1 - return (i + 1, cs.execute(fn, j)) + return (i + 1, cs.execute(lambda: fn(j))) (i_n, j_n) = control_flow_ops.while_loop( lambda i, _: i < 1000, body_args_capture, [0, 0], parallel_iterations=25) + # For consistency between eager and graph mode. + i_n = array_ops.identity(i_n) logging.warn( "\n==============\nRunning " "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " @@ -277,20 +302,23 @@ class CriticalSectionTest(test.TestCase): "body_args_capture'\n" "==============\n") + @test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode") def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self): # This does not work properly in eager mode. Eager users will # just hit a deadlock if they do this. But at least it'll be easier # to debug. cs = critical_section_ops.CriticalSection(shared_name="cs") cs_same = critical_section_ops.CriticalSection(shared_name="cs") + add = lambda x: x + 1 def fn(x): - return cs_same.execute(lambda x: x+1, x) + return cs_same.execute(lambda: add(x)) with self.assertRaisesRegexp( ValueError, r"attempts to directly access the CriticalSection in which it " r"would be running"): - cs.execute(fn, 1.0) + cs.execute(lambda: fn(1.0)) + @test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode") def testMultipleCSExecutionsRequestSameResource(self): cs0 = critical_section_ops.CriticalSection() cs1 = critical_section_ops.CriticalSection() @@ -328,8 +356,11 @@ class CriticalSectionTest(test.TestCase): # Note, here v must be a resource variable (or something similar), # otherwise it gets hoisted into the while_loop by the time we add # control dependencies to the lock_op. + def body(i): + add_j = lambda j: v + j + 1 + return cs.execute(lambda: add_j(i)) out = control_flow_ops.while_loop( - lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0]) + lambda i: i < 10, body, [0]) self.evaluate(v.initializer) self.assertEqual(10, self.evaluate(out)) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/python/ops/critical_section_ops.py similarity index 92% rename from tensorflow/contrib/framework/python/ops/critical_section_ops.py rename to tensorflow/python/ops/critical_section_ops.py index 71ab755aa29..21872ffff13 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/python/ops/critical_section_ops.py @@ -31,6 +31,10 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export + + +__all__ = ["CriticalSection"] # Graph Keys @@ -66,6 +70,7 @@ def _get_colocation(op): return None +@tf_export("CriticalSection") class CriticalSection(object): """Critical section. @@ -179,37 +184,36 @@ class CriticalSection(object): def name(self): return self._handle.op.name - def execute(self, fn, *args, **kwargs): - """Execute function `fn(*args, **kwargs)` inside the CriticalSection. + def execute(self, fn, exclusive_resource_access=True, name=None): + """Execute function `fn()` inside the critical section. + + `fn` should not accept any arguments. To add extra arguments to when + calling `fn` in the critical section, create a lambda: + + ```python + critical_section.execute(lambda: fn(*my_args, **my_kwargs)) + ``` Args: fn: The function to execute. Must return at least one tensor. - *args: Additional positional arguments to `fn`. - **kwargs: Additional keyword arguments to `fn`. - Several keywords are reserved for `execute`. These are: - - - name; The name to use when creating the execute operation. - - exclusive_resource_access; Whether the resources required by - `fn` should be exclusive to this `CriticalSection`. Default: `True`. - You may want to set this to `False` if you will be accessing a - resource in read-only mode in two different CriticalSections. + exclusive_resource_access: Whether the resources required by + `fn` should be exclusive to this `CriticalSection`. Default: `True`. + You may want to set this to `False` if you will be accessing a + resource in read-only mode in two different CriticalSections. + name: The name to use when creating the execute operation. Returns: - The tensors returned from `fn(*args, **kwargs)`. + The tensors returned from `fn()`. Raises: ValueError: If `fn` attempts to lock this `CriticalSection` in any nested or lazy way that may cause a deadlock. - ValueError: If `exclusive_resource_access` is not provided (is `True`) and + ValueError: If `exclusive_resource_access == True` and another `CriticalSection` has an execution requesting the same - resources as in `*args`, `**kwargs`, and any additionally captured - inputs in `fn`. Note, even if `exclusive_resource_access` is `True`, - if another execution in another `CriticalSection` was created without - `exclusive_resource_access=True`, a `ValueError` will be raised. + resources as `fn``. Note, even if `exclusive_resource_access` is + `True`, if another execution in another `CriticalSection` was created + without `exclusive_resource_access=True`, a `ValueError` will be raised. """ - name = kwargs.pop("name", None) - exclusive_resource_access = kwargs.pop("exclusive_resource_access", True) - with ops.name_scope(name, "critical_section_execute", []): # Ensure that mutex locking only happens *after* all args and @@ -222,7 +226,7 @@ class CriticalSection(object): with ops.get_default_graph()._lock: # pylint: disable=protected-access existing_ops = ops.get_default_graph().get_operations() with ops.control_dependencies([lock]): - r = fn(*args, **kwargs) + r = fn() # TODO(ebrevdo): If creating critical sections in a python loop, this # makes graph creation time quadratic. Revisit if this # becomes a problem. @@ -230,7 +234,7 @@ class CriticalSection(object): .difference(existing_ops)) else: with ops.control_dependencies([lock]): - r = fn(*args, **kwargs) + r = fn() if not context.executing_eagerly(): self._add_control_dependencies_to_lock(created_ops, lock.op) diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index ba3bd094923..5e217d8ed2f 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -54,6 +54,7 @@ from tensorflow.python.ops.control_flow_ops import tuple # pylint: disable=rede # pylint: enable=redefined-builtin from tensorflow.python.eager import wrap_function from tensorflow.python.ops.control_flow_ops import while_loop +from tensorflow.python.ops.critical_section_ops import * from tensorflow.python.ops.data_flow_ops import * from tensorflow.python.ops.functional_ops import * from tensorflow.python.ops.gradients import * diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-critical-section.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-critical-section.pbtxt new file mode 100644 index 00000000000..024a2083463 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.-critical-section.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.CriticalSection" +tf_class { + is_instance: "<class \'tensorflow.python.ops.critical_section_ops.CriticalSection\'>" + is_instance: "<type \'object\'>" + member { + name: "name" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'shared_name\', \'critical_section_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "execute" + argspec: "args=[\'self\', \'fn\', \'exclusive_resource_access\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 64f572022ef..cb9d6a907f7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -32,6 +32,10 @@ tf_module { name: "ConfigProto" mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>" } + member { + name: "CriticalSection" + mtype: "<type \'type\'>" + } member { name: "DType" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-critical-section.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-critical-section.pbtxt new file mode 100644 index 00000000000..024a2083463 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-critical-section.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.CriticalSection" +tf_class { + is_instance: "<class \'tensorflow.python.ops.critical_section_ops.CriticalSection\'>" + is_instance: "<type \'object\'>" + member { + name: "name" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'shared_name\', \'critical_section_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "execute" + argspec: "args=[\'self\', \'fn\', \'exclusive_resource_access\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index a75c0c27dcd..5db60220c76 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -4,6 +4,10 @@ tf_module { name: "AggregationMethod" mtype: "<type \'type\'>" } + member { + name: "CriticalSection" + mtype: "<type \'type\'>" + } member { name: "DType" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 7e86d6cfa50..080e9420a1b 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -540,6 +540,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.data.experimental.unbatch", "tf.contrib.data.unique": "tf.data.experimental.unique", + "tf.contrib.framework.CriticalSection": + "tf.CriticalSection", "tf.contrib.framework.is_tensor": "tf.is_tensor", "tf.contrib.framework.nest.assert_same_structure": diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index bf1e80a6816..c78db3fdf67 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -1157,6 +1157,12 @@ def _log_prob(self, x): _, _, _, new_text = self._upgrade(text) self.assertEqual(expected, new_text) + def test_CriticalSection_upgrade(self): + text = "tf.contrib.framework.CriticalSection(shared_name='blah')" + expected = "tf.CriticalSection(shared_name='blah')" + _, _, _, new_text = self._upgrade(text) + self.assertEqual(expected, new_text) + def test_sample_distorted_bounding_box(self): # pylint: disable=line-too-long text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)"