STT-tensorflow/tensorflow/compiler/jit/xla_cpu_device.cc
George Karpenkov fbd998ddf8 [NFC] [TF2XLA] Reduce log spam: no need to say every time that we are not creating XLA devices
PiperOrigin-RevId: 346700269
Change-Id: Ic923100e7ddb7da4a6eaed6a483d5eaf69799988
2020-12-09 21:14:55 -08:00

122 lines
4.9 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "Host" (CPU) backend.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// DeviceFactory for the XLA_CPU device (DEVICE_XLA_CPU). Both entry points do
// nothing unless XLA devices are enabled via tf_xla_enable_xla_devices.
class XlaCpuDeviceFactory : public DeviceFactory {
 public:
  // Appends "/physical_device:" DEVICE_XLA_CPU ":0" to `devices` when XLA
  // devices are enabled; otherwise leaves `devices` untouched.
  Status ListPhysicalDevices(std::vector<string>* devices) override;

  // Appends a single XLA_CPU device (ordinal 0) whose name uses `name_prefix`
  // to `devices` when XLA devices are enabled; otherwise leaves `devices`
  // untouched. Returns a non-OK status if device setup fails.
  Status CreateDevices(const SessionOptions& session_options,
                       const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};
// Reports the one physical XLA_CPU device, but only when the
// tf_xla_enable_xla_devices flag is set. Always returns OK.
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  const XlaDeviceFlags* const xla_flags = GetXlaDeviceFlags();
  if (xla_flags->tf_xla_enable_xla_devices) {
    devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
  } else {
    // Logged at VLOG(1) only, to avoid spamming the default log.
    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
  }
  return Status::OK();
}
// Creates the XLA_CPU device and appends it to `devices`.
//
// When tf_xla_enable_xla_devices is not set, returns OK without touching
// `devices`. Otherwise it:
//   1. registers DEVICE_XLA_CPU as a compilation device for the CPU XLA JIT,
//   2. registers the XLA device kernels (once per process),
//   3. looks up the "Host" StreamExecutor platform, and
//   4. builds a single XlaDevice with ordinal 0.
// Any error from the platform lookup or device setup is returned to the caller.
Status XlaCpuDeviceFactory::CreateDevices(
    const SessionOptions& session_options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }
  const bool compile_on_demand = flags->tf_xla_compile_on_demand;

  // Describe how ops placed on XLA_CPU are handed to the CPU XLA JIT. In
  // compile-on-demand mode, auto-clustering only happens when explicitly
  // requested; otherwise it is always applied.
  XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
  registration.autoclustering_policy =
      compile_on_demand
          ? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
          : XlaOpRegistry::AutoclusteringPolicy::kAlways;
  registration.cluster_resource_variable_ops_unsafely = true;
  registration.cluster_stack_ops = false;
  registration.cluster_tensor_array_ops = true;
  registration.cluster_stateful_rng_ops = true;
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
  registration.cluster_slow_ops = true;
  registration.cluster_inaccurate_ops = true;
  XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);

  // Function-local static: RegisterXlaDeviceKernels runs only on the first
  // call. The pointer itself is not used afterwards; the cast silences the
  // unused-variable warning.
  static XlaDeviceOpRegistrations* registrations =
      RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT);
  (void)registrations;

  TF_ASSIGN_OR_RETURN(auto platform,
                      se::MultiPlatformManager::PlatformWithName("Host"));

  XlaDevice::Options options;
  options.platform = platform;
  options.device_name_prefix = name_prefix;
  options.device_name = DEVICE_XLA_CPU;
  options.device_ordinal = 0;
  options.compilation_device_name = DEVICE_CPU_XLA_JIT;
  options.use_multiple_streams = false;
  auto device = absl::make_unique<XlaDevice>(session_options, options);

  // Setting GpuDeviceInfo because eager runtime relies on the device
  // context in tensorflow_gpu_device_info(). Also,
  // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test.
  // We need XlaCpuDevice to be treated not as CPU because it allocates
  // XlaTensors, not regular Tensors.
  Status status = device->UseGpuDeviceInfo();
  if (!status.ok()) {
    // Name the device actually being set up here (XLA_CPU), not a GPU device.
    errors::AppendToMessage(&status, "while setting up ", DEVICE_XLA_CPU);
    return status;
  }
  devices->push_back(std::move(device));
  return Status::OK();
}
// Registers XlaCpuDeviceFactory as the local device factory for DEVICE_XLA_CPU.
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
//
// Data types passed as the type list to every kernel registration below.
// NOTE(review): the array size (16) is written by hand and must equal the
// number of listed types. If fewer types are listed, the trailing elements are
// zero-initialized (presumably DT_INVALID) instead of causing a compile error.
constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
// Register the XLA launch/compile/run kernels and the generic XLA device
// kernels on DEVICE_XLA_CPU, each constrained to kAllXlaCpuTypes.
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
} // namespace tensorflow