Don't lower nested control flow if we're compiling to XLA.
PiperOrigin-RevId: 227579512
This commit is contained in:
parent
869f7cc28f
commit
0c2565de11
@ -145,6 +145,22 @@ class CondV2Test(test.TestCase):
|
||||
self.assertEqual(cond_op.type, "If")
|
||||
return output, cond_op
|
||||
|
||||
def _createNestedCond(self, name):
|
||||
"""Like _createCond but creates a nested cond_v2 call as well."""
|
||||
pred = constant_op.constant(True, name="pred")
|
||||
x = constant_op.constant(1.0, name="x")
|
||||
|
||||
def true_fn():
|
||||
return cond_v2.cond_v2(pred, lambda: x, lambda: x + 1)
|
||||
|
||||
def false_fn():
|
||||
return x + 2
|
||||
|
||||
output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
|
||||
cond_op = output.op.inputs[0].op
|
||||
self.assertEqual(cond_op.type, "If")
|
||||
return output, cond_op
|
||||
|
||||
def testDefaultName(self):
|
||||
with ops.Graph().as_default():
|
||||
_, cond_op = self._createCond(None)
|
||||
@ -645,9 +661,14 @@ class CondV2Test(test.TestCase):
|
||||
# Build the cond_v2 in an XLA context
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
cond_output, _ = self._createCond("cond")
|
||||
cond_output, cond_op = self._createCond("cond")
|
||||
xla_context.Exit()
|
||||
|
||||
# Check lowering attr is not set.
|
||||
with self.assertRaises(ValueError):
|
||||
cond_op.get_attr("_lower_using_switch_merge")
|
||||
|
||||
# Check the actual graph that is run.
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
sess.run(cond_output, options=run_options, run_metadata=run_metadata)
|
||||
@ -672,6 +693,29 @@ class CondV2Test(test.TestCase):
|
||||
if_found,
|
||||
"An `If` op was not found, but the graph should not be lowered.")
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNestedLoweringDisabledInXLA(self):
|
||||
# Build the cond_v2 in an XLA context
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
_, cond_op = self._createNestedCond("cond")
|
||||
xla_context.Exit()
|
||||
|
||||
# Check lowering attr is not set for either If node.
|
||||
with self.assertRaises(ValueError):
|
||||
cond_op.get_attr("_lower_using_switch_merge")
|
||||
|
||||
nested_if_ops = []
|
||||
for func in ops.get_default_graph()._functions.values():
|
||||
nested_if_ops.extend(op for op in func._graph.get_operations()
|
||||
if op.type == "If")
|
||||
self.assertEqual(len(nested_if_ops), 1)
|
||||
with self.assertRaises(ValueError):
|
||||
nested_if_ops[0].get_attr("_lower_using_switch_merge")
|
||||
|
||||
# TODO(skyewm): check the actual graphs that are run once we have a way to
|
||||
# programmatically access those graphs.
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testLoweringDisabledWithSingleThreadedExecutorContext(self):
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
|
@ -57,6 +57,15 @@ def InXlaContext(graph):
|
||||
return GetContainingXLAContext(ctxt) is not None
|
||||
|
||||
|
||||
def GraphOrParentsInXlaContext(graph):
|
||||
while True:
|
||||
if InXlaContext(graph): return True
|
||||
try:
|
||||
graph = graph.outer_graph
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def IsInWhileLoop(op):
|
||||
ctxt = op._get_control_flow_context() # pylint: disable=protected-access
|
||||
return GetContainingWhileContext(ctxt) is not None
|
||||
|
@ -114,7 +114,7 @@ def maybe_set_lowering_attr(op):
|
||||
Args:
|
||||
op: An `If` or `While` Operation.
|
||||
"""
|
||||
if (not control_flow_util.IsInXLAContext(op) and
|
||||
if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
|
||||
context.context().get_function_call_options().executor_type
|
||||
!= "SINGLE_THREADED_EXECUTOR"):
|
||||
# pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user