allow early termination of Profile Rpc for programmatic mode.

PiperOrigin-RevId: 311003274
Change-Id: I7e0e81c03aa96db6c272244316a53fab16fe3ebd
This commit is contained in:
A. Unique TensorFlower 2020-05-11 15:11:46 -07:00 committed by TensorFlower Gardener
parent 46e6af455a
commit 28bbc65d66
2 changed files with 39 additions and 0 deletions

View File

@ -10,6 +10,10 @@ import "tensorflow/core/profiler/profiler_service_monitor_result.proto";
service ProfilerService {
// Starts a profiling session, blocks until it completes, and returns data.
rpc Profile(ProfileRequest) returns (ProfileResponse) {}
// Signal to terminate the Profile rpc for a on-going profiling session,
// The Profile rpc will return successfully and prematurely without timeout.
// This is used by programmatic mode to end the session in workers.
rpc Terminate(TerminateRequest) returns (TerminateResponse) {}
// Collects profiling data and returns user-friendly metrics.
rpc Monitor(MonitorRequest) returns (MonitorResponse) {}
}
@ -81,6 +85,13 @@ message ProfileResponse {
// next-field: 8
}
message TerminateRequest {
// Which session id to terminate.
string session_id = 1;
}
message TerminateResponse {}
message MonitorRequest {
// Duration for which to profile between each update.
uint64 duration_ms = 1;

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
#include "grpcpp/support/status.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@ -24,9 +25,12 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
namespace tensorflow {
@ -66,6 +70,11 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
if (ctx->IsCancelled()) {
return ::grpc::Status::CANCELLED;
}
if (TF_PREDICT_FALSE(IsStopped(req->session_id()))) {
mutex_lock lock(mutex_);
stop_signals_per_session_.erase(req->session_id());
break;
}
}
status = CollectDataToResponse(*req, profiler.get(), response);
@ -76,6 +85,25 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
return ::grpc::Status::OK;
}
::grpc::Status Terminate(::grpc::ServerContext* ctx,
const TerminateRequest* req,
TerminateResponse* response) override {
mutex_lock lock(mutex_);
stop_signals_per_session_[req->session_id()] = true;
return ::grpc::Status::OK;
}
private:
bool IsStopped(const std::string& session_id) {
mutex_lock lock(mutex_);
auto it = stop_signals_per_session_.find(session_id);
return it != stop_signals_per_session_.end() && it->second;
}
mutex mutex_;
absl::flat_hash_map<std::string, bool> stop_signals_per_session_
GUARDED_BY(mutex_);
};
} // namespace