STT-tensorflow/tensorflow/compiler/jit/xla_tpu_device.cc
Frank Chen c5f474d1c8 [TPU] Move TPU node and system device initializers to compiler/jit
This is part of a series of changes that move TPU-related code to better locations, so that the TensorFlow build isn't confused and TPU-based TF can be built without the `--define=framework_shared_object=false` flag.

PiperOrigin-RevId: 352726495
Change-Id: Idc23455a8289c4a2546edad9ca59e9207a7492ce
2021-01-19 22:43:21 -08:00

487 lines
20 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_tpu_device.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_node_device_util.h"
#include "tensorflow/core/tpu/virtual_device.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
namespace tensorflow {
namespace {
// Process-wide TPU configuration. These are the defaults; RegisterTpuNodeDevice()
// overwrites all three with caller-supplied values. The enclosing anonymous
// namespace already gives them internal linkage, so `static` is unnecessary.
bool tpu_autoclustering_flag = false;
bool tpu_xla_device_failure_closes_chips_flag = true;
bool tpu_use_substreams_for_cross_tpu_device_transfers_flag = true;
// Returns the shape in which a tensor of the given `shape` and `type` should be
// stored on the TPU device. The conversion is delegated to the TPU C API
// (XlaShapeToTpuShapeRepresentationFn): excessively-padded tensors are
// transposed or flattened to rank 1, and other shapes are left alone.
// `use_fast_memory` is forwarded to the C API unchanged.
xla::StatusOr<xla::Shape> TpuShapeRepresentation(const TensorShape& shape,
                                                 DataType type,
                                                 bool use_fast_memory) {
  xla::Shape host_xla_shape;
  TF_RETURN_IF_ERROR(
      tensorflow::TensorShapeToXLAShape(type, shape, &host_xla_shape));

  // C-API views of the input and output shapes, plus the C status slot.
  ApiConverter::StackHelper<XLA_Shape> c_input_shape(host_xla_shape);
  ApiConverter::StackHelper<XLA_Shape> c_output_shape;
  StatusHelper c_api_status;
  tpu::ExecutorApiFn()->XlaShapeToTpuShapeRepresentationFn(
      &c_input_shape.value, type, use_fast_memory, &c_output_shape.value,
      c_api_status.c_status);

  if (c_api_status.status().ok()) {
    return c_output_shape.AsCpp<xla::Shape>();
  }
  return c_api_status.status();
}
// Computes the fully padded on-device shape of `tensor`. The tensor must be an
// XlaTensor that already has device memory (a shaped buffer) allocated;
// otherwise InvalidArgument is returned. `*shape` is written only on success.
Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* const device_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (device_tensor == nullptr) {
    return errors::InvalidArgument(
        "Expected an XlaTensor when computing padded shape");
  }
  if (!device_tensor->has_shaped_buffer()) {
    return errors::InvalidArgument(
        "XlaTensor is expected to have device memory allocated when "
        "computing padded shape");
  }

  // Ask the TPU C API to pad the shape the buffer currently has on device.
  StatusHelper c_api_status;
  ApiConverter::StackHelper<XLA_Shape> c_device_shape(
      device_tensor->shaped_buffer().on_device_shape());
  ApiConverter::StackHelper<XLA_Shape> c_padded_shape;
  tpu::ExecutorApiFn()->XlaShapeToTpuPaddedShapeFn(
      &c_device_shape.value, &c_padded_shape.value, c_api_status.c_status);
  if (!c_api_status.ok()) {
    return c_api_status.status();
  }

