diff --git a/tensorflow/compiler/xla/python/tpu_driver/event_id.h b/tensorflow/compiler/xla/python/tpu_driver/event_id.h
index ed5f9c87cf0..92a273c3e88 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/event_id.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/event_id.h
@@ -31,8 +31,8 @@ namespace tpu_driver {
 // This class provides a typed interface for these values as well as support for
 // hashing and ostreams (for logging).
 struct EventId {
-  int64_t client_id;
-  int64_t operation_id;
+  uint64_t client_id;
+  uint64_t operation_id;
 
   template <typename H>
   friend H AbslHashValue(H h, const EventId& c) {
@@ -51,9 +51,9 @@ struct EventId {
     return absl::StrCat(client_id, ":", operation_id);
   }
 
-  int64_t AsInt() const { return client_id << 44 | operation_id; }
+  uint64_t AsInt() const { return client_id << 44 | operation_id; }
 
-  static EventId FromInt(int64_t value) {
+  static EventId FromInt(uint64_t value) {
     return EventId{value >> 44, value & 0xfffffffffff};
   }
 };
diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc
index 842b83299ae..2cbeffd62c6 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc
@@ -64,6 +64,25 @@ class GrpcEvent : public Event {
   GrpcTpuStream* stream_;
 };
 
+class ErrorEvent : public GrpcEvent {
+ public:
+  explicit ErrorEvent(Status status) : GrpcEvent(EventId{0, 0}, nullptr) {
+    status_ = status;
+  }
+
+  xla::Status Await() override { return status_; }
+  absl::optional<xla::Status> AwaitWithTimeout(
+      absl::Duration duration) override {
+    return status_;
+  }
+  void AddCallback(std::function<void(Status)> callback) override {
+    callback(status_);
+  }
+
+ private:
+  Status status_;
+};
+
 class GrpcBufferHandle : public BufferHandle {
  public:
   explicit GrpcBufferHandle(
@@ -417,17 +436,19 @@ class GrpcTpuDriver : public TpuDriver {
   static std::unique_ptr<grpc::CloudTpuDriver::Stub> CreateTpuDriverStub(
       const TpuDriverConfig& config);
 
+  uint32 client_id() const { return client_id_; }
+
  private:
   std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
 
   const TpuDriverConfig config_;
-  const int32_t client_id_;
+  const uint32_t client_id_;
   // Map from stream IDs to streams.
   absl::flat_hash_map<int32_t, std::unique_ptr<GrpcTpuStream>> streams_;
   std::unique_ptr<GrpcTpuStream> host_stream_;
   // Shared by all streams.
-  std::atomic<int64_t> operation_id_{0};
-};
+  std::atomic<uint64_t> operation_id_{0};
+};  // namespace
 
 GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
 
@@ -464,8 +485,11 @@ GrpcTpuStream::~GrpcTpuStream() {
     // Mark all remaining events invalid.
     absl::MutexLock lock(&events_mutex_);
     for (auto e : events_) {
-      UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
-                                             "Tpustream was closed."));
+      if (!e.second.done) {
+        LOG(ERROR) << "Resetting: " << e.first;
+        UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
+                                               "Driver was closed."));
+      }
     }
   }
   VLOG(1) << "Closing stream.";
@@ -511,8 +535,9 @@ void GrpcTpuStream::UpdateEventStatus(EventId id, Status status) {
 
   // This is the first time this event finishes. Remember the results and call
   // the callbacks.
-  VLOG(1) << "Response received for GrpcEvent " << id << ". Firing "
-          << it->second.callbacks.size() << " callbacks.";
+  VLOG(1) << "Response received for GrpcEvent " << id << ". "
+          << status.ToString() << ". Firing " << it->second.callbacks.size()
+          << " callbacks.";
   it->second.done = true;
   it->second.status = status;
   for (const auto& callback : it->second.callbacks) {
@@ -544,6 +569,7 @@ absl::optional<Status> GrpcTpuStream::WaitForEvent(EventId id,
     events_mutex_.AssertHeld();
     return !events_.contains(id) || events_[id].done;
   };
+
   if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) {
     return events_.contains(id) ? events_[id].status : Status();
   }
@@ -594,6 +620,8 @@ void GrpcTpuStream::StreamWriterFn() {
         reqs.push_back(StreamRequest());
         request_bytes = 0;
       }
+      VLOG(1) << "Sending request: " << EventId::FromInt(e->operation_id());
+      VLOG(2) << "Sending request: " << e->DebugString();
       reqs.back().mutable_entry()->AddAllocated(e);
     }
     num_pending_requests_ = 0;
