Don't lower nested control flow if we're compiling to XLA.

PiperOrigin-RevId: 227579512
This commit is contained in:
Skye Wanderman-Milne 2019-01-02 14:13:27 -08:00 committed by TensorFlower Gardener
parent 869f7cc28f
commit 0c2565de11
3 changed files with 55 additions and 2 deletions

View File

@ -145,6 +145,22 @@ class CondV2Test(test.TestCase):
self.assertEqual(cond_op.type, "If") self.assertEqual(cond_op.type, "If")
return output, cond_op return output, cond_op
def _createNestedCond(self, name):
"""Like _createCond but creates a nested cond_v2 call as well."""
pred = constant_op.constant(True, name="pred")
x = constant_op.constant(1.0, name="x")
def true_fn():
return cond_v2.cond_v2(pred, lambda: x, lambda: x + 1)
def false_fn():
return x + 2
output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
cond_op = output.op.inputs[0].op
self.assertEqual(cond_op.type, "If")
return output, cond_op
def testDefaultName(self): def testDefaultName(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
_, cond_op = self._createCond(None) _, cond_op = self._createCond(None)
@ -645,9 +661,14 @@ class CondV2Test(test.TestCase):
# Build the cond_v2 in an XLA context # Build the cond_v2 in an XLA context
xla_context = control_flow_ops.XLAControlFlowContext() xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter() xla_context.Enter()
cond_output, _ = self._createCond("cond") cond_output, cond_op = self._createCond("cond")
xla_context.Exit() xla_context.Exit()
# Check lowering attr is not set.
with self.assertRaises(ValueError):
cond_op.get_attr("_lower_using_switch_merge")
# Check the actual graph that is run.
run_options = config_pb2.RunOptions(output_partition_graphs=True) run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata() run_metadata = config_pb2.RunMetadata()
sess.run(cond_output, options=run_options, run_metadata=run_metadata) sess.run(cond_output, options=run_options, run_metadata=run_metadata)
@ -672,6 +693,29 @@ class CondV2Test(test.TestCase):
if_found, if_found,
"An `If` op was not found, but the graph should not be lowered.") "An `If` op was not found, but the graph should not be lowered.")
@test_util.run_deprecated_v1
def testNestedLoweringDisabledInXLA(self):
# Build the cond_v2 in an XLA context
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
_, cond_op = self._createNestedCond("cond")
xla_context.Exit()
# Check lowering attr is not set for either If node.
with self.assertRaises(ValueError):
cond_op.get_attr("_lower_using_switch_merge")
nested_if_ops = []
for func in ops.get_default_graph()._functions.values():
nested_if_ops.extend(op for op in func._graph.get_operations()
if op.type == "If")
self.assertEqual(len(nested_if_ops), 1)
with self.assertRaises(ValueError):
nested_if_ops[0].get_attr("_lower_using_switch_merge")
# TODO(skyewm): check the actual graphs that are run once we have a way to
# programmatically access those graphs.
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testLoweringDisabledWithSingleThreadedExecutorContext(self): def testLoweringDisabledWithSingleThreadedExecutorContext(self):
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:

View File

@ -57,6 +57,15 @@ def InXlaContext(graph):
return GetContainingXLAContext(ctxt) is not None return GetContainingXLAContext(ctxt) is not None
def GraphOrParentsInXlaContext(graph):
while True:
if InXlaContext(graph): return True
try:
graph = graph.outer_graph
except AttributeError:
return False
def IsInWhileLoop(op): def IsInWhileLoop(op):
ctxt = op._get_control_flow_context() # pylint: disable=protected-access ctxt = op._get_control_flow_context() # pylint: disable=protected-access
return GetContainingWhileContext(ctxt) is not None return GetContainingWhileContext(ctxt) is not None

View File

@ -114,7 +114,7 @@ def maybe_set_lowering_attr(op):
Args: Args:
op: An `If` or `While` Operation. op: An `If` or `While` Operation.
""" """
if (not control_flow_util.IsInXLAContext(op) and if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
context.context().get_function_call_options().executor_type context.context().get_function_call_options().executor_type
!= "SINGLE_THREADED_EXECUTOR"): != "SINGLE_THREADED_EXECUTOR"):
# pylint: disable=protected-access # pylint: disable=protected-access