  *shape = c_padded_shape.AsCpp<xla::Shape>();
  return Status::OK();
}
// Checks whether the TPU system has been initialized. TPU initialization is
// not necessary for 1x1.
//
// Returns OK if the registered TPU platform reports it is initialized, and
// FailedPrecondition otherwise -- including when no TPU platform is registered.
Status CheckIfTPUInitialized() {
  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
  // GetRegisteredPlatform() returns nullptr when no TPU platform is registered,
  // so check before dereferencing rather than crashing.
  if (tpu_platform == nullptr || !tpu_platform->Initialized()) {
    return errors::FailedPrecondition(
        "The TPU system has not been initialized.");
  }
  return Status::OK();
}
// Implementation of TPU->TPU device copies that copies over the dedicated TPU
// interconnects, which is much faster than PCIe or the host network.
//
// `done` is called in one of two ways:
// - Synchronously: on error; when the TPU system is not initialized for a
//   copy between differently-named devices; for zero-element tensors; or when
//   both devices share the same memory location.
// - From a host callback enqueued on the destination device-to-device stream
//   after the transfer has been enqueued.
//
// TODO(b/117426293): This implementation is only called for direct interconnect
// transfers between TPU devices attached to the same host. Ideally, we would
// generalize this support to direct interconnect transfers across hosts, but
// currently the CopyTensor infrastructure seems to assume the network topology
// is strictly hierarchical, that is, transfers between devices on different
// hosts can only take place using the host network.
void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context,
                           DeviceContext* dst_dev_context, Device* src,
                           Device* dst, AllocatorAttributes src_allocator_attrs,
                           AllocatorAttributes dst_allocator_attrs,
                           const Tensor* input, Tensor* output,
                           int dev_to_dev_stream_index, StatusCallback done) {
  XlaDeviceContext* const src_xla_context =
      static_cast<XlaDeviceContext*>(src_dev_context);
  XlaDeviceContext* const dst_xla_context =
      static_cast<XlaDeviceContext*>(dst_dev_context);
  // Read once on the first call; later changes to the flag are not observed.
  static const bool should_use_substream =
      tpu_use_substreams_for_cross_tpu_device_transfers_flag;
  // Every path that calls `done` itself returns OK; a non-OK result means
  // `done` has not been called yet and the caller below must call it.
  auto impl = [&]() -> Status {
    if (src->name() != dst->name()) {
      Status s = CheckIfTPUInitialized();
      if (!s.ok()) {
        done(s);
        return Status::OK();
      }
    }
    if (input->shape().num_elements() == 0) {
      // Zero-element tensors have no backing buffers.
      done(Status::OK());
      return Status::OK();
    }
    se::Stream* const src_compute_stream = src_xla_context->stream();
    TF_RET_CHECK(src_compute_stream != nullptr);
    TF_RET_CHECK(input->dtype() == output->dtype())
        << "input type: " << DataTypeString(input->dtype()) << " output type "
        << DataTypeString(output->dtype());
    TF_RET_CHECK(input->shape() == output->shape());
    TF_RET_CHECK(DMAHelper::CanUseDMA(input));
    auto* const src_compute_stream_impl = static_cast<tpu::TpuStreamInterface*>(
        src_compute_stream->implementation());
    se::Stream* const dst_compute_stream = dst_xla_context->stream();
    // Mirror the source-side check: the destination stream is dereferenced
    // immediately below and throughout the rest of the copy.
    TF_RET_CHECK(dst_compute_stream != nullptr);
    auto* const dst_compute_stream_impl = static_cast<tpu::TpuStreamInterface*>(
        dst_compute_stream->implementation());
    if (src_compute_stream_impl->IsSameSharedMemoryLocation(
            dst_compute_stream_impl)) {
      // Surprisingly, this path does get triggered in practice.
      *output = *input;
      done(Status::OK());
      return Status::OK();
    }
    // To avoid stream exhaustion, we pick a substream from a pool if enabled.
    se::Stream* const device_to_device_master_stream =
        should_use_substream ? dst_xla_context->device_to_device_stream(0)
                             : nullptr;
    // The master stream is dereferenced by GetOrCreateSubStream() below, so it
    // must exist whenever substreams are enabled.
    TF_RET_CHECK(!should_use_substream ||
                 device_to_device_master_stream != nullptr);
    se::Stream* const dst_device_to_device_stream =
        should_use_substream
            ? device_to_device_master_stream->GetOrCreateSubStream()
            : dst_xla_context->GetDeviceToDeviceStream();
    TF_RET_CHECK(dst_device_to_device_stream != nullptr);
    // Hands the substream back to the pool if we return early with an error;
    // on success this duty moves to the host callback at the end.
    auto return_substream = gtl::MakeCleanup(
        [device_to_device_master_stream, dst_device_to_device_stream] {
          if (device_to_device_master_stream) {
            device_to_device_master_stream->ReturnSubStream(
                dst_device_to_device_stream);
          }
        });
    auto* const dst_device_to_device_stream_impl =
        static_cast<tpu::TpuStreamInterface*>(
            dst_device_to_device_stream->implementation());
    const int dst_device_ordinal =
        dst_xla_context->stream()->parent()->device_ordinal();
    XlaTensor* const xla_input = XlaTensor::FromTensor(input);
    TF_RET_CHECK(xla_input != nullptr && xla_input->has_shaped_buffer());
    XlaTensor* const xla_output = XlaTensor::FromTensor(output);
    TF_RET_CHECK(xla_output != nullptr && !xla_output->has_shaped_buffer());
    // Input/output shape equality was already verified above.
    TF_ASSIGN_OR_RETURN(xla::Shape shape,
                        dst_xla_context->shape_representation_fn()(
                            input->shape(), input->dtype(),
                            /*use_fast_memory=*/false));
    TF_RETURN_IF_ERROR(xla_output->AllocateShapedBuffer(
        input->dtype(), shape, dst_xla_context->client(), dst_device_ordinal));
    VLOG(2) << "TpuDeviceToDeviceCopy: src: "
            << src_compute_stream->parent()->device_ordinal() << ", "
            << " dst: " << dst_compute_stream->parent()->device_ordinal()
            << ", "
            << " input buffers: " << xla_input->shaped_buffer().ToString()
            << " output buffers: " << xla_output->shaped_buffer().ToString();
    // Wait for definition event of the source tensor so the input buffers are
    // available.
    xla_input->WaitForDefinitionEventOnStream(dst_device_to_device_stream);
    // Wait for the destination tensor buffers to be ready, if they are not
    // available for an immediate write.
    if (!dst_xla_context->transfer_manager()->CanShapedBufferBeAccessedNow(
            dst_compute_stream->parent(), xla_output->shaped_buffer())) {
      dst_device_to_device_stream->ThenWaitFor(dst_compute_stream);
      // If the representation is a tuple, we also must wait for the tuple index
      // buffers to be available on the destination host to device transfer
      // stream.
      if (xla_output->shaped_buffer().on_device_shape().IsTuple()) {
        dst_xla_context->host_to_device_stream()->ThenWaitFor(
            dst_compute_stream);
      }
    }
    // Enqueue one on-device send/recv per leaf buffer.
    for (const auto& leaf : xla_input->shaped_buffer().buffers().leaves()) {
      const xla::ShapeIndex& index = leaf.first;
      const se::DeviceMemoryBase& input_buffer = leaf.second;
      const se::DeviceMemoryBase& output_buffer =
          xla_output->shaped_buffer().buffer(index);
      TF_RET_CHECK(input_buffer.size() == output_buffer.size())
          << "input: " << input_buffer.size()
          << " output: " << output_buffer.size();
      TF_RETURN_IF_ERROR(
          dst_device_to_device_stream_impl->EnqueueOnTpuDeviceSendRecvLocal(
              input_buffer, output_buffer));
    }
    // If the on-device shape is a tuple, write new tuple index buffers.
    if (xla_output->shaped_buffer().on_device_shape().IsTuple()) {
      TF_RETURN_IF_ERROR(
          dst_xla_context->transfer_manager()->WriteTupleIndexTablesAsync(
              dst_xla_context->host_to_device_stream(),
              xla_output->shaped_buffer()));
      // We need a single definition event for an XlaTensor, so make the
      // device to device stream wait for the stream that wrote the tuple index
      // tables on the destination device. Should this prove to be a problem,
      // we can always extend XlaTensor to take a pair of definition events that
      // must all be satisfied, or add an Event::Merge() API that allows us to
      // build an event that is triggered when all of its dependencies are
      // triggered.
      dst_device_to_device_stream->ThenWaitFor(
          dst_xla_context->host_to_device_stream());
    }
    auto definition_event =
        std::make_shared<se::Event>(dst_xla_context->stream()->parent());
    TF_RET_CHECK(definition_event->Init()) << "Event failed to initialize!";
    dst_device_to_device_stream->ThenRecordEvent(definition_event.get());
    xla_output->ResetDefinitionEvent(std::move(definition_event),
                                     dst_device_to_device_stream);
    // The input must remain alive until the transfer completes, so we keep a
    // reference. We also wait until the transfer completes before calling
    // done().
    // The latter may be too conservative, but given the host is involved in
    // waiting for the transfer to complete anyway there is probably little
    // downside. If we were to add the ability for computations to wait directly
    // on transfers, then we might want to rethink this property.
    // Also ideally this host callback should be on source stream rather than
    // destination stream, but when this function returns, the send requests
    // might not be enqueued to the stream yet, we put it on destination stream.
    TensorReference input_reference(*input);
    // From here on the host callback returns the substream, not the cleanup.
    std::move(return_substream).release();
    dst_device_to_device_stream->ThenDoHostCallback(
        [input_reference, done = std::move(done),
         device_to_device_master_stream, dst_device_to_device_stream] {
          if (device_to_device_master_stream) {
            device_to_device_master_stream->ReturnSubStream(
                dst_device_to_device_stream);
          }
          input_reference.Unref();
          done(Status::OK());
        });
    return Status::OK();
  };
  Status status = impl();
  if (!status.ok()) {
    done(status);
  }
}
// DeviceFactory for DEVICE_TPU_NODE: lists and creates one XLA-backed device
// per TPU visible to the registered TPU platform on this host.
class TpuNodeDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;

  Status CreateDevices(const SessionOptions& session_options,
                       const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};
