[TF/XLA] Only enable XLA_ devices if TF_XLA_FLAGS=--tf_xla_enable_xla_devices is set. For now, the flag is set to "true" by default.
In the future, the flag will be switched to "false". PiperOrigin-RevId: 288939060 Change-Id: Ia0420edc9382f0ad0ae47ee4463f83677efe2e0c
This commit is contained in:
parent
e98a887ebe
commit
dbf459bcb0
@ -115,6 +115,7 @@ cc_library(
|
||||
srcs = ["xla_gpu_device.cc"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":flags",
|
||||
":jit_compilation_passes",
|
||||
":xla_device",
|
||||
":xla_kernel_creator", # buildcleaner: keep
|
||||
|
@ -155,6 +155,7 @@ void AllocateAndParseFlags() {
|
||||
|
||||
device_flags = new XlaDeviceFlags;
|
||||
device_flags->tf_xla_compile_on_demand = false;
|
||||
device_flags->tf_xla_enable_xla_devices = true;
|
||||
|
||||
ops_flags = new XlaOpsCommonFlags;
|
||||
ops_flags->tf_xla_always_defer_compilation = false;
|
||||
@ -187,6 +188,12 @@ void AllocateAndParseFlags() {
|
||||
"Switch a device into 'on-demand' mode, where instead of "
|
||||
"autoclustering ops are compiled one by one just-in-time."),
|
||||
|
||||
Flag("tf_xla_enable_xla_devices",
|
||||
&device_flags->tf_xla_enable_xla_devices,
|
||||
"Generate XLA_* devices, where placing a computation on such a "
|
||||
"device"
|
||||
"forces compilation by XLA. Deprecated."),
|
||||
|
||||
Flag("tf_xla_always_defer_compilation",
|
||||
&ops_flags->tf_xla_always_defer_compilation, ""),
|
||||
|
||||
|
@ -87,6 +87,9 @@ struct XlaDeviceFlags {
|
||||
// Enabling this mode by a legacy flag is a temporary mechanism. When this
|
||||
// feature is battle-tested, we will switch this to be a session option.
|
||||
bool tf_xla_compile_on_demand;
|
||||
|
||||
// Enables "XLA" devices if this flag is set.
|
||||
bool tf_xla_enable_xla_devices;
|
||||
};
|
||||
|
||||
// Flags common to the _Xla* ops and their kernels.
|
||||
|
@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
|
||||
};
|
||||
|
||||
// Reports the XLA_CPU physical device, gated on the XLA-devices flag.
//
// When --tf_xla_enable_xla_devices (TF_XLA_FLAGS) is unset, no XLA_CPU
// device is listed and Status::OK() is returned, so the rest of device
// enumeration proceeds normally without exposing XLA_* devices.
//
// Args:
//   devices: output list; on success (with the flag enabled) one entry
//            "/physical_device:XLA_CPU:0" is appended.
// Returns: always Status::OK().
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    // Flag disabled: succeed without registering the device. The device
    // must only be appended after this check; listing it unconditionally
    // would defeat the flag gate.
    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }

  devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
  return Status::OK();
}
|
||||
|
||||
@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
||||
const SessionOptions& session_options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
bool compile_on_demand = flags->tf_xla_compile_on_demand;
|
||||
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
|
@ -17,9 +17,11 @@ limitations under the License.
|
||||
// operators using XLA via the XLA "CUDA" (GPU) backend.
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
||||
@ -61,6 +63,12 @@ class XlaGpuDeviceFactory : public DeviceFactory {
|
||||
};
|
||||
|
||||
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
||||
if (!platform.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
@ -84,6 +92,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
Status XlaGpuDeviceFactory::CreateDevices(
|
||||
const SessionOptions& session_options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.autoclustering_policy =
|
||||
|
@ -140,7 +140,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
|
||||
// Lazily register the CPU and GPU JIT devices the first time
|
||||
// GetCompilationDevice is called.
|
||||
static void* registration_init = [®istry]() {
|
||||
{
|
||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
|
||||
VLOG(2) << "tf_xla_cpu_global_jit = " << cpu_global_jit;
|
||||
@ -162,9 +162,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
|
||||
}
|
||||
return nullptr;
|
||||
}();
|
||||
(void)registration_init;
|
||||
}
|
||||
|
||||
mutex_lock lock(registry.mutex_);
|
||||
auto it = registry.compilation_devices_.find(device_name);
|
||||
|
Loading…
Reference in New Issue
Block a user