Internal change only.

PiperOrigin-RevId: 316058358
Change-Id: I2317fe006b40b9f930e6610ea7c059d15662dcb1
This commit is contained in:
Wenhao Jia 2020-06-12 00:10:36 -07:00 committed by TensorFlower Gardener
parent fa58bd5e94
commit d7d97a0d72
15 changed files with 102 additions and 20 deletions

View File

@ -62,6 +62,10 @@ xla::StatusOr<xla::Shape> TpuShapeRepresentation(const TensorShape& shape,
DataType type,
bool use_fast_memory);
// Given a tensor, returns the shape of its representation on device,
// fully padded. Contents of `shape` are undefined on error.
Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
// A callback called on exit.
void LogAndExit(int code);

View File

@ -61,9 +61,11 @@ cc_library(
name = "tpu_stream",
hdrs = ["tpu_stream.h"],
deps = [
":c_api_conversions",
":status_helper",
":tpu_executor_c_api_hdrs",
":tpu_stream_interface",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/lib",
],
)
@ -222,7 +224,10 @@ cc_library(
name = "tpu_stream_interface",
hdrs = ["tpu_stream_interface.h"],
visibility = ["//visibility:public"],
deps = ["//tensorflow/stream_executor:stream_executor_internal"],
deps = [
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:stream_executor_internal",
],
)
cc_library(

View File

@ -118,14 +118,14 @@ bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
stream_executor::Event::Status TpuExecutor::PollForEventStatus(
stream_executor::Event* event) {
return stream_executor::Event::Status(TpuExecutor_PollForEventStatus(
executor_, event_map_.at(event->implementation())));
executor_, event_map().at(event->implementation())));
}
Status TpuExecutor::RecordEvent(Stream* stream,
::stream_executor::Event* event) {
StatusHelper status;
TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()),
event_map_.at(event->implementation()),
event_map().at(event->implementation()),
status.c_status);
return status.status();
}
@ -134,7 +134,7 @@ Status TpuExecutor::WaitForEvent(Stream* stream,
::stream_executor::Event* event) {
StatusHelper status;
TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()),
event_map_.at(event->implementation()),
event_map().at(event->implementation()),
status.c_status);
return status.status();
}
@ -168,7 +168,7 @@ std::unique_ptr<::stream_executor::internal::EventInterface>
TpuExecutor::CreateEventImplementation() {
SE_Event* tpu_event = TpuEvent_New(executor_);
auto ptr = absl::make_unique<TpuEvent>(tpu_event);
event_map_[ptr.get()] = tpu_event;
event_map()[ptr.get()] = tpu_event;
return ptr;
}

View File

@ -48,9 +48,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
using StreamExecutorInterface =
::stream_executor::internal::StreamExecutorInterface;
using EventMap =
absl::flat_hash_map<stream_executor::internal::EventInterface*,
SE_Event*>;
using TimerMap =
absl::flat_hash_map<stream_executor::internal::TimerInterface*,
SE_Timer*>;
@ -225,13 +222,16 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
}
private:
EventMap event_map_;
TimerMap timer_map_;
TpuPlatform::StreamMap& stream_map() {
return *(static_cast<TpuPlatform*>(platform_)->stream_map());
}
TpuPlatform::EventMap& event_map() {
return *(static_cast<TpuPlatform*>(platform_)->event_map());
}
::tensorflow::tpu::TpuPlatformInterface* platform_;
SE_StreamExecutor* executor_;
};

View File

@ -148,6 +148,7 @@ SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform,
SE_PlatformId TpuPlatform_Id(SE_Platform* platform);
int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform);
int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
bool TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(SE_Platform* platform);
void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
SE_DeviceOptions* device_options, SE_Status* status);
@ -231,6 +232,11 @@ SE_Stream* TpuStream_New(SE_StreamExecutor* parent);
void TpuStream_Free(SE_Stream*);
void* TpuStream_Stream(SE_Stream*);
bool TpuStream_Status(SE_Stream*);
bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
SE_DeviceMemoryBase send_buffer,
SE_DeviceMemoryBase recv_buffer,
SE_Status* status);
SE_Event* TpuEvent_New(SE_StreamExecutor* parent);
void TpuEvent_Free(SE_Event*);

View File

