TpuDriver: Improve handling of restarted clients/servers.
PiperOrigin-RevId: 280727303 Change-Id: I9c11368de26d8ce851c799a53a903a452d385975
This commit is contained in:
parent
64ac71b1e9
commit
690d47e60b
@ -31,8 +31,8 @@ namespace tpu_driver {
|
||||
// This class provides a typed interface for these values as well as support for
|
||||
// hashing and ostreams (for logging).
|
||||
struct EventId {
|
||||
int64_t client_id;
|
||||
int64_t operation_id;
|
||||
uint64_t client_id;
|
||||
uint64_t operation_id;
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const EventId& c) {
|
||||
@ -51,9 +51,9 @@ struct EventId {
|
||||
return absl::StrCat(client_id, ":", operation_id);
|
||||
}
|
||||
|
||||
int64_t AsInt() const { return client_id << 44 | operation_id; }
|
||||
uint64_t AsInt() const { return client_id << 44 | operation_id; }
|
||||
|
||||
static EventId FromInt(int64_t value) {
|
||||
static EventId FromInt(uint64_t value) {
|
||||
return EventId{value >> 44, value & 0xfffffffffff};
|
||||
}
|
||||
};
|
||||
|
@ -64,6 +64,25 @@ class GrpcEvent : public Event {
|
||||
GrpcTpuStream* stream_;
|
||||
};
|
||||
|
||||
class ErrorEvent : public GrpcEvent {
|
||||
public:
|
||||
explicit ErrorEvent(Status status) : GrpcEvent(EventId{0, 0}, nullptr) {
|
||||
status_ = status;
|
||||
}
|
||||
|
||||
xla::Status Await() override { return status_; }
|
||||
absl::optional<xla::Status> AwaitWithTimeout(
|
||||
absl::Duration duration) override {
|
||||
return status_;
|
||||
}
|
||||
void AddCallback(std::function<void(Status)> callback) override {
|
||||
callback(status_);
|
||||
}
|
||||
|
||||
private:
|
||||
Status status_;
|
||||
};
|
||||
|
||||
class GrpcBufferHandle : public BufferHandle {
|
||||
public:
|
||||
explicit GrpcBufferHandle(
|
||||
@ -417,17 +436,19 @@ class GrpcTpuDriver : public TpuDriver {
|
||||
static std::unique_ptr<grpc::CloudTpuDriver::Stub> CreateTpuDriverStub(
|
||||
const TpuDriverConfig& config);
|
||||
|
||||
uint32 client_id() const { return client_id_; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
|
||||
|
||||
const TpuDriverConfig config_;
|
||||
const int32_t client_id_;
|
||||
const uint32_t client_id_;
|
||||
// Map from stream IDs to streams.
|
||||
absl::flat_hash_map<int32_t, std::unique_ptr<GrpcTpuStream>> streams_;
|
||||
std::unique_ptr<GrpcTpuStream> host_stream_;
|
||||
// Shared by all streams.
|
||||
std::atomic<int64_t> operation_id_{0};
|
||||
};
|
||||
std::atomic<uint64_t> operation_id_{0};
|
||||
}; // namespace
|
||||
|
||||
GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
|
||||
|
||||
@ -464,8 +485,11 @@ GrpcTpuStream::~GrpcTpuStream() {
|
||||
// Mark all remaining events invalid.
|
||||
absl::MutexLock lock(&events_mutex_);
|
||||
for (auto e : events_) {
|
||||
UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
|
||||
"Tpustream was closed."));
|
||||
if (!e.second.done) {
|
||||
LOG(ERROR) << "Resetting: " << e.first;
|
||||
UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
|
||||
"Driver was closed."));
|
||||
}
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Closing stream.";
|
||||
@ -511,8 +535,9 @@ void GrpcTpuStream::UpdateEventStatus(EventId id, Status status) {
|
||||
|
||||
// This is the first time this event finishes. Remember the results and call
|
||||
// the callbacks.
|
||||
VLOG(1) << "Response received for GrpcEvent " << id << ". Firing "
|
||||
<< it->second.callbacks.size() << " callbacks.";
|
||||
VLOG(1) << "Response received for GrpcEvent " << id << ". "
|
||||
<< status.ToString() << ". Firing " << it->second.callbacks.size()
|
||||
<< " callbacks.";
|
||||
it->second.done = true;
|
||||
it->second.status = status;
|
||||
for (const auto& callback : it->second.callbacks) {
|
||||
@ -544,6 +569,7 @@ absl::optional<Status> GrpcTpuStream::WaitForEvent(EventId id,
|
||||
events_mutex_.AssertHeld();
|
||||
return !events_.contains(id) || events_[id].done;
|
||||
};
|
||||
|
||||
if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) {
|
||||
return events_.contains(id) ? events_[id].status : Status();
|
||||
}
|
||||
@ -594,6 +620,8 @@ void GrpcTpuStream::StreamWriterFn() {
|
||||
reqs.push_back(StreamRequest());
|
||||
request_bytes = 0;
|
||||
}
|
||||
VLOG(1) << "Sending request: " << EventId::FromInt(e->operation_id());
|
||||
VLOG(2) << "Sending request: " << e->DebugString();
|
||||
reqs.back().mutable_entry()->AddAllocated(e);
|
||||
}
|
||||
num_pending_requests_ = 0;
|
||||
@ -611,9 +639,10 @@ void GrpcTpuStream::StreamWriterFn() {
|
||||
void GrpcTpuStream::StreamReaderFn() {
|
||||
StreamResponse resp;
|
||||
while (stream_->Read(&resp)) {
|
||||
VLOG(1) << "Received response: " << resp.DebugString();
|
||||
VLOG(2) << "Received response: " << resp.DebugString();
|
||||
for (const StreamResponse::Entry entry : resp.entry()) {
|
||||
EventId event_id = EventId::FromInt(entry.operation_id());
|
||||
VLOG(1) << "Received response for: " << event_id;
|
||||
|
||||
TraceMe activity("GrpcTpuStream::RequestComplete");
|
||||
if (entry.has_transfer_from()) {
|
||||
@ -805,8 +834,15 @@ std::unique_ptr<LoadedProgramHandle> GrpcTpuStream::LoadProgram(
|
||||
InitializeRequest(req.get(), wait_for);
|
||||
TraceMe activity(absl::StrCat("GrpcTpuStream::LoadProgram"));
|
||||
req->mutable_load()->set_core_id(core_id);
|
||||
req->mutable_load()->set_compiled_program_handle(
|
||||
static_cast<const GrpcCompiledProgramHandle*>(handle)->id().AsInt());
|
||||
auto grpc_handle = static_cast<const GrpcCompiledProgramHandle*>(handle);
|
||||
if (grpc_handle->id().client_id != driver_->client_id()) {
|
||||
auto event = absl::make_unique<ErrorEvent>(
|
||||
xla::InvalidArgument("Invalid program handle (wrong client id). Did "
|
||||
"you restart the server or use a stale handle?"));
|
||||
return absl::make_unique<GrpcLoadedProgramHandle>(event->id(),
|
||||
std::move(event));
|
||||
}
|
||||
req->mutable_load()->set_compiled_program_handle(grpc_handle->id().AsInt());
|
||||
auto event =
|
||||
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
|
||||
AddWriteRequest(std::move(req));
|
||||
@ -835,13 +871,33 @@ std::unique_ptr<Event> GrpcTpuStream::ExecuteProgram(
|
||||
absl::Span<Event* const> wait_for) {
|
||||
auto req = absl::make_unique<StreamRequest::Entry>();
|
||||
InitializeRequest(req.get(), wait_for);
|
||||
req->mutable_execute()->set_loaded_program_handle(
|
||||
static_cast<GrpcLoadedProgramHandle*>(program)->id().AsInt());
|
||||
for (BufferHandle* input : inputs) {
|
||||
req->mutable_execute()->add_input_handle(
|
||||
static_cast<GrpcBufferHandle*>(input)->id().AsInt());
|
||||
auto program_handle = static_cast<GrpcLoadedProgramHandle*>(program);
|
||||
if (program_handle->id().client_id != driver_->client_id()) {
|
||||
return absl::make_unique<ErrorEvent>(
|
||||
xla::InvalidArgument("Invalid program handle (wrong client id). Did "
|
||||
"you restart the server or use a stale handle?"));
|
||||
}
|
||||
|
||||
req->mutable_execute()->set_loaded_program_handle(
|
||||
program_handle->id().AsInt());
|
||||
|
||||
for (BufferHandle* input : inputs) {
|
||||
auto* grpc_handle = static_cast<GrpcBufferHandle*>(input);
|
||||
if (grpc_handle->id().client_id != driver_->client_id()) {
|
||||
return absl::make_unique<ErrorEvent>(xla::InvalidArgument(
|
||||
"Invalid input buffer (wrong client id). Did you restart the server "
|
||||
"or use a stale handle?"));
|
||||
}
|
||||
req->mutable_execute()->add_input_handle(grpc_handle->id().AsInt());
|
||||
}
|
||||
|
||||
for (BufferHandle* output : outputs) {
|
||||
auto* grpc_handle = static_cast<GrpcBufferHandle*>(output);
|
||||
if (grpc_handle->id().client_id != driver_->client_id()) {
|
||||
return absl::make_unique<ErrorEvent>(xla::InvalidArgument(
|
||||
"Invalid output buffer (wrong client id). Did you restart the server "
|
||||
"or use a stale handle?"));
|
||||
}
|
||||
req->mutable_execute()->add_output_handle(
|
||||
static_cast<GrpcBufferHandle*>(output)->id().AsInt());
|
||||
}
|
||||
|
@ -137,11 +137,15 @@ message StreamResponse {
|
||||
message OpenRequest {}
|
||||
|
||||
message OpenResponse {
|
||||
required int32 client_id = 1;
|
||||
required fixed32 client_id = 1;
|
||||
|
||||
// Maximum time this client can be idle before it is GC'ed and all resources
|
||||
// released.
|
||||
optional int32 max_idle_time_seconds = 2 [default = 3600];
|
||||
}
|
||||
|
||||
message CloseRequest {
|
||||
required int32 client_id = 1;
|
||||
required fixed32 client_id = 1;
|
||||
}
|
||||
|
||||
message CloseResponse {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user