[TF/XLA] Only enable XLA_ devices if TF_XLA_FLAGS=--tf_xla_enable_xla_devices is set. For now, set the flag to "true" by default.
In future, the flag will be switched to "false". PiperOrigin-RevId: 288939060 Change-Id: Ia0420edc9382f0ad0ae47ee4463f83677efe2e0c
This commit is contained in:
parent
e98a887ebe
commit
dbf459bcb0
@ -115,6 +115,7 @@ cc_library(
|
|||||||
srcs = ["xla_gpu_device.cc"],
|
srcs = ["xla_gpu_device.cc"],
|
||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":flags",
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":xla_device",
|
":xla_device",
|
||||||
":xla_kernel_creator", # buildcleaner: keep
|
":xla_kernel_creator", # buildcleaner: keep
|
||||||
|
@ -155,6 +155,7 @@ void AllocateAndParseFlags() {
|
|||||||
|
|
||||||
device_flags = new XlaDeviceFlags;
|
device_flags = new XlaDeviceFlags;
|
||||||
device_flags->tf_xla_compile_on_demand = false;
|
device_flags->tf_xla_compile_on_demand = false;
|
||||||
|
device_flags->tf_xla_enable_xla_devices = true;
|
||||||
|
|
||||||
ops_flags = new XlaOpsCommonFlags;
|
ops_flags = new XlaOpsCommonFlags;
|
||||||
ops_flags->tf_xla_always_defer_compilation = false;
|
ops_flags->tf_xla_always_defer_compilation = false;
|
||||||
@ -187,6 +188,12 @@ void AllocateAndParseFlags() {
|
|||||||
"Switch a device into 'on-demand' mode, where instead of "
|
"Switch a device into 'on-demand' mode, where instead of "
|
||||||
"autoclustering ops are compiled one by one just-in-time."),
|
"autoclustering ops are compiled one by one just-in-time."),
|
||||||
|
|
||||||
|
Flag("tf_xla_enable_xla_devices",
|
||||||
|
&device_flags->tf_xla_enable_xla_devices,
|
||||||
|
"Generate XLA_* devices, where placing a computation on such a "
|
||||||
|
"device"
|
||||||
|
"forces compilation by XLA. Deprecated."),
|
||||||
|
|
||||||
Flag("tf_xla_always_defer_compilation",
|
Flag("tf_xla_always_defer_compilation",
|
||||||
&ops_flags->tf_xla_always_defer_compilation, ""),
|
&ops_flags->tf_xla_always_defer_compilation, ""),
|
||||||
|
|
||||||
|
@ -87,6 +87,9 @@ struct XlaDeviceFlags {
|
|||||||
// Enabling this mode by a legacy flag is a temporary mechanism. When this
|
// Enabling this mode by a legacy flag is a temporary mechanism. When this
|
||||||
// feature is battle-tested, we will switch this to be a session option.
|
// feature is battle-tested, we will switch this to be a session option.
|
||||||
bool tf_xla_compile_on_demand;
|
bool tf_xla_compile_on_demand;
|
||||||
|
|
||||||
|
// Enables "XLA" devices if this flag is set.
|
||||||
|
bool tf_xla_enable_xla_devices;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Flags common to the _Xla* ops and their kernels.
|
// Flags common to the _Xla* ops and their kernels.
|
||||||
|
@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||||
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
|
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||||
|
if (!flags->tf_xla_enable_xla_devices) {
|
||||||
|
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
|||||||
const SessionOptions& session_options, const string& name_prefix,
|
const SessionOptions& session_options, const string& name_prefix,
|
||||||
std::vector<std::unique_ptr<Device>>* devices) {
|
std::vector<std::unique_ptr<Device>>* devices) {
|
||||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||||
|
if (!flags->tf_xla_enable_xla_devices) {
|
||||||
|
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
bool compile_on_demand = flags->tf_xla_compile_on_demand;
|
bool compile_on_demand = flags->tf_xla_compile_on_demand;
|
||||||
|
|
||||||
XlaOpRegistry::DeviceRegistration registration;
|
XlaOpRegistry::DeviceRegistration registration;
|
||||||
|
@ -17,9 +17,11 @@ limitations under the License.
|
|||||||
// operators using XLA via the XLA "CUDA" (GPU) backend.
|
// operators using XLA via the XLA "CUDA" (GPU) backend.
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||||
#include "tensorflow/compiler/jit/xla_device.h"
|
#include "tensorflow/compiler/jit/xla_device.h"
|
||||||
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
||||||
@ -61,6 +63,12 @@ class XlaGpuDeviceFactory : public DeviceFactory {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||||
|
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||||
|
if (!flags->tf_xla_enable_xla_devices) {
|
||||||
|
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
||||||
if (!platform.ok()) {
|
if (!platform.ok()) {
|
||||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||||
@ -84,6 +92,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
|||||||
Status XlaGpuDeviceFactory::CreateDevices(
|
Status XlaGpuDeviceFactory::CreateDevices(
|
||||||
const SessionOptions& session_options, const string& name_prefix,
|
const SessionOptions& session_options, const string& name_prefix,
|
||||||
std::vector<std::unique_ptr<Device>>* devices) {
|
std::vector<std::unique_ptr<Device>>* devices) {
|
||||||
|
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||||
|
if (!flags->tf_xla_enable_xla_devices) {
|
||||||
|
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
XlaOpRegistry::DeviceRegistration registration;
|
XlaOpRegistry::DeviceRegistration registration;
|
||||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||||
registration.autoclustering_policy =
|
registration.autoclustering_policy =
|
||||||
|
@ -140,7 +140,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
|||||||
|
|
||||||
// Lazily register the CPU and GPU JIT devices the first time
|
// Lazily register the CPU and GPU JIT devices the first time
|
||||||
// GetCompilationDevice is called.
|
// GetCompilationDevice is called.
|
||||||
static void* registration_init = [®istry]() {
|
{
|
||||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||||
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
|
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
|
||||||
VLOG(2) << "tf_xla_cpu_global_jit = " << cpu_global_jit;
|
VLOG(2) << "tf_xla_cpu_global_jit = " << cpu_global_jit;
|
||||||
@ -162,9 +162,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
|||||||
registration.autoclustering_policy =
|
registration.autoclustering_policy =
|
||||||
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
|
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
|
||||||
}
|
}
|
||||||
return nullptr;
|
}
|
||||||
}();
|
|
||||||
(void)registration_init;
|
|
||||||
|
|
||||||
mutex_lock lock(registry.mutex_);
|
mutex_lock lock(registry.mutex_);
|
||||||
auto it = registry.compilation_devices_.find(device_name);
|
auto it = registry.compilation_devices_.find(device_name);
|
||||||
|
Loading…
Reference in New Issue
Block a user