TpuDriver: Improve handling of restarted clients/servers.

PiperOrigin-RevId: 280727303
Change-Id: I9c11368de26d8ce851c799a53a903a452d385975
This commit is contained in:
Russell Power 2019-11-15 13:45:48 -08:00 committed by TensorFlower Gardener
parent 64ac71b1e9
commit 690d47e60b
3 changed files with 81 additions and 21 deletions

View File

@ -31,8 +31,8 @@ namespace tpu_driver {
// This class provides a typed interface for these values as well as support for
// hashing and ostreams (for logging).
struct EventId {
int64_t client_id;
int64_t operation_id;
uint64_t client_id;
uint64_t operation_id;
template <typename H>
friend H AbslHashValue(H h, const EventId& c) {
@ -51,9 +51,9 @@ struct EventId {
return absl::StrCat(client_id, ":", operation_id);
}
int64_t AsInt() const { return client_id << 44 | operation_id; }
uint64_t AsInt() const { return client_id << 44 | operation_id; }
static EventId FromInt(int64_t value) {
static EventId FromInt(uint64_t value) {
return EventId{value >> 44, value & 0xfffffffffff};
}
};

View File

@ -64,6 +64,25 @@ class GrpcEvent : public Event {
GrpcTpuStream* stream_;
};
// An event that is already resolved with a fixed status.
//
// The GrpcEvent base is built with EventId{0, 0} and a null stream, so this
// event is not associated with any GrpcTpuStream. Every accessor answers
// immediately with the status supplied at construction.
class ErrorEvent : public GrpcEvent {
 public:
  // Stores `status` as the result reported by all accessors below.
  explicit ErrorEvent(Status status)
      : GrpcEvent(EventId{0, 0}, nullptr), status_(std::move(status)) {}

  // Returns the stored status without waiting.
  xla::Status Await() override { return status_; }

  // Returns the stored status without waiting; the timeout is not consulted.
  absl::optional<xla::Status> AwaitWithTimeout(
      absl::Duration /*duration*/) override {
    return absl::optional<xla::Status>(status_);
  }

  // Runs `callback` synchronously, on the calling thread, with the stored
  // status.
  void AddCallback(std::function<void(Status)> callback) override {
    callback(status_);
  }

 private:
  // Result handed back by Await/AwaitWithTimeout/AddCallback.
  Status status_;
};
class GrpcBufferHandle : public BufferHandle {
public:
explicit GrpcBufferHandle(
@ -417,17 +436,19 @@ class GrpcTpuDriver : public TpuDriver {
static std::unique_ptr<grpc::CloudTpuDriver::Stub> CreateTpuDriverStub(
const TpuDriverConfig& config);
uint32 client_id() const { return client_id_; }
private:
std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
const TpuDriverConfig config_;
const int32_t client_id_;
const uint32_t client_id_;
// Map from stream IDs to streams.
absl::flat_hash_map<int32_t, std::unique_ptr<GrpcTpuStream>> streams_;
std::unique_ptr<GrpcTpuStream> host_stream_;
// Shared by all streams.
std::atomic<int64_t> operation_id_{0};
};
std::atomic<uint64_t> operation_id_{0};
}; // namespace
// Asks the owning stream to delete the record for this event's id.
//
// An event may have been constructed without a stream: ErrorEvent passes
// nullptr to the GrpcEvent constructor. Such an event has no stream to notify,
// and calling DeleteEvent through the null pointer would be undefined
// behavior, so that case is skipped.
GrpcEvent::~GrpcEvent() {
  if (stream_ != nullptr) {
    stream_->DeleteEvent(id_);
  }
}
@ -464,8 +485,11 @@ GrpcTpuStream::~GrpcTpuStream() {
// Mark all remaining events invalid.
absl::MutexLock lock(&events_mutex_);
for (auto e : events_) {
UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
"Tpustream was closed."));
if (!e.second.done) {
LOG(ERROR) << "Resetting: " << e.first;
UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
"Driver was closed."));
}
}
}
VLOG(1) << "Closing stream.";
@ -511,8 +535,9 @@ void GrpcTpuStream::UpdateEventStatus(EventId id, Status status) {
// This is the first time this event finishes. Remember the results and call
// the callbacks.
VLOG(1) << "Response received for GrpcEvent " << id << ". Firing "
<< it->second.callbacks.size() << " callbacks.";
VLOG(1) << "Response received for GrpcEvent " << id << ". "
<< status.ToString() << ". Firing " << it->second.callbacks.size()
<< " callbacks.";
it->second.done = true;
it->second.status = status;
for (const auto& callback : it->second.callbacks) {
@ -544,6 +569,7 @@ absl::optional<Status> GrpcTpuStream::WaitForEvent(EventId id,
events_mutex_.AssertHeld();
return !events_.contains(id) || events_[id].done;
};
if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) {
return events_.contains(id) ? events_[id].status : Status();
}
@ -594,6 +620,8 @@ void GrpcTpuStream::StreamWriterFn() {
reqs.push_back(StreamRequest());
request_bytes = 0;
}
VLOG(1) << "Sending request: " << EventId::FromInt(e->operation_id());
VLOG(2) << "Sending request: " << e->DebugString();
reqs.back().mutable_entry()->AddAllocated(e);
}
num_pending_requests_ = 0;
@ -611,9 +639,10 @@ void GrpcTpuStream::StreamWriterFn() {
void GrpcTpuStream::StreamReaderFn() {
StreamResponse resp;
while (stream_->Read(&resp)) {
VLOG(1) << "Received response: " << resp.DebugString();
VLOG(2) << "Received response: " << resp.DebugString();
for (const StreamResponse::Entry entry : resp.entry()) {
EventId event_id = EventId::FromInt(entry.operation_id());
VLOG(1) << "Received response for: " << event_id;
TraceMe activity("GrpcTpuStream::RequestComplete");
if (entry.has_transfer_from()) {
@ -805,8 +834,15 @@ std::unique_ptr<LoadedProgramHandle> GrpcTpuStream::LoadProgram(
InitializeRequest(req.get(), wait_for);
TraceMe activity(absl::StrCat("GrpcTpuStream::LoadProgram"));
req->mutable_load()->set_core_id(core_id);
req->mutable_load()->set_compiled_program_handle(
static_cast<const GrpcCompiledProgramHandle*>(handle)->id().AsInt());
auto grpc_handle = static_cast<const GrpcCompiledProgramHandle*>(handle);
if (grpc_handle->id().client_id != driver_->client_id()) {
auto event = absl::make_unique<ErrorEvent>(
xla::InvalidArgument("Invalid program handle (wrong client id). Did "
"you restart the server or use a stale handle?"));
return absl::make_unique<GrpcLoadedProgramHandle>(event->id(),
std::move(event));
}
req->mutable_load()->set_compiled_program_handle(grpc_handle->id().AsInt());
auto event =
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
AddWriteRequest(std::move(req));
@ -835,13 +871,33 @@ std::unique_ptr<Event> GrpcTpuStream::ExecuteProgram(
absl::Span<Event* const> wait_for) {
auto req = absl::make_unique<StreamRequest::Entry>();
InitializeRequest(req.get(), wait_for);
req->mutable_execute()->set_loaded_program_handle(
static_cast<GrpcLoadedProgramHandle*>(program)->id().AsInt());
for (BufferHandle* input : inputs) {
req->mutable_execute()->add_input_handle(
static_cast<GrpcBufferHandle*>(input)->id().AsInt());
auto program_handle = static_cast<GrpcLoadedProgramHandle*>(program);
if (program_handle->id().client_id != driver_->client_id()) {
return absl::make_unique<ErrorEvent>(
xla::InvalidArgument("Invalid program handle (wrong client id). Did "
"you restart the server or use a stale handle?"));
}
req->mutable_execute()->set_loaded_program_handle(
program_handle->id().AsInt());
for (BufferHandle* input : inputs) {
auto* grpc_handle = static_cast<GrpcBufferHandle*>(input);
if (grpc_handle->id().client_id != driver_->client_id()) {
return absl::make_unique<ErrorEvent>(xla::InvalidArgument(
"Invalid input buffer (wrong client id). Did you restart the server "
"or use a stale handle?"));
}
req->mutable_execute()->add_input_handle(grpc_handle->id().AsInt());
}
for (BufferHandle* output : outputs) {
auto* grpc_handle = static_cast<GrpcBufferHandle*>(output);
if (grpc_handle->id().client_id != driver_->client_id()) {
return absl::make_unique<ErrorEvent>(xla::InvalidArgument(
"Invalid output buffer (wrong client id). Did you restart the server "
"or use a stale handle?"));
}
req->mutable_execute()->add_output_handle(
static_cast<GrpcBufferHandle*>(output)->id().AsInt());
}

View File

@ -137,11 +137,15 @@ message StreamResponse {
message OpenRequest {}
message OpenResponse {
required int32 client_id = 1;
required fixed32 client_id = 1;
// Maximum time this client can be idle before it is GC'ed and all resources
// released.
optional int32 max_idle_time_seconds = 2 [default = 3600];
}
message CloseRequest {
required int32 client_id = 1;
required fixed32 client_id = 1;
}
message CloseResponse {}