@ -27,7 +27,8 @@ namespace tpu {
using stream_executor::port::Status;
using stream_executor::port::StatusOr;
/*static*/ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Initialize(
/*static*/
StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
int device_ordinal) {
StatusHelper status;
XLA_TpuNodeContext* node_context =
@ -41,6 +42,13 @@ using stream_executor::port::StatusOr;
TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }
/* static */
Status TpuNodeContext::Initialize(int device_ordinal) {
StatusHelper status;
TpuNodeContext_Initialize(device_ordinal, status.c_status);
return status.status();
}
/* static */
Status TpuNodeContext::StopChipHeartbeats() {
StatusHelper status;

View File

@ -39,8 +39,7 @@ class TpuNodeContext final {
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;
static StatusOr<std::unique_ptr<TpuNodeContext>> Initialize(
int device_ordinal);
static StatusOr<std::unique_ptr<TpuNodeContext>> Create(int device_ordinal);
explicit TpuNodeContext(int device_ordinal, XLA_TpuNodeContext* node_context)
: device_ordinal_(device_ordinal), node_context_(node_context) {
@ -51,6 +50,8 @@ class TpuNodeContext final {
TpuNodeContext(const TpuNodeContext&) = delete;
TpuNodeContext& operator=(const TpuNodeContext&) = delete;
static Status Initialize(int device_ordinal);
static Status StopChipHeartbeats();
static Status CloseTpuHost();

View File

@ -23,6 +23,8 @@ XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
SE_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
void TpuNodeContext_CloseTpuHost(SE_Status* status);

View File

@ -106,6 +106,10 @@ int64 TpuPlatform::TpuMemoryLimit() {
return TpuPlatform_TpuMemoryLimit(platform_);
}
bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
return TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(platform_);
}
} // namespace tensorflow
void RegisterTpuPlatform() {

View File

@ -33,6 +33,9 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
using StreamMap =
absl::flat_hash_map<stream_executor::internal::StreamInterface*,
SE_Stream*>;
using EventMap =
absl::flat_hash_map<stream_executor::internal::EventInterface*,
SE_Event*>;
static const ::stream_executor::Platform::Id kId;
static constexpr char kName[] = "TPU";
@ -55,6 +58,8 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
int64 TpuMemoryLimit() override;
bool ShouldRegisterTpuDeviceToDeviceCopy() override;
bool Initialized() const override {
return TpuPlatform_Initialized(platform_);
}
@ -109,11 +114,14 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
StreamMap* stream_map() { return &stream_map_; }
EventMap* event_map() { return &event_map_; }
private:
SE_Platform* platform_;
stream_executor::ExecutorCache executor_cache_;
StreamMap stream_map_;
EventMap event_map_;
};
} // namespace tensorflow

View File

@ -15,19 +15,30 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include <atomic>
#include "tensorflow/stream_executor/multi_platform_manager.h"
namespace tensorflow {
namespace tpu {
namespace {
TpuPlatformInterface* tpu_registered_platform = nullptr;
} // namespace
/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
if (tpu_registered_platform != nullptr) {
return tpu_registered_platform;
}
// Prefer TpuPlatform if it's registered.
auto status_or_tpu_platform =
stream_executor::MultiPlatformManager::PlatformWithName("TPU");
if (status_or_tpu_platform.ok()) {
return static_cast<TpuPlatformInterface*>(
status_or_tpu_platform.ValueOrDie());
tpu_registered_platform =
static_cast<TpuPlatformInterface*>(status_or_tpu_platform.ValueOrDie());
return tpu_registered_platform;
}
if (status_or_tpu_platform.status().code() != error::NOT_FOUND) {
LOG(WARNING) << "Error when getting the TPU platform: "
@ -52,7 +63,9 @@ TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
LOG(WARNING) << other_tpu_platforms.size()
<< " TPU platforms registered, selecting "
<< other_tpu_platforms[0]->Name();
return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
tpu_registered_platform =
static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
return tpu_registered_platform;
}
LOG(WARNING) << "No TPU platform registered";

View File

@ -36,6 +36,8 @@ class TpuPlatformInterface : public stream_executor::Platform {
virtual Status Reset(bool only_tear_down) = 0;
virtual int64 TpuMemoryLimit() = 0;
virtual bool ShouldRegisterTpuDeviceToDeviceCopy() = 0;
};
} // namespace tpu

View File

@ -17,13 +17,36 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
class TpuStream : public stream_executor::internal::StreamInterface {
class TpuStream : public tensorflow::tpu::TpuStreamInterface {
public:
using Status = stream_executor::port::Status;
explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
~TpuStream() override { TpuStream_Free(stream_); }
bool IsSameSharedMemoryLocation(
tensorflow::tpu::TpuStreamInterface* other) override {
return TpuStream_IsSameSharedMemoryLocation(
stream_, static_cast<TpuStream*>(other)->stream_);
}
Status EnqueueOnTpuDeviceSendRecvLocal(
stream_executor::DeviceMemoryBase send_buffer,
stream_executor::DeviceMemoryBase recv_buffer) override {
StatusHelper status;
TpuStream_TpuEnqueueOnDeviceSendRecvLocal(
stream_,
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer),
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer),
status.c_status);
return status.status();
}
private:
SE_Stream* stream_;
};

View File

@ -16,12 +16,20 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace tensorflow {
namespace tpu {
class TpuStreamInterface : public ::stream_executor::internal::StreamInterface {
class TpuStreamInterface : public stream_executor::internal::StreamInterface {
public:
using Status = stream_executor::port::Status;
virtual bool IsSameSharedMemoryLocation(TpuStreamInterface* other) = 0;
virtual Status EnqueueOnTpuDeviceSendRecvLocal(
stream_executor::DeviceMemoryBase send_buffer,
stream_executor::DeviceMemoryBase recv_buffer) = 0;
};
} // namespace tpu

View File

@ -27,8 +27,6 @@ limitations under the License.
namespace tensorflow {
using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;
TpuTransferManager::TpuTransferManager() {
manager_ = TpuTransferManager_New();