TpuDriver: Improve handling of restarted clients/servers.
PiperOrigin-RevId: 280727303 Change-Id: I9c11368de26d8ce851c799a53a903a452d385975
This commit is contained in:
parent
64ac71b1e9
commit
690d47e60b
@ -31,8 +31,8 @@ namespace tpu_driver {
|
|||||||
// This class provides a typed interface for these values as well as support for
|
// This class provides a typed interface for these values as well as support for
|
||||||
// hashing and ostreams (for logging).
|
// hashing and ostreams (for logging).
|
||||||
struct EventId {
|
struct EventId {
|
||||||
int64_t client_id;
|
uint64_t client_id;
|
||||||
int64_t operation_id;
|
uint64_t operation_id;
|
||||||
|
|
||||||
template <typename H>
|
template <typename H>
|
||||||
friend H AbslHashValue(H h, const EventId& c) {
|
friend H AbslHashValue(H h, const EventId& c) {
|
||||||
@ -51,9 +51,9 @@ struct EventId {
|
|||||||
return absl::StrCat(client_id, ":", operation_id);
|
return absl::StrCat(client_id, ":", operation_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t AsInt() const { return client_id << 44 | operation_id; }
|
uint64_t AsInt() const { return client_id << 44 | operation_id; }
|
||||||
|
|
||||||
static EventId FromInt(int64_t value) {
|
static EventId FromInt(uint64_t value) {
|
||||||
return EventId{value >> 44, value & 0xfffffffffff};
|
return EventId{value >> 44, value & 0xfffffffffff};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -64,6 +64,25 @@ class GrpcEvent : public Event {
|
|||||||
GrpcTpuStream* stream_;
|
GrpcTpuStream* stream_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ErrorEvent : public GrpcEvent {
|
||||||
|
public:
|
||||||
|
explicit ErrorEvent(Status status) : GrpcEvent(EventId{0, 0}, nullptr) {
|
||||||
|
status_ = status;
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::Status Await() override { return status_; }
|
||||||
|
absl::optional<xla::Status> AwaitWithTimeout(
|
||||||
|
absl::Duration duration) override {
|
||||||
|
return status_;
|
||||||
|
}
|
||||||
|
void AddCallback(std::function<void(Status)> callback) override {
|
||||||
|
callback(status_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status status_;
|
||||||
|
};
|
||||||
|
|
||||||
class GrpcBufferHandle : public BufferHandle {
|
class GrpcBufferHandle : public BufferHandle {
|
||||||
public:
|
public:
|
||||||
explicit GrpcBufferHandle(
|
explicit GrpcBufferHandle(
|
||||||
@ -417,17 +436,19 @@ class GrpcTpuDriver : public TpuDriver {
|
|||||||
static std::unique_ptr<grpc::CloudTpuDriver::Stub> CreateTpuDriverStub(
|
static std::unique_ptr<grpc::CloudTpuDriver::Stub> CreateTpuDriverStub(
|
||||||
const TpuDriverConfig& config);
|
const TpuDriverConfig& config);
|
||||||
|
|
||||||
|
uint32 client_id() const { return client_id_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
|
std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
|
||||||
|
|
||||||
const TpuDriverConfig config_;
|
const TpuDriverConfig config_;
|
||||||
const int32_t client_id_;
|
const uint32_t client_id_;
|
||||||
// Map from stream IDs to streams.
|
// Map from stream IDs to streams.
|
||||||
absl::flat_hash_map<int32_t, std::unique_ptr<GrpcTpuStream>> streams_;
|
absl::flat_hash_map<int32_t, std::unique_ptr<GrpcTpuStream>> streams_;
|
||||||
std::unique_ptr<GrpcTpuStream> host_stream_;
|
std::unique_ptr<GrpcTpuStream> host_stream_;
|
||||||
// Shared by all streams.
|
// Shared by all streams.
|
||||||
std::atomic<int64_t> operation_id_{0};
|
std::atomic<uint64_t> operation_id_{0};
|
||||||
};
|
}; // namespace
|
||||||
|
|
||||||
GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
|
GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
|
||||||
|
|
||||||
@ -464,8 +485,11 @@ GrpcTpuStream::~GrpcTpuStream() {
|
|||||||
// Mark all remaining events invalid.
|
// Mark all remaining events invalid.
|
||||||
absl::MutexLock lock(&events_mutex_);
|
absl::MutexLock lock(&events_mutex_);
|
||||||
for (auto e : events_) {
|
for (auto e : events_) {
|
||||||
|
if (!e.second.done) {
|
||||||
|
LOG(ERROR) << "Resetting: " << e.first;
|
||||||
UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
|
UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
|
||||||
"Tpustream was closed."));
|
"Driver was closed."));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
VLOG(1) << "Closing stream.";
|
VLOG(1) << "Closing stream.";
|
||||||
@ -511,8 +535,9 @@ void GrpcTpuStream::UpdateEventStatus(EventId id, Status status) {
|
|||||||
|
|
||||||
// This is the first time this event finishes. Remember the results and call
|
// This is the first time this event finishes. Remember the results and call
|
||||||
// the callbacks.
|
// the callbacks.
|
||||||
VLOG(1) << "Response received for GrpcEvent " << id << ". Firing "
|
VLOG(1) << "Response received for GrpcEvent " << id << ". "
|
||||||
<< it->second.callbacks.size() << " callbacks.";
|
<< status.ToString() << ". Firing " << it->second.callbacks.size()
|
||||||
|
<< " callbacks.";
|
||||||
it->second.done = true;
|
it->second.done = true;
|
||||||
it->second.status = status;
|
it->second.status = status;
|
||||||
for (const auto& callback : it->second.callbacks) {
|
for (const auto& callback : it->second.callbacks) {
|
||||||
@ -544,6 +569,7 @@ absl::optional<Status> GrpcTpuStream::WaitForEvent(EventId id,
|
|||||||
events_mutex_.AssertHeld();
|
events_mutex_.AssertHeld();
|
||||||
return !events_.contains(id) || events_[id].done;
|
return !events_.contains(id) || events_[id].done;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) {
|
if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) {
|
||||||
return events_.contains(id) ? events_[id].status : Status();
|
return events_.contains(id) ? events_[id].status : Status();
|
||||||
}
|
}
|
||||||
@ -594,6 +620,8 @@ void GrpcTpuStream::StreamWriterFn() {
|
|||||||
reqs.push_back(StreamRequest());
|
reqs.push_back(StreamRequest());
|
||||||
request_bytes = 0;
|
request_bytes = 0;
|
||||||
}
|
}
|
||||||
|
VLOG(1) << "Sending request: " << EventId::FromInt(e->operation_id());
|
||||||
|
VLOG(2) << "Sending request: " << e->DebugString();
|
||||||
reqs.back().mutable_entry()->AddAllocated(e);
|
reqs.back().mutable_entry()->AddAllocated(e);
|
||||||
}
|
}
|
||||||
num_pending_requests_ = 0;
|
num_pending_requests_ = 0;
|
||||||
@ -611,9 +639,10 @@ void GrpcTpuStream::StreamWriterFn() {
|
|||||||
void GrpcTpuStream::StreamReaderFn() {
|
void GrpcTpuStream::StreamReaderFn() {
|
||||||
StreamResponse resp;
|
StreamResponse resp;
|
||||||
while (stream_->Read(&resp)) {
|
while (stream_->Read(&resp)) {
|
||||||
VLOG(1) << "Received response: " << resp.DebugString();
|
VLOG(2) << "Received response: " << resp.DebugString();
|
||||||
for (const StreamResponse::Entry entry : resp.entry()) {
|
for (const StreamResponse::Entry entry : resp.entry()) {
|
||||||
EventId event_id = EventId::FromInt(entry.operation_id());
|
EventId event_id = EventId::FromInt(entry.operation_id());
|
||||||
|
VLOG(1) << "Received response for: " << event_id;
|
||||||
|
|
||||||
TraceMe activity("GrpcTpuStream::RequestComplete");
|
TraceMe activity("GrpcTpuStream::RequestComplete");
|
||||||
if (entry.has_transfer_from()) {
|
if (entry.has_transfer_from()) {
|
||||||
@ -805,8 +834,15 @@ std::unique_ptr<LoadedProgramHandle> GrpcTpuStream::LoadProgram(
|
|||||||
InitializeRequest(req.get(), wait_for);
|
InitializeRequest(req.get(), wait_for);
|
||||||
TraceMe activity(absl::StrCat("GrpcTpuStream::LoadProgram"));
|
TraceMe activity(absl::StrCat("GrpcTpuStream::LoadProgram"));
|
||||||
req->mutable_load()->set_core_id(core_id);
|
req->mutable_load()->set_core_id(core_id);
|
||||||
req->mutable_load()->set_compiled_program_handle(
|
auto grpc_handle = static_cast<const GrpcCompiledProgramHandle*>(handle);
|
||||||
static_cast<const GrpcCompiledProgramHandle*>(handle)->id().AsInt());
|
if (grpc_handle->id().client_id != driver_->client_id()) {
|
||||||
|
auto event = absl::make_unique<ErrorEvent>(
|
||||||
|
xla::InvalidArgument("Invalid program handle (wrong client id). Did "
|
||||||
|
"you restart the server or use a stale handle?"));
|
||||||
|
return absl::make_unique<GrpcLoadedProgramHandle>(event->id(),
|
||||||
|
std::move(event));
|
||||||
|
}
|
||||||
|
req->mutable_load()->set_compiled_program_handle(grpc_handle->id().AsInt());
|
||||||
auto event =
|
auto event =
|
||||||
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
|
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
|
||||||
AddWriteRequest(std::move(req));
|
AddWriteRequest(std::move(req));
|
||||||
@ -835,13 +871,33 @@ std::unique_ptr<Event> GrpcTpuStream::ExecuteProgram(
|
|||||||
absl::Span<Event* const> wait_for) {
|
absl::Span<Event* const> wait_for) {
|
||||||
auto req = absl::make_unique<StreamRequest::Entry>();
|
auto req = absl::make_unique<StreamRequest::Entry>();
|
||||||
InitializeRequest(req.get(), wait_for);
|
InitializeRequest(req.get(), wait_for);
|
||||||
req->mutable_execute()->set_loaded_program_handle(
|
auto program_handle = static_cast<GrpcLoadedProgramHandle*>(program);
|
||||||
static_cast<GrpcLoadedProgramHandle*>(program)->id().AsInt());
|
if (program_handle->id().client_id != driver_->client_id()) {
|
||||||
for (BufferHandle* input : inputs) {
|
return absl::make_unique<ErrorEvent>(
|
||||||
req->mutable_execute()->add_input_handle(
|
xla::InvalidArgument("Invalid program handle (wrong client id). Did "
|
||||||
static_cast<GrpcBufferHandle*>(input)->id().AsInt());
|
"you restart the server or use a stale handle?"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req->mutable_execute()->set_loaded_program_handle(
|
||||||
|
program_handle->id().AsInt());
|
||||||
|
|
||||||
|
for (BufferHandle* input : inputs) {
|
||||||
|
auto* grpc_handle = static_cast<GrpcBufferHandle*>(input);
|
||||||
|
if (grpc_handle->id().client_id != driver_->client_id()) {
|
||||||
|
return absl::make_unique<ErrorEvent>(xla::InvalidArgument(
|
||||||
|
"Invalid input buffer (wrong client id). Did you restart the server "
|
||||||
|
"or use a stale handle?"));
|
||||||
|
}
|
||||||
|
req->mutable_execute()->add_input_handle(grpc_handle->id().AsInt());
|
||||||
|
}
|
||||||
|
|
||||||
for (BufferHandle* output : outputs) {
|
for (BufferHandle* output : outputs) {
|
||||||
|
auto* grpc_handle = static_cast<GrpcBufferHandle*>(output);
|
||||||
|
if (grpc_handle->id().client_id != driver_->client_id()) {
|
||||||
|
return absl::make_unique<ErrorEvent>(xla::InvalidArgument(
|
||||||
|
"Invalid output buffer (wrong client id). Did you restart the server "
|
||||||
|
"or use a stale handle?"));
|
||||||
|
}
|
||||||
req->mutable_execute()->add_output_handle(
|
req->mutable_execute()->add_output_handle(
|
||||||
static_cast<GrpcBufferHandle*>(output)->id().AsInt());
|
static_cast<GrpcBufferHandle*>(output)->id().AsInt());
|
||||||
}
|
}
|
||||||
|
@ -137,11 +137,15 @@ message StreamResponse {
|
|||||||
message OpenRequest {}
|
message OpenRequest {}
|
||||||
|
|
||||||
message OpenResponse {
|
message OpenResponse {
|
||||||
required int32 client_id = 1;
|
required fixed32 client_id = 1;
|
||||||
|
|
||||||
|
// Maximum time this client can be idle before it is GC'ed and all resources
|
||||||
|
// released.
|
||||||
|
optional int32 max_idle_time_seconds = 2 [default = 3600];
|
||||||
}
|
}
|
||||||
|
|
||||||
message CloseRequest {
|
message CloseRequest {
|
||||||
required int32 client_id = 1;
|
required fixed32 client_id = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message CloseResponse {}
|
message CloseResponse {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user