diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index c844f6d1801..618165d4b64 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -115,6 +115,7 @@ cc_library(
     srcs = ["xla_gpu_device.cc"],
     visibility = [":friends"],
     deps = [
+        ":flags",
         ":jit_compilation_passes",
         ":xla_device",
         ":xla_kernel_creator",  # buildcleaner: keep
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index 1cf71298b05..991ad82daa1 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -155,6 +155,7 @@ void AllocateAndParseFlags() {
 
   device_flags = new XlaDeviceFlags;
   device_flags->tf_xla_compile_on_demand = false;
+  device_flags->tf_xla_enable_xla_devices = true;
 
   ops_flags = new XlaOpsCommonFlags;
   ops_flags->tf_xla_always_defer_compilation = false;
@@ -187,6 +188,12 @@ void AllocateAndParseFlags() {
            "Switch a device into 'on-demand' mode, where instead of "
            "autoclustering ops are compiled one by one just-in-time."),
 
+      Flag("tf_xla_enable_xla_devices",
+           &device_flags->tf_xla_enable_xla_devices,
+           "Generate XLA_* devices, where placing a computation on such a "
+           "device "
+           "forces compilation by XLA. Deprecated."),
+
       Flag("tf_xla_always_defer_compilation",
            &ops_flags->tf_xla_always_defer_compilation, ""),
 
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index 87a89841b91..618e839fa36 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -87,6 +87,9 @@ struct XlaDeviceFlags {
   // Enabling this mode by a legacy flag is a temporary mechanism. When this
   // feature is battle-tested, we will switch this to be a session option.
   bool tf_xla_compile_on_demand;
+
+  // Enables "XLA" devices if this flag is set.
+  bool tf_xla_enable_xla_devices;
 };
 
 // Flags common to the _Xla* ops and their kernels.
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 85c09a027d3..446cd8944de 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
 };
 
 Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
-  devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
+  XlaDeviceFlags* flags = GetXlaDeviceFlags();
+  if (!flags->tf_xla_enable_xla_devices) {
+    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
+    return Status::OK();
+  }
 
+  devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
   return Status::OK();
 }
 
@@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
     const SessionOptions& session_options, const string& name_prefix,
     std::vector<std::unique_ptr<Device>>* devices) {
   XlaDeviceFlags* flags = GetXlaDeviceFlags();
+  if (!flags->tf_xla_enable_xla_devices) {
+    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
+    return Status::OK();
+  }
   bool compile_on_demand = flags->tf_xla_compile_on_demand;
 
   XlaOpRegistry::DeviceRegistration registration;
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 8dc75c969a4..91943edd775 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -17,9 +17,11 @@ limitations under the License.
 // operators using XLA via the XLA "CUDA" (GPU) backend.
 
 #include <set>
+
 #include "absl/memory/memory.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_split.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_device_ops.h"
@@ -61,6 +63,12 @@ class XlaGpuDeviceFactory : public DeviceFactory {
 };
 
 Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
+  XlaDeviceFlags* flags = GetXlaDeviceFlags();
+  if (!flags->tf_xla_enable_xla_devices) {
+    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
+    return Status::OK();
+  }
+
   auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
   if (!platform.ok()) {
     // Treat failures as non-fatal; there might not be a GPU in the machine.
@@ -84,6 +92,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
 Status XlaGpuDeviceFactory::CreateDevices(
     const SessionOptions& session_options, const string& name_prefix,
     std::vector<std::unique_ptr<Device>>* devices) {
+  XlaDeviceFlags* flags = GetXlaDeviceFlags();
+  if (!flags->tf_xla_enable_xla_devices) {
+    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
+    return Status::OK();
+  }
+
   XlaOpRegistry::DeviceRegistration registration;
   registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
   registration.autoclustering_policy =
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index a43608bd434..b16dd3086fe 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -140,7 +140,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
 
   // Lazily register the CPU and GPU JIT devices the first time
   // GetCompilationDevice is called.
-  static void* registration_init = [&registry]() {
+  {
     MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
     bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
     VLOG(2) << "tf_xla_cpu_global_jit = " << cpu_global_jit;
@@ -162,9 +162,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
       registration.autoclustering_policy =
          XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
     }
-    return nullptr;
-  }();
-  (void)registration_init;
+  }
 
   mutex_lock lock(registry.mutex_);
   auto it = registry.compilation_devices_.find(device_name);
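A minimal usage sketch, not part of the patch itself: the new flag defaults to true, so the XLA_* devices are still registered unless it is switched off through the TF_XLA_FLAGS environment variable that flags.cc parses on first use. The flag name and the log message come from this diff; the surrounding Python is standard TF 2.x API, shown only for illustration.

    import os

    # TF_XLA_FLAGS must be set before TensorFlow first parses its flags,
    # so set it before the import.
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_xla_devices=false"

    import tensorflow as tf

    # The patched ListPhysicalDevices implementations now return early and
    # log "Not creating XLA devices, tf_xla_enable_xla_devices not set",
    # so this list comes back empty.
    print(tf.config.experimental.list_physical_devices("XLA_CPU"))

    # With the flag left at its default (true), the devices exist, and per
    # the flag's help text placing a computation on one forces compilation
    # by XLA:
    #
    #   with tf.device("/device:XLA_CPU:0"):
    #       y = tf.matmul(x, x)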