[XLA] Warn if autojit is enabled but CPU autojit is disabled.

To enable CPU autojit, you need to set

  TF_XLA_FLAGS=--tf_xla_cpu_global_jit

PiperOrigin-RevId: 239090111
This commit is contained in:
Justin Lebar 2019-03-18 16:55:51 -07:00 committed by TensorFlower Gardener
parent 17a8ee5006
commit ba994787e6
2 changed files with 60 additions and 26 deletions

View File

@ -1061,6 +1061,29 @@ static Status ShouldCompileClusterImpl(
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
global_jit_level != OptimizerOptions::OFF);
if (!*should_compile &&
registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested &&
device_type.type_string() == DEVICE_CPU) {
static std::once_flag once;
std::call_once(once, [] {
LOG(WARNING)
<< "(One-time warning): Not using XLA:CPU for cluster because envvar "
"TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "
"XLA:CPU, either set that envvar, or use experimental_jit_scope "
"to enable XLA:CPU. To confirm that XLA is active, pass "
"--vmodule=xla_compilation_cache=1 (as a proper command-line "
"flag, not via TF_XLA_FLAGS) or set the envvar "
"XLA_FLAGS=--xla_hlo_profile.";
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
if (flags->tf_xla_cpu_global_jit) {
LOG(WARNING)
<< "(Although the tf_xla_cpu_global_jit flag is currently enabled, "
"perhaps it wasn't enabled at process startup?)";
}
});
}
VLOG(3) << (*should_compile ? "Compiling" : "Not compiling")
<< " cluster with device " << chosen_device;

View File

@ -135,6 +135,17 @@ def isum(s, maximum_iterations=None):
return r_s
def enqueue_print_op(s):
"""Enqueues an op that prints a message to be captured in the test."""
return logging_ops.print_v2("ControlFlowOpsTest: " + s)
def filter_test_messages(s):
"""Returns a list of messages printed by enqueue_print_op."""
prefix = "ControlFlowOpsTest: "
return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)]
@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):
@ -1142,10 +1153,12 @@ class ControlFlowTest(test.TestCase):
if test_util.is_gpu_available():
self.skipTest("b/128676188 causes OOM on opensource gpu tests")
print_prefix = "testCondAutoControlDeps: "
def branch_fn():
logging_ops.print_v2("A")
logging_ops.print_v2("B")
with ops.control_dependencies([logging_ops.print_v2("C")]):
enqueue_print_op("A")
enqueue_print_op("B")
with ops.control_dependencies([enqueue_print_op("C")]):
return constant_op.constant(10)
def build_cond():
@ -1161,11 +1174,11 @@ class ControlFlowTest(test.TestCase):
with self.cached_session():
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(build_cond()), 10)
self.assertEqual(printed.contents(), "C\n")
self.assertEqual(["C"], filter_test_messages(printed.contents()))
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(build_nested_cond()), 10)
self.assertEqual(printed.contents(), "C\n")
self.assertEqual(["C"], filter_test_messages(printed.contents()))
# In defuns, all prints should execute in program order.
# This doesn't work with legacy control flow.
@ -1177,8 +1190,8 @@ class ControlFlowTest(test.TestCase):
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(cond()), 10)
self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
printed.contents())
self.assertEqual(["A", "B", "C"],
filter_test_messages(printed.contents()))
@eager_function.defun
def nested_cond():
@ -1186,8 +1199,8 @@ class ControlFlowTest(test.TestCase):
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(nested_cond()), 10)
self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
printed.contents())
self.assertEqual(["A", "B", "C"],
filter_test_messages(printed.contents()))
# wrap_function should prune.
def pruned_cond():
@ -1196,7 +1209,7 @@ class ControlFlowTest(test.TestCase):
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(pruned_cond()), 10)
self.assertEqual(printed.contents(), "C\n")
self.assertEqual(["C"], filter_test_messages(printed.contents()))
def pruned_nested_cond():
return build_nested_cond()
@ -1204,7 +1217,8 @@ class ControlFlowTest(test.TestCase):
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(pruned_nested_cond()), 10)
self.assertEqual(printed.contents(), "C\n")
self.assertEqual(["C"], filter_test_messages(printed.contents()))
@test_util.disable_xla("b/128643646 PrintV2")
@test_util.run_in_graph_and_eager_modes
@ -1214,14 +1228,14 @@ class ControlFlowTest(test.TestCase):
if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return
def cond(i, unused_x):
logging_ops.print_v2("A")
enqueue_print_op("A")
return i < 2
def body(i, x):
logging_ops.print_v2("B")
with ops.control_dependencies([logging_ops.print_v2("C")]):
enqueue_print_op("B")
with ops.control_dependencies([enqueue_print_op("C")]):
x = array_ops.identity(x)
with ops.control_dependencies([logging_ops.print_v2("D")]):
with ops.control_dependencies([enqueue_print_op("D")]):
return i + 1, x
def build_while():
@ -1237,13 +1251,11 @@ class ControlFlowTest(test.TestCase):
with self.cached_session():
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(build_while()[0]), 2)
self.assertTrue(printed.contents().endswith("D\nD\n"),
printed.contents())
self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(build_nested_while()[0]), 2)
self.assertTrue(printed.contents().endswith("D\nD\n"),
printed.contents())
self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
# In defuns, all prints should execute in program order.
@eager_function.defun
@ -1252,8 +1264,8 @@ class ControlFlowTest(test.TestCase):
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(while_loop()), 2)
self.assertTrue(printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
printed.contents())
self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"],
filter_test_messages(printed.contents()))
@eager_function.defun
def nested_while_loop():
@ -1263,9 +1275,8 @@ class ControlFlowTest(test.TestCase):
if not context.executing_eagerly():
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(nested_while_loop()), 2)
self.assertTrue(
printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
printed.contents())
self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"],
filter_test_messages(printed.contents()))
# wrap_function should prune.
def pruned_while():
@ -1274,7 +1285,7 @@ class ControlFlowTest(test.TestCase):
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(pruned_while()), 2)
self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())
self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
def pruned_nested_while():
return build_nested_while()[0]
@ -1284,7 +1295,7 @@ class ControlFlowTest(test.TestCase):
if not context.executing_eagerly():
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(pruned_nested_while()), 2)
self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())
self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
# Microbenchmark: 256,000 iterations/s.
def testWhile_1(self):