allow early termination of Profile Rpc for programmatic mode.
PiperOrigin-RevId: 311003274 Change-Id: I7e0e81c03aa96db6c272244316a53fab16fe3ebd
This commit is contained in:
parent
46e6af455a
commit
28bbc65d66
tensorflow/core/profiler
@ -10,6 +10,10 @@ import "tensorflow/core/profiler/profiler_service_monitor_result.proto";
|
||||
service ProfilerService {
|
||||
// Starts a profiling session, blocks until it completes, and returns data.
|
||||
rpc Profile(ProfileRequest) returns (ProfileResponse) {}
|
||||
// Signal to terminate the Profile rpc for a on-going profiling session,
|
||||
// The Profile rpc will return successfully and prematurely without timeout.
|
||||
// This is used by programmatic mode to end the session in workers.
|
||||
rpc Terminate(TerminateRequest) returns (TerminateResponse) {}
|
||||
// Collects profiling data and returns user-friendly metrics.
|
||||
rpc Monitor(MonitorRequest) returns (MonitorResponse) {}
|
||||
}
|
||||
@ -81,6 +85,13 @@ message ProfileResponse {
|
||||
// next-field: 8
|
||||
}
|
||||
|
||||
message TerminateRequest {
|
||||
// Which session id to terminate.
|
||||
string session_id = 1;
|
||||
}
|
||||
|
||||
message TerminateResponse {}
|
||||
|
||||
message MonitorRequest {
|
||||
// Duration for which to profile between each update.
|
||||
uint64 duration_ms = 1;
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
|
||||
|
||||
#include "grpcpp/support/status.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
@ -24,9 +25,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/env_time.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
|
||||
#include "tensorflow/core/profiler/internal/profiler_interface.h"
|
||||
#include "tensorflow/core/profiler/lib/profiler_session.h"
|
||||
#include "tensorflow/core/profiler/profiler_service.pb.h"
|
||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -66,6 +70,11 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
|
||||
if (ctx->IsCancelled()) {
|
||||
return ::grpc::Status::CANCELLED;
|
||||
}
|
||||
if (TF_PREDICT_FALSE(IsStopped(req->session_id()))) {
|
||||
mutex_lock lock(mutex_);
|
||||
stop_signals_per_session_.erase(req->session_id());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
status = CollectDataToResponse(*req, profiler.get(), response);
|
||||
@ -76,6 +85,25 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
|
||||
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
|
||||
::grpc::Status Terminate(::grpc::ServerContext* ctx,
|
||||
const TerminateRequest* req,
|
||||
TerminateResponse* response) override {
|
||||
mutex_lock lock(mutex_);
|
||||
stop_signals_per_session_[req->session_id()] = true;
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsStopped(const std::string& session_id) {
|
||||
mutex_lock lock(mutex_);
|
||||
auto it = stop_signals_per_session_.find(session_id);
|
||||
return it != stop_signals_per_session_.end() && it->second;
|
||||
}
|
||||
|
||||
mutex mutex_;
|
||||
absl::flat_hash_map<std::string, bool> stop_signals_per_session_
|
||||
GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user