Add a TPUExecute function.

PiperOrigin-RevId: 321720840
Change-Id: I9f7304c6f8fd6ffe8266c60b10a6f19a7b3bdc54
Wenhao Jia 2020-07-16 23:14:25 -07:00 committed by TensorFlower Gardener
parent 94d2ab31f8
commit 6b8687f97c
5 changed files with 628 additions and 0 deletions

@@ -227,3 +227,45 @@ cc_library(
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tpu_execute",
srcs = ["tpu_execute.cc"],
hdrs = ["tpu_execute.h"],
deps = [
":tpu_api",
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_executable_info_proto_cc",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/tpu:c_api_conversions",
"//tensorflow/stream_executor/tpu:status_helper",
"//tensorflow/stream_executor/tpu:tpu_executable",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_node_context",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
],
)

@@ -37,11 +37,21 @@ TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
XLA_Shape* host_shape, XLA_Shape* device_shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
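// Converts `runtime_input_ptr`, whose contents are laid out according to
// `runtime_shape`, into `padded_data_ptr`, padded and laid out according to
// `compile_time_shape`. Used to prepare dynamically shaped inputs for the
// device.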
TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
uint32_t* runtime_input_ptr, size_t runtime_input_size,
int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape,
XLA_Shape* compile_time_shape, SE_Status* status);
struct TfTpu_ExecuteApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);
TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);
};
} // extern "C"

@@ -0,0 +1,519 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_execute.h"
#include <cstdlib>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executable.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {
namespace {
using ::tensorflow::tpu::TpuNodeContext;
static bool tpu_cancellation_terminates_process = false;
static bool tpu_cancellation_closes_chips = true;
// Host-side runtime for transfers between TPU and host.
class HostTransferManager {
public:
using HostCommandHandler = xla::TpuExecutable::HostCommandHandler;
explicit HostTransferManager(TpuNodeContext* node_context)
: node_context_(node_context) {}
// Returns a function to be called when the TPU triggers a host command
// interrupt while executing the current program.
xla::StatusOr<HostCommandHandler> Initialize(
const TPUHostTransferInfoProto& program,
const std::string& rendezvous_key_base, OpKernelContext* ctx);
private:
TpuNodeContext* node_context_; // not owned
TF_DISALLOW_COPY_AND_ASSIGN(HostTransferManager);
};
xla::StatusOr<HostTransferManager::HostCommandHandler>
HostTransferManager::Initialize(const TPUHostTransferInfoProto& program,
const string& rendezvous_key_base,
OpKernelContext* ctx) {
return HostCommandHandler([](uint32, int64) {
LOG(WARNING) << "HostTransferManager is unimplemented.";
});
}
// Sleep for 5 seconds, then call std::quick_exit(42) to quickly restart.
void ExitCountdown(Env* env) {
const int kSleepSeconds = 5;
LOG(INFO) << "TpuExecute was cancelled. Sleeping for " << kSleepSeconds
<< " seconds before terminating the process to give time "
"for other errors to propagate";
env->SleepForMicroseconds(kSleepSeconds * 1000000);
LOG(ERROR) << "Aborting process due to cancelled TPUExecute. Consult "
"the anomalies reported above (if any), run state of job "
"(including failed RPCs) and worker logs. This "
"termination is to ensure a consistent state, if your job "
"does not restart, modify the retries allowed. See "
"b/62262381 and b/65223927.";
std::quick_exit(42);
}
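// Returns the on-device shape corresponding to `host_shape`, as computed by
// the TPU hardware layout API.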
xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) {
XLA_Shape c_host_shape;
XLA_Shape c_device_shape;
TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
&c_host_shape, &c_device_shape);
xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
TpuConversions::CShapeCleanup(&c_host_shape);
TpuConversions::CShapeCleanup(&c_device_shape);
return device_shape;
}
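// Returns the size in bytes of `shape` in its compact on-device layout.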
int64 ShapeSizeCompact(const xla::Shape& shape) {
XLA_Shape c_shape;
TpuConversions::XlaShapeToCShape(shape, &c_shape);
int64 size =
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactFn(
&c_shape);
TpuConversions::CShapeCleanup(&c_shape);
return size;
}
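// As ShapeSizeCompact, but for the raw data payload only (excluding any
// trailing shape metadata).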
int64 ShapeSizeCompactRaw(const xla::Shape& shape) {
XLA_Shape c_shape;
TpuConversions::XlaShapeToCShape(shape, &c_shape);
int64 size =
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
&c_shape);
TpuConversions::CShapeCleanup(&c_shape);
return size;
}
// Given a tuple, fix all non-leaf nodes (tuples) such that the tuple tables
// point to the correct leaf nodes.
xla::Status FixTupleTableAsync(se::Stream* stream,
const xla::Shape& tuple_shape,
xla::ExecutionInput* mem,
xla::TransferManager* transfer_manager) {
return xla::ShapeUtil::ForEachSubshapeWithStatus(
tuple_shape,
[&](const xla::Shape& element_shape,
const xla::ShapeIndex& index) -> Status {
if (!element_shape.IsTuple()) {
return Status::OK();
}
std::vector<se::DeviceMemoryBase> elements;
xla::ShapeIndex element_index = index;
element_index.push_back(0);
for (int64 i = 0; i < element_shape.tuple_shapes_size(); ++i) {
// Gather all children of the tuple element.
element_index.back() = i;
elements.push_back(mem->Buffer(element_index).AsDeviceMemoryBase());
}
se::DeviceMemoryBase tuple_table_addr =
mem->Buffer(index).AsDeviceMemoryBase();
return transfer_manager->WriteSingleTupleIndexTable(
stream, elements, element_shape, &tuple_table_addr);
});
}
// Returns true if `dynamic_shape` has dimensions that are less than or equal
// to the corresponding dimensions of `bounded_shape`.
bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
const xla::Shape& bounded_shape) {
if (dynamic_shape.rank() != bounded_shape.rank()) {
return false;
}
for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
return false;
}
}
return true;
}
// For dynamic inputs, copy them and attach metadata of shape sizes to the
// end of the tensor.
//
// The buffer for dynamic shapes contains three parts:
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// The metadata records the unpadded dimension sizes of the shape, i.e., the
// extent of the valid data.
xla::Status UpdateDynamicInputs(
se::Stream* stream, se::DeviceMemoryAllocator* allocator,
std::vector<xla::ExecutionInput>* runtime_inputs,
const std::vector<xla::Shape>& compile_time_shapes) {
TF_RET_CHECK(runtime_inputs->size() == compile_time_shapes.size());
for (int64 i = 0; i < compile_time_shapes.size(); i++) {
// TODO(yunxing): Iterating over thousands of elements can be slow. One way
// to optimize the fast path when there are no dynamic shapes is to add a field
// to the compilation result indicating whether any dynamic input is present.
if (compile_time_shapes[i].is_static()) {
continue;
}
auto& runtime_input = (*runtime_inputs)[i];
xla::Shape compile_time_shapes_on_device =
HostShapeToDeviceShape(compile_time_shapes[i]);
bool element_modified = false;
TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
compile_time_shapes_on_device,
[&](const xla::Shape& compile_time_shape,
const xla::ShapeIndex& index) -> Status {
if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
return Status::OK();
}
const xla::Shape& runtime_shape =
xla::ShapeUtil::GetSubshape(runtime_input.shape(), index);
TF_RET_CHECK(!runtime_shape.IsTuple());
TF_RET_CHECK(
DynamicShapeIsCompatible(runtime_shape, compile_time_shape));
xla::MaybeOwningDeviceMemory* mutable_input_mem =
runtime_input.MutableBuffer(index);
auto padded_data = std::make_shared<std::vector<int8>>(
ShapeSizeCompact(compile_time_shape), -1);
auto raw_input_runtime = std::make_shared<std::vector<uint32>>(
ShapeSizeCompact(runtime_shape) / sizeof(uint32));
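// Copy the linearized runtime input from the device to the host.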
stream->ThenMemcpyD2H(
se::DeviceMemory<int8>(mutable_input_mem->AsDeviceMemoryBase()),
absl::MakeSpan(absl::bit_cast<int8*>(raw_input_runtime->data()),
ShapeSizeCompactRaw(runtime_shape)));
stream->ThenDoHostCallback([raw_input_runtime, padded_data,
runtime_shape, compile_time_shape]() {
// After getting the data onto the host, transpose the data to
// the correct layout by delinearizing it and linearizing it again.
XLA_Shape c_runtime_shape, c_compile_time_shape;
TpuConversions::XlaShapeToCShape(runtime_shape, &c_runtime_shape);
TpuConversions::XlaShapeToCShape(compile_time_shape,
&c_compile_time_shape);
StatusHelper status;
tensorflow::tpu::ExecuteApiFn()
->TpuExecute_RuntimeInputToPaddedDataFn(
raw_input_runtime->data(), raw_input_runtime->size(),
padded_data->data(), padded_data->size(), &c_runtime_shape,
&c_compile_time_shape, status.c_status);
TpuConversions::CShapeCleanup(&c_runtime_shape);
TpuConversions::CShapeCleanup(&c_compile_time_shape);
return status.status();
});
// Allocate new input and transfer the padded and transposed data to
// the new input location.
TF_ASSIGN_OR_RETURN(
auto new_input,
allocator->Allocate(stream->parent()->device_ordinal(),
ShapeSizeCompact(compile_time_shape)));
auto typed_new_input_memory =
se::DeviceMemory<int8>(new_input.cref());
stream->ThenMemcpyH2D<int8>(*padded_data, &typed_new_input_memory);
// Retain the memory until the end of the transfer.
stream->ThenDoHostCallback([padded_data]() { return Status::OK(); });
// Modify the memory location in the input shape tree to point to the
// new input.
*mutable_input_mem =
xla::MaybeOwningDeviceMemory(std::move(new_input));
element_modified = true;
return Status::OK();
}));
if (element_modified) {
// The input location has been modified, so fix the tuple table to point to
// the correct address.
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
TF_RETURN_IF_ERROR(FixTupleTableAsync(stream,
compile_time_shapes_on_device,
&runtime_input, transfer_manager));
}
}
return Status::OK();
}
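// Cancellation callback for a running TPU program. Depending on the flags
// above, this either stops chip heartbeats and terminates the process, closes
// all TPU chips attached to this host, or does nothing.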
void TPUCancelExecution(Env* env, int device_ordinal) {
if (tpu_cancellation_terminates_process) {
LOG(INFO) << "TPUCancelExecution StopChipHeartbeats on device "
<< device_ordinal;
Status status = TpuNodeContext::StopChipHeartbeats();
LOG(INFO) << "TPUCancelExecution StopChipHeartbeats done: " << status
<< " on device " << device_ordinal;
// Sleep and exit in another thread so the cancellation manager can
// continue running callbacks. The new thread will call quick_exit,
// so we discard the returned Thread pointer because we won't have
// an opportunity to delete it.
(void)env->StartThread(ThreadOptions(), "tpu_execute_exit_countdown",
[env]() { ExitCountdown(env); });
} else if (tpu_cancellation_closes_chips) {
LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal;
Status status = TpuNodeContext::CloseTpuHost();
LOG(INFO) << "TPUCancelExecution CloseTPUHost done: " << status
<< " on device " << device_ordinal;
} else {
LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal
<< " is suppressed";
}
}
std::pair<CancellationToken, bool> RegisterCancellation(
OpKernelContext* ctx, CancellationManager* cancellation_manager,
int device_ordinal) {
// Set up a cancellation callback, to ensure the TPU program we run will
// halt if the RPC is cancelled. Without this the TPU program might block
// forever. The mechanism itself is a big hammer; we close all devices
// attached to this host on each cancellation callback. This is necessary to
// ensure the system will eventually halt, since the TensorNodes on each
// chip may be stuck waiting for mutual communication.
//
// By closing all devices, we ensure all subsequent attempts to use the
// device will fail, until the devices are re-initialized via a new call to
// tpu.initialize_system.
//
// In a multi-TensorNode setup, CloseTPUHost may be called once for each
// TensorNode, and each call will close all TensorNodes. This quadratic
// behavior ensures the mechanism is robust to various orderings
// (i.e. races) between the TPU programs, which are run on separate threads.
// In practice the quadratic behavior isn't that bad; the first call will
// actually halt any running TPU programs (which may be expensive), while
// subsequent calls will attempt to close an already-closed device (which is
// cheap).
//
// TODO(b/62262381): The cancellation manager is shared between multiple TPU
// execute ops, and cancellation is not triggered only when the RPC is
// cancelled (it may also be induced by OOM errors from a different TPU
// execute), which results in a pretty coarse cancellation domain. This
// cancellation callback should only execute in a narrower scope so that it is
// not triggered in such cases.
CancellationToken token = cancellation_manager->get_cancellation_token();
// Don't rely on OpKernelContext being available when the callback runs.
Env* env = ctx->env();
bool already_cancelled = !cancellation_manager->RegisterCallback(
token,
[device_ordinal, env]() { TPUCancelExecution(env, device_ordinal); });
return std::pair<CancellationToken, bool>(token, already_cancelled);
}
void UnregisterCancellation(
OpKernelContext* ctx, CancellationManager* cancellation_manager,
se::Stream* stream, int device_ordinal, CancellationToken token,
std::shared_ptr<HostTransferManager> host_transfer_manager) {
// If execution reaches this point, the host callback enqueued below will get
// called regardless of stream status. Call inc_num_deferred_ops_function here
// and dec_num_deferred_ops_function in the host callback.
ctx->inc_num_deferred_ops_function()();
auto dec_num_deferred_ops_function = ctx->dec_num_deferred_ops_function();
// Try to avoid running callbacks on the compute stream, because this reduces
// the frequency of back-to-back programs (which are most efficient because
// they don't require host synchronization). Instead, borrow a substream and
// have the substream wait on the compute stream.
se::Stream* deregister_stream = stream->GetOrCreateSubStream();
deregister_stream->ThenWaitFor(stream);
deregister_stream->ThenDoHostCallback([=]() {
// Ensure the host_transfer_manager is copied into the callback scope.
(void)host_transfer_manager;
// We must deregister the callback in the success case, to avoid closing all
// devices. In the failure case we must NOT call DeregisterCallback as that
// waits for all previous cancellation callbacks to complete and any call
// to XlaDevice::Sync() will cause deadlock. Consider:
// 1) CancellationManager::StartCancel() is in progress (state is
// cancelling_).
// 2) The call below to DeregisterCallback will block until state is
// cancelled_ (all callbacks are completed).
// 3) A different cancellation callback has called XlaDevice::Sync(),
// which will block until (2) is done.
// 4) StartCancel() in (1) cannot complete until (3) is done.
//
// Instead, call TryDeregisterCallback. The functional difference is that
// TryDeregisterCallback will not block if cancellation is in progress, and
// so makes no guarantees as to the state of any callbacks.
// This is not a problem, as our cancellation handler does not rely on
// any external state.
VLOG(1) << "cancellation_manager->TryDeregisterCallback on device "
<< device_ordinal;
cancellation_manager->TryDeregisterCallback(token);
VLOG(1) << "cancellation_manager->TryDeregisterCallback done on device "
<< device_ordinal;
// ExecutorState is held alive until at least this point to ensure
// cancellation_manager is valid. After all outstanding
// dec_num_deferred_ops_function are called, ExecutorState::Finish will be
// allowed to proceed.
dec_num_deferred_ops_function();
});
stream->ReturnSubStream(deregister_stream);
}
} // namespace
xla::StatusOr<xla::ExecutionOutput> TPUExecute(
const TPUExecutableInfoProto& executable,
const TPUHostTransferInfoProto& host_transfers,
const xla::HloProto& hlo_metadata,
std::vector<xla::ExecutionInput> arguments,
const string& rendezvous_key_base, uint32 rng_seed,
TpuNodeContext* node_context, xla::DeviceAssignment* device_assignment,
CancellationManager* cancellation_manager, OpKernelContext* ctx,
stream_executor::Stream* stream,
stream_executor::Stream* host_to_device_stream,
const XLA_TpuProgram* tpu_program) {
profiler::TraceMe traceme("TPUExecute", 2);
TF_RET_CHECK(tpu::TpuPlatformInterface::GetRegisteredPlatform() != nullptr);
TF_RET_CHECK(tpu_program != nullptr);
VLOG(1) << "TPUExecute on device " << node_context->tensor_core_location();
XlaDevice* device =
tensorflow::down_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
TF_RET_CHECK(device);
// Create a HostTransferManager to handle Send/Recv operations from the TPU.
std::shared_ptr<HostTransferManager> host_transfer_manager =
std::make_shared<HostTransferManager>(node_context);
TF_ASSIGN_OR_RETURN(HostTransferManager::HostCommandHandler handler,
host_transfer_manager->Initialize(
host_transfers, rendezvous_key_base, ctx));
VLOG(2) << "Cloud TPU: Executing computation on device "
<< node_context->index_on_host();
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_device_assignment(device_assignment);
run_options.set_rng_seed(rng_seed);
run_options.set_allocator(node_context->memory_allocator());
run_options.set_host_to_device_stream(host_to_device_stream);
const xla::ServiceExecutableRunOptions service_run_options(run_options);
std::unique_ptr<xla::HloModule> module;
std::vector<xla::Shape> input_shapes;
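// Rebuild a skeleton HloModule whose entry computation layout matches the
// input and output shapes recorded in the executable metadata.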
{
xla::ComputationLayout computation_layout(
xla::ShapeLayout(xla::Shape(executable.output_shape())));
for (const xla::ShapeProto& shape_proto : executable.input_shapes()) {
xla::Shape shape(shape_proto);
computation_layout.add_parameter_layout(xla::ShapeLayout(shape));
input_shapes.push_back(std::move(shape));
}
module = absl::make_unique<xla::HloModule>(
"TpuExecutableModule",
xla::HloModuleConfig(std::move(computation_layout)));
}
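// Propagate input/output aliasing and cross-program prefetch information from
// the compiled program's HLO metadata onto the reconstructed module.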
TF_ASSIGN_OR_RETURN(
module->input_output_alias_config(),
xla::HloInputOutputAliasConfig::CreateFromProto(
node_context->transfer_manager()->HostShapeToDeviceShape(
module->config().entry_computation_layout().result_shape()),
hlo_metadata.hlo_module().input_output_alias()));
TF_RET_CHECK(executable.input_shapes().size() == arguments.size());
for (auto& prefetch : hlo_metadata.hlo_module().cross_program_prefetches()) {
module->AddCrossProgramPrefetch(
prefetch.parameter(),
xla::ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
}
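// Pad any dynamically shaped inputs to their compile-time bounded shapes
// before enqueueing the program.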
TF_RETURN_IF_ERROR(UpdateDynamicInputs(
stream, node_context->memory_allocator(), &arguments, input_shapes));
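// Wrap the compiled TPU program in an executable, with `handler` servicing
// host command interrupts raised during execution.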
auto tpu_executable = absl::make_unique<xla::TpuExecutable>(
tpu_program, std::move(module), handler);
const int32 device_ordinal = node_context->device_ordinal();
CancellationToken token;
bool already_cancelled;
std::tie(token, already_cancelled) =
RegisterCancellation(ctx, cancellation_manager, device_ordinal);
// If the RPC was already cancelled before we managed to register the
// cancellation callback, we shouldn't attempt to run the TPU program, since
// it might block forever.
if (already_cancelled) {
return errors::Cancelled(
"RPC cancelled, not running TPU program on device ", device_ordinal);
}
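// Enqueue the program for asynchronous execution on the compute stream.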
xla::StatusOr<xla::ExecutionOutput> output =
tpu_executable->ExecuteAsyncOnStream(&service_run_options,
std::move(arguments),
/*hlo_execution_profile=*/nullptr);
// If !output.ok(), it means we failed to enqueue the program on the TPU. This
// is possibly caused by a cancellation callback having already closed the
// chips.
if (!output.ok()) {
// If the cancellation manager is already cancelled or cancelling, another
// failure has occurred earlier, and this TpuExecuteOp is cancelled regardless
// of whether it is itself the source of the error.
already_cancelled = cancellation_manager->IsCancelling() ||
cancellation_manager->IsCancelled();
if (already_cancelled) {
return errors::Cancelled(
"RPC cancelled, not running TPU program on device ", device_ordinal);
}
}
UnregisterCancellation(ctx, cancellation_manager, stream, device_ordinal,
token, host_transfer_manager);
VLOG(1) << "Cloud TPU: TPUExecute done";
return output;
}
} // namespace tensorflow

@@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_
#define TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
namespace tensorflow {
// Runs the TPU program described by `executable` and `tpu_program` on
// `stream`. `arguments` holds the input buffers; the result buffers are
// returned in the xla::ExecutionOutput. `host_transfers` and
// `rendezvous_key_base` configure host transfers triggered by the program,
// and `cancellation_manager` is used to halt the program if the enclosing
// step is cancelled.
xla::StatusOr<xla::ExecutionOutput> TPUExecute(
const TPUExecutableInfoProto& executable,
const TPUHostTransferInfoProto& host_transfers,
const xla::HloProto& hlo_metadata,
std::vector<xla::ExecutionInput> arguments,
const std::string& rendezvous_key_base, uint32 rng_seed,
tpu::TpuNodeContext* node_context, xla::DeviceAssignment* device_assignment,
CancellationManager* cancellation_manager, OpKernelContext* ctx,
stream_executor::Stream* stream,
stream_executor::Stream* host_to_device_stream,
const XLA_TpuProgram* tpu_program);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_

@@ -43,6 +43,9 @@ tensorflow::Status SetExecuteStructFn(void* library_handle) {
TFTPU_SET_FN(execute_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
TFTPU_SET_FN(execute_fn, HardwareLayout_HostShapeToDeviceShape);
TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSize);
TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompact);
TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompactRaw);
TFTPU_SET_FN(execute_fn, TpuExecute_RuntimeInputToPaddedData);
return tensorflow::Status::OK();
}