Expose CriticalSection in core as tf.CriticalSection.
PiperOrigin-RevId: 233146763
This commit is contained in:
parent
ce5a8d8ff7
commit
c7621b86bc
@ -32,7 +32,6 @@ tf_custom_op_py_library(
|
||||
"python/ops/arg_scope.py",
|
||||
"python/ops/audio_ops.py",
|
||||
"python/ops/checkpoint_ops.py",
|
||||
"python/ops/critical_section_ops.py",
|
||||
"python/ops/ops.py",
|
||||
"python/ops/prettyprint_ops.py",
|
||||
"python/ops/script_ops.py",
|
||||
@ -172,27 +171,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "critical_section_test",
|
||||
size = "medium",
|
||||
srcs = ["python/ops/critical_section_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
":framework_py",
|
||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ops_test",
|
||||
size = "small",
|
||||
|
@ -94,8 +94,6 @@
|
||||
@@smart_constant_value
|
||||
@@smart_case
|
||||
|
||||
@@CriticalSection
|
||||
|
||||
@@BoundedTensorSpec
|
||||
@@TensorSpec
|
||||
|
||||
|
@ -22,7 +22,6 @@ from __future__ import print_function
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.framework.python.ops.arg_scope import *
|
||||
from tensorflow.contrib.framework.python.ops.checkpoint_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.critical_section_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.ops import *
|
||||
from tensorflow.contrib.framework.python.ops.prettyprint_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.script_ops import *
|
||||
|
@ -2737,6 +2737,22 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "critical_section_ops",
|
||||
srcs = ["ops/critical_section_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":array_ops",
|
||||
":control_flow_ops",
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":resource_variable_ops_gen",
|
||||
":tensor_array_ops",
|
||||
":util",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "list_ops",
|
||||
srcs = ["ops/list_ops.py"],
|
||||
@ -3148,6 +3164,7 @@ py_library(
|
||||
":clip_ops",
|
||||
":confusion_matrix",
|
||||
":control_flow_ops",
|
||||
":critical_section_ops",
|
||||
":cudnn_rnn_grad",
|
||||
":data_flow_grad",
|
||||
":data_flow_ops",
|
||||
|
@ -3592,3 +3592,24 @@ cuda_py_test(
|
||||
grpc_enabled = True,
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "critical_section_test",
|
||||
size = "medium",
|
||||
srcs = ["critical_section_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:critical_section_ops",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
@ -18,14 +18,15 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import critical_section_ops
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import critical_section_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -48,7 +49,7 @@ class CriticalSectionTest(test.TestCase):
|
||||
return array_ops.identity(c)
|
||||
|
||||
num_concurrent = 100
|
||||
r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)]
|
||||
r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)]
|
||||
self.evaluate(v.initializer)
|
||||
r_value = self.evaluate(r)
|
||||
self.assertAllClose([2.0 * i for i in range(num_concurrent)],
|
||||
@ -74,7 +75,7 @@ class CriticalSectionTest(test.TestCase):
|
||||
array_ops.identity(inner_cond), true_fn, lambda: c)
|
||||
|
||||
def execute():
|
||||
return cs.execute(fn, 1.0, 2.0)
|
||||
return cs.execute(lambda: fn(1.0, 2.0))
|
||||
|
||||
r = [
|
||||
control_flow_ops.cond(array_ops.identity(outer_cond),
|
||||
@ -92,6 +93,7 @@ class CriticalSectionTest(test.TestCase):
|
||||
else:
|
||||
self.assertAllClose([0] * num_concurrent, r_value)
|
||||
|
||||
@test_util.run_v1_only("b/123990562 Sees CancelledError on some calls")
|
||||
def testCriticalSectionInParallelDoesntDeadlockOnError(self):
|
||||
# No eager mode execution of this test because eager does not
|
||||
# run fn() in parallel, which is where the deadlock could
|
||||
@ -103,12 +105,23 @@ class CriticalSectionTest(test.TestCase):
|
||||
error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
|
||||
with ops.control_dependencies([error]):
|
||||
return v.read_value()
|
||||
|
||||
num_concurrent = 2
|
||||
r = [cs.execute(fn, i) for i in range(num_concurrent)]
|
||||
|
||||
@def_function.function(autograph=False)
|
||||
def run_concurrently():
|
||||
return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]
|
||||
|
||||
if not context.executing_eagerly():
|
||||
run_concurrently = run_concurrently()
|
||||
|
||||
self.evaluate(v.initializer)
|
||||
for _ in range(100):
|
||||
with self.assertRaisesOpError("Error"):
|
||||
self.evaluate(r)
|
||||
if context.executing_eagerly():
|
||||
run_concurrently()
|
||||
else:
|
||||
self.evaluate(run_concurrently)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCreateCriticalSectionFnReturnsOp(self):
|
||||
@ -123,17 +136,20 @@ class CriticalSectionTest(test.TestCase):
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
num_concurrent = 100
|
||||
r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)]
|
||||
r = [cs.execute(lambda: fn_return_op(1.0, 2.0))
|
||||
for _ in range(num_concurrent)]
|
||||
self.evaluate(v.initializer)
|
||||
self.evaluate(r)
|
||||
final_v = self.evaluate(v)
|
||||
self.assertAllClose(2.0 * num_concurrent, final_v)
|
||||
|
||||
@test_util.run_v1_only("Collections don't exist in TF2")
|
||||
def testCollection(self):
|
||||
cs = critical_section_ops.CriticalSection(shared_name="cs")
|
||||
self.assertIn(
|
||||
cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
|
||||
execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute")
|
||||
add = lambda x: x + 1
|
||||
execute = cs.execute(lambda: add(1.0), name="my_execute")
|
||||
execute_op = [
|
||||
x for x in execute.graph.get_operations()
|
||||
if "my_execute" in x.name and "MutexLock" in x.type
|
||||
@ -143,18 +159,21 @@ class CriticalSectionTest(test.TestCase):
|
||||
[signature.op for signature in
|
||||
ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
|
||||
|
||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
||||
def testRecursiveCriticalSectionAccessIsIllegal(self):
|
||||
# This does not work properly in eager mode. Eager users will
|
||||
# just hit a deadlock if they do this. But at least it'll be easier
|
||||
# to debug.
|
||||
cs = critical_section_ops.CriticalSection()
|
||||
add = lambda y: y + 1
|
||||
def fn(x):
|
||||
return cs.execute(lambda y: y + 1, x)
|
||||
return cs.execute(lambda: add(x))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"attempts to directly access the CriticalSection in which it "
|
||||
r"would be running"):
|
||||
cs.execute(fn, 1.0)
|
||||
cs.execute(lambda: fn(1.0))
|
||||
|
||||
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
|
||||
# This one is subtle; and we're being overly cautious here. The
|
||||
@ -174,24 +193,24 @@ class CriticalSectionTest(test.TestCase):
|
||||
# operations are finished before anything runs within the critical section.
|
||||
cs = critical_section_ops.CriticalSection(shared_name="cs")
|
||||
fn = array_ops.identity
|
||||
to_capture = cs.execute(fn, 1.0)
|
||||
to_capture = cs.execute(lambda: fn(1.0))
|
||||
fn_captures = lambda x: x + to_capture
|
||||
to_capture_too = array_ops.identity(to_capture)
|
||||
|
||||
ex_0 = cs.execute(fn_captures, 1.0)
|
||||
ex_0 = cs.execute(lambda: fn_captures(1.0))
|
||||
|
||||
with ops.control_dependencies([to_capture]):
|
||||
# This is OK because to_capture will execute before this next call
|
||||
ex_1 = cs.execute(fn_captures, 1.0)
|
||||
ex_1 = cs.execute(lambda: fn_captures(1.0))
|
||||
|
||||
dependency = array_ops.identity(to_capture)
|
||||
|
||||
fn_captures_dependency = lambda x: x + dependency
|
||||
|
||||
ex_2 = cs.execute(fn_captures_dependency, 1.0)
|
||||
ex_2 = cs.execute(lambda: fn_captures_dependency(1.0))
|
||||
|
||||
with ops.control_dependencies([to_capture_too]):
|
||||
ex_3 = cs.execute(fn_captures_dependency, 1.0)
|
||||
ex_3 = cs.execute(lambda: fn_captures_dependency(1.0))
|
||||
|
||||
# Ensure there's no actual deadlock on to_execute.
|
||||
self.assertEquals(2.0, self.evaluate(ex_0))
|
||||
@ -217,6 +236,8 @@ class CriticalSectionTest(test.TestCase):
|
||||
body_implicit_capture,
|
||||
[0, 0],
|
||||
parallel_iterations=25)
|
||||
# For consistency between eager and graph mode.
|
||||
i_n = array_ops.identity(i_n)
|
||||
logging.warn(
|
||||
"\n==============\nRunning "
|
||||
"'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
|
||||
@ -242,6 +263,8 @@ class CriticalSectionTest(test.TestCase):
|
||||
body_implicit_capture_protected,
|
||||
[0, 0],
|
||||
parallel_iterations=25)
|
||||
# For consistency between eager and graph mode.
|
||||
i_n = array_ops.identity(i_n)
|
||||
logging.warn(
|
||||
"\n==============\nRunning "
|
||||
"'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
|
||||
@ -258,13 +281,15 @@ class CriticalSectionTest(test.TestCase):
|
||||
# This version is ok because j is an argument to fn and we can
|
||||
# ensure there's a control dependency on j.
|
||||
fn = lambda x: x + 1
|
||||
return (i + 1, cs.execute(fn, j))
|
||||
return (i + 1, cs.execute(lambda: fn(j)))
|
||||
|
||||
(i_n, j_n) = control_flow_ops.while_loop(
|
||||
lambda i, _: i < 1000,
|
||||
body_args_capture,
|
||||
[0, 0],
|
||||
parallel_iterations=25)
|
||||
# For consistency between eager and graph mode.
|
||||
i_n = array_ops.identity(i_n)
|
||||
logging.warn(
|
||||
"\n==============\nRunning "
|
||||
"'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
|
||||
@ -277,20 +302,23 @@ class CriticalSectionTest(test.TestCase):
|
||||
"body_args_capture'\n"
|
||||
"==============\n")
|
||||
|
||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
||||
def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
|
||||
# This does not work properly in eager mode. Eager users will
|
||||
# just hit a deadlock if they do this. But at least it'll be easier
|
||||
# to debug.
|
||||
cs = critical_section_ops.CriticalSection(shared_name="cs")
|
||||
cs_same = critical_section_ops.CriticalSection(shared_name="cs")
|
||||
add = lambda x: x + 1
|
||||
def fn(x):
|
||||
return cs_same.execute(lambda x: x+1, x)
|
||||
return cs_same.execute(lambda: add(x))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"attempts to directly access the CriticalSection in which it "
|
||||
r"would be running"):
|
||||
cs.execute(fn, 1.0)
|
||||
cs.execute(lambda: fn(1.0))
|
||||
|
||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
||||
def testMultipleCSExecutionsRequestSameResource(self):
|
||||
cs0 = critical_section_ops.CriticalSection()
|
||||
cs1 = critical_section_ops.CriticalSection()
|
||||
@ -328,8 +356,11 @@ class CriticalSectionTest(test.TestCase):
|
||||
# Note, here v must be a resource variable (or something similar),
|
||||
# otherwise it gets hoisted into the while_loop by the time we add
|
||||
# control dependencies to the lock_op.
|
||||
def body(i):
|
||||
add_j = lambda j: v + j + 1
|
||||
return cs.execute(lambda: add_j(i))
|
||||
out = control_flow_ops.while_loop(
|
||||
lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0])
|
||||
lambda i: i < 10, body, [0])
|
||||
self.evaluate(v.initializer)
|
||||
self.assertEqual(10, self.evaluate(out))
|
||||
|
@ -31,6 +31,10 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
__all__ = ["CriticalSection"]
|
||||
|
||||
|
||||
# Graph Keys
|
||||
@ -66,6 +70,7 @@ def _get_colocation(op):
|
||||
return None
|
||||
|
||||
|
||||
@tf_export("CriticalSection")
|
||||
class CriticalSection(object):
|
||||
"""Critical section.
|
||||
|
||||
@ -179,37 +184,36 @@ class CriticalSection(object):
|
||||
def name(self):
|
||||
return self._handle.op.name
|
||||
|
||||
def execute(self, fn, *args, **kwargs):
|
||||
"""Execute function `fn(*args, **kwargs)` inside the CriticalSection.
|
||||
def execute(self, fn, exclusive_resource_access=True, name=None):
|
||||
"""Execute function `fn()` inside the critical section.
|
||||
|
||||
`fn` should not accept any arguments. To add extra arguments to when
|
||||
calling `fn` in the critical section, create a lambda:
|
||||
|
||||
```python
|
||||
critical_section.execute(lambda: fn(*my_args, **my_kwargs))
|
||||
```
|
||||
|
||||
Args:
|
||||
fn: The function to execute. Must return at least one tensor.
|
||||
*args: Additional positional arguments to `fn`.
|
||||
**kwargs: Additional keyword arguments to `fn`.
|
||||
Several keywords are reserved for `execute`. These are:
|
||||
|
||||
- name; The name to use when creating the execute operation.
|
||||
- exclusive_resource_access; Whether the resources required by
|
||||
`fn` should be exclusive to this `CriticalSection`. Default: `True`.
|
||||
You may want to set this to `False` if you will be accessing a
|
||||
resource in read-only mode in two different CriticalSections.
|
||||
exclusive_resource_access: Whether the resources required by
|
||||
`fn` should be exclusive to this `CriticalSection`. Default: `True`.
|
||||
You may want to set this to `False` if you will be accessing a
|
||||
resource in read-only mode in two different CriticalSections.
|
||||
name: The name to use when creating the execute operation.
|
||||
|
||||
Returns:
|
||||
The tensors returned from `fn(*args, **kwargs)`.
|
||||
The tensors returned from `fn()`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
|
||||
or lazy way that may cause a deadlock.
|
||||
ValueError: If `exclusive_resource_access` is not provided (is `True`) and
|
||||
ValueError: If `exclusive_resource_access == True` and
|
||||
another `CriticalSection` has an execution requesting the same
|
||||
resources as in `*args`, `**kwargs`, and any additionally captured
|
||||
inputs in `fn`. Note, even if `exclusive_resource_access` is `True`,
|
||||
if another execution in another `CriticalSection` was created without
|
||||
`exclusive_resource_access=True`, a `ValueError` will be raised.
|
||||
resources as `fn``. Note, even if `exclusive_resource_access` is
|
||||
`True`, if another execution in another `CriticalSection` was created
|
||||
without `exclusive_resource_access=True`, a `ValueError` will be raised.
|
||||
"""
|
||||
name = kwargs.pop("name", None)
|
||||
exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
|
||||
|
||||
with ops.name_scope(name, "critical_section_execute", []):
|
||||
|
||||
# Ensure that mutex locking only happens *after* all args and
|
||||
@ -222,7 +226,7 @@ class CriticalSection(object):
|
||||
with ops.get_default_graph()._lock: # pylint: disable=protected-access
|
||||
existing_ops = ops.get_default_graph().get_operations()
|
||||
with ops.control_dependencies([lock]):
|
||||
r = fn(*args, **kwargs)
|
||||
r = fn()
|
||||
# TODO(ebrevdo): If creating critical sections in a python loop, this
|
||||
# makes graph creation time quadratic. Revisit if this
|
||||
# becomes a problem.
|
||||
@ -230,7 +234,7 @@ class CriticalSection(object):
|
||||
.difference(existing_ops))
|
||||
else:
|
||||
with ops.control_dependencies([lock]):
|
||||
r = fn(*args, **kwargs)
|
||||
r = fn()
|
||||
|
||||
if not context.executing_eagerly():
|
||||
self._add_control_dependencies_to_lock(created_ops, lock.op)
|
@ -54,6 +54,7 @@ from tensorflow.python.ops.control_flow_ops import tuple # pylint: disable=rede
|
||||
# pylint: enable=redefined-builtin
|
||||
from tensorflow.python.eager import wrap_function
|
||||
from tensorflow.python.ops.control_flow_ops import while_loop
|
||||
from tensorflow.python.ops.critical_section_ops import *
|
||||
from tensorflow.python.ops.data_flow_ops import *
|
||||
from tensorflow.python.ops.functional_ops import *
|
||||
from tensorflow.python.ops.gradients import *
|
||||
|
@ -0,0 +1,17 @@
|
||||
path: "tensorflow.CriticalSection"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.critical_section_ops.CriticalSection\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'name\', \'shared_name\', \'critical_section_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "execute"
|
||||
argspec: "args=[\'self\', \'fn\', \'exclusive_resource_access\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
}
|
@ -32,6 +32,10 @@ tf_module {
|
||||
name: "ConfigProto"
|
||||
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
|
||||
}
|
||||
member {
|
||||
name: "CriticalSection"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "DType"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,17 @@
|
||||
path: "tensorflow.CriticalSection"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.critical_section_ops.CriticalSection\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'name\', \'shared_name\', \'critical_section_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "execute"
|
||||
argspec: "args=[\'self\', \'fn\', \'exclusive_resource_access\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
}
|
@ -4,6 +4,10 @@ tf_module {
|
||||
name: "AggregationMethod"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "CriticalSection"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "DType"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -540,6 +540,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.data.experimental.unbatch",
|
||||
"tf.contrib.data.unique":
|
||||
"tf.data.experimental.unique",
|
||||
"tf.contrib.framework.CriticalSection":
|
||||
"tf.CriticalSection",
|
||||
"tf.contrib.framework.is_tensor":
|
||||
"tf.is_tensor",
|
||||
"tf.contrib.framework.nest.assert_same_structure":
|
||||
|
@ -1157,6 +1157,12 @@ def _log_prob(self, x):
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
|
||||
def test_CriticalSection_upgrade(self):
|
||||
text = "tf.contrib.framework.CriticalSection(shared_name='blah')"
|
||||
expected = "tf.CriticalSection(shared_name='blah')"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
|
||||
def test_sample_distorted_bounding_box(self):
|
||||
# pylint: disable=line-too-long
|
||||
text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)"
|
||||
|
Loading…
Reference in New Issue
Block a user