[TF/XLA] Only enable XLA_ devices if TF_XLA_FLAGS=--tf_xla_enable_xla_devices is set. For now, set the flag to "true" by default.

In the future, the flag will be switched to "false".

PiperOrigin-RevId: 288939060
Change-Id: Ia0420edc9382f0ad0ae47ee4463f83677efe2e0c
This commit is contained in:
Anna R 2020-01-09 11:39:00 -08:00 committed by TensorFlower Gardener
parent e98a887ebe
commit dbf459bcb0
6 changed files with 37 additions and 5 deletions

View File

@ -115,6 +115,7 @@ cc_library(
srcs = ["xla_gpu_device.cc"],
visibility = [":friends"],
deps = [
":flags",
":jit_compilation_passes",
":xla_device",
":xla_kernel_creator", # buildcleaner: keep

View File

@ -155,6 +155,7 @@ void AllocateAndParseFlags() {
device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false;
device_flags->tf_xla_enable_xla_devices = true;
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
@ -187,6 +188,12 @@ void AllocateAndParseFlags() {
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
Flag("tf_xla_enable_xla_devices",
&device_flags->tf_xla_enable_xla_devices,
"Generate XLA_* devices, where placing a computation on such a "
"device"
"forces compilation by XLA. Deprecated."),
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),

View File

@ -87,6 +87,9 @@ struct XlaDeviceFlags {
// Enabling this mode by a legacy flag is a temporary mechanism. When this
// feature is battle-tested, we will switch this to be a session option.
bool tf_xla_compile_on_demand;
// Enables "XLA" devices if this flag is set.
bool tf_xla_enable_xla_devices;
};
// Flags common to the _Xla* ops and their kernels.

View File

@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
};
// Lists the XLA_CPU physical device, but only when the legacy
// --tf_xla_enable_xla_devices flag is on. When the flag is off this is a
// silent no-op (aside from an informational log line) so that no XLA_* device
// is surfaced to users.
//
// NOTE(review): the rendered diff had lost its '-' marker and showed a second,
// unconditional push_back before the flag check; that line predates the flag
// gate and must not survive — otherwise the device is listed twice and is
// listed even when the flag is disabled.
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }
  // Device gating passed: advertise exactly one XLA_CPU physical device.
  devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
  return Status::OK();
}
@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
bool compile_on_demand = flags->tf_xla_compile_on_demand;
XlaOpRegistry::DeviceRegistration registration;

View File

@ -17,9 +17,11 @@ limitations under the License.
// operators using XLA via the XLA "CUDA" (GPU) backend.
#include <set>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
@ -61,6 +63,12 @@ class XlaGpuDeviceFactory : public DeviceFactory {
};
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
@ -84,6 +92,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =

View File

@ -140,7 +140,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
// Lazily register the CPU and GPU JIT devices the first time
// GetCompilationDevice is called.
static void* registration_init = [&registry]() {
{
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
VLOG(2) << "tf_xla_cpu_global_jit = " << cpu_global_jit;
@ -162,9 +162,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
}
return nullptr;
}();
(void)registration_init;
}
mutex_lock lock(registry.mutex_);
auto it = registry.compilation_devices_.find(device_name);