From 6b8687f97cc349c5d3cdea39b535ba6292831b2e Mon Sep 17 00:00:00 2001
From: Wenhao Jia
Date: Thu, 16 Jul 2020 23:14:25 -0700
Subject: [PATCH] Add a TPUExecute function.

PiperOrigin-RevId: 321720840
Change-Id: I9f7304c6f8fd6ffe8266c60b10a6f19a7b3bdc54
---
 tensorflow/core/tpu/BUILD                    |  42 ++
 .../core/tpu/kernels/tpu_execute_c_api.h     |  10 +
 tensorflow/core/tpu/tpu_execute.cc           | 519 ++++++++++++++++++
 tensorflow/core/tpu/tpu_execute.h            |  54 ++
 tensorflow/core/tpu/tpu_library_init_fns.inc |   3 +
 5 files changed, 628 insertions(+)
 create mode 100644 tensorflow/core/tpu/tpu_execute.cc
 create mode 100644 tensorflow/core/tpu/tpu_execute.h

diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index f9031b440f9..d82011c6961 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -227,3 +227,45 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
     ],
 )
+
+cc_library(
+    name = "tpu_execute",
+    srcs = ["tpu_execute.cc"],
+    hdrs = ["tpu_execute.h"],
+    deps = [
+        ":tpu_api",
+        "//tensorflow/compiler/jit:xla_device",
+        "//tensorflow/compiler/xla:executable_run_options",
+        "//tensorflow/compiler/xla:shape_layout",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:computation_layout",
+        "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/compiler/xla/service:executable",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
+        "//tensorflow/compiler/xla/service:transfer_manager",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
+        "//tensorflow/core/tpu/kernels:tpu_executable_info_proto_cc",
+        "//tensorflow/stream_executor:device_memory",
+        "//tensorflow/stream_executor:stream",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/tpu:c_api_conversions",
+        "//tensorflow/stream_executor/tpu:status_helper",
+        "//tensorflow/stream_executor/tpu:tpu_executable",
+        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
+        "//tensorflow/stream_executor/tpu:tpu_node_context",
+        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/memory",
+    ],
+)
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_c_api.h b/tensorflow/core/tpu/kernels/tpu_execute_c_api.h
index 38a550444a9..81d23441ddc 100644
--- a/tensorflow/core/tpu/kernels/tpu_execute_c_api.h
+++ b/tensorflow/core/tpu/kernels/tpu_execute_c_api.h
@@ -37,11 +37,21 @@ TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
 TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
     XLA_Shape* host_shape, XLA_Shape* device_shape);
 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
+TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
+TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
+
+TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
+    uint32_t* runtime_input_ptr, size_t runtime_input_size,
+    int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape,
+    XLA_Shape* compile_time_shape, SE_Status* status);

 struct TfTpu_ExecuteApiFn {
   TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
   TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
   TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
+  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
+  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);
 };

 }  // extern "C"
diff --git a/tensorflow/core/tpu/tpu_execute.cc b/tensorflow/core/tpu/tpu_execute.cc
new file mode 100644
index 00000000000..022e8c2a07e
--- /dev/null
+++ b/tensorflow/core/tpu/tpu_execute.cc
@@ -0,0 +1,519 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/tpu_execute.h"
+
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/base/casts.h"
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
+#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executable.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+
+namespace {
+
+using ::tensorflow::tpu::TpuNodeContext;
+
+static bool tpu_cancellation_terminates_process = false;
+static bool tpu_cancellation_closes_chips = true;
+
+// Host-side runtime for transfers between TPU and host.
+class HostTransferManager {
+ public:
+  using HostCommmandHandler = xla::TpuExecutable::HostCommandHandler;
+
+  explicit HostTransferManager(TpuNodeContext* node_context)
+      : node_context_(node_context) {}
+
+  // Returns a function to be called when the TPU triggers a host command
+  // interrupt while executing the current program.
+  xla::StatusOr<HostCommmandHandler> Initialize(
+      const TPUHostTransferInfoProto& program,
+      const std::string& rendezvous_key_base, OpKernelContext* ctx);
+
+ private:
+  TpuNodeContext* node_context_;  // not owned
+
+  TF_DISALLOW_COPY_AND_ASSIGN(HostTransferManager);
+};
+
+xla::StatusOr<HostTransferManager::HostCommmandHandler>
+HostTransferManager::Initialize(const TPUHostTransferInfoProto& program,
+                                const string& rendezvous_key_base,
+                                OpKernelContext* ctx) {
+  return HostCommmandHandler([](uint32, int64) {
+    LOG(WARNING) << "HostTransferManager is unimplemented.";
+  });
+}
+
+// Sleep for 5 seconds, then call std::quick_exit(42) to quickly restart.
+void ExitCountdown(Env* env) {
+  const int kSleepSeconds = 5;
+  LOG(INFO) << "TpuExecute was cancelled. Sleeping for " << kSleepSeconds
+            << " seconds before terminating the process to give time "
+               "for other errors to propagate";
+  env->SleepForMicroseconds(kSleepSeconds * 1000000);
+  LOG(ERROR) << "Aborting process due to cancelled TPUExecute. Consult "
+                "the anomalies reported above (if any), run state of job "
+                "(including failed RPCs) and worker logs. This "
+                "termination is to ensure a consistent state, if your job "
+                "does not restart, modify the retries allowed. See "
+                "b/62262381 and b/65223927.";
+  std::quick_exit(42);
+}
+
+xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) {
+  XLA_Shape c_host_shape;
+  XLA_Shape c_device_shape;
+  TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
+  tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
+      &c_host_shape, &c_device_shape);
+  xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
+  TpuConversions::CShapeCleanup(&c_host_shape);
+  TpuConversions::CShapeCleanup(&c_device_shape);
+  return device_shape;
+}
+
+int64 ShapeSizeCompact(const xla::Shape& shape) {
+  XLA_Shape c_shape;
+  TpuConversions::XlaShapeToCShape(shape, &c_shape);
+  int64 size =
+      tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactFn(
+          &c_shape);
+  TpuConversions::CShapeCleanup(&c_shape);
+  return size;
+}
+
+int64 ShapeSizeCompactRaw(const xla::Shape& shape) {
+  XLA_Shape c_shape;
+  TpuConversions::XlaShapeToCShape(shape, &c_shape);
+  int64 size =
+      tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
+          &c_shape);
+  TpuConversions::CShapeCleanup(&c_shape);
+  return size;
+}
+
+// Given a tuple, fix all non-leaf nodes (tuples) such that the tuple tables
+// point to the correct leaf nodes.
+xla::Status FixTupleTableAsync(se::Stream* stream,
+                               const xla::Shape& tuple_shape,
+                               xla::ExecutionInput* mem,
+                               xla::TransferManager* transfer_manager) {
+  return xla::ShapeUtil::ForEachSubshapeWithStatus(
+      tuple_shape,
+      [&](const xla::Shape& element_shape,
+          const xla::ShapeIndex& index) -> Status {
+        if (!element_shape.IsTuple()) {
+          return Status::OK();
+        }
+        std::vector<se::DeviceMemoryBase> elements;
+        xla::ShapeIndex element_index = index;
+        element_index.push_back(0);
+        for (int64 i = 0; i < element_shape.tuple_shapes_size(); ++i) {
+          // Gather all children of the tuple element.
+          element_index.back() = i;
+          elements.push_back(mem->Buffer(element_index).AsDeviceMemoryBase());
+        }
+        se::DeviceMemoryBase tuple_table_addr =
+            mem->Buffer(index).AsDeviceMemoryBase();
+        return transfer_manager->WriteSingleTupleIndexTable(
+            stream, elements, element_shape, &tuple_table_addr);
+      });
+}
+
+// Returns true if `dynamic_shape` has dimensions that are less-equal to the
+// `bounded_shape`.
+bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
+                              const xla::Shape& bounded_shape) {
+  if (dynamic_shape.rank() != bounded_shape.rank()) {
+    return false;
+  }
+  for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
+    if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// For dynamic inputs, copy them and attach metadata of shape sizes to the
+// end of the tensor.
+//
+// The buffer for dynamic shapes contains three parts:
+// +--------+
+// |Payload |
+// +--------+
+// | Padding|
+// +--------+
+// |Metadata|
+// +--------+
+//
+// Metadata contains the sizes of shape without padding, eventually
+// representing the size of valid data.
+xla::Status UpdateDynamicInputs(
+    se::Stream* stream, se::DeviceMemoryAllocator* allocator,
+    std::vector<xla::ExecutionInput>* runtime_inputs,
+    const std::vector<xla::Shape>& compile_time_shapes) {
+  TF_RET_CHECK(runtime_inputs->size() == compile_time_shapes.size());
+  for (int64 i = 0; i < compile_time_shapes.size(); i++) {
+    // TODO(yunxing): Iterating over thousands of elements can be slow. One way
+    // to optimize for the fast path without dynamic shapes is to add a field
+    // to the compilation result indicating whether a dynamic input is present.
+    if (compile_time_shapes[i].is_static()) {
+      continue;
+    }
+    auto& runtime_input = (*runtime_inputs)[i];
+    xla::Shape compile_time_shapes_on_device =
+        HostShapeToDeviceShape(compile_time_shapes[i]);
+    bool element_modified = false;
+    TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
+        compile_time_shapes_on_device,
+        [&](const xla::Shape& compile_time_shape,
+            const xla::ShapeIndex& index) -> Status {
+          if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
+            return Status::OK();
+          }
+
+          const xla::Shape& runtime_shape =
+              xla::ShapeUtil::GetSubshape(runtime_input.shape(), index);
+
+          TF_RET_CHECK(!runtime_shape.IsTuple());
+          TF_RET_CHECK(
+              DynamicShapeIsCompatible(runtime_shape, compile_time_shape));
+
+          xla::MaybeOwningDeviceMemory* mutable_input_mem =
+              runtime_input.MutableBuffer(index);
+          auto padded_data = std::make_shared<std::vector<int8>>(
+              ShapeSizeCompact(compile_time_shape), -1);
+          auto raw_input_runtime = std::make_shared<std::vector<uint32>>(
+              ShapeSizeCompact(runtime_shape) / sizeof(uint32));
+          stream->ThenMemcpyD2H(
+              se::DeviceMemory<int8>(mutable_input_mem->AsDeviceMemoryBase()),
+              absl::MakeSpan(absl::bit_cast<int8*>(raw_input_runtime->data()),
+                             ShapeSizeCompactRaw(runtime_shape)));
+          stream->ThenDoHostCallback([raw_input_runtime, padded_data,
+                                      runtime_shape, compile_time_shape]() {
+            // After getting the data onto the host, transpose the data to
+            // the correct layout by delinearizing it and linearizing it again.
+            XLA_Shape c_runtime_shape, c_compile_time_shape;
+            TpuConversions::XlaShapeToCShape(runtime_shape, &c_runtime_shape);
+            TpuConversions::XlaShapeToCShape(compile_time_shape,
+                                             &c_compile_time_shape);
+            StatusHelper status;
+            tensorflow::tpu::ExecuteApiFn()
+                ->TpuExecute_RuntimeInputToPaddedDataFn(
+                    raw_input_runtime->data(), raw_input_runtime->size(),
+                    padded_data->data(), padded_data->size(), &c_runtime_shape,
+                    &c_compile_time_shape, status.c_status);
+            TpuConversions::CShapeCleanup(&c_runtime_shape);
+            TpuConversions::CShapeCleanup(&c_compile_time_shape);
+            return status.status();
+          });
+          // Allocate new input and transfer the padded and transposed data to
+          // the new input location.
+          TF_ASSIGN_OR_RETURN(
+              auto new_input,
+              allocator->Allocate(stream->parent()->device_ordinal(),
+                                  ShapeSizeCompact(compile_time_shape)));
+          auto typed_new_input_memory =
+              se::DeviceMemory<int8>(new_input.cref());
+          stream->ThenMemcpyH2D<int8>(*padded_data, &typed_new_input_memory);
+
+          // Retain the memory until the end of the transfer.
+          stream->ThenDoHostCallback([padded_data]() { return Status::OK(); });
+
+          // Modify the memory location in the input shape tree to point to the
+          // new input.
+          *mutable_input_mem =
+              xla::MaybeOwningDeviceMemory(std::move(new_input));
+          element_modified = true;
+          return Status::OK();
+        }));
+    if (element_modified) {
+      // The input location has been modified, need to fix tuple table to
+      // point to the correct address.
+      TF_ASSIGN_OR_RETURN(
+          auto transfer_manager,
+          xla::TransferManager::GetForPlatform(stream->parent()->platform()));
+      TF_RETURN_IF_ERROR(FixTupleTableAsync(stream,
+                                            compile_time_shapes_on_device,
+                                            &runtime_input, transfer_manager));
+    }
+  }
+  return Status::OK();
+}
+
+void TPUCancelExecution(Env* env, int device_ordinal) {
+  if (tpu_cancellation_terminates_process) {
+    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats on device "
+              << device_ordinal;
+    Status status = TpuNodeContext::StopChipHeartbeats();
+    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats done: " << status
+              << " on device " << device_ordinal;
+    // Sleep and exit in another thread so the cancellation manager can
+    // continue running callbacks. The new thread will call quick_exit,
+    // so we discard the returned Thread pointer because we won't have
+    // an opportunity to delete it.
+    (void)env->StartThread(ThreadOptions(), "tpu_execute_exit_countdown",
+                           [env]() { ExitCountdown(env); });
+  } else if (tpu_cancellation_closes_chips) {
+    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device "
+              << device_ordinal;
+    Status status = TpuNodeContext::CloseTpuHost();
+    LOG(INFO) << "TPUCancelExecution CloseTPUHost done: " << status
+              << " on device " << device_ordinal;
+  } else {
+    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device "
+              << device_ordinal << " is suppressed";
+  }
+}
+
+std::pair<CancellationToken, bool> RegisterCancellation(
+    OpKernelContext* ctx, CancellationManager* cancellation_manager,
+    int device_ordinal) {
+  // Set up a cancellation callback, to ensure the TPU program we run will
+  // halt if the RPC is cancelled. Without this the TPU program might block
+  // forever. The mechanism itself is a big hammer; we close all devices
+  // attached to this host on each cancellation callback. This is necessary to
+  // ensure the system will eventually halt, since the TensorNodes on each
+  // chip may be stuck waiting for mutual communication.
+  //
+  // By closing all devices, we ensure all subsequent attempts to use the
+  // device will fail, until the devices are re-initialized via a new call to
+  // tpu.initialize_system.
+  //
+  // In a multi-TensorNode setup, CloseTPUHost may be called once for each
+  // TensorNode, and each call will close all TensorNodes. This quadratic
+  // behavior ensures the mechanism is robust to various orderings
+  // (i.e. races) between the TPU programs, which are run on separate threads.
+  // In practice the quadratic behavior isn't that bad; the first call will
+  // actually halt any running TPU programs (which may be expensive), while
+  // subsequent calls will attempt to close an already-closed device (which is
+  // cheap).
+  //
+  // TODO(b/62262381): The cancellation manager is shared between multiple TPU
+  // execute ops and the cancellation will not be invoked only when RPC is
+  // cancelled (it may also be induced by OOM errors from a different TPU
+  // execute), this results in a pretty coarse cancellation domain. This
+  // cancellation callback should only execute in a narrower scope to not be
+  // triggered in such cases.
+  CancellationToken token = cancellation_manager->get_cancellation_token();
+  // Don't rely on OpKernelContext being available when the callback runs.
+  Env* env = ctx->env();
+  bool already_cancelled = !cancellation_manager->RegisterCallback(
+      token,
+      [device_ordinal, env]() { TPUCancelExecution(env, device_ordinal); });
+  return std::pair<CancellationToken, bool>(token, already_cancelled);
+}
+
+void UnregisterCancellation(
+    OpKernelContext* ctx, CancellationManager* cancellation_manager,
+    se::Stream* stream, int device_ordinal, CancellationToken token,
+    std::shared_ptr<HostTransferManager> host_transfer_manager) {
+  // If execution reaches this point, the host callback enqueued below will
+  // get called regardless of stream status. Call
+  // inc_num_deferred_ops_function here and dec_num_deferred_ops_function in
+  // the host callback.
+  ctx->inc_num_deferred_ops_function()();
+  auto dec_num_deferred_ops_function = ctx->dec_num_deferred_ops_function();
+
+  // Try to avoid running callbacks on the compute stream, because this
+  // reduces the frequency of back-to-back programs (which are most efficient
+  // because they don't require host synchronization). Instead, borrow a
+  // substream and have the substream wait on the compute stream.
+  se::Stream* deregister_stream = stream->GetOrCreateSubStream();
+  deregister_stream->ThenWaitFor(stream);
+  deregister_stream->ThenDoHostCallback([=]() {
+    // Ensure the host_transfer_manager is copied into the callback scope.
+    (void)host_transfer_manager;
+
+    // We must deregister the callback in the success case, to avoid closing
+    // all devices. In the failure case we must NOT call DeregisterCallback as
+    // that waits for all previous cancellation callbacks to complete and any
+    // call to XlaDevice::Sync() will cause deadlock. Consider:
+    //   1) CancellationManager::StartCancel() is in progress (state is
+    //      cancelling_).
+    //   2) The call below to DeregisterCallback will block until state is
+    //      cancelled_ (all callbacks are completed).
+    //   3) A different cancellation callback has called XlaDevice::Sync(),
+    //      which will block until (2) is done.
+    //   4) StartCancel() in (1) cannot complete until (3) is done.
+    //
+    // Instead, call TryDeregisterCallback. The functional difference is that
+    // TryDeregisterCallback will not block if cancellation is in progress,
+    // and so makes no guarantees as to the state of any callbacks.
+    // This is not a problem, as our cancellation handler does not rely on
+    // any external state.
+    VLOG(1) << "cancellation_manager->TryDeregisterCallback on device "
+            << device_ordinal;
+    cancellation_manager->TryDeregisterCallback(token);
+    VLOG(1) << "cancellation_manager->TryDeregisterCallback done on device "
+            << device_ordinal;
+
+    // ExecutorState is held alive until at least this point to ensure
+    // cancellation_manager is valid. After all outstanding
+    // dec_num_deferred_ops_function are called, ExecutorState::Finish will be
+    // allowed to proceed.
+    dec_num_deferred_ops_function();
+  });
+  stream->ReturnSubStream(deregister_stream);
+}
+
+}  // namespace
+
+xla::StatusOr<xla::ExecutionOutput> TPUExecute(
+    const TPUExecutableInfoProto& executable,
+    const TPUHostTransferInfoProto& host_transfers,
+    const xla::HloProto& hlo_metadata,
+    std::vector<xla::ExecutionInput> arguments,
+    const string& rendezvous_key_base, uint32 rng_seed,
+    TpuNodeContext* node_context, xla::DeviceAssignment* device_assignment,
+    CancellationManager* cancellation_manager, OpKernelContext* ctx,
+    stream_executor::Stream* stream,
+    stream_executor::Stream* host_to_device_stream,
+    const XLA_TpuProgram* tpu_program) {
+  profiler::TraceMe traceme("TPUExecute", 2);
+  TF_RET_CHECK(tpu::TpuPlatformInterface::GetRegisteredPlatform() != nullptr);
+  TF_RET_CHECK(tpu_program != nullptr);
+  VLOG(1) << "TPUExecute on device " << node_context->tensor_core_location();
+
+  XlaDevice* device =
+      tensorflow::down_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
+  TF_RET_CHECK(device);
+
+  // Create a HostTransferManager to handle Send/Recv operations from the TPU.
+  std::shared_ptr<HostTransferManager> host_transfer_manager =
+      std::make_shared<HostTransferManager>(node_context);
+  TF_ASSIGN_OR_RETURN(HostTransferManager::HostCommmandHandler handler,
+                      host_transfer_manager->Initialize(
+                          host_transfers, rendezvous_key_base, ctx));
+
+  VLOG(2) << "Cloud TPU: Executing computation on device "
+          << node_context->index_on_host();
+
+  xla::ExecutableRunOptions run_options;
+  run_options.set_stream(stream);
+  run_options.set_device_assignment(device_assignment);
+  run_options.set_rng_seed(rng_seed);
+  run_options.set_allocator(node_context->memory_allocator());
+  run_options.set_host_to_device_stream(host_to_device_stream);
+
+  const xla::ServiceExecutableRunOptions service_run_options(run_options);
+
+  std::unique_ptr<xla::HloModule> module;
+  std::vector<xla::Shape> input_shapes;
+  {
+    xla::ComputationLayout computation_layout(
+        xla::ShapeLayout(xla::Shape(executable.output_shape())));
+    for (const xla::ShapeProto& shape_proto : executable.input_shapes()) {
+      xla::Shape shape(shape_proto);
+      computation_layout.add_parameter_layout(xla::ShapeLayout(shape));
+      input_shapes.push_back(std::move(shape));
+    }
+    module = absl::make_unique<xla::HloModule>(
+        "TpuExecutableModule",
+        xla::HloModuleConfig(std::move(computation_layout)));
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      module->input_output_alias_config(),
+      xla::HloInputOutputAliasConfig::CreateFromProto(
+          node_context->transfer_manager()->HostShapeToDeviceShape(
+              module->config().entry_computation_layout().result_shape()),
+          hlo_metadata.hlo_module().input_output_alias()));
+  TF_RET_CHECK(executable.input_shapes().size() == arguments.size());
+
+  for (auto& prefetch : hlo_metadata.hlo_module().cross_program_prefetches()) {
+    module->AddCrossProgramPrefetch(
+        prefetch.parameter(),
+        xla::ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
+  }
+
+  TF_RETURN_IF_ERROR(UpdateDynamicInputs(
+      stream, node_context->memory_allocator(), &arguments, input_shapes));
+
+  auto tpu_executable = absl::make_unique<xla::TpuExecutable>(
+      tpu_program, std::move(module), handler);
+
+  const int32 device_ordinal = node_context->device_ordinal();
+  CancellationToken token;
+  bool already_cancelled;
+  std::tie(token, already_cancelled) =
+      RegisterCancellation(ctx, cancellation_manager, device_ordinal);
+
+  // If the RPC was already cancelled before we managed to register the
+  // cancellation callback, we shouldn't attempt to run the TPU program, since
+  // it might block forever.
+  if (already_cancelled) {
+    return errors::Cancelled(
+        "RPC cancelled, not running TPU program on device ", device_ordinal);
+  }
+
+  xla::StatusOr<xla::ExecutionOutput> output =
+      tpu_executable->ExecuteAsyncOnStream(&service_run_options,
+                                           std::move(arguments),
+                                           /*hlo_execution_profile=*/nullptr);
+
+  // If !output.ok(), it means we failed to enqueue the program to the TPU.
+  // This is possibly caused by a failed cancellation callback closing the
+  // chips.
+  if (!output.ok()) {
+    // If the cancellation manager is already cancelled or cancelling, it
+    // means another failure has occurred earlier and this TpuExecuteOp is
+    // cancelled regardless of whether it itself is an error.
+    already_cancelled = cancellation_manager->IsCancelling() ||
+                        cancellation_manager->IsCancelled();
+    if (already_cancelled) {
+      return errors::Cancelled(
+          "RPC cancelled, not running TPU program on device ", device_ordinal);
+    }
+  }
+  UnregisterCancellation(ctx, cancellation_manager, stream, device_ordinal,
+                         token, host_transfer_manager);
+  VLOG(1) << "Cloud TPU: TPUExecute done";
+  return output;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/tpu_execute.h b/tensorflow/core/tpu/tpu_execute.h
new file mode 100644
index 00000000000..e2142ad7a7a
--- /dev/null
+++ b/tensorflow/core/tpu/tpu_execute.h
@@ -0,0 +1,54 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_
+#define TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
+
+namespace tensorflow {
+
+// Runs a TPU executable. `input_allocations` and `output_allocations` are
+// non-owning pointers to the root buffers of each argument/result tuple.
+// `output_shape` is the output shape of the XLA computation from which
+// `program` was derived. If `session_module` is not nullptr, it will be filled
+// with the input and output literals of the execution.
+xla::StatusOr<xla::ExecutionOutput> TPUExecute(
+    const TPUExecutableInfoProto& executable,
+    const TPUHostTransferInfoProto& host_transfers,
+    const xla::HloProto& hlo_metadata,
+    std::vector<xla::ExecutionInput> arguments,
+    const std::string& rendezvous_key_base, uint32 rng_seed,
+    tpu::TpuNodeContext* node_context,
+    xla::DeviceAssignment* device_assignment,
+    CancellationManager* cancellation_manager, OpKernelContext* ctx,
+    stream_executor::Stream* stream,
+    stream_executor::Stream* host_to_device_stream,
+    const XLA_TpuProgram* tpu_program);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index 7a7c6ecad30..06197870fee 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -43,6 +43,9 @@ tensorflow::Status SetExecuteStructFn(void* library_handle) {
   TFTPU_SET_FN(execute_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
   TFTPU_SET_FN(execute_fn, HardwareLayout_HostShapeToDeviceShape);
   TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSize);
+  TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompact);
+  TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompactRaw);
+  TFTPU_SET_FN(execute_fn, TpuExecute_RuntimeInputToPaddedData);

   return tensorflow::Status::OK();
 }
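
Usage sketch (illustrative, not part of the patch): the snippet below shows how a caller such as a TPU execute op kernel might thread the pieces introduced here (the executable and host-transfer protos, the node context, the streams, and the step's cancellation manager) into TPUExecute. The RunTpuProgram wrapper, the rendezvous-key construction, and the assumption that the caller already holds the XLA_TpuProgram and its metadata (for example from a compilation cache) are hypothetical; only TPUExecute and the types declared in tpu_execute.h come from this change.

// Illustrative only. The lookup of the compiled program and its protos is
// elided; this sketch assumes the caller already has them in hand.
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/tpu/tpu_execute.h"

namespace tensorflow {

// Hypothetical wrapper around the new entry point.
xla::StatusOr<xla::ExecutionOutput> RunTpuProgram(
    OpKernelContext* ctx, tpu::TpuNodeContext* node_context,
    const TPUExecutableInfoProto& executable_info,
    const TPUHostTransferInfoProto& host_transfers,
    const xla::HloProto& hlo_metadata, const XLA_TpuProgram* tpu_program,
    std::vector<xla::ExecutionInput> arguments,
    xla::DeviceAssignment* device_assignment,
    stream_executor::Stream* stream,
    stream_executor::Stream* host_to_device_stream, uint32 rng_seed) {
  // Share the step's cancellation manager so that an RPC cancellation (or an
  // earlier TPU failure) can also halt this program.
  CancellationManager* cancellation_manager = ctx->cancellation_manager();

  // Any per-step unique prefix works for this sketch; the real op builds its
  // rendezvous key differently.
  const std::string rendezvous_key_base =
      absl::StrCat("host_compute_channel_", ctx->step_id(), "_");

  // TPUExecute registers the cancellation callback, fixes up dynamic inputs,
  // enqueues the program on `stream`, and returns the asynchronous execution
  // output.
  return TPUExecute(executable_info, host_transfers, hlo_metadata,
                    std::move(arguments), rendezvous_key_base, rng_seed,
                    node_context, device_assignment, cancellation_manager,
                    ctx, stream, host_to_device_stream, tpu_program);
}

}  // namespace tensorflow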
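
The padded-buffer layout used by UpdateDynamicInputs (payload, then padding, then metadata holding the unpadded dimension sizes) can be illustrated with a small host-side helper. This is a toy model only: the real work happens inside libtpu via TpuExecute_RuntimeInputToPaddedData, which also re-linearizes the data to the device layout, and the exact placement and encoding of the metadata here are assumptions.

// Toy model of the documented buffer layout; not the libtpu implementation.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// `payload` is the linearized runtime data, `dims` are the runtime (unpadded)
// dimension sizes, and `padded_bytes` is the compile-time (bounded) buffer
// size. The metadata is assumed to be one int32 per dimension placed flush
// against the end of the buffer; everything between payload and metadata is
// padding. No bounds checking, since this only illustrates the layout.
std::vector<int8_t> PadWithMetadata(const std::vector<int8_t>& payload,
                                    const std::vector<int32_t>& dims,
                                    size_t padded_bytes) {
  std::vector<int8_t> out(padded_bytes, 0);
  // Payload goes first.
  std::memcpy(out.data(), payload.data(), payload.size());
  // Metadata describing the valid (unpadded) extent goes last.
  const size_t metadata_bytes = dims.size() * sizeof(int32_t);
  std::memcpy(out.data() + padded_bytes - metadata_bytes, dims.data(),
              metadata_bytes);
  return out;
}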