[XLA:PJRT] Add TPU client and associated libtpu.so initializer.
Also adds TpuPlatformInterface::topology() convenience method.

PiperOrigin-RevId: 335730454
Change-Id: I74a4e3a555d8b9030f735f9f8ca423959b250424
parent 11b687edfe
commit 62544da433
tensorflow/compiler/xla/pjrt/BUILD
@@ -168,6 +168,46 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tpu_client",
+    srcs = ["tpu_client.cc"],
+    hdrs = ["tpu_client.h"],
+    visibility = [
+        "//learning/brain/research/jax:__subpackages__",
+        "//learning/deepmind/tensorflow/tensorfn:__subpackages__",
+        "//learning/pathways:__subpackages__",
+    ],
+    deps = [
+        ":local_device_state",
+        ":pjrt_client",
+        ":tracked_device_buffer",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/compiler/xla/service:shaped_buffer",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/tpu:tpu_executor_dlsym_initializer",
+        "//tensorflow/core/tpu:tpu_on_demand_compiler",
+        "//tensorflow/stream_executor:device_memory",
+        "//tensorflow/stream_executor:stream",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/tpu:tpu_computation_placer",
+        "//tensorflow/stream_executor/tpu:tpu_executable_interface",
+        "//tensorflow/stream_executor/tpu:tpu_executor",
+        "//tensorflow/stream_executor/tpu:tpu_executor_interface",
+        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "//tensorflow/stream_executor/tpu:tpu_topology_external",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
+    ],
+)
+
 cc_library(
     name = "interpreter_device",
     srcs = ["interpreter_device.cc"],
tensorflow/compiler/xla/pjrt/tpu_client.cc (new file, 246 lines)
@@ -0,0 +1,246 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/tpu_client.h"

#include <memory>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h"
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"

namespace tf_tpu = tensorflow::tpu;

namespace xla {
namespace {

class TpuDeviceState : public LocalDeviceState {
 public:
  TpuDeviceState(se::StreamExecutor* executor, LocalClient* client,
                 bool asynchronous);

  Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
                                  se::Stream* dst_stream,
                                  se::DeviceMemoryBase src_buffer,
                                  se::DeviceMemoryBase dst_buffer) override;
};

TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor,
                               LocalClient* client, bool asynchronous)
    : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous,
                       asynchronous,
                       /*allow_event_reuse=*/false) {}

Status TpuDeviceState::ThenMemcpyDeviceToDevice(
    se::Stream* transfer_stream, se::Stream* dst_stream,
    se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
  auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
      transfer_stream->implementation());
  tf_tpu::TpuTopologyExternal topology =
      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
  // TODO(b/157179600): use device-to-device transfers when implemented instead
  // of copying via host.
  if (topology.version() == kTpuV4) {
    LOG(WARNING)
        << "device-to-device transfers not yet implemented, copying via host";
    auto* dst_tpu_stream =
        tensorflow::down_cast<tf_tpu::TpuStream*>(dst_stream->implementation());
    TF_RET_CHECK(src_buffer.size() == dst_buffer.size());
    auto host_tmp = std::make_unique<char[]>(src_buffer.size());
    TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueTransferDeviceToHost(
        src_buffer, host_tmp.get(), src_buffer.size()));
    dst_stream->ThenWaitFor(transfer_stream);
    TF_RETURN_IF_ERROR(dst_tpu_stream->EnqueueTransferHostToDevice(
        dst_buffer, host_tmp.get(), dst_buffer.size()));
    transfer_stream->ThenWaitFor(dst_stream);
    char* tmp = host_tmp.release();
    dst_stream->ThenDoHostCallback([tmp] { delete[] tmp; });
  } else {
    TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
        src_buffer, dst_buffer));
  }
  return Status::OK();
}

class PjRtTpuClient : public PjRtClient {
 public:
  PjRtTpuClient(LocalClient* client,
                std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
                tf_tpu::TpuPlatformInterface* tpu_platform);

  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;

  bool EnqueueD2DTransfersOnSrcStream() const override {
    return tpu_platform_->topology().version() == kTpuV4;
  }

  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
      const PjRtExecutable& executable) const override;

 private:
  tf_tpu::TpuPlatformInterface* tpu_platform_;
};

PjRtTpuClient::PjRtTpuClient(LocalClient* client,
                             std::vector<std::unique_ptr<PjRtDevice>> devices,
                             int host_id,
                             tf_tpu::TpuPlatformInterface* tpu_platform)
    : PjRtClient("tpu", client, std::move(devices), host_id,
                 /*allocator=*/nullptr,
                 /*host_memory_allocator=*/nullptr,
                 /*should_stage_host_to_device_transfers=*/false,
                 /*gpu_run_options=*/nullptr),
      tpu_platform_(tpu_platform) {}

StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  tf_tpu::TpuPlatformInterface* platform =
      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
  tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation();
  int num_local_devices = host.Cores(kTensorCore).size();
  if (num_replicas * num_partitions <= num_local_devices) {
    return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas,
                                                            num_partitions);
  }
  // Fall back to the default global device assignment if we can't run locally.
  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
}

StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
    const PjRtExecutable& executable) const {
  if (executable.client() != this) {
    return InvalidArgument(
        "Passed executable from different client (platform '%s') to "
        "PjRtTpuClient::ExecutableFingerprint",
        executable.client()->platform_name());
  }
  if (executable.executables().size() > 1) {
    LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
                 "executables, fingerprint may not be unique.";
  }
  xla::TpuExecutableInterface* tpu_executable =
      tensorflow::down_cast<xla::TpuExecutableInterface*>(
          executable.executables()[0]->executable());
  return absl::optional<std::string>(tpu_executable->fingerprint());
}

StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
    LocalClient* client,
    std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
  std::vector<std::unique_ptr<PjRtDevice>> devices;
  tf_tpu::TpuTopologyExternal topology =
      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();

  std::map<int, int> core_id_to_device_ordinal;
  for (int i = 0; i < client->device_count(); ++i) {
    se::StreamExecutor* executor =
        client->backend().stream_executor(i).ValueOrDie();
    tf_tpu::TpuExecutorInterface* tpu_executor =
        tensorflow::down_cast<tf_tpu::TpuExecutorInterface*>(
            executor->implementation());
    core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i;
  }

  for (const tf_tpu::TpuCoreLocationExternal& core :
       topology.cores(TpuCoreTypeEnum::kTensorCore)) {
    auto it = core_id_to_device_ordinal.find(core.Id());
    int device_ordinal =
        (it != core_id_to_device_ordinal.end()) ? it->second : -1;
    int host_id = topology.IdForHost(core.host_coordinates());
    const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
    std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
    std::unique_ptr<LocalDeviceState> local_device_state;
    if (device_ordinal >= 0) {
      local_device_state = std::move(local_device_states[device_ordinal]);
    }
    auto device = absl::make_unique<PjRtTpuDevice>(
        core, std::move(local_device_state), host_id, coords_array,
        std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
    devices.push_back(std::move(device));
  }
  return devices;
}

}  // namespace

StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
    bool asynchronous, absl::Duration init_retry_timeout) {
  tf_tpu::TpuPlatformInterface* platform =
      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
  if (platform == nullptr) {
    return InvalidArgument("TpuPlatform is not available.");
  }
  // NOTE: We retry in a loop since some pod failures are transient (e.g. some
  // RPCs may time out waiting for other hosts to come up, but will succeed
  // at a later point if retried).
  auto start = absl::Now();
  // TODO(b/165870356): TpuPlatform::Initialized() always returns true!
  auto status = platform->Initialize({});
  while (!platform->Initialized()) {
    status = platform->Initialize({});
    if (!status.ok()) {
      LOG(ERROR) << "Platform initialization failed: " << status;
      if ((absl::Now() - start) >= init_retry_timeout) {
        return status;
      }
    }
  }
  if (platform->VisibleDeviceCount() <= 0) {
    return InvalidArgument("No TPU devices found.");
  }
  LocalClientOptions options;
  options.set_platform(platform);
  TF_ASSIGN_OR_RETURN(LocalClient * client,
                      ClientLibrary::GetOrCreateLocalClient(options));

  std::vector<std::unique_ptr<LocalDeviceState>> local_device_states;
  local_device_states.reserve(client->device_count());
  for (int i = 0; i < client->device_count(); ++i) {
    se::StreamExecutor* executor =
        client->backend().stream_executor(i).ValueOrDie();
    local_device_states.push_back(
        absl::make_unique<TpuDeviceState>(executor, client, asynchronous));
  }

  TF_ASSIGN_OR_RETURN(auto devices,
                      GetTpuDevices(client, std::move(local_device_states)));
  int host_id = platform->GetTpuHostLocation().Id();

  return std::shared_ptr<PjRtClient>(absl::make_unique<PjRtTpuClient>(
      client, std::move(devices), host_id, platform));
}

}  // namespace xla
tensorflow/compiler/xla/pjrt/tpu_client.h (new file, 60 lines)
@@ -0,0 +1,60 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_

#include <array>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"

namespace xla {

class PjRtTpuDevice : public PjRtDevice {
 public:
  PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
                std::unique_ptr<LocalDeviceState> local_device_state,
                int host_id, const std::array<int, 3>& coords,
                std::string device_kind)
      : PjRtDevice(core.Id(), std::move(local_device_state),
                   /*platform_name=*/"tpu", std::move(device_kind), host_id),
        core_(core),
        coords_(coords) {}

  const std::array<int, 3>& coords() const { return coords_; }
  int core_on_chip() const { return core_.index(); }
  const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; }

  std::string DebugString() const override {
    return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
                           coords_[0], coords_[1], coords_[2], core_.index());
  }

 private:
  const tensorflow::tpu::TpuCoreLocationExternal core_;
  const std::array<int, 3> coords_;
};

StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
    bool asynchronous,
    absl::Duration init_retry_timeout = absl::ZeroDuration());

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_
tensorflow/core/tpu/BUILD
@@ -177,6 +177,23 @@ cc_library(
     ],
 )
 
+# This is an alternative to "tpu_api_dlsym_initializer" that only initializes
+# methods needed for the base TPU executor APIs (and thus has fewer deps). Do
+# not link in both this and "tpu_api_dlsym_initializer".
+cc_library(
+    name = "tpu_executor_dlsym_initializer",
+    srcs = ["tpu_executor_dlsym_initializer.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tpu_api_dlsym_set_fn",
+        ":tpu_executor_init_fns",
+        "//tensorflow/core:lib",
+        "//tensorflow/stream_executor/tpu:tpu_computation_placer",
+        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
+    ],
+    alwayslink = True,
+)
+
 cc_library(
     name = "tpu_api_dlsym_set_fn",
     hdrs = ["tpu_api_dlsym_set_fn.h"],
@@ -294,6 +311,7 @@ cc_library(
 cc_library(
     name = "tpu_on_demand_compiler",
     srcs = ["tpu_on_demand_compiler.cc"],
+    visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:util",
tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc (new file, 70 lines)
@@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// TODO(skye): this is largely a copy of tpu_api_dlsym_initializer.cc. Figure
// out how to deduplicate these files a little.

#include <dlfcn.h>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/tpu_api_dlsym_set_fn.h"
#if !defined(PLATFORM_GOOGLE)
#include "tensorflow/core/tpu/tpu_executor_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#endif

namespace tensorflow {
namespace tpu {

#if defined(PLATFORM_GOOGLE)
Status InitializeTpuLibrary(void* library_handle) {
  return errors::Unimplemented("You must statically link in a TPU library.");
}
#else  // PLATFORM_GOOGLE
#include "tensorflow/core/tpu/tpu_executor_init_fns.inc"

Status InitializeTpuLibrary(void* library_handle) {
  Status s = SetExecutorStructFn(library_handle);

  // TPU platform registration must only be performed after the library is
  // loaded. We do not want to register a TPU platform in XLA without the
  // supporting library providing the necessary APIs.
  if (s.ok()) {
    void (*initialize_fn)();
    initialize_fn = reinterpret_cast<decltype(initialize_fn)>(
        dlsym(library_handle, "TfTpu_Initialize"));
    (*initialize_fn)();

    RegisterTpuPlatform();
  }

  return s;
}

bool FindAndLoadTpuLibrary() {
  void* library = dlopen("libtpu.so", RTLD_NOW);
  if (library) {
    InitializeTpuLibrary(library);
  }
  return true;
}

static bool tpu_library_finder = FindAndLoadTpuLibrary();
#endif  // PLATFORM_GOOGLE

}  // namespace tpu
}  // namespace tensorflow
tensorflow/stream_executor/tpu/BUILD
@@ -222,6 +222,7 @@ cc_library(
 cc_library(
     name = "tpu_transfer_manager",
     srcs = ["tpu_transfer_manager_registration.cc"],
+    visibility = ["//visibility:public"],
     deps = [
         ":tpu_executor",
         ":tpu_transfer_manager_base",
@@ -256,6 +257,7 @@ cc_library(
     name = "tpu_computation_placer",
     srcs = ["tpu_computation_placer.cc"],
     hdrs = ["tpu_computation_placer.h"],
+    visibility = ["//visibility:public"],
     deps = [
         ":status_helper",
         ":tpu_executor",
tensorflow/stream_executor/tpu/tpu_platform_interface.h
@@ -56,6 +56,10 @@ class TpuPlatformInterface : public stream_executor::Platform {
   virtual const TpuTopologyPtr GetTopologyPtr() = 0;
 
   virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0;
+
+  TpuTopologyExternal topology() {
+    return TpuTopologyExternal(GetTopologyPtr());
+  }
 };
 
 }  // namespace tpu
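
The new convenience method saves callers from wrapping the raw topology pointer themselves; compare the two call-site forms (illustrative snippet, not part of the commit):

// Before this change: wrap GetTopologyPtr() manually at every call site.
tensorflow::tpu::TpuTopologyExternal topology_before(
    tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform()
        ->GetTopologyPtr());

// After this change: the one-liner used by tpu_client.cc above.
tensorflow::tpu::TpuTopologyExternal topology_after =
    tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();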