From ba994787e69ade025949c79fc95932e83ac0a1af Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 18 Mar 2019 16:55:51 -0700 Subject: [PATCH] [XLA] Warn if autojit is enabled but CPU autojit is disabled. To enable CPU autojit, you need to set TF_XLA_FLAGS=--tf_xla_cpu_global_jit PiperOrigin-RevId: 239090111 --- .../compiler/jit/mark_for_compilation_pass.cc | 23 +++++++ .../kernel_tests/control_flow_ops_py_test.py | 63 +++++++++++-------- 2 files changed, 60 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 184afe3aa8e..d3376a788d3 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1061,6 +1061,29 @@ static Status ShouldCompileClusterImpl( XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && global_jit_level != OptimizerOptions::OFF); + if (!*should_compile && + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested && + device_type.type_string() == DEVICE_CPU) { + static std::once_flag once; + std::call_once(once, [] { + LOG(WARNING) + << "(One-time warning): Not using XLA:CPU for cluster because envvar " + "TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want " + "XLA:CPU, either set that envvar, or use experimental_jit_scope " + "to enable XLA:CPU. To confirm that XLA is active, pass " + "--vmodule=xla_compilation_cache=1 (as a proper command-line " + "flag, not via TF_XLA_FLAGS) or set the envvar " + "XLA_FLAGS=--xla_hlo_profile."; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_cpu_global_jit) { + LOG(WARNING) + << "(Although the tf_xla_cpu_global_jit flag is currently enabled, " + "perhaps it wasn't enabled at process startup?)"; + } + }); + } + VLOG(3) << (*should_compile ? "Compiling" : "Not compiling") << " cluster with device " << chosen_device; diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 9208ad55240..6cdd88630f4 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -135,6 +135,17 @@ def isum(s, maximum_iterations=None): return r_s +def enqueue_print_op(s): + """Enqueues an op that prints a message to be captured in the test.""" + return logging_ops.print_v2("ControlFlowOpsTest: " + s) + + +def filter_test_messages(s): + """Returns a list of messages printed by enqueue_print_op.""" + prefix = "ControlFlowOpsTest: " + return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)] + + @test_util.with_control_flow_v2 class ControlFlowTest(test.TestCase): @@ -1142,10 +1153,12 @@ class ControlFlowTest(test.TestCase): if test_util.is_gpu_available(): self.skipTest("b/128676188 causes OOM on opensource gpu tests") + print_prefix = "testCondAutoControlDeps: " + def branch_fn(): - logging_ops.print_v2("A") - logging_ops.print_v2("B") - with ops.control_dependencies([logging_ops.print_v2("C")]): + enqueue_print_op("A") + enqueue_print_op("B") + with ops.control_dependencies([enqueue_print_op("C")]): return constant_op.constant(10) def build_cond(): @@ -1161,11 +1174,11 @@ class ControlFlowTest(test.TestCase): with self.cached_session(): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(build_cond()), 10) - self.assertEqual(printed.contents(), "C\n") + self.assertEqual(["C"], filter_test_messages(printed.contents())) with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(build_nested_cond()), 10) - self.assertEqual(printed.contents(), "C\n") + self.assertEqual(["C"], filter_test_messages(printed.contents())) # In defuns, all prints should execute in program order. # This doesn't work with legacy control flow. @@ -1177,8 +1190,8 @@ class ControlFlowTest(test.TestCase): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(cond()), 10) - self.assertTrue(printed.contents().endswith("A\nB\nC\n"), - printed.contents()) + self.assertEqual(["A", "B", "C"], + filter_test_messages(printed.contents())) @eager_function.defun def nested_cond(): @@ -1186,8 +1199,8 @@ class ControlFlowTest(test.TestCase): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(nested_cond()), 10) - self.assertTrue(printed.contents().endswith("A\nB\nC\n"), - printed.contents()) + self.assertEqual(["A", "B", "C"], + filter_test_messages(printed.contents())) # wrap_function should prune. def pruned_cond(): @@ -1196,7 +1209,7 @@ class ControlFlowTest(test.TestCase): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(pruned_cond()), 10) - self.assertEqual(printed.contents(), "C\n") + self.assertEqual(["C"], filter_test_messages(printed.contents())) def pruned_nested_cond(): return build_nested_cond() @@ -1204,7 +1217,8 @@ class ControlFlowTest(test.TestCase): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(pruned_nested_cond()), 10) - self.assertEqual(printed.contents(), "C\n") + self.assertEqual(["C"], filter_test_messages(printed.contents())) + @test_util.disable_xla("b/128643646 PrintV2") @test_util.run_in_graph_and_eager_modes @@ -1214,14 +1228,14 @@ class ControlFlowTest(test.TestCase): if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return def cond(i, unused_x): - logging_ops.print_v2("A") + enqueue_print_op("A") return i < 2 def body(i, x): - logging_ops.print_v2("B") - with ops.control_dependencies([logging_ops.print_v2("C")]): + enqueue_print_op("B") + with ops.control_dependencies([enqueue_print_op("C")]): x = array_ops.identity(x) - with ops.control_dependencies([logging_ops.print_v2("D")]): + with ops.control_dependencies([enqueue_print_op("D")]): return i + 1, x def build_while(): @@ -1237,13 +1251,11 @@ class ControlFlowTest(test.TestCase): with self.cached_session(): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(build_while()[0]), 2) - self.assertTrue(printed.contents().endswith("D\nD\n"), - printed.contents()) + self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(build_nested_while()[0]), 2) - self.assertTrue(printed.contents().endswith("D\nD\n"), - printed.contents()) + self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) # In defuns, all prints should execute in program order. @eager_function.defun @@ -1252,8 +1264,8 @@ class ControlFlowTest(test.TestCase): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(while_loop()), 2) - self.assertTrue(printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"), - printed.contents()) + self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"], + filter_test_messages(printed.contents())) @eager_function.defun def nested_while_loop(): @@ -1263,9 +1275,8 @@ class ControlFlowTest(test.TestCase): if not context.executing_eagerly(): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(nested_while_loop()), 2) - self.assertTrue( - printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"), - printed.contents()) + self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"], + filter_test_messages(printed.contents())) # wrap_function should prune. def pruned_while(): @@ -1274,7 +1285,7 @@ class ControlFlowTest(test.TestCase): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(pruned_while()), 2) - self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents()) + self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) def pruned_nested_while(): return build_nested_while()[0] @@ -1284,7 +1295,7 @@ class ControlFlowTest(test.TestCase): if not context.executing_eagerly(): with self.captureWritesToStream(sys.stderr) as printed: self.assertEqual(self.evaluate(pruned_nested_while()), 2) - self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents()) + self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) # Microbenchmark: 256,000 iterations/s. def testWhile_1(self):