diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index d81b9a4b84c..664780abcc7 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -168,6 +168,46 @@ cc_library( ], ) +cc_library( + name = "tpu_client", + srcs = ["tpu_client.cc"], + hdrs = ["tpu_client.h"], + visibility = [ + "//learning/brain/research/jax:__subpackages__", + "//learning/deepmind/tensorflow/tensorfn:__subpackages__", + "//learning/pathways:__subpackages__", + ], + deps = [ + ":local_device_state", + ":pjrt_client", + ":tracked_device_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core/tpu:tpu_executor_dlsym_initializer", + "//tensorflow/core/tpu:tpu_on_demand_compiler", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/tpu:tpu_computation_placer", + "//tensorflow/stream_executor/tpu:tpu_executable_interface", + "//tensorflow/stream_executor/tpu:tpu_executor", + "//tensorflow/stream_executor/tpu:tpu_executor_interface", + "//tensorflow/stream_executor/tpu:tpu_platform_interface", + "//tensorflow/stream_executor/tpu:tpu_topology_external", + "//tensorflow/stream_executor/tpu:tpu_transfer_manager", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "interpreter_device", srcs = ["interpreter_device.cc"], diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc new file mode 100644 index 00000000000..a8711631605 --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc @@ -0,0 +1,246 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/pjrt/tpu_client.h" + +#include <memory> +#include <vector> + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h" +#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_stream.h" + +namespace tf_tpu = tensorflow::tpu; + +namespace xla { +namespace { + +class TpuDeviceState : public LocalDeviceState { + public: + TpuDeviceState(se::StreamExecutor* executor, LocalClient* client, + bool asynchronous); + + Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream, + se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, + se::DeviceMemoryBase dst_buffer) override; +}; + +TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor, + LocalClient* client, bool asynchronous) + : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous, + asynchronous, + /*allow_event_reuse=*/false) {} + +Status TpuDeviceState::ThenMemcpyDeviceToDevice( + se::Stream* transfer_stream, se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { + auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>( + transfer_stream->implementation()); + tf_tpu::TpuTopologyExternal topology = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); + // TODO(b/157179600): use device-to-device transfers when implemented instead + // of copying via host. + if (topology.version() == kTpuV4) { + LOG(WARNING) + << "device-to-device transfers not yet implemented, copying via host"; + auto* dst_tpu_stream = + tensorflow::down_cast<tf_tpu::TpuStream*>(dst_stream->implementation()); + TF_RET_CHECK(src_buffer.size() == dst_buffer.size()); + auto host_tmp = std::make_unique<char[]>(src_buffer.size()); + TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueTransferDeviceToHost( + src_buffer, host_tmp.get(), src_buffer.size())); + dst_stream->ThenWaitFor(transfer_stream); + TF_RETURN_IF_ERROR(dst_tpu_stream->EnqueueTransferHostToDevice( + dst_buffer, host_tmp.get(), dst_buffer.size())); + transfer_stream->ThenWaitFor(dst_stream); + char* tmp = host_tmp.release(); + dst_stream->ThenDoHostCallback([tmp] { delete[] tmp; }); + } else { + TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal( + src_buffer, dst_buffer)); + } + return Status::OK(); +} + +class PjRtTpuClient : public PjRtClient { + public: + PjRtTpuClient(LocalClient* client, + std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id, + tf_tpu::TpuPlatformInterface* tpu_platform); + + StatusOr<DeviceAssignment> GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + + bool EnqueueD2DTransfersOnSrcStream() const override { + return tpu_platform_->topology().version() == kTpuV4; + } + + StatusOr<absl::optional<std::string>> ExecutableFingerprint( + const PjRtExecutable& executable) const override; + + private: + tf_tpu::TpuPlatformInterface* tpu_platform_; +}; + +PjRtTpuClient::PjRtTpuClient(LocalClient* client, + std::vector<std::unique_ptr<PjRtDevice>> devices, + int host_id, + tf_tpu::TpuPlatformInterface* tpu_platform) + : PjRtClient("tpu", client, std::move(devices), host_id, + /*allocator=*/nullptr, + /*host_memory_allocator=*/nullptr, + /*should_stage_host_to_device_transfers=*/false, + /*gpu_run_options=*/nullptr), + tpu_platform_(tpu_platform) {} + +StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const { + tf_tpu::TpuPlatformInterface* platform = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(); + tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation(); + int num_local_devices = host.Cores(kTensorCore).size(); + if (num_replicas * num_partitions <= num_local_devices) { + return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas, + num_partitions); + } + // Fallback to default global device assignment if we can't run locally. + return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions); +} + +StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint( + const PjRtExecutable& executable) const { + if (executable.client() != this) { + return InvalidArgument( + "Passed executable from different client (platform '%s') to " + "PjRtTpuClient::ExecutableFingerprint", + executable.client()->platform_name()); + } + if (executable.executables().size() > 1) { + LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD " + "executables, fingerprint may not be unique."; + } + xla::TpuExecutableInterface* tpu_executable = + tensorflow::down_cast<xla::TpuExecutableInterface*>( + executable.executables()[0]->executable()); + return absl::optional<std::string>(tpu_executable->fingerprint()); +} + +StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices( + LocalClient* client, + std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) { + std::vector<std::unique_ptr<PjRtDevice>> devices; + tf_tpu::TpuTopologyExternal topology = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); + + std::map<int, int> core_id_to_device_ordinal; + for (int i = 0; i < client->device_count(); ++i) { + se::StreamExecutor* executor = + client->backend().stream_executor(i).ValueOrDie(); + tf_tpu::TpuExecutorInterface* tpu_executor = + tensorflow::down_cast<tf_tpu::TpuExecutorInterface*>( + executor->implementation()); + core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i; + } + + for (const tf_tpu::TpuCoreLocationExternal& core : + topology.cores(TpuCoreTypeEnum::kTensorCore)) { + auto it = core_id_to_device_ordinal.find(core.Id()); + int device_ordinal = + (it != core_id_to_device_ordinal.end()) ? it->second : -1; + int host_id = topology.IdForHost(core.host_coordinates()); + const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates(); + std::array<int, 3> coords_array = {coords.x, coords.y, coords.z}; + std::unique_ptr<LocalDeviceState> local_device_state; + if (device_ordinal >= 0) { + local_device_state = std::move(local_device_states[device_ordinal]); + } + auto device = absl::make_unique<PjRtTpuDevice>( + core, std::move(local_device_state), host_id, coords_array, + std::string(tf_tpu::TpuVersionEnumToString(topology.version()))); + devices.push_back(std::move(device)); + } + return devices; +} + +} // namespace + +StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient( + bool asynchronous, absl::Duration init_retry_timeout) { + tf_tpu::TpuPlatformInterface* platform = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(); + if (platform == nullptr) { + return InvalidArgument("TpuPlatform is not available."); + } + // NOTE: We retry in a loop since some pod failures are transient (e.g. some + // RPCs may timeout waiting for other hosts to come up, but will succeed + // at a later point if retried). + auto start = absl::Now(); + // TODO(b/165870356): TpuPlatform::Initialized() always returns true! + auto status = platform->Initialize({}); + while (!platform->Initialized()) { + status = platform->Initialize({}); + if (!status.ok()) { + LOG(ERROR) << "Platform initialization failed: " << status; + if ((absl::Now() - start) >= init_retry_timeout) { + return status; + } + } + } + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("No TPU devices found."); + } + LocalClientOptions options; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + + std::vector<std::unique_ptr<LocalDeviceState>> local_device_states; + local_device_states.reserve(client->device_count()); + for (int i = 0; i < client->device_count(); ++i) { + se::StreamExecutor* executor = + client->backend().stream_executor(i).ValueOrDie(); + local_device_states.push_back( + absl::make_unique<TpuDeviceState>(executor, client, asynchronous)); + } + + TF_ASSIGN_OR_RETURN(auto devices, + GetTpuDevices(client, std::move(local_device_states))); + int host_id = platform->GetTpuHostLocation().Id(); + + return std::shared_ptr<PjRtClient>(absl::make_unique<PjRtTpuClient>( + client, std::move(devices), host_id, platform)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h new file mode 100644 index 00000000000..1a458c1480b --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/tpu_client.h @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ + +#include <array> +#include <memory> +#include <vector> + +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/tpu/tpu_topology.h" + +namespace xla { + +class PjRtTpuDevice : public PjRtDevice { + public: + PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core, + std::unique_ptr<LocalDeviceState> local_device_state, + int host_id, const std::array<int, 3>& coords, + std::string device_kind) + : PjRtDevice(core.Id(), std::move(local_device_state), + /*platform_name=*/"tpu", std::move(device_kind), host_id), + core_(core), + coords_(coords) {} + + const std::array<int, 3>& coords() const { return coords_; } + int core_on_chip() const { return core_.index(); } + const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; } + + std::string DebugString() const override { + return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(), + coords_[0], coords_[1], coords_[2], core_.index()); + } + + private: + const tensorflow::tpu::TpuCoreLocationExternal core_; + const std::array<int, 3> coords_; +}; + +StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient( + bool asynchronous, + absl::Duration init_retry_timeout = absl::ZeroDuration()); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 484c9a47e60..85586809014 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -177,6 +177,23 @@ cc_library( ], ) +# This is an alternative to "tpu_api_dlsym_initializer" that only initializes +# methods needed for the base TPU executor APIs (and thus has fewer deps). Do +# not link in both this and "tpu_api_dlsym_initializer". +cc_library( + name = "tpu_executor_dlsym_initializer", + srcs = ["tpu_executor_dlsym_initializer.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tpu_api_dlsym_set_fn", + ":tpu_executor_init_fns", + "//tensorflow/core:lib", + "//tensorflow/stream_executor/tpu:tpu_computation_placer", + "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", + ], + alwayslink = True, +) + cc_library( name = "tpu_api_dlsym_set_fn", hdrs = ["tpu_api_dlsym_set_fn.h"], @@ -294,6 +311,7 @@ cc_library( cc_library( name = "tpu_on_demand_compiler", srcs = ["tpu_on_demand_compiler.cc"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc new file mode 100644 index 00000000000..4d84781f4e3 --- /dev/null +++ b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(skye): this is largely a copy of tpu_api_dlsym_initializer.cc. Figure +// out how to deduplicate these files a little. + +#include <dlfcn.h> + +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/tpu_api_dlsym_set_fn.h" +#if !defined(PLATFORM_GOOGLE) +#include "tensorflow/core/tpu/tpu_executor_api.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" +#include "tensorflow/stream_executor/tpu/tpu_platform.h" +#endif + +namespace tensorflow { +namespace tpu { + +#if defined(PLATFORM_GOOGLE) +Status InitializeTpuLibrary(void* library_handle) { + return errors::Unimplemented("You must statically link in a TPU library."); +} +#else // PLATFORM_GOOGLE +#include "tensorflow/core/tpu/tpu_executor_init_fns.inc" + +Status InitializeTpuLibrary(void* library_handle) { + Status s = SetExecutorStructFn(library_handle); + + // TPU platform registration must only be performed after the library is + // loaded. We do not want to register a TPU platform in XLA without the + // supporting library providing the necessary APIs. + if (s.ok()) { + void (*initialize_fn)(); + initialize_fn = reinterpret_cast<decltype(initialize_fn)>( + dlsym(library_handle, "TfTpu_Initialize")); + (*initialize_fn)(); + + RegisterTpuPlatform(); + } + + return s; +} + +bool FindAndLoadTpuLibrary() { + void* library = dlopen("libtpu.so", RTLD_NOW); + if (library) { + InitializeTpuLibrary(library); + } + return true; +} + +static bool tpu_library_finder = FindAndLoadTpuLibrary(); +#endif // PLATFORM_GOOGLE + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD index df0867042f2..540a0a234ff 100644 --- a/tensorflow/stream_executor/tpu/BUILD +++ b/tensorflow/stream_executor/tpu/BUILD @@ -222,6 +222,7 @@ cc_library( cc_library( name = "tpu_transfer_manager", srcs = ["tpu_transfer_manager_registration.cc"], + visibility = ["//visibility:public"], deps = [ ":tpu_executor", ":tpu_transfer_manager_base", @@ -256,6 +257,7 @@ cc_library( name = "tpu_computation_placer", srcs = ["tpu_computation_placer.cc"], hdrs = ["tpu_computation_placer.h"], + visibility = ["//visibility:public"], deps = [ ":status_helper", ":tpu_executor", diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h index fee9d92b42d..240148977e3 100644 --- a/tensorflow/stream_executor/tpu/tpu_platform_interface.h +++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h @@ -56,6 +56,10 @@ class TpuPlatformInterface : public stream_executor::Platform { virtual const TpuTopologyPtr GetTopologyPtr() = 0; virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0; + + TpuTopologyExternal topology() { + return TpuTopologyExternal(GetTopologyPtr()); + } }; } // namespace tpu