@@ -611,9 +639,10 @@
 void GrpcTpuStream::StreamReaderFn() {
   StreamResponse resp;
   while (stream_->Read(&resp)) {
-    VLOG(1) << "Received response: " << resp.DebugString();
+    VLOG(2) << "Received response: " << resp.DebugString();
     for (const StreamResponse::Entry entry : resp.entry()) {
       EventId event_id = EventId::FromInt(entry.operation_id());
+      VLOG(1) << "Received response for: " << event_id;
 
       TraceMe activity("GrpcTpuStream::RequestComplete");
       if (entry.has_transfer_from()) {
@@ -805,8 +834,15 @@ std::unique_ptr<LoadedProgramHandle> GrpcTpuStream::LoadProgram(
   InitializeRequest(req.get(), wait_for);
   TraceMe activity(absl::StrCat("GrpcTpuStream::LoadProgram"));
   req->mutable_load()->set_core_id(core_id);
-  req->mutable_load()->set_compiled_program_handle(
-      static_cast<const GrpcCompiledProgramHandle*>(handle)->id().AsInt());
+  auto grpc_handle = static_cast<const GrpcCompiledProgramHandle*>(handle);
+  if (grpc_handle->id().client_id != driver_->client_id()) {
+    auto event = absl::make_unique<ErrorEvent>(
+        xla::InvalidArgument("Invalid program handle (wrong client id). Did "
+                             "you restart the server or use a stale handle?"));
+    return absl::make_unique<GrpcLoadedProgramHandle>(event->id(),
+                                                      std::move(event));
+  }
+  req->mutable_load()->set_compiled_program_handle(grpc_handle->id().AsInt());
   auto event = absl::make_unique<GrpcEvent>(
       EventId::FromInt(req->operation_id()), this);
   AddWriteRequest(std::move(req));
@@ -835,13 +871,33 @@ std::unique_ptr<Event> GrpcTpuStream::ExecuteProgram(
     absl::Span<Event* const> wait_for) {
   auto req = absl::make_unique<StreamRequest::Entry>();
   InitializeRequest(req.get(), wait_for);
-  req->mutable_execute()->set_loaded_program_handle(
-      static_cast<GrpcLoadedProgramHandle*>(program)->id().AsInt());
-  for (BufferHandle* input : inputs) {
-    req->mutable_execute()->add_input_handle(
-        static_cast<GrpcBufferHandle*>(input)->id().AsInt());
+  auto program_handle = static_cast<GrpcLoadedProgramHandle*>(program);
+  if (program_handle->id().client_id != driver_->client_id()) {
+    return absl::make_unique<ErrorEvent>(
+        xla::InvalidArgument("Invalid program handle (wrong client id). Did "
+                             "you restart the server or use a stale handle?"));
   }
+
+  req->mutable_execute()->set_loaded_program_handle(
+      program_handle->id().AsInt());
+
+  for (BufferHandle* input : inputs) {
+    auto* grpc_handle = static_cast<GrpcBufferHandle*>(input);
+    if (grpc_handle->id().client_id != driver_->client_id()) {
+      return absl::make_unique<ErrorEvent>(xla::InvalidArgument(
+          "Invalid input buffer (wrong client id). Did you restart the server "
+          "or use a stale handle?"));
+    }
+    req->mutable_execute()->add_input_handle(grpc_handle->id().AsInt());
+  }
+  for (BufferHandle* output : outputs) {
+    auto* grpc_handle = static_cast<GrpcBufferHandle*>(output);
+    if (grpc_handle->id().client_id != driver_->client_id()) {
+      return absl::make_unique<ErrorEvent>(xla::InvalidArgument(
+          "Invalid output buffer (wrong client id). Did you restart the server "
+          "or use a stale handle?"));
+    }
     req->mutable_execute()->add_output_handle(
         static_cast<GrpcBufferHandle*>(output)->id().AsInt());
   }
diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_service.proto b/tensorflow/compiler/xla/python/tpu_driver/tpu_service.proto
index 3b9b69e7cb4..9ad1c54d912 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/tpu_service.proto
+++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_service.proto
@@ -137,11 +137,15 @@ message StreamResponse {
 message OpenRequest {}
 
 message OpenResponse {
-  required int32 client_id = 1;
+  required fixed32 client_id = 1;
+
+  // Maximum time this client can be idle before it is GC'ed and all resources
+  // released.
+  optional int32 max_idle_time_seconds = 2 [default = 3600];
 }
 
 message CloseRequest {
-  required int32 client_id = 1;
+  required fixed32 client_id = 1;
 }
 
 message CloseResponse {}