diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 2c990261055..050da5ff6cc 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -145,6 +145,22 @@ class CondV2Test(test.TestCase): self.assertEqual(cond_op.type, "If") return output, cond_op + def _createNestedCond(self, name): + """Like _createCond but creates a nested cond_v2 call as well.""" + pred = constant_op.constant(True, name="pred") + x = constant_op.constant(1.0, name="x") + + def true_fn(): + return cond_v2.cond_v2(pred, lambda: x, lambda: x + 1) + + def false_fn(): + return x + 2 + + output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name) + cond_op = output.op.inputs[0].op + self.assertEqual(cond_op.type, "If") + return output, cond_op + def testDefaultName(self): with ops.Graph().as_default(): _, cond_op = self._createCond(None) @@ -645,9 +661,14 @@ class CondV2Test(test.TestCase): # Build the cond_v2 in an XLA context xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() - cond_output, _ = self._createCond("cond") + cond_output, cond_op = self._createCond("cond") xla_context.Exit() + # Check lowering attr is not set. + with self.assertRaises(ValueError): + cond_op.get_attr("_lower_using_switch_merge") + + # Check the actual graph that is run. run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() sess.run(cond_output, options=run_options, run_metadata=run_metadata) @@ -672,6 +693,29 @@ class CondV2Test(test.TestCase): if_found, "An `If` op was not found, but the graph should not be lowered.") + @test_util.run_deprecated_v1 + def testNestedLoweringDisabledInXLA(self): + # Build the cond_v2 in an XLA context + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + _, cond_op = self._createNestedCond("cond") + xla_context.Exit() + + # Check lowering attr is not set for either If node. + with self.assertRaises(ValueError): + cond_op.get_attr("_lower_using_switch_merge") + + nested_if_ops = [] + for func in ops.get_default_graph()._functions.values(): + nested_if_ops.extend(op for op in func._graph.get_operations() + if op.type == "If") + self.assertEqual(len(nested_if_ops), 1) + with self.assertRaises(ValueError): + nested_if_ops[0].get_attr("_lower_using_switch_merge") + + # TODO(skyewm): check the actual graphs that are run once we have a way to + # programmatically access those graphs. + @test_util.run_deprecated_v1 def testLoweringDisabledWithSingleThreadedExecutorContext(self): with self.session(graph=ops.Graph()) as sess: diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py index 8f5442da5e4..ff0dff0042e 100644 --- a/tensorflow/python/ops/control_flow_util.py +++ b/tensorflow/python/ops/control_flow_util.py @@ -57,6 +57,15 @@ def InXlaContext(graph): return GetContainingXLAContext(ctxt) is not None +def GraphOrParentsInXlaContext(graph): + while True: + if InXlaContext(graph): return True + try: + graph = graph.outer_graph + except AttributeError: + return False + + def IsInWhileLoop(op): ctxt = op._get_control_flow_context() # pylint: disable=protected-access return GetContainingWhileContext(ctxt) is not None diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py index 5f56850884a..58917ad264a 100644 --- a/tensorflow/python/ops/control_flow_util_v2.py +++ b/tensorflow/python/ops/control_flow_util_v2.py @@ -114,7 +114,7 @@ def maybe_set_lowering_attr(op): Args: op: An `If` or `While` Operation. """ - if (not control_flow_util.IsInXLAContext(op) and + if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and context.context().get_function_call_options().executor_type != "SINGLE_THREADED_EXECUTOR"): # pylint: disable=protected-access