// Appends "/physical_device:TPU:<i>" for each TPU the registered platform
// reports as visible. Appends nothing if no TPU platform is registered.
Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  tpu::TpuPlatformInterface* const tpu_platform =
      tpu::TpuPlatformInterface::GetRegisteredPlatform();
  // Without a registered platform there are no devices to report.
  if (tpu_platform == nullptr) return Status::OK();

  const int num_visible = tpu_platform->VisibleDeviceCount();
  for (int ordinal = 0; ordinal < num_visible; ++ordinal) {
    devices->push_back(absl::StrCat("/physical_device:TPU:", ordinal));
  }
  return Status::OK();
}
Status TpuNodeDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
tpu::TpuPlatformInterface* platform =
tpu::TpuPlatformInterface::GetRegisteredPlatform();
if (platform == nullptr) {
// If we don't have a platform registered, then we should not create any.
return Status::OK();
}
if (platform != nullptr && platform->ShouldRegisterTpuDeviceToDeviceCopy()) {
RegisterTpuDeviceToDeviceCopy();
}
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_TPU_XLA_JIT;
registration.autoclustering_policy =
tpu_autoclustering_flag
? XlaOpRegistry::AutoclusteringPolicy::kAlways
: XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
registration.cluster_resource_variable_ops_unsafely = true;
registration.cluster_stack_ops = false;
registration.cluster_tensor_array_ops = true;
registration.cluster_stateful_rng_ops = true;
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_TPU_NODE, registration);
static XlaDeviceOpRegistrations* registrations =
RegisterXlaDeviceKernels(DEVICE_TPU_NODE, DEVICE_TPU_XLA_JIT);
(void)registrations;
int device_count = platform->VisibleDeviceCount();
VLOG(1) << "Creating " << device_count << " TPU devices";
for (int i = 0; i < device_count; ++i) {
TF_RETURN_IF_ERROR(tpu::TpuNodeContext::Initialize(i));
XlaDevice::Options options;
options.platform = platform;
options.device_name_prefix = name_prefix;
options.device_name = DEVICE_TPU_NODE;
options.device_ordinal = i;
options.compilation_device_name = DEVICE_TPU_XLA_JIT;
options.use_multiple_streams = true;
options.shape_representation_fn = &TpuShapeRepresentation;
options.padded_shape_fn = &TpuPaddedShapeFn;
auto device = absl::make_unique<XlaDevice>(session_options, options);
// The GpuDeviceInfo actually provides information not only for GPU
// devices but also for TPU. The name is a legacy from the pre-TPU
// dark ages.
Status status = device->UseGpuDeviceInfo();
if (!status.ok()) {
errors::AppendToMessage(&status, "while setting up ", DEVICE_TPU_XLA_JIT,
" device number ", i);
return status;
}
device->SetAllowsSyncOnCompletion(false);
if (tpu_xla_device_failure_closes_chips_flag) {
device->SetHandleDeviceErrorCallback(&tpu::TpuNodeContext::CloseTpuHost);
}
devices->push_back(std::move(device));
}
return Status::OK();
}
// DeviceFactory for DEVICE_TPU_SYSTEM: exposes a single virtual device
// standing for the TPU system, but only on hosts that have TPUs.
class TpuSystemDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;

  Status CreateDevices(const SessionOptions& session_options,
                       const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};
