Internal change only.

PiperOrigin-RevId: 316058358
Change-Id: I2317fe006b40b9f930e6610ea7c059d15662dcb1
This commit is contained in:
parent fa58bd5e94
commit d7d97a0d72
@@ -62,6 +62,10 @@ xla::StatusOr<xla::Shape> TpuShapeRepresentation(const TensorShape& shape,
                                                  DataType type,
                                                  bool use_fast_memory);
 
+// Given a tensor, returns the shape of its representation on device,
+// fully padded. Contents of `shape` are undefined on error.
+Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
+
 // A callback called on exit.
 void LogAndExit(int code);
 
@@ -61,9 +61,11 @@ cc_library(
     name = "tpu_stream",
     hdrs = ["tpu_stream.h"],
     deps = [
+        ":c_api_conversions",
         ":status_helper",
         ":tpu_executor_c_api_hdrs",
+        ":tpu_stream_interface",
         "//tensorflow/stream_executor:stream",
         "//tensorflow/stream_executor/lib",
     ],
 )
 
@@ -222,7 +224,10 @@ cc_library(
     name = "tpu_stream_interface",
     hdrs = ["tpu_stream_interface.h"],
     visibility = ["//visibility:public"],
-    deps = ["//tensorflow/stream_executor:stream_executor_internal"],
+    deps = [
+        "//tensorflow/stream_executor:device_memory",
+        "//tensorflow/stream_executor:stream_executor_internal",
+    ],
 )
 
 cc_library(
@@ -118,14 +118,14 @@ bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
 stream_executor::Event::Status TpuExecutor::PollForEventStatus(
     stream_executor::Event* event) {
   return stream_executor::Event::Status(TpuExecutor_PollForEventStatus(
-      executor_, event_map_.at(event->implementation())));
+      executor_, event_map().at(event->implementation())));
 }
 
 Status TpuExecutor::RecordEvent(Stream* stream,
                                 ::stream_executor::Event* event) {
   StatusHelper status;
   TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()),
-                          event_map_.at(event->implementation()),
+                          event_map().at(event->implementation()),
                           status.c_status);
   return status.status();
 }
@@ -134,7 +134,7 @@ Status TpuExecutor::WaitForEvent(Stream* stream,
                                  ::stream_executor::Event* event) {
   StatusHelper status;
   TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()),
-                           event_map_.at(event->implementation()),
+                           event_map().at(event->implementation()),
                            status.c_status);
   return status.status();
 }
@@ -168,7 +168,7 @@ std::unique_ptr<::stream_executor::internal::EventInterface>
 TpuExecutor::CreateEventImplementation() {
   SE_Event* tpu_event = TpuEvent_New(executor_);
   auto ptr = absl::make_unique<TpuEvent>(tpu_event);
-  event_map_[ptr.get()] = tpu_event;
+  event_map()[ptr.get()] = tpu_event;
   return ptr;
 }
 
@@ -48,9 +48,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
   using StreamExecutorInterface =
       ::stream_executor::internal::StreamExecutorInterface;
 
-  using EventMap =
-      absl::flat_hash_map<stream_executor::internal::EventInterface*,
-                          SE_Event*>;
   using TimerMap =
       absl::flat_hash_map<stream_executor::internal::TimerInterface*,
                           SE_Timer*>;
@@ -225,13 +222,16 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
   }
 
  private:
-  EventMap event_map_;
   TimerMap timer_map_;
 
   TpuPlatform::StreamMap& stream_map() {
     return *(static_cast<TpuPlatform*>(platform_)->stream_map());
   }
 
+  TpuPlatform::EventMap& event_map() {
+    return *(static_cast<TpuPlatform*>(platform_)->event_map());
+  }
+
   ::tensorflow::tpu::TpuPlatformInterface* platform_;
   SE_StreamExecutor* executor_;
 };
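Taken together, the tpu_executor hunks above replace the executor-owned `event_map_` member with an accessor that reaches into a map owned by the platform, so every `TpuExecutor` on the same `TpuPlatform` shares one event table. A minimal sketch of that ownership pattern, using hypothetical `Platform`/`Executor` stand-ins rather than the real classes:

    // Sketch only: hypothetical stand-ins for TpuPlatform and TpuExecutor,
    // showing the platform-owned event map introduced above.
    #include "absl/container/flat_hash_map.h"

    struct EventInterface {};  // Stand-in for stream_executor's event interface.
    struct SE_Event {};        // Stand-in for the C-API event handle.

    class Platform {
     public:
      using EventMap = absl::flat_hash_map<EventInterface*, SE_Event*>;
      EventMap* event_map() { return &event_map_; }

     private:
      EventMap event_map_;  // One table shared by every executor on the platform.
    };

    class Executor {
     public:
      explicit Executor(Platform* platform) : platform_(platform) {}

      // Mirrors the new TpuExecutor::event_map(): delegate to the platform so
      // all executors see the same event registrations.
      Platform::EventMap& event_map() { return *platform_->event_map(); }

     private:
      Platform* platform_;
    };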
@@ -148,6 +148,7 @@ SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform,
 SE_PlatformId TpuPlatform_Id(SE_Platform* platform);
 int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform);
 int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
+bool TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(SE_Platform* platform);
 
 void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
                       SE_DeviceOptions* device_options, SE_Status* status);
@@ -231,6 +232,11 @@ SE_Stream* TpuStream_New(SE_StreamExecutor* parent);
 void TpuStream_Free(SE_Stream*);
 void* TpuStream_Stream(SE_Stream*);
 bool TpuStream_Status(SE_Stream*);
+bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
+void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
+                                               SE_DeviceMemoryBase send_buffer,
+                                               SE_DeviceMemoryBase recv_buffer,
+                                               SE_Status* status);
 
 SE_Event* TpuEvent_New(SE_StreamExecutor* parent);
 void TpuEvent_Free(SE_Event*);
@@ -27,7 +27,8 @@ namespace tpu {
 using stream_executor::port::Status;
 using stream_executor::port::StatusOr;
 
-/*static*/ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Initialize(
+/*static*/
+StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
     int device_ordinal) {
   StatusHelper status;
   XLA_TpuNodeContext* node_context =
@@ -41,6 +42,13 @@ using stream_executor::port::StatusOr;
 
 TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }
 
+/* static */
+Status TpuNodeContext::Initialize(int device_ordinal) {
+  StatusHelper status;
+  TpuNodeContext_Initialize(device_ordinal, status.c_status);
+  return status.status();
+}
+
 /* static */
 Status TpuNodeContext::StopChipHeartbeats() {
   StatusHelper status;
@@ -39,8 +39,7 @@ class TpuNodeContext final {
   template <typename T>
   using StatusOr = stream_executor::port::StatusOr<T>;
 
-  static StatusOr<std::unique_ptr<TpuNodeContext>> Initialize(
-      int device_ordinal);
+  static StatusOr<std::unique_ptr<TpuNodeContext>> Create(int device_ordinal);
 
   explicit TpuNodeContext(int device_ordinal, XLA_TpuNodeContext* node_context)
       : device_ordinal_(device_ordinal), node_context_(node_context) {
@@ -51,6 +50,8 @@ class TpuNodeContext final {
   TpuNodeContext(const TpuNodeContext&) = delete;
   TpuNodeContext& operator=(const TpuNodeContext&) = delete;
 
+  static Status Initialize(int device_ordinal);
+
   static Status StopChipHeartbeats();
 
   static Status CloseTpuHost();
@@ -23,6 +23,8 @@ XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
                                           SE_Status* status);
 void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
 
+void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
+
 void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
 void TpuNodeContext_CloseTpuHost(SE_Status* status);
 
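The tpu_node_context hunks above split the old StatusOr-returning `Initialize` into `Create` (builds a per-device context) plus a separate `static Status Initialize(int)` that only runs device initialization, backed by a new `TpuNodeContext_Initialize` C entry point. A hedged sketch of what a caller might look like after this change; the call order and the use of the TF status macros are assumptions for illustration, not taken from the commit:

    // Sketch only; assumes the post-commit TpuNodeContext API shown above.
    // Whether Initialize must precede Create is an assumption.
    tensorflow::Status SetUpTpuDevice(
        int device_ordinal,
        std::unique_ptr<tensorflow::tpu::TpuNodeContext>* out) {
      // One-time device initialization, now separate from context creation.
      TF_RETURN_IF_ERROR(
          tensorflow::tpu::TpuNodeContext::Initialize(device_ordinal));

      // Per-device context creation (formerly named Initialize).
      TF_ASSIGN_OR_RETURN(
          *out, tensorflow::tpu::TpuNodeContext::Create(device_ordinal));
      return tensorflow::Status::OK();
    }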
@@ -106,6 +106,10 @@ int64 TpuPlatform::TpuMemoryLimit() {
   return TpuPlatform_TpuMemoryLimit(platform_);
 }
 
+bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
+  return TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(platform_);
+}
+
 }  // namespace tensorflow
 
 void RegisterTpuPlatform() {
@@ -33,6 +33,9 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
   using StreamMap =
       absl::flat_hash_map<stream_executor::internal::StreamInterface*,
                           SE_Stream*>;
+  using EventMap =
+      absl::flat_hash_map<stream_executor::internal::EventInterface*,
+                          SE_Event*>;
 
   static const ::stream_executor::Platform::Id kId;
   static constexpr char kName[] = "TPU";
@@ -55,6 +58,8 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
 
   int64 TpuMemoryLimit() override;
 
+  bool ShouldRegisterTpuDeviceToDeviceCopy() override;
+
   bool Initialized() const override {
     return TpuPlatform_Initialized(platform_);
   }
@@ -109,11 +114,14 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
 
   StreamMap* stream_map() { return &stream_map_; }
 
+  EventMap* event_map() { return &event_map_; }
+
  private:
   SE_Platform* platform_;
 
   stream_executor::ExecutorCache executor_cache_;
   StreamMap stream_map_;
+  EventMap event_map_;
 };
 
 }  // namespace tensorflow
@@ -15,19 +15,30 @@ limitations under the License.
 
 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
 
+#include <atomic>
+
 #include "tensorflow/stream_executor/multi_platform_manager.h"
 
 namespace tensorflow {
 namespace tpu {
 
+namespace {
+TpuPlatformInterface* tpu_registered_platform = nullptr;
+}  // namespace
+
 /* static */
 TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
+  if (tpu_registered_platform != nullptr) {
+    return tpu_registered_platform;
+  }
+
   // Prefer TpuPlatform if it's registered.
   auto status_or_tpu_platform =
       stream_executor::MultiPlatformManager::PlatformWithName("TPU");
   if (status_or_tpu_platform.ok()) {
-    return static_cast<TpuPlatformInterface*>(
-        status_or_tpu_platform.ValueOrDie());
+    tpu_registered_platform =
+        static_cast<TpuPlatformInterface*>(status_or_tpu_platform.ValueOrDie());
+    return tpu_registered_platform;
   }
   if (status_or_tpu_platform.status().code() != error::NOT_FOUND) {
     LOG(WARNING) << "Error when getting the TPU platform: "
@@ -52,7 +63,9 @@ TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
     LOG(WARNING) << other_tpu_platforms.size()
                  << " TPU platforms registered, selecting "
                  << other_tpu_platforms[0]->Name();
-    return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
+    tpu_registered_platform =
+        static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
+    return tpu_registered_platform;
   }
 
   LOG(WARNING) << "No TPU platform registered";
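Both tpu_platform_interface.cc hunks above implement the same idea: cache the result of `GetRegisteredPlatform()` in the file-local `tpu_registered_platform` pointer so the `MultiPlatformManager` lookup only runs until a platform is found. A minimal sketch of that lookup-then-cache shape, with a hypothetical `LookUpPlatform()` standing in for the manager queries:

    // Sketch of the memoization pattern above; Platform and LookUpPlatform are
    // hypothetical stand-ins, not the real stream_executor types.
    struct Platform {};
    Platform* LookUpPlatform();  // Expensive lookup; may return nullptr.

    Platform* GetPlatform() {
      static Platform* cached_platform = nullptr;
      if (cached_platform != nullptr) {
        return cached_platform;  // Fast path: reuse the earlier lookup.
      }
      Platform* found = LookUpPlatform();
      if (found != nullptr) {
        cached_platform = found;  // Cache only successful lookups, as the commit does.
      }
      return found;
    }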
@@ -36,6 +36,8 @@ class TpuPlatformInterface : public stream_executor::Platform {
   virtual Status Reset(bool only_tear_down) = 0;
 
   virtual int64 TpuMemoryLimit() = 0;
+
+  virtual bool ShouldRegisterTpuDeviceToDeviceCopy() = 0;
 };
 
 }  // namespace tpu
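The `ShouldRegisterTpuDeviceToDeviceCopy` query is plumbed through three layers in this commit: a pure-virtual on `TpuPlatformInterface` (just above), an override on `TpuPlatform` that forwards across the C API, and the `TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy` entry point itself. A condensed sketch of that forwarding chain, with hypothetical class names:

    // Condensed, hypothetical view of the three layers; only the function name
    // TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy comes from the commit.
    struct SE_Platform;  // Opaque handle on the C-API side.
    extern "C" bool TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(
        SE_Platform* platform);

    class PlatformInterface {
     public:
      virtual ~PlatformInterface() = default;
      // Pure-virtual hook, mirroring TpuPlatformInterface.
      virtual bool ShouldRegisterTpuDeviceToDeviceCopy() = 0;
    };

    class PlatformImpl : public PlatformInterface {
     public:
      explicit PlatformImpl(SE_Platform* platform) : platform_(platform) {}
      bool ShouldRegisterTpuDeviceToDeviceCopy() override {
        // Forward across the C API boundary, as TpuPlatform does.
        return TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(platform_);
      }

     private:
      SE_Platform* platform_;
    };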
@@ -17,13 +17,36 @@ limitations under the License.
 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
 
 #include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
 
-class TpuStream : public stream_executor::internal::StreamInterface {
+class TpuStream : public tensorflow::tpu::TpuStreamInterface {
  public:
+  using Status = stream_executor::port::Status;
+
   explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
   ~TpuStream() override { TpuStream_Free(stream_); }
 
+  bool IsSameSharedMemoryLocation(
+      tensorflow::tpu::TpuStreamInterface* other) override {
+    return TpuStream_IsSameSharedMemoryLocation(
+        stream_, static_cast<TpuStream*>(other)->stream_);
+  }
+
+  Status EnqueueOnTpuDeviceSendRecvLocal(
+      stream_executor::DeviceMemoryBase send_buffer,
+      stream_executor::DeviceMemoryBase recv_buffer) override {
+    StatusHelper status;
+    TpuStream_TpuEnqueueOnDeviceSendRecvLocal(
+        stream_,
+        TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer),
+        TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer),
+        status.c_status);
+    return status.status();
+  }
+
  private:
   SE_Stream* stream_;
 };
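After this rewrite, `TpuStream` derives from the new `tensorflow::tpu::TpuStreamInterface` and forwards both virtual methods to the C API. A hedged sketch of a caller exercising the new surface; how the `TpuStream` objects are obtained is assumed, and the shared-memory guard is inferred from the API shape rather than documented usage:

    // Sketch only; assumes two already-constructed TpuStream objects and
    // valid device buffers. Logging style is illustrative.
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/stream_executor/device_memory.h"

    void MaybeEnqueueLocalSendRecv(TpuStream* src, TpuStream* dst,
                                   stream_executor::DeviceMemoryBase send_buffer,
                                   stream_executor::DeviceMemoryBase recv_buffer) {
      // The local send/recv path presumably only applies when both streams
      // share the same memory location (an assumption drawn from the API).
      if (!src->IsSameSharedMemoryLocation(dst)) {
        LOG(WARNING) << "Streams do not share a memory location; skipping.";
        return;
      }
      stream_executor::port::Status status =
          src->EnqueueOnTpuDeviceSendRecvLocal(send_buffer, recv_buffer);
      if (!status.ok()) {
        LOG(ERROR) << "EnqueueOnTpuDeviceSendRecvLocal failed: " << status;
      }
    }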
@@ -16,12 +16,20 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
 
+#include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 
 namespace tensorflow {
 namespace tpu {
 
-class TpuStreamInterface : public ::stream_executor::internal::StreamInterface {
+class TpuStreamInterface : public stream_executor::internal::StreamInterface {
+ public:
+  using Status = stream_executor::port::Status;
+
+  virtual bool IsSameSharedMemoryLocation(TpuStreamInterface* other) = 0;
+  virtual Status EnqueueOnTpuDeviceSendRecvLocal(
+      stream_executor::DeviceMemoryBase send_buffer,
+      stream_executor::DeviceMemoryBase recv_buffer) = 0;
 };
 
 }  // namespace tpu
@@ -27,8 +27,6 @@ limitations under the License.
 namespace tensorflow {
 
-using Status = stream_executor::port::Status;
-
 template <typename T>
 using StatusOr = stream_executor::port::StatusOr<T>;
 
 TpuTransferManager::TpuTransferManager() {
   manager_ = TpuTransferManager_New();