diff --git a/tensorflow/core/profiler/profiler_service.proto b/tensorflow/core/profiler/profiler_service.proto index 37ca4084e42..a096a10efe2 100644 --- a/tensorflow/core/profiler/profiler_service.proto +++ b/tensorflow/core/profiler/profiler_service.proto @@ -10,6 +10,10 @@ import "tensorflow/core/profiler/profiler_service_monitor_result.proto"; service ProfilerService { // Starts a profiling session, blocks until it completes, and returns data. rpc Profile(ProfileRequest) returns (ProfileResponse) {} + // Signal to terminate the Profile rpc for a on-going profiling session, + // The Profile rpc will return successfully and prematurely without timeout. + // This is used by programmatic mode to end the session in workers. + rpc Terminate(TerminateRequest) returns (TerminateResponse) {} // Collects profiling data and returns user-friendly metrics. rpc Monitor(MonitorRequest) returns (MonitorResponse) {} } @@ -81,6 +85,13 @@ message ProfileResponse { // next-field: 8 } +message TerminateRequest { + // Which session id to terminate. + string session_id = 1; +} + +message TerminateResponse {} + message MonitorRequest { // Duration for which to profile between each update. uint64 duration_ms = 1; diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc index 555f4c3366a..8cf052f165b 100644 --- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc +++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/profiler/rpc/profiler_service_impl.h" #include "grpcpp/support/status.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -24,9 +25,12 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h" #include "tensorflow/core/profiler/internal/profiler_interface.h" #include "tensorflow/core/profiler/lib/profiler_session.h" +#include "tensorflow/core/profiler/profiler_service.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" namespace tensorflow { @@ -66,6 +70,11 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service { if (ctx->IsCancelled()) { return ::grpc::Status::CANCELLED; } + if (TF_PREDICT_FALSE(IsStopped(req->session_id()))) { + mutex_lock lock(mutex_); + stop_signals_per_session_.erase(req->session_id()); + break; + } } status = CollectDataToResponse(*req, profiler.get(), response); @@ -76,6 +85,25 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service { return ::grpc::Status::OK; } + + ::grpc::Status Terminate(::grpc::ServerContext* ctx, + const TerminateRequest* req, + TerminateResponse* response) override { + mutex_lock lock(mutex_); + stop_signals_per_session_[req->session_id()] = true; + return ::grpc::Status::OK; + } + + private: + bool IsStopped(const std::string& session_id) { + mutex_lock lock(mutex_); + auto it = stop_signals_per_session_.find(session_id); + return it != stop_signals_per_session_.end() && it->second; + } + + mutex mutex_; + absl::flat_hash_map stop_signals_per_session_ + GUARDED_BY(mutex_); }; } // namespace