Prepare critical_section_test for upcoming changes that enable v2 control flow as part of TF2 behavior in graph mode.
PiperOrigin-RevId: 260804842
This commit is contained in:
parent
6089f636b8
commit
2b7e42fe8e
@ -3891,6 +3891,7 @@ cuda_py_test(
|
||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:control_flow_v2_toggles",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:gradients",
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.ops import critical_section_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -63,10 +64,12 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
||||
@parameterized.named_parameters(
|
||||
("Inner%sOuter%s" % (inner, outer), inner, outer)
|
||||
for (inner, outer) in itertools.product(*([(False, True)] * 2)))
|
||||
@test_util.disable_control_flow_v2("b/135070612")
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.xla_allow_fallback("b/128495870")
|
||||
def testCriticalSectionWithControlFlow(self, outer_cond, inner_cond):
|
||||
if (not context.executing_eagerly() and
|
||||
control_flow_v2_toggles.control_flow_v2_enabled()):
|
||||
self.skipTest("b/135070612")
|
||||
cs = critical_section_ops.CriticalSection(shared_name="cs")
|
||||
v = resource_variable_ops.ResourceVariable(0.0, name="v")
|
||||
num_concurrent = 100
|
||||
|
Loading…
x
Reference in New Issue
Block a user