Added multi-worker sampling mode.
* Profile() now calls RemoteProfilerSessionManager (RPSM). * Profiler server saves XSpace to repository instead of returning XSpace by ProfileResponse. * ProfileResponse will have a non-empty trace iff there are any XPlanes. PiperOrigin-RevId: 336807134 Change-Id: I664400c9715d73605b9daa4cfdcf6475dee4e959
This commit is contained in:
parent
78ba66c122
commit
8813a286b8
@ -47,8 +47,11 @@ cc_library(
|
|||||||
"//tensorflow/core/profiler:profiler_service_proto_cc",
|
"//tensorflow/core/profiler:profiler_service_proto_cc",
|
||||||
"//tensorflow/core/profiler/lib:profiler_session",
|
"//tensorflow/core/profiler/lib:profiler_session",
|
||||||
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||||
|
"//tensorflow/core/profiler/utils:file_system_utils",
|
||||||
|
"//tensorflow/core/profiler/utils:xplane_utils",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
tf_grpc_cc_dependency(),
|
tf_grpc_cc_dependency(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -27,6 +27,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":profiler_client_for_pybind",
|
":profiler_client_for_pybind",
|
||||||
|
":remote_profiler_session_manager",
|
||||||
":save_profile",
|
":save_profile",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
@ -133,14 +134,11 @@ cc_library(
|
|||||||
srcs = ["remote_profiler_session_manager.cc"],
|
srcs = ["remote_profiler_session_manager.cc"],
|
||||||
hdrs = ["remote_profiler_session_manager.h"],
|
hdrs = ["remote_profiler_session_manager.h"],
|
||||||
copts = tf_profiler_copts(),
|
copts = tf_profiler_copts(),
|
||||||
visibility = ["//tensorflow/core/profiler:internal"],
|
|
||||||
deps = [
|
deps = [
|
||||||
":profiler_client_for_pybind",
|
":profiler_client_for_pybind",
|
||||||
":save_profile",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/profiler:profiler_options_proto_cc",
|
"//tensorflow/core/profiler:profiler_options_proto_cc",
|
||||||
"//tensorflow/core/profiler/lib:profiler_session",
|
|
||||||
"//tensorflow/core/profiler/utils:time_utils",
|
"//tensorflow/core/profiler/utils:time_utils",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
@ -30,12 +30,16 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/profiler/profiler_options.pb.h"
|
#include "tensorflow/core/profiler/profiler_options.pb.h"
|
||||||
#include "tensorflow/core/profiler/profiler_service.pb.h"
|
#include "tensorflow/core/profiler/profiler_service.pb.h"
|
||||||
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
|
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
|
||||||
|
#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"
|
||||||
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
|
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace profiler {
|
namespace profiler {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::tensorflow::profiler::RemoteProfilerSessionManager;
|
||||||
|
using Response = ::tensorflow::profiler::RemoteProfilerSessionManager::Response;
|
||||||
|
|
||||||
constexpr uint64 kMaxEvents = 1000000;
|
constexpr uint64 kMaxEvents = 1000000;
|
||||||
const absl::string_view kXPlanePb = "xplane.pb";
|
const absl::string_view kXPlanePb = "xplane.pb";
|
||||||
|
|
||||||
@ -48,17 +52,18 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
|
|||||||
return request;
|
return request;
|
||||||
}
|
}
|
||||||
|
|
||||||
ProfileRequest PopulateProfileRequest(int duration_ms,
|
ProfileRequest PopulateProfileRequest(
|
||||||
const std::string& repository_root,
|
absl::string_view repository_root, absl::string_view session_id,
|
||||||
const std::string& session_id,
|
absl::string_view host_name,
|
||||||
const std::string& host_name,
|
const RemoteProfilerSessionManagerOptions& options) {
|
||||||
const ProfileOptions& opts) {
|
|
||||||
ProfileRequest request;
|
ProfileRequest request;
|
||||||
request.set_duration_ms(duration_ms);
|
// TODO(b/169976117) Remove duration from request.
|
||||||
|
request.set_duration_ms(options.profiler_options().duration_ms());
|
||||||
request.set_max_events(kMaxEvents);
|
request.set_max_events(kMaxEvents);
|
||||||
request.set_repository_root(repository_root);
|
request.set_repository_root(repository_root.data(), repository_root.size());
|
||||||
request.set_session_id(session_id);
|
request.set_session_id(session_id.data(), session_id.size());
|
||||||
request.set_host_name(host_name);
|
request.set_host_name(host_name.data(), host_name.size());
|
||||||
|
// These tools are only used by TPU profiler.
|
||||||
request.add_tools("trace_viewer");
|
request.add_tools("trace_viewer");
|
||||||
request.add_tools("op_profile");
|
request.add_tools("op_profile");
|
||||||
request.add_tools("input_pipeline");
|
request.add_tools("input_pipeline");
|
||||||
@ -68,21 +73,26 @@ ProfileRequest PopulateProfileRequest(int duration_ms,
|
|||||||
request.add_tools("overview_page");
|
request.add_tools("overview_page");
|
||||||
request.add_tools("pod_viewer");
|
request.add_tools("pod_viewer");
|
||||||
request.add_tools("tensorflow_stats");
|
request.add_tools("tensorflow_stats");
|
||||||
*request.mutable_opts() = opts;
|
// XPlane tool is only used by OSS profiler and safely ignored by TPU
|
||||||
|
// profiler.
|
||||||
|
request.add_tools(kXPlanePb.data(), kXPlanePb.size());
|
||||||
|
*request.mutable_opts() = options.profiler_options();
|
||||||
return request;
|
return request;
|
||||||
}
|
}
|
||||||
|
|
||||||
NewProfileSessionRequest PopulateNewProfileSessionRequest(
|
NewProfileSessionRequest PopulateNewProfileSessionRequest(
|
||||||
const std::string& service_addr, const std::string& repository_root,
|
absl::string_view repository_root, absl::string_view session_id,
|
||||||
const std::vector<string>& hostnames, int duration_ms,
|
const RemoteProfilerSessionManagerOptions& opts) {
|
||||||
const std::string& session_id, const ProfileOptions& opts) {
|
|
||||||
NewProfileSessionRequest request;
|
NewProfileSessionRequest request;
|
||||||
std::vector<std::string> parts = absl::StrSplit(service_addr, ':');
|
std::vector<absl::string_view> parts =
|
||||||
*request.mutable_request() = PopulateProfileRequest(
|
absl::StrSplit(opts.service_addresses(0), ':');
|
||||||
duration_ms, repository_root, session_id, parts[0], opts);
|
DCHECK(!parts.empty());
|
||||||
request.set_repository_root(repository_root);
|
|
||||||
request.set_session_id(session_id);
|
*request.mutable_request() =
|
||||||
for (const auto& hostname : hostnames) {
|
PopulateProfileRequest(repository_root, session_id, parts[0], opts);
|
||||||
|
request.set_repository_root(repository_root.data(), repository_root.size());
|
||||||
|
request.set_session_id(session_id.data(), session_id.size());
|
||||||
|
for (const auto& hostname : opts.service_addresses()) {
|
||||||
request.add_hosts(hostname);
|
request.add_hosts(hostname);
|
||||||
}
|
}
|
||||||
return request;
|
return request;
|
||||||
@ -99,44 +109,40 @@ inline bool ShouldRetryTracing(Status status) {
|
|||||||
status.error_message() == "Stream removed");
|
status.error_message() == "Stream removed");
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the ProfileResponse has single 'xplane.pb' tool, convert the xplane to
|
Status Profile(const std::string& repository_root,
|
||||||
// other tools and add in ProfileResponse. Otherwise, the ProfileResponse is
|
const std::string& session_id,
|
||||||
// already converted, simply return.
|
const RemoteProfilerSessionManagerOptions& opts) {
|
||||||
Status ConvertXSpaceToToolsInProfileResponse(const ProfileRequest& request,
|
Status status;
|
||||||
ProfileResponse* response) {
|
// Host name will be overwritten by RemoteProfilerSessionManager later.
|
||||||
if (response->tool_data_size() != 1) return Status::OK();
|
ProfileRequest request = PopulateProfileRequest(repository_root, session_id,
|
||||||
if (response->tool_data(0).name() != kXPlanePb) return Status::OK();
|
/*host_name=*/"", opts);
|
||||||
XSpace xspace;
|
auto session = RemoteProfilerSessionManager::Create(opts, request, status);
|
||||||
xspace.ParseFromString(response->tool_data(0).data());
|
TF_RETURN_IF_ERROR(status);
|
||||||
TF_RETURN_IF_ERROR(ConvertXSpaceToProfileResponse(xspace, request, response));
|
// Expect one or more service addresses.
|
||||||
return Status::OK();
|
DCHECK_GT(opts.service_addresses_size(), 0);
|
||||||
}
|
std::vector<Response> responses = session->WaitForCompletion();
|
||||||
|
// Expect responses to have the same size as clients.
|
||||||
|
DCHECK_EQ(responses.size(), opts.service_addresses_size());
|
||||||
|
|
||||||
Status Profile(const std::string& service_addr,
|
bool has_trace_data = false;
|
||||||
const std::string& repository_root, int duration_ms,
|
for (const auto& client_response : responses) {
|
||||||
const std::string& session_id, const ProfileOptions& opts) {
|
ProfileResponse& response = *client_response.profile_response;
|
||||||
std::vector<std::string> parts = absl::StrSplit(service_addr, ':');
|
if (response.empty_trace()) {
|
||||||
ProfileRequest request = PopulateProfileRequest(duration_ms, repository_root,
|
LOG(WARNING) << "No trace event is collected from "
|
||||||
session_id, parts[0], opts);
|
<< client_response.service_address;
|
||||||
ProfileResponse response;
|
} else {
|
||||||
TF_RETURN_IF_ERROR(ProfileGrpc(service_addr, request, &response));
|
has_trace_data = true;
|
||||||
|
}
|
||||||
if (!response.empty_trace()) {
|
if (!client_response.status.ok()) {
|
||||||
TF_RETURN_IF_ERROR(
|
LOG(WARNING) << client_response.service_address << " returned "
|
||||||
ConvertXSpaceToToolsInProfileResponse(request, &response));
|
<< client_response.status;
|
||||||
TF_RETURN_IF_ERROR(SaveProfile(repository_root, session_id,
|
}
|
||||||
request.host_name(), response, &std::cout));
|
|
||||||
// Print this at the end so that it's not buried in irrelevant LOG messages.
|
|
||||||
std::cout
|
|
||||||
<< "NOTE: using the trace duration " << duration_ms << "ms.\n"
|
|
||||||
<< "Set an appropriate duration (with --duration_ms) if you "
|
|
||||||
"don't see a full step in your trace or the captured trace is too "
|
|
||||||
"large."
|
|
||||||
<< std::endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (response.empty_trace()) {
|
if (!has_trace_data) {
|
||||||
return Status(error::Code::UNAVAILABLE, "No trace event is collected");
|
return Status(error::Code::UNAVAILABLE,
|
||||||
|
"No trace event was collected because there were no responses"
|
||||||
|
" from clients or the responses did not have trace data.");
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -144,52 +150,47 @@ Status Profile(const std::string& service_addr,
|
|||||||
// Start a new profiling session that include all the hosts included in
|
// Start a new profiling session that include all the hosts included in
|
||||||
// hostnames, for the time interval of duration_ms. Possibly save the profiling
|
// hostnames, for the time interval of duration_ms. Possibly save the profiling
|
||||||
// result in the directory specified by repository_root and session_id.
|
// result in the directory specified by repository_root and session_id.
|
||||||
Status NewSession(const std::string& service_addr,
|
Status NewSession(absl::string_view repository_root,
|
||||||
const std::string& repository_root,
|
absl::string_view session_id,
|
||||||
const std::vector<string>& hostnames, int duration_ms,
|
const RemoteProfilerSessionManagerOptions& opts) {
|
||||||
const std::string& session_id, const ProfileOptions& opts) {
|
NewProfileSessionRequest request =
|
||||||
NewProfileSessionRequest request = PopulateNewProfileSessionRequest(
|
PopulateNewProfileSessionRequest(repository_root, session_id, opts);
|
||||||
service_addr, repository_root, hostnames, duration_ms, session_id, opts);
|
|
||||||
NewProfileSessionResponse response;
|
NewProfileSessionResponse response;
|
||||||
TF_RETURN_IF_ERROR(NewSessionGrpc(service_addr, request, &response));
|
TF_RETURN_IF_ERROR(
|
||||||
|
NewSessionGrpc(opts.service_addresses(0), request, &response));
|
||||||
|
|
||||||
std::cout << "Profile session succeed for host(s):"
|
std::cout << "Profile session succeed for host(s):"
|
||||||
<< absl::StrJoin(hostnames, ",") << std::endl;
|
<< absl::StrJoin(opts.service_addresses(), ",") << std::endl;
|
||||||
if (response.empty_trace()) {
|
if (response.empty_trace()) {
|
||||||
return Status(error::Code::UNAVAILABLE, "No trace event is collected");
|
return errors::Unavailable("No trace event is collected");
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Starts tracing on a single or multiple hosts and saves the result in the
|
Status Trace(const std::string& logdir, int num_tracing_attempts,
|
||||||
// given logdir. If no trace was collected, retries tracing for
|
const RemoteProfilerSessionManagerOptions& opts,
|
||||||
// num_tracing_attempts.
|
bool is_cloud_tpu_session) {
|
||||||
Status Trace(const std::string& service_addr, const std::string& logdir,
|
DCHECK_GT(opts.profiler_options().duration_ms(), 0);
|
||||||
const std::string& workers_list, int duration_ms,
|
DCHECK(!opts.service_addresses().empty());
|
||||||
int num_tracing_attempts, const ProfileOptions& opts) {
|
|
||||||
// Use the current timestamp as the run name.
|
// Use the current timestamp as the run name.
|
||||||
std::string session_id = GetCurrentTimeStampAsString();
|
std::string session_id = GetCurrentTimeStampAsString();
|
||||||
std::vector<std::string> hostnames;
|
std::string repository_root = GetTensorBoardProfilePluginDir(logdir);
|
||||||
if (!workers_list.empty()) {
|
auto duration_ms = opts.profiler_options().duration_ms();
|
||||||
hostnames = absl::StrSplit(workers_list, ',');
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
|
TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
|
||||||
std::string repository_root =
|
|
||||||
profiler::GetTensorBoardProfilePluginDir(logdir);
|
|
||||||
|
|
||||||
Status status = Status::OK();
|
Status status;
|
||||||
int remaining_attempts = num_tracing_attempts;
|
int remaining_attempts = num_tracing_attempts;
|
||||||
while (true) {
|
while (true) {
|
||||||
std::cout << "Starting to trace for " << duration_ms << " ms. "
|
std::cout << "Starting to trace for " << duration_ms << " ms. "
|
||||||
<< "Remaining attempt(s): " << --remaining_attempts << std::endl;
|
<< "Remaining attempt(s): " << --remaining_attempts << std::endl;
|
||||||
if (hostnames.empty()) {
|
|
||||||
status =
|
if (is_cloud_tpu_session) {
|
||||||
Profile(service_addr, repository_root, duration_ms, session_id, opts);
|
status = NewSession(repository_root, session_id, opts);
|
||||||
} else {
|
} else {
|
||||||
status = NewSession(service_addr, repository_root, hostnames, duration_ms,
|
status = Profile(repository_root, session_id, opts);
|
||||||
session_id, opts);
|
|
||||||
}
|
}
|
||||||
if (remaining_attempts <= 0 || status.ok() || !ShouldRetryTracing(status))
|
if (remaining_attempts <= 0 || status.ok() || !ShouldRetryTracing(status))
|
||||||
break;
|
break;
|
||||||
@ -223,11 +224,10 @@ Status ExportToTensorBoard(const XSpace& xspace, const std::string& logdir) {
|
|||||||
|
|
||||||
ProfileResponse response;
|
ProfileResponse response;
|
||||||
ProfileRequest request = PopulateProfileRequest(
|
ProfileRequest request = PopulateProfileRequest(
|
||||||
/*duration_ms=*/0, GetTensorBoardProfilePluginDir(logdir),
|
GetTensorBoardProfilePluginDir(logdir), GetCurrentTimeStampAsString(),
|
||||||
GetCurrentTimeStampAsString(), port::Hostname(), /*opts=*/{});
|
port::Hostname(), /*options=*/{});
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
ConvertXSpaceToProfileResponse(xspace, request, &response));
|
ConvertXSpaceToProfileResponse(xspace, request, &response));
|
||||||
|
|
||||||
std::stringstream ss; // Record LOG messages.
|
std::stringstream ss; // Record LOG messages.
|
||||||
TF_RETURN_IF_ERROR(SaveProfile(request.repository_root(),
|
TF_RETURN_IF_ERROR(SaveProfile(request.repository_root(),
|
||||||
request.session_id(), request.host_name(),
|
request.session_id(), request.host_name(),
|
||||||
|
@ -36,12 +36,12 @@ Status Monitor(const std::string& service_addr, int duration_ms,
|
|||||||
int monitoring_level, bool display_timestamp,
|
int monitoring_level, bool display_timestamp,
|
||||||
std::string* result);
|
std::string* result);
|
||||||
|
|
||||||
// Starts tracing on a single or multiple hosts and saves the result in the
|
// Starts tracing on a single or multiple hosts. Each host will save the result
|
||||||
// given logdir. If no trace was collected, retries tracing for
|
// in the given logdir. If no trace was collected, retries tracing for
|
||||||
// num_tracing_attempts.
|
// num_tracing_attempts. Assumes that options have been validated.
|
||||||
Status Trace(const std::string& service_addr, const std::string& logdir,
|
Status Trace(const std::string& logdir, int num_tracing_attempts,
|
||||||
const std::string& workers_list, int duration_ms,
|
const RemoteProfilerSessionManagerOptions& opts,
|
||||||
int num_tracing_attempts, const ProfileOptions& opts);
|
bool is_cloud_tpu_session);
|
||||||
|
|
||||||
} // namespace profiler
|
} // namespace profiler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -99,10 +99,11 @@ RemoteProfilerSession::RemoteProfilerSession(std::string service_address,
|
|||||||
service_address_(std::move(service_address)),
|
service_address_(std::move(service_address)),
|
||||||
stub_(CreateStub<grpc::ProfilerService>(service_address_)),
|
stub_(CreateStub<grpc::ProfilerService>(service_address_)),
|
||||||
deadline_(deadline),
|
deadline_(deadline),
|
||||||
profile_request_(std::move(profile_request)) {}
|
profile_request_(std::move(profile_request)) {
|
||||||
|
response_->set_empty_trace(true);
|
||||||
|
}
|
||||||
|
|
||||||
RemoteProfilerSession::~RemoteProfilerSession() {
|
RemoteProfilerSession::~RemoteProfilerSession() {
|
||||||
LOG(INFO) << "Waiting for completion.";
|
|
||||||
Status dummy;
|
Status dummy;
|
||||||
WaitForCompletion(dummy);
|
WaitForCompletion(dummy);
|
||||||
grpc_context_.TryCancel();
|
grpc_context_.TryCancel();
|
||||||
@ -113,6 +114,8 @@ void RemoteProfilerSession::ProfileAsync() {
|
|||||||
grpc_context_.set_deadline(absl::ToChronoTime(deadline_));
|
grpc_context_.set_deadline(absl::ToChronoTime(deadline_));
|
||||||
VLOG(1) << "Deadline set to " << deadline_;
|
VLOG(1) << "Deadline set to " << deadline_;
|
||||||
rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_);
|
rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_);
|
||||||
|
// Connection failure will create lame channel whereby grpc_status_ will be an
|
||||||
|
// error.
|
||||||
rpc_->Finish(response_.get(), &grpc_status_,
|
rpc_->Finish(response_.get(), &grpc_status_,
|
||||||
static_cast<void*>(&status_on_completion_));
|
static_cast<void*>(&status_on_completion_));
|
||||||
VLOG(2) << "Asynchronous gRPC Profile() issued." << absl::Now();
|
VLOG(2) << "Asynchronous gRPC Profile() issued." << absl::Now();
|
||||||
@ -125,6 +128,7 @@ std::unique_ptr<ProfileResponse> RemoteProfilerSession::WaitForCompletion(
|
|||||||
"WaitForCompletion must only be called once.");
|
"WaitForCompletion must only be called once.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
LOG(INFO) << "Waiting for completion.";
|
||||||
|
|
||||||
void* got_tag = nullptr;
|
void* got_tag = nullptr;
|
||||||
bool ok = false;
|
bool ok = false;
|
||||||
|
@ -82,7 +82,7 @@ class RemoteProfilerSession {
|
|||||||
absl::Time deadline_;
|
absl::Time deadline_;
|
||||||
::grpc::ClientContext grpc_context_;
|
::grpc::ClientContext grpc_context_;
|
||||||
std::unique_ptr<::grpc::ClientAsyncResponseReader<ProfileResponse>> rpc_;
|
std::unique_ptr<::grpc::ClientAsyncResponseReader<ProfileResponse>> rpc_;
|
||||||
::grpc::Status grpc_status_;
|
::grpc::Status grpc_status_ = ::grpc::Status::OK;
|
||||||
|
|
||||||
// Asynchronous completion queue states.
|
// Asynchronous completion queue states.
|
||||||
::grpc::CompletionQueue cq_;
|
::grpc::CompletionQueue cq_;
|
||||||
|
@ -52,8 +52,10 @@ TEST(RemoteProfilerSession, Simple) {
|
|||||||
absl::Duration elapsed = absl::Now() - approx_start;
|
absl::Duration elapsed = absl::Now() - approx_start;
|
||||||
// At end of session this evaluates to true still.
|
// At end of session this evaluates to true still.
|
||||||
EXPECT_TRUE(status.ok());
|
EXPECT_TRUE(status.ok());
|
||||||
EXPECT_FALSE(response->empty_trace());
|
// True because there was no workload traced and subsequently no XEvents.
|
||||||
EXPECT_GT(response->tool_data_size(), 0);
|
EXPECT_TRUE(response->empty_trace());
|
||||||
|
// XSpaces are serialized and not returned as tools in ProfileResponse.
|
||||||
|
EXPECT_EQ(response->tool_data_size(), 0);
|
||||||
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
|
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,8 +88,9 @@ TEST(RemoteProfilerSession, Timeout) {
|
|||||||
auto response = remote_session->WaitForCompletion(status);
|
auto response = remote_session->WaitForCompletion(status);
|
||||||
// At end of session we will have a timeout error.
|
// At end of session we will have a timeout error.
|
||||||
EXPECT_TRUE(errors::IsDeadlineExceeded(status));
|
EXPECT_TRUE(errors::IsDeadlineExceeded(status));
|
||||||
|
// True because there was no workload traced and subsequently no XEvents.
|
||||||
EXPECT_FALSE(response->empty_trace()); // This defaults to false.
|
EXPECT_TRUE(response->empty_trace());
|
||||||
|
// XSpaces are serialized and not returned as tools in ProfileResponse.
|
||||||
EXPECT_EQ(response->tool_data_size(), 0);
|
EXPECT_EQ(response->tool_data_size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,8 +112,10 @@ TEST(RemoteProfilerSession, LongDeadline) {
|
|||||||
absl::Duration elapsed = absl::Now() - approx_start;
|
absl::Duration elapsed = absl::Now() - approx_start;
|
||||||
// At end of session this evaluates to true still.
|
// At end of session this evaluates to true still.
|
||||||
EXPECT_TRUE(status.ok());
|
EXPECT_TRUE(status.ok());
|
||||||
EXPECT_FALSE(response->empty_trace());
|
// True because there was no workload traced and subsequently no XEvents.
|
||||||
EXPECT_GT(response->tool_data_size(), 0);
|
EXPECT_TRUE(response->empty_trace());
|
||||||
|
// XSpaces are serialized and not returned as tools in ProfileResponse.
|
||||||
|
EXPECT_EQ(response->tool_data_size(), 0);
|
||||||
// Elapsed time is near profiling duration despite long grace period.
|
// Elapsed time is near profiling duration despite long grace period.
|
||||||
EXPECT_THAT(elapsed, DurationNear(duration));
|
EXPECT_THAT(elapsed, DurationNear(duration));
|
||||||
}
|
}
|
||||||
@ -134,8 +139,10 @@ TEST(RemoteProfilerSession, LongDuration) {
|
|||||||
absl::Duration elapsed = absl::Now() - approx_start;
|
absl::Duration elapsed = absl::Now() - approx_start;
|
||||||
// At end of session this evaluates to true still.
|
// At end of session this evaluates to true still.
|
||||||
EXPECT_TRUE(status.ok());
|
EXPECT_TRUE(status.ok());
|
||||||
EXPECT_FALSE(response->empty_trace());
|
// True because there was no workload traced and subsequently no XEvents.
|
||||||
EXPECT_GT(response->tool_data_size(), 0);
|
EXPECT_TRUE(response->empty_trace());
|
||||||
|
// XSpaces are serialized and not returned as tools in ProfileResponse.
|
||||||
|
EXPECT_EQ(response->tool_data_size(), 0);
|
||||||
// Elapsed time takes longer to complete for larger traces.
|
// Elapsed time takes longer to complete for larger traces.
|
||||||
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
|
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
|
||||||
}
|
}
|
||||||
|
@ -37,14 +37,14 @@ namespace profiler {
|
|||||||
namespace test {
|
namespace test {
|
||||||
|
|
||||||
inline std::unique_ptr<ProfilerServer> StartServer(
|
inline std::unique_ptr<ProfilerServer> StartServer(
|
||||||
absl::Duration duration, std::string* service_addresses,
|
absl::Duration duration, std::string* service_address,
|
||||||
ProfileRequest* request = nullptr) {
|
ProfileRequest* request = nullptr) {
|
||||||
auto profiler_server = absl::make_unique<ProfilerServer>();
|
auto profiler_server = absl::make_unique<ProfilerServer>();
|
||||||
int port = testing::PickUnusedPortOrDie();
|
int port = testing::PickUnusedPortOrDie();
|
||||||
profiler_server->StartProfilerServer(port);
|
profiler_server->StartProfilerServer(port);
|
||||||
|
|
||||||
DCHECK(service_addresses);
|
DCHECK(service_address);
|
||||||
*service_addresses = absl::StrCat("localhost:", port);
|
*service_address = absl::StrCat("localhost:", port);
|
||||||
|
|
||||||
if (request) {
|
if (request) {
|
||||||
request->set_duration_ms(absl::ToInt64Milliseconds(duration));
|
request->set_duration_ms(absl::ToInt64Milliseconds(duration));
|
||||||
@ -53,10 +53,11 @@ inline std::unique_ptr<ProfilerServer> StartServer(
|
|||||||
request->mutable_opts()->set_duration_ms(
|
request->mutable_opts()->set_duration_ms(
|
||||||
absl::ToInt64Milliseconds(duration));
|
absl::ToInt64Milliseconds(duration));
|
||||||
request->set_session_id("test_session");
|
request->set_session_id("test_session");
|
||||||
request->set_host_name(*service_addresses);
|
request->set_host_name(*service_address);
|
||||||
|
request->set_repository_root(testing::TmpDir());
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG(INFO) << "Started " << *service_addresses << " at " << absl::Now();
|
LOG(INFO) << "Started " << *service_address << " at " << absl::Now();
|
||||||
LOG(INFO) << "Duration: " << duration;
|
LOG(INFO) << "Duration: " << duration;
|
||||||
|
|
||||||
return profiler_server;
|
return profiler_server;
|
||||||
|
@ -26,47 +26,20 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
|
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
|
||||||
#include "tensorflow/core/profiler/utils/time_utils.h"
|
#include "tensorflow/core/profiler/utils/time_utils.h"
|
||||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace profiler {
|
namespace profiler {
|
||||||
namespace {
|
|
||||||
|
|
||||||
constexpr uint64 kMaxEvents = 1000000;
|
|
||||||
|
|
||||||
// TODO(yisitu) merge with the implementation in capture_profile.
|
|
||||||
void PopulateProfileRequest(const RemoteProfilerSessionManagerOptions& options,
|
|
||||||
absl::string_view session_id,
|
|
||||||
absl::string_view host_name,
|
|
||||||
ProfileRequest* request) {
|
|
||||||
request->set_max_events(kMaxEvents);
|
|
||||||
request->set_repository_root(options.profiler_options().repository_path());
|
|
||||||
request->set_session_id(session_id.data(), session_id.size());
|
|
||||||
request->add_tools("trace_viewer");
|
|
||||||
request->add_tools("op_profile");
|
|
||||||
request->add_tools("input_pipeline");
|
|
||||||
request->add_tools("kernel_stats");
|
|
||||||
request->add_tools("memory_viewer");
|
|
||||||
request->add_tools("memory_profile");
|
|
||||||
request->add_tools("overview_page");
|
|
||||||
request->add_tools("pod_viewer");
|
|
||||||
request->add_tools("tensorflow_stats");
|
|
||||||
request->set_host_name(host_name.data(), host_name.size());
|
|
||||||
*request->mutable_opts() = options.profiler_options();
|
|
||||||
request->set_duration_ms(options.profiler_options().duration_ms());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
/*static*/ std::unique_ptr<RemoteProfilerSessionManager>
|
/*static*/ std::unique_ptr<RemoteProfilerSessionManager>
|
||||||
RemoteProfilerSessionManager::Create(
|
RemoteProfilerSessionManager::Create(
|
||||||
const RemoteProfilerSessionManagerOptions& options,
|
const RemoteProfilerSessionManagerOptions& options,
|
||||||
tensorflow::Status& out_status, AddressResolver resolver) {
|
const ProfileRequest& request, tensorflow::Status& out_status,
|
||||||
|
AddressResolver resolver) {
|
||||||
VLOG(1) << "Creating a RemoteProfilerSessionManager.";
|
VLOG(1) << "Creating a RemoteProfilerSessionManager.";
|
||||||
auto session_manager =
|
auto session_manager = absl::WrapUnique(
|
||||||
absl::WrapUnique(new RemoteProfilerSessionManager(options, resolver));
|
new RemoteProfilerSessionManager(options, request, resolver));
|
||||||
out_status = session_manager->Init();
|
out_status = session_manager->Init();
|
||||||
if (!out_status.ok()) {
|
if (!out_status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -75,8 +48,9 @@ RemoteProfilerSessionManager::Create(
|
|||||||
}
|
}
|
||||||
|
|
||||||
RemoteProfilerSessionManager::RemoteProfilerSessionManager(
|
RemoteProfilerSessionManager::RemoteProfilerSessionManager(
|
||||||
RemoteProfilerSessionManagerOptions options, AddressResolver resolver)
|
RemoteProfilerSessionManagerOptions options, ProfileRequest request,
|
||||||
: options_(std::move(options)) {
|
AddressResolver resolver)
|
||||||
|
: options_(std::move(options)), request_(std::move(request)) {
|
||||||
if (resolver) {
|
if (resolver) {
|
||||||
resolver_ = std::move(resolver);
|
resolver_ = std::move(resolver);
|
||||||
} else {
|
} else {
|
||||||
@ -91,14 +65,7 @@ RemoteProfilerSessionManager::~RemoteProfilerSessionManager() {
|
|||||||
Status RemoteProfilerSessionManager::Init() {
|
Status RemoteProfilerSessionManager::Init() {
|
||||||
mutex_lock lock(mutex_);
|
mutex_lock lock(mutex_);
|
||||||
VLOG(1) << "SessionManager initializing.";
|
VLOG(1) << "SessionManager initializing.";
|
||||||
// TODO(b/169482824) Move validation to call site.
|
|
||||||
Status status = ValidateOptionsLocked();
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << status;
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string session_id = GetCurrentTimeStampAsString();
|
|
||||||
const absl::Time session_created_ts =
|
const absl::Time session_created_ts =
|
||||||
absl::FromUnixNanos(options_.session_creation_timestamp_ns());
|
absl::FromUnixNanos(options_.session_creation_timestamp_ns());
|
||||||
const absl::Time deadline =
|
const absl::Time deadline =
|
||||||
@ -115,16 +82,14 @@ Status RemoteProfilerSessionManager::Init() {
|
|||||||
// Prepare a list of clients.
|
// Prepare a list of clients.
|
||||||
clients_.reserve(options_.service_addresses_size());
|
clients_.reserve(options_.service_addresses_size());
|
||||||
|
|
||||||
for (auto& service_addr : options_.service_addresses()) {
|
for (auto& service_address : options_.service_addresses()) {
|
||||||
std::string resolved_service_addr = resolver_(service_addr);
|
std::string resolved_service_address = resolver_(service_address);
|
||||||
ProfileRequest profile_request;
|
ProfileRequest request = request_;
|
||||||
PopulateProfileRequest(options_, session_id, resolved_service_addr,
|
request.set_host_name(resolved_service_address);
|
||||||
&profile_request);
|
|
||||||
|
|
||||||
// Creation also issues Profile RPC asynchronously.
|
// Creation also issues Profile RPC asynchronously.
|
||||||
auto client = RemoteProfilerSession::Create(
|
auto client = RemoteProfilerSession::Create(
|
||||||
std::move(resolved_service_addr), deadline, std::move(profile_request));
|
std::move(resolved_service_address), deadline, std::move(request));
|
||||||
|
|
||||||
clients_.push_back(std::move(client));
|
clients_.push_back(std::move(client));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,41 +97,18 @@ Status RemoteProfilerSessionManager::Init() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RemoteProfilerSessionManager::ValidateOptionsLocked() {
|
|
||||||
if (!options_.service_addresses_size()) {
|
|
||||||
return errors::InvalidArgument("No service addresses specified.");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options_.profiler_options().duration_ms() == 0) {
|
|
||||||
if (options_.max_session_duration_ms() != 0) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"If local profiler duration is unbounded, profiling session duration "
|
|
||||||
"must be unbounded.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options_.max_session_duration_ms() <
|
|
||||||
options_.profiler_options().duration_ms()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"The maximum profiling session duration must be greater than or equal "
|
|
||||||
"to the local profiler duration.");
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<RemoteProfilerSessionManager::Response>
|
std::vector<RemoteProfilerSessionManager::Response>
|
||||||
RemoteProfilerSessionManager::WaitForCompletion() {
|
RemoteProfilerSessionManager::WaitForCompletion() {
|
||||||
mutex_lock lock(mutex_);
|
mutex_lock lock(mutex_);
|
||||||
std::vector<RemoteProfilerSessionManager::Response> remote_responses;
|
std::vector<RemoteProfilerSessionManager::Response> remote_responses(
|
||||||
remote_responses.reserve(clients_.size());
|
clients_.size());
|
||||||
|
|
||||||
for (auto& client : clients_) {
|
for (int32 idx = 0; idx < clients_.size(); ++idx) {
|
||||||
remote_responses.emplace_back();
|
auto& remote_response = remote_responses[idx];
|
||||||
auto* profile_response = &remote_responses.back().profile_response;
|
auto* client = clients_[idx].get();
|
||||||
Status& status = remote_responses.back().status;
|
remote_response.profile_response =
|
||||||
std::string* service_addr = &remote_responses.back().service_addr;
|
client->WaitForCompletion(remote_response.status);
|
||||||
*profile_response = client->WaitForCompletion(status);
|
remote_response.service_address = std::string(client->GetServiceAddress());
|
||||||
*service_addr = std::string(client->GetServiceAddress());
|
|
||||||
}
|
}
|
||||||
return remote_responses;
|
return remote_responses;
|
||||||
}
|
}
|
||||||
|
@ -26,9 +26,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/profiler/lib/profiler_session.h"
|
|
||||||
#include "tensorflow/core/profiler/profiler_options.pb.h"
|
|
||||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
|
||||||
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
|
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -40,21 +37,16 @@ using AddressResolver = std::function<std::string(absl::string_view)>;
|
|||||||
class RemoteProfilerSessionManager {
|
class RemoteProfilerSessionManager {
|
||||||
public:
|
public:
|
||||||
struct Response {
|
struct Response {
|
||||||
std::string service_addr;
|
std::string service_address;
|
||||||
std::unique_ptr<ProfileResponse> profile_response;
|
std::unique_ptr<ProfileResponse> profile_response;
|
||||||
Status status;
|
Status status;
|
||||||
};
|
};
|
||||||
// Instantiates a collection of RemoteProfilerSessions starts profiling on
|
// Instantiates a collection of RemoteProfilerSessions starts profiling on
|
||||||
// each of them immediately.
|
// each of them immediately. Assumes that options have already been validated.
|
||||||
static std::unique_ptr<RemoteProfilerSessionManager> Create(
|
static std::unique_ptr<RemoteProfilerSessionManager> Create(
|
||||||
const RemoteProfilerSessionManagerOptions& options,
|
const RemoteProfilerSessionManagerOptions& options,
|
||||||
tensorflow::Status& out_status, AddressResolver resolver = nullptr);
|
const ProfileRequest& request, tensorflow::Status& out_status,
|
||||||
|
AddressResolver resolver = nullptr);
|
||||||
static RemoteProfilerSessionManagerOptions DefaultOptions() {
|
|
||||||
RemoteProfilerSessionManagerOptions options;
|
|
||||||
*options.mutable_profiler_options() = ProfilerSession::DefaultOptions();
|
|
||||||
return options;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Awaits for responses from remote profiler sessions and returns them as a
|
// Awaits for responses from remote profiler sessions and returns them as a
|
||||||
// list. Subsequent calls beyond the first will yield a list of errors.
|
// list. Subsequent calls beyond the first will yield a list of errors.
|
||||||
@ -69,16 +61,16 @@ class RemoteProfilerSessionManager {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
explicit RemoteProfilerSessionManager(
|
explicit RemoteProfilerSessionManager(
|
||||||
RemoteProfilerSessionManagerOptions options, AddressResolver resolver);
|
RemoteProfilerSessionManagerOptions options, ProfileRequest request,
|
||||||
|
AddressResolver resolver);
|
||||||
|
|
||||||
// Initialization of all client contexts.
|
// Initialization of all client contexts.
|
||||||
Status Init();
|
Status Init();
|
||||||
|
|
||||||
Status ValidateOptionsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
|
|
||||||
|
|
||||||
mutex mutex_;
|
mutex mutex_;
|
||||||
// Remote profiler session options.
|
// Remote profiler session options.
|
||||||
RemoteProfilerSessionManagerOptions options_ TF_GUARDED_BY(mutex_);
|
RemoteProfilerSessionManagerOptions options_ TF_GUARDED_BY(mutex_);
|
||||||
|
ProfileRequest request_ TF_GUARDED_BY(mutex_);
|
||||||
// List of clients, each connects to a profiling service.
|
// List of clients, each connects to a profiling service.
|
||||||
std::vector<std::unique_ptr<RemoteProfilerSession>> clients_
|
std::vector<std::unique_ptr<RemoteProfilerSession>> clients_
|
||||||
TF_GUARDED_BY(mutex_);
|
TF_GUARDED_BY(mutex_);
|
||||||
|
@ -35,46 +35,73 @@ namespace {
|
|||||||
using ::tensorflow::profiler::test::DurationApproxLess;
|
using ::tensorflow::profiler::test::DurationApproxLess;
|
||||||
using ::tensorflow::profiler::test::DurationNear;
|
using ::tensorflow::profiler::test::DurationNear;
|
||||||
using ::tensorflow::profiler::test::StartServer;
|
using ::tensorflow::profiler::test::StartServer;
|
||||||
|
using ::tensorflow::testing::TmpDir;
|
||||||
using Response = tensorflow::profiler::RemoteProfilerSessionManager::Response;
|
using Response = tensorflow::profiler::RemoteProfilerSessionManager::Response;
|
||||||
|
|
||||||
|
// Copied from capture_profile to not introduce a dependency.
|
||||||
|
ProfileRequest PopulateProfileRequest(
|
||||||
|
absl::string_view repository_root, absl::string_view session_id,
|
||||||
|
absl::string_view host_name,
|
||||||
|
const RemoteProfilerSessionManagerOptions& options) {
|
||||||
|
constexpr uint64 kMaxEvents = 1000000;
|
||||||
|
const absl::string_view kXPlanePb = "xplane.pb";
|
||||||
|
ProfileRequest request;
|
||||||
|
// TODO(b/169976117) Remove duration from request.
|
||||||
|
request.set_duration_ms(options.profiler_options().duration_ms());
|
||||||
|
request.set_max_events(kMaxEvents);
|
||||||
|
request.set_repository_root(repository_root.data(), repository_root.size());
|
||||||
|
request.set_session_id(session_id.data(), session_id.size());
|
||||||
|
request.set_host_name(host_name.data(), host_name.size());
|
||||||
|
// XPlane tool is only used by OSS profiler and safely ignored by TPU
|
||||||
|
// profiler.
|
||||||
|
request.add_tools(kXPlanePb.data(), kXPlanePb.size());
|
||||||
|
*request.mutable_opts() = options.profiler_options();
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
TEST(RemoteProfilerSessionManagerTest, Simple) {
|
TEST(RemoteProfilerSessionManagerTest, Simple) {
|
||||||
absl::Duration duration = absl::Milliseconds(30);
|
absl::Duration duration = absl::Milliseconds(30);
|
||||||
RemoteProfilerSessionManagerOptions options =
|
RemoteProfilerSessionManagerOptions options;
|
||||||
RemoteProfilerSessionManager::DefaultOptions();
|
*options.mutable_profiler_options() =
|
||||||
|
tensorflow::ProfilerSession::DefaultOptions();
|
||||||
options.mutable_profiler_options()->set_duration_ms(
|
options.mutable_profiler_options()->set_duration_ms(
|
||||||
absl::ToInt64Milliseconds(duration));
|
absl::ToInt64Milliseconds(duration));
|
||||||
|
|
||||||
std::string service_addresses;
|
std::string service_address;
|
||||||
auto server = StartServer(duration, &service_addresses);
|
auto server = StartServer(duration, &service_address);
|
||||||
options.add_service_addresses(service_addresses);
|
options.add_service_addresses(service_address);
|
||||||
absl::Time approx_start = absl::Now();
|
absl::Time approx_start = absl::Now();
|
||||||
absl::Duration grace = absl::Seconds(1);
|
absl::Duration grace = absl::Seconds(1);
|
||||||
absl::Duration max_duration = duration + grace;
|
absl::Duration max_duration = duration + grace;
|
||||||
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
|
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
|
||||||
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
|
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
|
||||||
|
|
||||||
|
ProfileRequest request =
|
||||||
|
PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
|
||||||
Status status;
|
Status status;
|
||||||
auto sessions = RemoteProfilerSessionManager::Create(options, status);
|
auto sessions =
|
||||||
|
RemoteProfilerSessionManager::Create(options, request, status);
|
||||||
EXPECT_TRUE(status.ok());
|
EXPECT_TRUE(status.ok());
|
||||||
std::vector<Response> responses = sessions->WaitForCompletion();
|
std::vector<Response> responses = sessions->WaitForCompletion();
|
||||||
absl::Duration elapsed = absl::Now() - approx_start;
|
absl::Duration elapsed = absl::Now() - approx_start;
|
||||||
ASSERT_EQ(responses.size(), 1);
|
ASSERT_EQ(responses.size(), 1);
|
||||||
EXPECT_TRUE(responses.back().status.ok());
|
EXPECT_TRUE(responses.back().status.ok());
|
||||||
EXPECT_FALSE(responses.back().profile_response->empty_trace());
|
EXPECT_TRUE(responses.back().profile_response->empty_trace());
|
||||||
EXPECT_GT(responses.back().profile_response->tool_data_size(), 0);
|
EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
|
||||||
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
|
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) {
|
TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) {
|
||||||
absl::Duration duration = absl::Milliseconds(30);
|
absl::Duration duration = absl::Milliseconds(30);
|
||||||
RemoteProfilerSessionManagerOptions options =
|
RemoteProfilerSessionManagerOptions options;
|
||||||
RemoteProfilerSessionManager::DefaultOptions();
|
*options.mutable_profiler_options() =
|
||||||
|
tensorflow::ProfilerSession::DefaultOptions();
|
||||||
options.mutable_profiler_options()->set_duration_ms(
|
options.mutable_profiler_options()->set_duration_ms(
|
||||||
absl::ToInt64Milliseconds(duration));
|
absl::ToInt64Milliseconds(duration));
|
||||||
|
|
||||||
std::string service_addresses;
|
std::string service_address;
|
||||||
auto server = StartServer(duration, &service_addresses);
|
auto server = StartServer(duration, &service_address);
|
||||||
options.add_service_addresses(service_addresses);
|
options.add_service_addresses(service_address);
|
||||||
absl::Duration grace = absl::Seconds(1);
|
absl::Duration grace = absl::Seconds(1);
|
||||||
absl::Duration max_duration = duration + grace;
|
absl::Duration max_duration = duration + grace;
|
||||||
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
|
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
|
||||||
@ -82,28 +109,32 @@ TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) {
|
|||||||
options.set_session_creation_timestamp_ns(0);
|
options.set_session_creation_timestamp_ns(0);
|
||||||
|
|
||||||
absl::Time approx_start = absl::Now();
|
absl::Time approx_start = absl::Now();
|
||||||
|
ProfileRequest request =
|
||||||
|
PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
|
||||||
Status status;
|
Status status;
|
||||||
auto sessions = RemoteProfilerSessionManager::Create(options, status);
|
auto sessions =
|
||||||
|
RemoteProfilerSessionManager::Create(options, request, status);
|
||||||
EXPECT_TRUE(status.ok());
|
EXPECT_TRUE(status.ok());
|
||||||
std::vector<Response> responses = sessions->WaitForCompletion();
|
std::vector<Response> responses = sessions->WaitForCompletion();
|
||||||
absl::Duration elapsed = absl::Now() - approx_start;
|
absl::Duration elapsed = absl::Now() - approx_start;
|
||||||
EXPECT_THAT(elapsed, DurationNear(absl::Seconds(0)));
|
EXPECT_THAT(elapsed, DurationNear(absl::Seconds(0)));
|
||||||
ASSERT_EQ(responses.size(), 1);
|
ASSERT_EQ(responses.size(), 1);
|
||||||
EXPECT_TRUE(errors::IsDeadlineExceeded(responses.back().status));
|
EXPECT_TRUE(errors::IsDeadlineExceeded(responses.back().status));
|
||||||
EXPECT_FALSE(responses.back().profile_response->empty_trace());
|
EXPECT_TRUE(responses.back().profile_response->empty_trace());
|
||||||
EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
|
EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(RemoteProfilerSessionManagerTest, LongSession) {
|
TEST(RemoteProfilerSessionManagerTest, LongSession) {
|
||||||
absl::Duration duration = absl::Seconds(3);
|
absl::Duration duration = absl::Seconds(3);
|
||||||
RemoteProfilerSessionManagerOptions options =
|
RemoteProfilerSessionManagerOptions options;
|
||||||
RemoteProfilerSessionManager::DefaultOptions();
|
*options.mutable_profiler_options() =
|
||||||
|
tensorflow::ProfilerSession::DefaultOptions();
|
||||||
options.mutable_profiler_options()->set_duration_ms(
|
options.mutable_profiler_options()->set_duration_ms(
|
||||||
absl::ToInt64Milliseconds(duration));
|
absl::ToInt64Milliseconds(duration));
|
||||||
|
|
||||||
std::string service_addresses;
|
std::string service_address;
|
||||||
auto server = StartServer(duration, &service_addresses);
|
auto server = StartServer(duration, &service_address);
|
||||||
options.add_service_addresses(service_addresses);
|
options.add_service_addresses(service_address);
|
||||||
absl::Time approx_start = absl::Now();
|
absl::Time approx_start = absl::Now();
|
||||||
// Empirically determined value.
|
// Empirically determined value.
|
||||||
absl::Duration grace = absl::Seconds(20);
|
absl::Duration grace = absl::Seconds(20);
|
||||||
@ -111,15 +142,18 @@ TEST(RemoteProfilerSessionManagerTest, LongSession) {
|
|||||||
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
|
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
|
||||||
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
|
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
|
||||||
|
|
||||||
|
ProfileRequest request =
|
||||||
|
PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
|
||||||
Status status;
|
Status status;
|
||||||
auto sessions = RemoteProfilerSessionManager::Create(options, status);
|
auto sessions =
|
||||||
|
RemoteProfilerSessionManager::Create(options, request, status);
|
||||||
EXPECT_TRUE(status.ok());
|
EXPECT_TRUE(status.ok());
|
||||||
std::vector<Response> responses = sessions->WaitForCompletion();
|
std::vector<Response> responses = sessions->WaitForCompletion();
|
||||||
absl::Duration elapsed = absl::Now() - approx_start;
|
absl::Duration elapsed = absl::Now() - approx_start;
|
||||||
ASSERT_EQ(responses.size(), 1);
|
ASSERT_EQ(responses.size(), 1);
|
||||||
EXPECT_TRUE(responses.back().status.ok());
|
EXPECT_TRUE(responses.back().status.ok());
|
||||||
EXPECT_FALSE(responses.back().profile_response->empty_trace());
|
EXPECT_TRUE(responses.back().profile_response->empty_trace());
|
||||||
EXPECT_GT(responses.back().profile_response->tool_data_size(), 0);
|
EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
|
||||||
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
|
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "grpcpp/support/status.h"
|
#include "grpcpp/support/status.h"
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/strings/str_replace.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/env_time.h"
|
#include "tensorflow/core/platform/env_time.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
@ -31,6 +32,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
|
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
|
||||||
#include "tensorflow/core/profiler/profiler_service.pb.h"
|
#include "tensorflow/core/profiler/profiler_service.pb.h"
|
||||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
||||||
|
#include "tensorflow/core/profiler/utils/file_system_utils.h"
|
||||||
|
#include "tensorflow/core/profiler/utils/xplane_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace profiler {
|
namespace profiler {
|
||||||
@ -38,15 +41,31 @@ namespace {
|
|||||||
|
|
||||||
const absl::string_view kXPlanePb = "xplane.pb";
|
const absl::string_view kXPlanePb = "xplane.pb";
|
||||||
|
|
||||||
Status CollectDataToResponse(const ProfileRequest& req,
|
// Collects data in XSpace format. The data is saved to a repository
|
||||||
ProfilerSession* profiler,
|
// unconditionally.
|
||||||
ProfileResponse* response) {
|
Status CollectDataToRepository(const ProfileRequest& request,
|
||||||
profiler::XSpace xspace;
|
ProfilerSession* profiler,
|
||||||
|
ProfileResponse* response) {
|
||||||
|
response->set_empty_trace(true);
|
||||||
|
// Read the profile data into xspace.
|
||||||
|
XSpace xspace;
|
||||||
TF_RETURN_IF_ERROR(profiler->CollectData(&xspace));
|
TF_RETURN_IF_ERROR(profiler->CollectData(&xspace));
|
||||||
auto* tool_data = response->add_tool_data();
|
VLOG(3) << "Collected XSpace to repository.";
|
||||||
tool_data->set_name(kXPlanePb.data(), kXPlanePb.size());
|
response->set_empty_trace(IsEmpty(xspace));
|
||||||
xspace.SerializeToString(tool_data->mutable_data());
|
|
||||||
return Status::OK();
|
std::string log_dir_path =
|
||||||
|
ProfilerJoinPath(request.repository_root(), request.session_id());
|
||||||
|
VLOG(1) << "Creating " << log_dir_path;
|
||||||
|
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(log_dir_path));
|
||||||
|
|
||||||
|
std::string file_name = absl::StrCat(request.host_name(), ".", kXPlanePb);
|
||||||
|
// Windows file names do not support colons.
|
||||||
|
absl::StrReplaceAll({{":", "_"}}, &file_name);
|
||||||
|
// Dumps profile data to <repository_root>/<run>/<host>_<port>.<kXPlanePb>
|
||||||
|
std::string out_path = ProfilerJoinPath(log_dir_path, file_name);
|
||||||
|
LOG(INFO) << "Collecting XSpace to repository: " << out_path;
|
||||||
|
|
||||||
|
return WriteBinaryProto(Env::Default(), out_path, xspace);
|
||||||
}
|
}
|
||||||
|
|
||||||
class ProfilerServiceImpl : public grpc::ProfilerService::Service {
|
class ProfilerServiceImpl : public grpc::ProfilerService::Service {
|
||||||
@ -68,7 +87,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
for (uint64 i = 0; i < req->duration_ms(); ++i) {
|
for (uint64 i = 0; i < req->opts().duration_ms(); ++i) {
|
||||||
env->SleepForMicroseconds(EnvTime::kMillisToMicros);
|
env->SleepForMicroseconds(EnvTime::kMillisToMicros);
|
||||||
if (ctx->IsCancelled()) {
|
if (ctx->IsCancelled()) {
|
||||||
return ::grpc::Status::CANCELLED;
|
return ::grpc::Status::CANCELLED;
|
||||||
@ -80,7 +99,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
status = CollectDataToResponse(*req, profiler.get(), response);
|
status = CollectDataToRepository(*req, profiler.get(), response);
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
|
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
|
||||||
status.error_message());
|
status.error_message());
|
||||||
@ -116,5 +135,4 @@ std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace profiler
|
} // namespace profiler
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -172,7 +172,7 @@ class ServerLibTest(test.TestCase):
|
|||||||
# return UnavailableError with no trace events collected string.
|
# return UnavailableError with no trace events collected string.
|
||||||
with self.assertRaises(errors.UnavailableError) as error:
|
with self.assertRaises(errors.UnavailableError) as error:
|
||||||
profiler_client.trace(worker._address, tempfile.mkdtemp(), duration_ms=10)
|
profiler_client.trace(worker._address, tempfile.mkdtemp(), duration_ms=10)
|
||||||
self.assertEqual("No trace event is collected", str(error.exception))
|
self.assertStartsWith(str(error.exception), "No trace event was collected")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -19,6 +19,7 @@ cuda_py_test(
|
|||||||
srcs = ["profiler_api_test.py"],
|
srcs = ["profiler_api_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
|
"external", # So that test suite reruns unconditionally.
|
||||||
"no_pip",
|
"no_pip",
|
||||||
"no_rocm",
|
"no_rocm",
|
||||||
],
|
],
|
||||||
|
@ -67,10 +67,15 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
|
|||||||
'kernel_stats.pb',
|
'kernel_stats.pb',
|
||||||
]
|
]
|
||||||
for file in expected_files:
|
for file in expected_files:
|
||||||
path = os.path.join(logdir, 'plugins/profile/*/*{}'.format(file))
|
path = os.path.join(logdir, 'plugins', 'profile', '*', '*{}'.format(file))
|
||||||
self.assertEqual(1, len(glob.glob(path)),
|
self.assertEqual(1, len(glob.glob(path)),
|
||||||
'Expected one path match: ' + path)
|
'Expected one path match: ' + path)
|
||||||
|
|
||||||
|
def _check_xspace_pb_exist(self, logdir):
|
||||||
|
path = os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb')
|
||||||
|
self.assertEqual(1, len(glob.glob(path)),
|
||||||
|
'Expected one path match: ' + path)
|
||||||
|
|
||||||
def test_single_worker_no_profiling(self):
|
def test_single_worker_no_profiling(self):
|
||||||
"""Test single worker without profiling."""
|
"""Test single worker without profiling."""
|
||||||
|
|
||||||
@ -86,7 +91,6 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
|
|||||||
profiler.start_server(port)
|
profiler.start_server(port)
|
||||||
_, steps, train_ds, model = _model_setup()
|
_, steps, train_ds, model = _model_setup()
|
||||||
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
||||||
logging.info('worker finishing')
|
|
||||||
|
|
||||||
def on_profile(port, logdir):
|
def on_profile(port, logdir):
|
||||||
# Request for 30 milliseconds of profile.
|
# Request for 30 milliseconds of profile.
|
||||||
@ -109,7 +113,7 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
|
|||||||
thread_profiler.start()
|
thread_profiler.start()
|
||||||
thread_profiler.join()
|
thread_profiler.join()
|
||||||
thread_worker.join(120)
|
thread_worker.join(120)
|
||||||
self._check_tools_pb_exist(logdir)
|
self._check_xspace_pb_exist(logdir)
|
||||||
|
|
||||||
def test_single_worker_programmatic_mode(self):
|
def test_single_worker_programmatic_mode(self):
|
||||||
"""Test single worker programmatic mode."""
|
"""Test single worker programmatic mode."""
|
||||||
|
@ -130,6 +130,7 @@ tf_python_pybind_extension(
|
|||||||
"//tensorflow/python:pybind11_status",
|
"//tensorflow/python:pybind11_status",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/time",
|
||||||
"@pybind11",
|
"@pybind11",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -14,11 +14,17 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
#include "pybind11/cast.h"
|
#include "absl/strings/str_split.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/strings/strip.h"
|
||||||
|
#include "absl/time/clock.h"
|
||||||
|
#include "absl/time/time.h"
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "pybind11/pytypes.h"
|
#include "pybind11/pytypes.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
@ -38,7 +44,12 @@ namespace py = ::pybind11;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
tensorflow::Status ValidateHostPortPair(const std::string& host_port) {
|
using ::tensorflow::RemoteProfilerSessionManagerOptions;
|
||||||
|
|
||||||
|
// Profiler gives grace after profiling duration to terminate.
|
||||||
|
constexpr absl::Duration kSessionGraceTime = absl::Seconds(5);
|
||||||
|
|
||||||
|
tensorflow::Status ValidateHostPortPair(absl::string_view host_port) {
|
||||||
tensorflow::uint32 port;
|
tensorflow::uint32 port;
|
||||||
std::vector<absl::string_view> parts = absl::StrSplit(host_port, ':');
|
std::vector<absl::string_view> parts = absl::StrSplit(host_port, ':');
|
||||||
// Must be host:port, port must be a number, host must not contain a '/',
|
// Must be host:port, port must be a number, host must not contain a '/',
|
||||||
@ -51,34 +62,156 @@ tensorflow::Status ValidateHostPortPair(const std::string& host_port) {
|
|||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Takes profiler options in a py::dict and returns a ProfileOptions.
|
tensorflow::Status ValidateOptions(
|
||||||
|
const RemoteProfilerSessionManagerOptions& options) {
|
||||||
|
if (options.service_addresses().empty()) {
|
||||||
|
return tensorflow::errors::InvalidArgument("No service address provided.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.profiler_options().duration_ms() == 0) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"duration_ms must be greater than zero.");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (absl::string_view host_port : options.service_addresses()) {
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHostPortPair(host_port));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.max_session_duration_ms() <
|
||||||
|
options.profiler_options().duration_ms()) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"The maximum profiling session duration must be greater than or equal "
|
||||||
|
"to the local profiler duration.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receives a comma delimited list of service_addresses and adds them to
|
||||||
|
// RemoteProfilerSessionManagerOptions::service_addresses.
|
||||||
|
void AddServiceAddresses(absl::string_view service_addresses,
|
||||||
|
RemoteProfilerSessionManagerOptions* options) {
|
||||||
|
for (absl::string_view server : absl::StrSplit(service_addresses, ',')) {
|
||||||
|
options->add_service_addresses(server.data(), server.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets gRPC deadline to a grace period based on the profiling duration.
|
||||||
|
void UpdateMaxSessionDuration(RemoteProfilerSessionManagerOptions& options) {
|
||||||
|
auto local_profiler_duration = options.profiler_options().duration_ms();
|
||||||
|
auto session_creation_ts = options.session_creation_timestamp_ns();
|
||||||
|
auto requested_start_ts = options.profiler_options().start_timestamp_ns();
|
||||||
|
// User only needs to set maximal session duration if the profiling duration
|
||||||
|
// is bounded.
|
||||||
|
DCHECK_GT(local_profiler_duration, 0);
|
||||||
|
VLOG(3) << "duration_ms was given as " << local_profiler_duration;
|
||||||
|
// Max session duration includes the profiling session and grace time.
|
||||||
|
auto profile_duration =
|
||||||
|
absl::Milliseconds(local_profiler_duration) + kSessionGraceTime;
|
||||||
|
absl::Duration delay_duration;
|
||||||
|
// When requested start timestamp is 0, profiling starts immediately.
|
||||||
|
if (requested_start_ts > 0) {
|
||||||
|
delay_duration =
|
||||||
|
absl::Nanoseconds(requested_start_ts - session_creation_ts);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto max_session_duration = profile_duration + delay_duration;
|
||||||
|
options.set_max_session_duration_ms(
|
||||||
|
absl::ToInt64Milliseconds(max_session_duration));
|
||||||
|
VLOG(1) << "max_session_duration set to " << max_session_duration;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Takes profiler options in a py::dict and returns a
|
||||||
|
// RemoteProfilerSessionManagerOptions.
|
||||||
// This must be called under GIL because it reads Python objects. Reading Python
|
// This must be called under GIL because it reads Python objects. Reading Python
|
||||||
// objects require GIL because the objects can be mutated by other Python
|
// objects require GIL because the objects can be mutated by other Python
|
||||||
// threads. In addition, Python objects are reference counted; reading py::dict
|
// threads. In addition, Python objects are reference counted; reading py::dict
|
||||||
// will increase its reference count.
|
// will increase its reference count.
|
||||||
tensorflow::ProfileOptions GetOptionsLocked(const py::dict& opts) {
|
RemoteProfilerSessionManagerOptions GetOptionsLocked(absl::string_view logdir,
|
||||||
tensorflow::ProfileOptions options =
|
const py::dict& opts) {
|
||||||
|
RemoteProfilerSessionManagerOptions options;
|
||||||
|
*options.mutable_profiler_options() =
|
||||||
tensorflow::ProfilerSession::DefaultOptions();
|
tensorflow::ProfilerSession::DefaultOptions();
|
||||||
|
// Store a timestamp of when this session was created. This will be the basis
|
||||||
|
// of gRPC deadline afterwards.
|
||||||
|
auto now = absl::Now();
|
||||||
|
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(now));
|
||||||
|
VLOG(2) << "set_session_creation_timestamp_ns set to "
|
||||||
|
<< options.session_creation_timestamp_ns() << " [" << now << "]";
|
||||||
|
|
||||||
|
// Set the path of where to store XSpaces.
|
||||||
|
options.mutable_profiler_options()->set_repository_path(logdir.data(),
|
||||||
|
logdir.size());
|
||||||
|
VLOG(2) << "repository_path set to "
|
||||||
|
<< options.profiler_options().repository_path();
|
||||||
|
|
||||||
for (const auto& kw : opts) {
|
for (const auto& kw : opts) {
|
||||||
std::string key = py::cast<std::string>(kw.first);
|
std::string key = py::cast<std::string>(kw.first);
|
||||||
if (key == "host_tracer_level") {
|
if (key == "host_tracer_level") {
|
||||||
options.set_host_tracer_level(py::cast<int>(kw.second));
|
auto value = py::cast<int>(kw.second);
|
||||||
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level();
|
options.mutable_profiler_options()->set_host_tracer_level(value);
|
||||||
|
VLOG(1) << "host_tracer_level set to " << value;
|
||||||
} else if (key == "device_tracer_level") {
|
} else if (key == "device_tracer_level") {
|
||||||
options.set_device_tracer_level(py::cast<int>(kw.second));
|
auto value = py::cast<int>(kw.second);
|
||||||
VLOG(1) << "device_tracer_level set to " << options.device_tracer_level();
|
options.mutable_profiler_options()->set_device_tracer_level(value);
|
||||||
|
VLOG(1) << "device_tracer_level set to " << value;
|
||||||
} else if (key == "python_tracer_level") {
|
} else if (key == "python_tracer_level") {
|
||||||
options.set_python_tracer_level(py::cast<int>(kw.second));
|
auto value = py::cast<int>(kw.second);
|
||||||
VLOG(1) << "python_tracer_level set to " << options.python_tracer_level();
|
options.mutable_profiler_options()->set_python_tracer_level(value);
|
||||||
|
VLOG(1) << "python_tracer_level set to " << value;
|
||||||
|
} else {
|
||||||
|
LOG(WARNING) << "Unrecognised key: " << key;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
|
||||||
|
RemoteProfilerSessionManagerOptions GetOptionsLocked(
|
||||||
|
absl::string_view service_addresses, absl::string_view logdir,
|
||||||
|
absl::string_view worker_list, bool include_dataset_ops,
|
||||||
|
tensorflow::int32 duration_ms, py::dict opts, bool* is_cloud_tpu_session) {
|
||||||
|
RemoteProfilerSessionManagerOptions options = GetOptionsLocked(logdir, opts);
|
||||||
|
|
||||||
|
// Remote profiling does not support any use cases where the following options
|
||||||
|
// are set by `py::dict opts`. e.g. `opts['service_addrs']` will not happen.
|
||||||
|
DCHECK(options.service_addresses().empty());
|
||||||
|
// In remote profiling, duration is always passed by value explicitly and not
|
||||||
|
// set in py::dict opts.
|
||||||
|
DCHECK_EQ(options.profiler_options().duration_ms(), 0);
|
||||||
|
// Because duration_ms is not set from py::dict opts, it follows that
|
||||||
|
// max_session_duration_ms must be unset as well.
|
||||||
|
DCHECK_EQ(options.max_session_duration_ms(), 0);
|
||||||
|
|
||||||
|
// Worker_list is only used for TensorBoard TPU capture cases. For a TPU
|
||||||
|
// cluster, service_address is the Master, which can already be found in the
|
||||||
|
// list of workers. These sessions will be used with the ProfileAnalysis
|
||||||
|
// service.
|
||||||
|
*is_cloud_tpu_session = !worker_list.empty();
|
||||||
|
AddServiceAddresses(*is_cloud_tpu_session ? worker_list : service_addresses,
|
||||||
|
&options);
|
||||||
|
|
||||||
|
// Set local profiler duration and profiler session durations.
|
||||||
|
options.mutable_profiler_options()->set_include_dataset_ops(
|
||||||
|
include_dataset_ops);
|
||||||
|
options.mutable_profiler_options()->set_duration_ms(duration_ms);
|
||||||
|
UpdateMaxSessionDuration(options);
|
||||||
|
|
||||||
|
for (int idx = 0; idx < options.service_addresses_size(); ++idx) {
|
||||||
|
VLOG(1) << "service_addr " << idx << " set to "
|
||||||
|
<< options.service_addresses(idx);
|
||||||
|
}
|
||||||
|
VLOG(1) << "include_dataset_ops set to " << include_dataset_ops;
|
||||||
|
VLOG(1) << "duration_ms set to " << duration_ms;
|
||||||
|
|
||||||
return options;
|
return options;
|
||||||
}
|
}
|
||||||
|
|
||||||
class ProfilerSessionWrapper {
|
class ProfilerSessionWrapper {
|
||||||
public:
|
public:
|
||||||
void Start(const char* logdir, const py::dict& options) {
|
void Start(const char* logdir, const py::dict& options) {
|
||||||
session_ = tensorflow::ProfilerSession::Create(GetOptionsLocked(options));
|
auto opts = GetOptionsLocked(logdir, options);
|
||||||
|
session_ = tensorflow::ProfilerSession::Create(opts.profiler_options());
|
||||||
logdir_ = logdir;
|
logdir_ = logdir;
|
||||||
tensorflow::MaybeRaiseRegisteredFromStatus(session_->Status());
|
tensorflow::MaybeRaiseRegisteredFromStatus(session_->Status());
|
||||||
}
|
}
|
||||||
@ -130,26 +263,28 @@ PYBIND11_MODULE(_pywrap_profiler, m) {
|
|||||||
profiler_server.release();
|
profiler_server.release();
|
||||||
});
|
});
|
||||||
|
|
||||||
m.def("trace",
|
m.def("trace", [](const char* service_addr, const char* logdir,
|
||||||
[](const char* service_addr, const char* logdir,
|
const char* worker_list, bool include_dataset_ops,
|
||||||
const char* worker_list, bool include_dataset_ops, int duration_ms,
|
int duration_ms, int num_tracing_attempts,
|
||||||
int num_tracing_attempts, py::dict options) {
|
py::dict options) {
|
||||||
// Normalize py::dict into a well defined proto.
|
// TPU capture is true if the user sets worker_list.
|
||||||
tensorflow::ProfileOptions opts = GetOptionsLocked(options);
|
bool is_cloud_tpu_session = false;
|
||||||
|
// Normalize py::dict into a well defined and validated proto.
|
||||||
|
tensorflow::RemoteProfilerSessionManagerOptions opts =
|
||||||
|
GetOptionsLocked(service_addr, logdir, worker_list, include_dataset_ops,
|
||||||
|
duration_ms, options, &is_cloud_tpu_session);
|
||||||
|
tensorflow::Status status = ValidateOptions(opts);
|
||||||
|
tensorflow::MaybeRaiseRegisteredFromStatus(status);
|
||||||
|
|
||||||
tensorflow::Status status = ValidateHostPortPair(service_addr);
|
{
|
||||||
tensorflow::MaybeRaiseRegisteredFromStatus(status);
|
// Release the lock to keep the lock scope to a minimum, and allow
|
||||||
opts.set_include_dataset_ops(include_dataset_ops);
|
// other threads to proceed.
|
||||||
{
|
py::gil_scoped_release release;
|
||||||
// Release the lock to keep the lock scope to a minimum, and allow
|
status = tensorflow::profiler::Trace(logdir, num_tracing_attempts, opts,
|
||||||
// other threads to proceed.
|
is_cloud_tpu_session);
|
||||||
py::gil_scoped_release release;
|
}
|
||||||
status = tensorflow::profiler::Trace(service_addr, logdir,
|
tensorflow::MaybeRaiseRegisteredFromStatus(status);
|
||||||
worker_list, duration_ms,
|
});
|
||||||
num_tracing_attempts, opts);
|
|
||||||
}
|
|
||||||
tensorflow::MaybeRaiseRegisteredFromStatus(status);
|
|
||||||
});
|
|
||||||
|
|
||||||
m.def("monitor", [](const char* service_addr, int duration_ms,
|
m.def("monitor", [](const char* service_addr, int duration_ms,
|
||||||
int monitoring_level, bool display_timestamp) {
|
int monitoring_level, bool display_timestamp) {
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.profiler.internal import _pywrap_profiler
|
from tensorflow.python.profiler.internal import _pywrap_profiler
|
||||||
|
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
@ -32,38 +33,57 @@ def trace(service_addr,
|
|||||||
worker_list='',
|
worker_list='',
|
||||||
num_tracing_attempts=3,
|
num_tracing_attempts=3,
|
||||||
options=None):
|
options=None):
|
||||||
"""Sends grpc requests to profiler server to perform on-demand profiling.
|
"""Sends gRPC requests to one or more profiler servers to perform on-demand profiling.
|
||||||
|
|
||||||
This method will block caller thread until it receives tracing result. This
|
This method will block the calling thread until it receives responses from all
|
||||||
method supports CPU, GPU, and Cloud TPU. This method supports profiling a
|
servers or until deadline expiration. Both single host and multiple host
|
||||||
single host for CPU, GPU, TPU, as well as multiple TPU workers.
|
profiling are supported on CPU, GPU, and TPU.
|
||||||
The profiled results will be saved to your specified TensorBoard log
|
The profiled results will be saved by each server to the specified TensorBoard
|
||||||
directory (e.g. the directory you save your model checkpoints). Use the
|
log directory (i.e. the directory you save your model checkpoints). Use the
|
||||||
TensorBoard profile plugin to view the visualization and analysis results.
|
TensorBoard profile plugin to view the visualization and analysis results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_addr: gRPC address of profiler service e.g. grpc://localhost:6009.
|
service_addr: A comma delimited string of gRPC addresses of the workers to
|
||||||
logdir: Path of TensorBoard log directory e.g. /tmp/tb_log.
|
profile.
|
||||||
duration_ms: Duration of tracing or monitoring in ms.
|
e.g. service_addr='grpc://localhost:6009'
|
||||||
worker_list: Optional. The list of workers that we are about to profile in
|
service_addr='grpc://10.0.0.2:8466,grpc://10.0.0.3:8466'
|
||||||
the current session (TPU only).
|
service_addr='grpc://localhost:12345,grpc://localhost:23456'
|
||||||
|
logdir: Path to save profile data to, typically a TensorBoard log directory.
|
||||||
|
This path must be accessible to both the client and server.
|
||||||
|
e.g. logdir='gs://your_tb_dir'
|
||||||
|
duration_ms: Duration of tracing or monitoring in mliiseconds. Must be
|
||||||
|
greater than zero.
|
||||||
|
worker_list: An optional TPU only configuration. The list of workers to
|
||||||
|
profile in the current session.
|
||||||
num_tracing_attempts: Optional. Automatically retry N times when no trace
|
num_tracing_attempts: Optional. Automatically retry N times when no trace
|
||||||
event is collected (default 3).
|
event is collected (default 3).
|
||||||
options: profiler.experimental.ProfilerOptions namedtuple for miscellaneous
|
options: profiler.experimental.ProfilerOptions namedtuple for miscellaneous
|
||||||
profiler options.
|
profiler options.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
UnavailableError: If no trace event is collected.
|
InvalidArgumentError: For when arguments fail validation checks.
|
||||||
|
UnavailableError: If no trace event was collected.
|
||||||
|
|
||||||
Example usage (CPU/GPU):
|
Example usage (CPU/GPU):
|
||||||
# Start a profiler server before your model runs.
|
# Start a profiler server before your model runs.
|
||||||
```python
|
```python
|
||||||
tf.profiler.experimental.server.start(6009)
|
tf.profiler.experimental.server.start(6009)
|
||||||
# your model code.
|
# (Model code goes here).
|
||||||
# Send gRPC request to the profiler server to collect a trace of your model.
|
# Send gRPC request to the profiler server to collect a trace of your model.
|
||||||
```python
|
```python
|
||||||
tf.profiler.experimental.client.trace('grpc://localhost:6009',
|
tf.profiler.experimental.client.trace('grpc://localhost:6009',
|
||||||
'/tmp/tb_log', 2000)
|
'/nfs/tb_log', 2000)
|
||||||
|
|
||||||
|
Example usage (Multiple GPUs):
|
||||||
|
# E.g. your worker IP addresses are 10.0.0.2, 10.0.0.3, 10.0.0.4, and you
|
||||||
|
# would like to schedule start of profiling 1 second from now, for a duration
|
||||||
|
# of 2 seconds.
|
||||||
|
options['delay_ms'] = 1000
|
||||||
|
tf.profiler.experimental.client.trace(
|
||||||
|
'grpc://10.0.0.2:8466,grpc://10.0.0.3:8466,grpc://10.0.0.4:8466',
|
||||||
|
'gs://your_tb_dir',
|
||||||
|
2000,
|
||||||
|
options=options)
|
||||||
|
|
||||||
Example usage (TPU):
|
Example usage (TPU):
|
||||||
# Send gRPC request to a TPU worker to collect a trace of your model. A
|
# Send gRPC request to a TPU worker to collect a trace of your model. A
|
||||||
@ -82,16 +102,19 @@ def trace(service_addr,
|
|||||||
# profile for 2 seconds.
|
# profile for 2 seconds.
|
||||||
tf.profiler.experimental.client.trace('grpc://10.0.0.2:8466',
|
tf.profiler.experimental.client.trace('grpc://10.0.0.2:8466',
|
||||||
'gs://your_tb_dir',
|
'gs://your_tb_dir',
|
||||||
2000, '10.0.0.3,10.0.0.4')
|
2000, '10.0.0.2,10.0.0.3,10.0.0.4')
|
||||||
|
|
||||||
Launch TensorBoard and point it to the same logdir you provided to this API.
|
Launch TensorBoard and point it to the same logdir you provided to this API.
|
||||||
$ tensorboard --logdir=/tmp/tb_log (or gs://your_tb_dir in the above examples)
|
$ tensorboard --logdir=/tmp/tb_log (or gs://your_tb_dir in the above examples)
|
||||||
Open your browser and go to localhost:6006/#profile to view profiling results.
|
Open your browser and go to localhost:6006/#profile to view profiling results.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if duration_ms <= 0:
|
||||||
|
raise errors.InvalidArgumentError(None, None,
|
||||||
|
'duration_ms must be greater than zero.')
|
||||||
|
|
||||||
opts = dict(options._asdict()) if options is not None else {}
|
opts = dict(options._asdict()) if options is not None else {}
|
||||||
_pywrap_profiler.trace(
|
_pywrap_profiler.trace(
|
||||||
_strip_prefix(service_addr, _GRPC_PREFIX), logdir, worker_list, True,
|
_strip_addresses(service_addr, _GRPC_PREFIX), logdir, worker_list, True,
|
||||||
duration_ms, num_tracing_attempts, opts)
|
duration_ms, num_tracing_attempts, opts)
|
||||||
|
|
||||||
|
|
||||||
@ -127,3 +150,7 @@ def monitor(service_addr, duration_ms, level=1):
|
|||||||
|
|
||||||
def _strip_prefix(s, prefix):
|
def _strip_prefix(s, prefix):
|
||||||
return s[len(prefix):] if s.startswith(prefix) else s
|
return s[len(prefix):] if s.startswith(prefix) else s
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_addresses(addresses, prefix):
|
||||||
|
return ','.join([_strip_prefix(s, prefix) for s in addresses.split(',')])
|
||||||
|
@ -38,7 +38,7 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(errors.UnavailableError) as error:
|
with self.assertRaises(errors.UnavailableError) as error:
|
||||||
profiler_client.trace(
|
profiler_client.trace(
|
||||||
'localhost:' + str(test_port), self.get_temp_dir(), duration_ms=10)
|
'localhost:' + str(test_port), self.get_temp_dir(), duration_ms=10)
|
||||||
self.assertEqual('No trace event is collected', str(error.exception))
|
self.assertStartsWith(str(error.exception), 'No trace event was collected')
|
||||||
|
|
||||||
def testTrace_ProfileIdleServerWithOptions(self):
|
def testTrace_ProfileIdleServerWithOptions(self):
|
||||||
test_port = portpicker.pick_unused_port()
|
test_port = portpicker.pick_unused_port()
|
||||||
@ -54,7 +54,7 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
|
|||||||
self.get_temp_dir(),
|
self.get_temp_dir(),
|
||||||
duration_ms=10,
|
duration_ms=10,
|
||||||
options=options)
|
options=options)
|
||||||
self.assertEqual('No trace event is collected', str(error.exception))
|
self.assertStartsWith(str(error.exception), 'No trace event was collected')
|
||||||
|
|
||||||
def testMonitor_ProcessInvalidAddress(self):
|
def testMonitor_ProcessInvalidAddress(self):
|
||||||
# Monitor is only supported in cloud TPU. Test invalid address instead.
|
# Monitor is only supported in cloud TPU. Test invalid address instead.
|
||||||
|
@ -346,6 +346,10 @@ tensorflow::profiler::ProfilerServer::~ProfilerServer
|
|||||||
tensorflow::profiler::ProfileGrpc
|
tensorflow::profiler::ProfileGrpc
|
||||||
tensorflow::profiler::NewSessionGrpc
|
tensorflow::profiler::NewSessionGrpc
|
||||||
tensorflow::profiler::MonitorGrpc
|
tensorflow::profiler::MonitorGrpc
|
||||||
|
tensorflow::profiler::RemoteProfilerSession::Create
|
||||||
|
tensorflow::profiler::RemoteProfilerSession::GetServiceAddress
|
||||||
|
tensorflow::profiler::RemoteProfilerSession::WaitForCompletion
|
||||||
|
tensorflow::profiler::RemoteProfilerSession::~RemoteProfilerSession
|
||||||
|
|
||||||
[status_macros] # tfcompile
|
[status_macros] # tfcompile
|
||||||
xla::status_macros::MakeErrorStream::Impl::Impl
|
xla::status_macros::MakeErrorStream::Impl::Impl
|
||||||
|
Loading…
x
Reference in New Issue
Block a user