/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).
#include <functional>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {
class XlaDevice : public LocalDevice {
// Given a tensor, sets `xla::Shape*` the shape of tensor's representation
// on device, fully padded. On error, the contents of `xla::Shape*`
// are undefined.
typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
// Wrapper class to store metadata about the XlaDevice, where it can be
// retrieved e.g., when lazily creating the XlaCompilationCache device.
class Metadata {
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
PaddedShapeFn padded_shape_fn, bool use_multiple_streams);
// The index of the device on this host.
int device_ordinal() const;
se::Platform* platform() const;
xla::LocalClient* client() const;
const DeviceType& jit_device_type() const;
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
return shape_representation_fn_;
const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }
bool UseMultipleStreams() const { return use_multiple_streams_; }
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
PaddedShapeFn padded_shape_fn_;
const bool use_multiple_streams_;
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
static Status GetMetadata(OpKernelConstruction* ctx,
const Metadata** metadata);
struct Options {
// The StreamExecutor platform. Not owned. Must be non-null.
se::Platform* platform = nullptr;
// The device name's prefix (e.g., "/task:7")
string device_name_prefix;
// The name of the XLA device (e.g., "XLA_CPU")
string device_name;
// The number of the device.
int device_ordinal = -1;
// The name of the compilation device (e.g., "XLA_CPU_JIT");
string compilation_device_name;
// If 'use_multiple_streams' is true, we create separate streams for
// compute, host-to-device, and device-to-host communication.
bool use_multiple_streams = false;
// A function that describes how the on-host shapes of
// a) argument and return value, for entry computations
// b) variables, for all computations,
// should be represented in XLA. Parameters/return values will be shaped
// according to this function, and reshaped back to/from their declared
// shapes for computations. Must be non-null.
XlaCompiler::ShapeRepresentationFn shape_representation_fn;
// If padded_shape_fn is empty, a default implementation that returns
// the logical on-device shape without padding is used.
PaddedShapeFn padded_shape_fn;
// Set of devices to use. This controls which of the devices on the given
// platform will have resources allocated. For GPUs this will be
// filled from visible_gpu_devices list from session configuration.
absl::optional<std::set<int>> allowed_devices;
// Creates a new XLA Device.
XlaDevice(const SessionOptions& session_options, const Options& options);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
Status Sync() override;
void Sync(const DoneCallback& done) override;
Status TryGetDeviceContext(DeviceContext** out_context) override
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);
// Allocate tensor on fast memory space. This is only applied to the new TPU
// hardware which has faster read/write memory. If the hardware doesn't
// have such memory space, we fallback to the ordinary memory space.
Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);
const Metadata& metadata() { return xla_metadata_; }
// Ensures the DeviceContext associated with this XlaDevice is created and
// valid (i.e. all streams are ok). If any state is not valid, a new
// DeviceContext will be created.
// TODO(b/111859745): The Eager context needs to call this method to recover
// from failures.
Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);
// Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);
// Instructs this XlaDevice to return 'sync_on_completion' for
// AllowsSyncOnCompletion().
void SetAllowsSyncOnCompletion(bool sync_on_completion)
bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);
// Installs an error handling callback when RefreshStatus sees !status.ok().
void SetHandleDeviceErrorCallback(std::function<Status()> callback);
Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);
xla::StatusOr<xla::LocalClient*> GetOrCreateClient() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed)
// Return a pair of device context, the second one is fast_mem device context.
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
Status MakeTensorFromProto(XlaDeviceContext* device_context,
const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor);
// Handles error when RefreshStatus sees !status.ok().
Status HandleDeviceError();
mutable mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls.
const int device_ordinal_;
// The name of the device that is used to compile Ops for this XlaDevice.
const DeviceType jit_device_name_;
// The platform for this device.
se::Platform* const platform_; // Not owned.
// Intra-op threads to spawn (from SessionOptions).
const int intra_op_parallelism_threads_;
// Memory allocator associated with this device.
Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_device stream or borrowing a stream
// for each device to host transfer.
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
// If use_multiple_streams_, transfers between different devices are performed
// using these streams.
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// The device context accessed by all users of the XlaDevice, set by calls to
// EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
// also filled in to that struct. XlaDeviceContext is a ref-counted object.
XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;
// The device context will allocate memory on fast memory space on TPU.
// XlaDeviceContext is a ref-counted object.
XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
// True if the device allows XlaDevice::Sync to be called on completion
// regardless of status.
bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;
// A callback that will be invoked when RefreshStatus sees a status error.
std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);
// Set of devices to use. This controls which of the devices on the given
// platform will have resources allocated. For GPUs this will be
// filled from visible_gpu_devices list from session configuration.
absl::optional<std::set<int>> allowed_devices_;
// Builds OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations
// object that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device);
} // namespace tensorflow