// Appends "/physical_device:TPU_SYSTEM:0" unless TpusPerHost() reports zero
// TPUs on this host. Errors from TpusPerHost() are propagated.
Status TpuSystemDeviceFactory::ListPhysicalDevices(
    std::vector<string>* devices) {
  int tpus_on_host = 0;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&tpus_on_host));
  if (tpus_on_host != 0) {
    devices->push_back("/physical_device:TPU_SYSTEM:0");
    return Status::OK();
  }
  VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
  return Status::OK();
}
// Creates the single TPU_SYSTEM VirtualDevice, named
// "<name_prefix>/device:TPU_SYSTEM:0" (via DEVICE_TPU_SYSTEM), with its memory
// limit taken from TpuMemoryLimit(). Creates nothing when TpusPerHost()
// reports zero TPUs. Errors from either query are propagated.
Status TpuSystemDeviceFactory::CreateDevices(
    const SessionOptions& options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  int device_count = 0;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count));
  if (device_count == 0) {
    VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
    return Status::OK();
  }
  // Initialized so an indeterminate value can never reach the device
  // attributes, even if TpuMemoryLimit() reports success without writing it.
  int64 memory_limit = 0;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpuMemoryLimit(&memory_limit));
  // Creates a device that represents a TPU distributed system.
  const DeviceAttributes attrs = Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", DEVICE_TPU_SYSTEM, ":", 0),
      DeviceType(DEVICE_TPU_SYSTEM), Bytes(memory_limit), DeviceLocality(),
      absl::StrCat("device: ", DEVICE_TPU_SYSTEM, " device"));
  devices->push_back(absl::make_unique<VirtualDevice>(options.env, attrs));
  VLOG(1) << "Created TPU_SYSTEM device. This host has " << device_count
          << " TPUs";
  return Status::OK();
}
} // namespace
void RegisterTpuDeviceToDeviceCopy() {
static auto* const register_tpu_tpu_copy = new CopyTensor::Registration(
DEVICE_TPU_NODE, DEVICE_TPU_NODE, TpuDeviceToDeviceCopy);
(void)register_tpu_tpu_copy;
}
// Registers the TPU_NODE device type. It first stores the caller's options in
// this file's process-wide flags, then registers the following for
// DEVICE_TPU_NODE:
// - the XLA launch, compile and run kernels;
// - the XLA device kernels;
// - TpuNodeDeviceFactory.
//
// The stored flags are consumed later:
// - `tpu_autoclustering` and `tpu_xla_device_failure_closes_chips` are read in
//   TpuNodeDeviceFactory::CreateDevices;
// - `tpu_use_substreams_for_cross_tpu_device_transfers` is read on the first
//   TpuDeviceToDeviceCopy call, which latches it in a function-local static.
//
// NOTE(review): the REGISTER_* macros are expanded inside a function body.
// Presumably they declare static registrar objects that take effect on the
// first call only; confirm before calling this more than once.
// NOTE(review): in non-PLATFORM_GOOGLE builds the same kernels and factory are
// also registered at namespace scope at the end of this file. Calling this in
// such a build may therefore register them twice.
void RegisterTpuNodeDevice(
    bool tpu_autoclustering, bool tpu_xla_device_failure_closes_chips,
    bool tpu_use_substreams_for_cross_tpu_device_transfers) {
  tpu_autoclustering_flag = tpu_autoclustering;
  tpu_xla_device_failure_closes_chips_flag =
      tpu_xla_device_failure_closes_chips;
  tpu_use_substreams_for_cross_tpu_device_transfers_flag =
      tpu_use_substreams_for_cross_tpu_device_transfers;
  REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes);
  REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes);
  REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes);
  REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes);
  REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
}
// Registers TpuSystemDeviceFactory as the local device factory for
// DEVICE_TPU_SYSTEM.
// NOTE(review): in non-PLATFORM_GOOGLE builds this factory is also registered
// at namespace scope at the end of this file.
void RegisterTpuSystemDevice() {
  REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
}
#if !defined(PLATFORM_GOOGLE)
// We automatically register this if we are building for open source. For
// Google platforms, we initialize these devices in other places.
// These namespace-scope registrations mirror RegisterTpuNodeDevice() and
// RegisterTpuSystemDevice(), but take no options. TPU_NODE devices therefore
// see whatever flag values are in effect when the factory runs: the
// file-level defaults, unless RegisterTpuNodeDevice() is also called.
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes);
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
#endif  // !defined(PLATFORM_GOOGLE)
} // namespace tensorflow