Added multi-worker sampling mode.

* Profile() now calls RemoteProfilerSessionManager (RPSM).
* Profiler server saves XSpace to the repository instead of returning it in the ProfileResponse.
  * ProfileResponse will have a non-empty trace iff there are any XPlanes.

PiperOrigin-RevId: 336807134
Change-Id: I664400c9715d73605b9daa4cfdcf6475dee4e959
Author: Yi Situ (committed by TensorFlower Gardener)
Date: 2020-10-12 21:54:35 -07:00
Commit: 8813a286b8 (parent: 78ba66c122)
20 changed files with 462 additions and 291 deletions
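
Rough sketch of the new client-side flow, based on the functions introduced below (Profile() in capture_profile.cc and RemoteProfilerSessionManager); the names are as they appear in this diff, but option setup and error handling are simplified, so treat it as orientation rather than the exact implementation:

    // Options are built and validated in profiler_wrapper.cc before this point.
    RemoteProfilerSessionManagerOptions opts = ...;
    ProfileRequest request = PopulateProfileRequest(
        repository_root, session_id, /*host_name=*/"", opts);
    Status status;
    // Create() issues an asynchronous Profile RPC to every service address.
    auto session = RemoteProfilerSessionManager::Create(opts, request, status);
    TF_RETURN_IF_ERROR(status);
    // One response per worker; each profiler server has already written its
    // XSpace to <repository_root>/<session_id>/<host>.xplane.pb, so the
    // response only carries status and the empty_trace flag.
    std::vector<RemoteProfilerSessionManager::Response> responses =
        session->WaitForCompletion();
    for (const auto& r : responses) {
      if (r.profile_response->empty_trace()) {
        LOG(WARNING) << "No trace events collected from " << r.service_address;
      }
    }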

View File

@@ -47,8 +47,11 @@ cc_library(
         "//tensorflow/core/profiler:profiler_service_proto_cc",
         "//tensorflow/core/profiler/lib:profiler_session",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "//tensorflow/core/profiler/utils:file_system_utils",
+        "//tensorflow/core/profiler/utils:xplane_utils",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         tf_grpc_cc_dependency(),
     ],
 )

View File

@@ -27,6 +27,7 @@ cc_library(
     ],
     deps = [
         ":profiler_client_for_pybind",
+        ":remote_profiler_session_manager",
         ":save_profile",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
@@ -133,14 +134,11 @@ cc_library(
     srcs = ["remote_profiler_session_manager.cc"],
     hdrs = ["remote_profiler_session_manager.h"],
     copts = tf_profiler_copts(),
-    visibility = ["//tensorflow/core/profiler:internal"],
     deps = [
         ":profiler_client_for_pybind",
-        ":save_profile",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/profiler:profiler_options_proto_cc",
-        "//tensorflow/core/profiler/lib:profiler_session",
         "//tensorflow/core/profiler/utils:time_utils",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",

View File

@@ -30,12 +30,16 @@ limitations under the License.
 #include "tensorflow/core/profiler/profiler_options.pb.h"
 #include "tensorflow/core/profiler/profiler_service.pb.h"
 #include "tensorflow/core/profiler/rpc/client/profiler_client.h"
+#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"
 #include "tensorflow/core/profiler/rpc/client/save_profile.h"
 
 namespace tensorflow {
 namespace profiler {
 namespace {
 
+using ::tensorflow::profiler::RemoteProfilerSessionManager;
+using Response = ::tensorflow::profiler::RemoteProfilerSessionManager::Response;
+
 constexpr uint64 kMaxEvents = 1000000;
 const absl::string_view kXPlanePb = "xplane.pb";
@@ -48,17 +52,18 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
   return request;
 }
 
-ProfileRequest PopulateProfileRequest(int duration_ms,
-                                      const std::string& repository_root,
-                                      const std::string& session_id,
-                                      const std::string& host_name,
-                                      const ProfileOptions& opts) {
+ProfileRequest PopulateProfileRequest(
+    absl::string_view repository_root, absl::string_view session_id,
+    absl::string_view host_name,
+    const RemoteProfilerSessionManagerOptions& options) {
   ProfileRequest request;
-  request.set_duration_ms(duration_ms);
+  // TODO(b/169976117) Remove duration from request.
+  request.set_duration_ms(options.profiler_options().duration_ms());
   request.set_max_events(kMaxEvents);
-  request.set_repository_root(repository_root);
-  request.set_session_id(session_id);
-  request.set_host_name(host_name);
+  request.set_repository_root(repository_root.data(), repository_root.size());
+  request.set_session_id(session_id.data(), session_id.size());
+  request.set_host_name(host_name.data(), host_name.size());
+  // These tools are only used by TPU profiler.
   request.add_tools("trace_viewer");
   request.add_tools("op_profile");
   request.add_tools("input_pipeline");
@@ -68,21 +73,26 @@ ProfileRequest PopulateProfileRequest(int duration_ms,
   request.add_tools("overview_page");
   request.add_tools("pod_viewer");
   request.add_tools("tensorflow_stats");
-  *request.mutable_opts() = opts;
+  // XPlane tool is only used by OSS profiler and safely ignored by TPU
+  // profiler.
+  request.add_tools(kXPlanePb.data(), kXPlanePb.size());
+  *request.mutable_opts() = options.profiler_options();
   return request;
 }
 
 NewProfileSessionRequest PopulateNewProfileSessionRequest(
-    const std::string& service_addr, const std::string& repository_root,
-    const std::vector<string>& hostnames, int duration_ms,
-    const std::string& session_id, const ProfileOptions& opts) {
+    absl::string_view repository_root, absl::string_view session_id,
+    const RemoteProfilerSessionManagerOptions& opts) {
   NewProfileSessionRequest request;
-  std::vector<std::string> parts = absl::StrSplit(service_addr, ':');
-  *request.mutable_request() = PopulateProfileRequest(
-      duration_ms, repository_root, session_id, parts[0], opts);
-  request.set_repository_root(repository_root);
-  request.set_session_id(session_id);
-  for (const auto& hostname : hostnames) {
+  std::vector<absl::string_view> parts =
+      absl::StrSplit(opts.service_addresses(0), ':');
+  DCHECK(!parts.empty());
+
+  *request.mutable_request() =
+      PopulateProfileRequest(repository_root, session_id, parts[0], opts);
+  request.set_repository_root(repository_root.data(), repository_root.size());
+  request.set_session_id(session_id.data(), session_id.size());
+  for (const auto& hostname : opts.service_addresses()) {
     request.add_hosts(hostname);
   }
   return request;
@@ -99,44 +109,40 @@ inline bool ShouldRetryTracing(Status status) {
           status.error_message() == "Stream removed");
 }
 
-// If the ProfileResponse has single 'xplane.pb' tool, convert the xplane to
-// other tools and add in ProfileResponse. Otherwise, the ProfileResponse is
-// already converted, simply return.
-Status ConvertXSpaceToToolsInProfileResponse(const ProfileRequest& request,
-                                             ProfileResponse* response) {
-  if (response->tool_data_size() != 1) return Status::OK();
-  if (response->tool_data(0).name() != kXPlanePb) return Status::OK();
-  XSpace xspace;
-  xspace.ParseFromString(response->tool_data(0).data());
-  TF_RETURN_IF_ERROR(ConvertXSpaceToProfileResponse(xspace, request, response));
-  return Status::OK();
-}
-
-Status Profile(const std::string& service_addr,
-               const std::string& repository_root, int duration_ms,
-               const std::string& session_id, const ProfileOptions& opts) {
-  std::vector<std::string> parts = absl::StrSplit(service_addr, ':');
-  ProfileRequest request = PopulateProfileRequest(duration_ms, repository_root,
-                                                  session_id, parts[0], opts);
-  ProfileResponse response;
-  TF_RETURN_IF_ERROR(ProfileGrpc(service_addr, request, &response));
-
-  if (!response.empty_trace()) {
-    TF_RETURN_IF_ERROR(
-        ConvertXSpaceToToolsInProfileResponse(request, &response));
-    TF_RETURN_IF_ERROR(SaveProfile(repository_root, session_id,
-                                   request.host_name(), response, &std::cout));
-    // Print this at the end so that it's not buried in irrelevant LOG messages.
-    std::cout
-        << "NOTE: using the trace duration " << duration_ms << "ms.\n"
-        << "Set an appropriate duration (with --duration_ms) if you "
-           "don't see a full step in your trace or the captured trace is too "
-           "large."
-        << std::endl;
+Status Profile(const std::string& repository_root,
+               const std::string& session_id,
+               const RemoteProfilerSessionManagerOptions& opts) {
+  Status status;
+  // Host name will be overwritten by RemoteProfilerSessionManager later.
+  ProfileRequest request = PopulateProfileRequest(repository_root, session_id,
+                                                  /*host_name=*/"", opts);
+  auto session = RemoteProfilerSessionManager::Create(opts, request, status);
+  TF_RETURN_IF_ERROR(status);
+  // Expect one or more service addresses.
+  DCHECK_GT(opts.service_addresses_size(), 0);
+  std::vector<Response> responses = session->WaitForCompletion();
+  // Expect responses to have the same size as clients.
+  DCHECK_EQ(responses.size(), opts.service_addresses_size());
+
+  bool has_trace_data = false;
+  for (const auto& client_response : responses) {
+    ProfileResponse& response = *client_response.profile_response;
+    if (response.empty_trace()) {
+      LOG(WARNING) << "No trace event is collected from "
+                   << client_response.service_address;
+    } else {
+      has_trace_data = true;
+    }
+    if (!client_response.status.ok()) {
+      LOG(WARNING) << client_response.service_address << " returned "
+                   << client_response.status;
+    }
   }
 
-  if (response.empty_trace()) {
-    return Status(error::Code::UNAVAILABLE, "No trace event is collected");
+  if (!has_trace_data) {
+    return Status(error::Code::UNAVAILABLE,
+                  "No trace event was collected because there were no responses"
+                  " from clients or the responses did not have trace data.");
   }
   return Status::OK();
 }
@@ -144,52 +150,47 @@ Status Profile(const std::string& service_addr,
 // Start a new profiling session that include all the hosts included in
 // hostnames, for the time interval of duration_ms. Possibly save the profiling
 // result in the directory specified by repository_root and session_id.
-Status NewSession(const std::string& service_addr,
-                  const std::string& repository_root,
-                  const std::vector<string>& hostnames, int duration_ms,
-                  const std::string& session_id, const ProfileOptions& opts) {
-  NewProfileSessionRequest request = PopulateNewProfileSessionRequest(
-      service_addr, repository_root, hostnames, duration_ms, session_id, opts);
+Status NewSession(absl::string_view repository_root,
+                  absl::string_view session_id,
+                  const RemoteProfilerSessionManagerOptions& opts) {
+  NewProfileSessionRequest request =
+      PopulateNewProfileSessionRequest(repository_root, session_id, opts);
   NewProfileSessionResponse response;
-  TF_RETURN_IF_ERROR(NewSessionGrpc(service_addr, request, &response));
+  TF_RETURN_IF_ERROR(
+      NewSessionGrpc(opts.service_addresses(0), request, &response));
 
   std::cout << "Profile session succeed for host(s):"
-            << absl::StrJoin(hostnames, ",") << std::endl;
+            << absl::StrJoin(opts.service_addresses(), ",") << std::endl;
   if (response.empty_trace()) {
-    return Status(error::Code::UNAVAILABLE, "No trace event is collected");
+    return errors::Unavailable("No trace event is collected");
   }
   return Status::OK();
 }
 
 }  // namespace
 
-// Starts tracing on a single or multiple hosts and saves the result in the
-// given logdir. If no trace was collected, retries tracing for
-// num_tracing_attempts.
-Status Trace(const std::string& service_addr, const std::string& logdir,
-             const std::string& workers_list, int duration_ms,
-             int num_tracing_attempts, const ProfileOptions& opts) {
+Status Trace(const std::string& logdir, int num_tracing_attempts,
+             const RemoteProfilerSessionManagerOptions& opts,
+             bool is_cloud_tpu_session) {
+  DCHECK_GT(opts.profiler_options().duration_ms(), 0);
+  DCHECK(!opts.service_addresses().empty());
+
   // Use the current timestamp as the run name.
   std::string session_id = GetCurrentTimeStampAsString();
-  std::vector<std::string> hostnames;
-  if (!workers_list.empty()) {
-    hostnames = absl::StrSplit(workers_list, ',');
-  }
+  std::string repository_root = GetTensorBoardProfilePluginDir(logdir);
+  auto duration_ms = opts.profiler_options().duration_ms();
   TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
-  std::string repository_root =
-      profiler::GetTensorBoardProfilePluginDir(logdir);
 
-  Status status = Status::OK();
+  Status status;
   int remaining_attempts = num_tracing_attempts;
   while (true) {
     std::cout << "Starting to trace for " << duration_ms << " ms. "
              << "Remaining attempt(s): " << --remaining_attempts << std::endl;
-    if (hostnames.empty()) {
-      status =
-          Profile(service_addr, repository_root, duration_ms, session_id, opts);
+
+    if (is_cloud_tpu_session) {
+      status = NewSession(repository_root, session_id, opts);
     } else {
-      status = NewSession(service_addr, repository_root, hostnames, duration_ms,
-                          session_id, opts);
+      status = Profile(repository_root, session_id, opts);
     }
     if (remaining_attempts <= 0 || status.ok() || !ShouldRetryTracing(status))
       break;
@@ -223,11 +224,10 @@ Status ExportToTensorBoard(const XSpace& xspace, const std::string& logdir) {
   ProfileResponse response;
   ProfileRequest request = PopulateProfileRequest(
-      /*duration_ms=*/0, GetTensorBoardProfilePluginDir(logdir),
-      GetCurrentTimeStampAsString(), port::Hostname(), /*opts=*/{});
+      GetTensorBoardProfilePluginDir(logdir), GetCurrentTimeStampAsString(),
+      port::Hostname(), /*options=*/{});
   TF_RETURN_IF_ERROR(
       ConvertXSpaceToProfileResponse(xspace, request, &response));
   std::stringstream ss;  // Record LOG messages.
   TF_RETURN_IF_ERROR(SaveProfile(request.repository_root(),
                                  request.session_id(), request.host_name(),

View File

@@ -36,12 +36,12 @@ Status Monitor(const std::string& service_addr, int duration_ms,
                int monitoring_level, bool display_timestamp,
                std::string* result);
 
-// Starts tracing on a single or multiple hosts and saves the result in the
-// given logdir. If no trace was collected, retries tracing for
-// num_tracing_attempts.
-Status Trace(const std::string& service_addr, const std::string& logdir,
-             const std::string& workers_list, int duration_ms,
-             int num_tracing_attempts, const ProfileOptions& opts);
+// Starts tracing on a single or multiple hosts. Each host will save the result
+// in the given logdir. If no trace was collected, retries tracing for
+// num_tracing_attempts. Assumes that options have been validated.
+Status Trace(const std::string& logdir, int num_tracing_attempts,
+             const RemoteProfilerSessionManagerOptions& opts,
+             bool is_cloud_tpu_session);
 
 }  // namespace profiler
 }  // namespace tensorflow

View File

@@ -99,10 +99,11 @@ RemoteProfilerSession::RemoteProfilerSession(std::string service_address,
       service_address_(std::move(service_address)),
       stub_(CreateStub<grpc::ProfilerService>(service_address_)),
       deadline_(deadline),
-      profile_request_(std::move(profile_request)) {}
+      profile_request_(std::move(profile_request)) {
+  response_->set_empty_trace(true);
+}
 
 RemoteProfilerSession::~RemoteProfilerSession() {
-  LOG(INFO) << "Waiting for completion.";
   Status dummy;
   WaitForCompletion(dummy);
   grpc_context_.TryCancel();
@@ -113,6 +114,8 @@ void RemoteProfilerSession::ProfileAsync() {
   grpc_context_.set_deadline(absl::ToChronoTime(deadline_));
   VLOG(1) << "Deadline set to " << deadline_;
   rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_);
+  // Connection failure will create lame channel whereby grpc_status_ will be an
+  // error.
   rpc_->Finish(response_.get(), &grpc_status_,
                static_cast<void*>(&status_on_completion_));
   VLOG(2) << "Asynchronous gRPC Profile() issued." << absl::Now();
@@ -125,6 +128,7 @@ std::unique_ptr<ProfileResponse> RemoteProfilerSession::WaitForCompletion(
         "WaitForCompletion must only be called once.");
     return nullptr;
   }
+  LOG(INFO) << "Waiting for completion.";
 
   void* got_tag = nullptr;
   bool ok = false;

View File

@@ -82,7 +82,7 @@ class RemoteProfilerSession {
   absl::Time deadline_;
   ::grpc::ClientContext grpc_context_;
   std::unique_ptr<::grpc::ClientAsyncResponseReader<ProfileResponse>> rpc_;
-  ::grpc::Status grpc_status_;
+  ::grpc::Status grpc_status_ = ::grpc::Status::OK;
 
   // Asynchronous completion queue states.
   ::grpc::CompletionQueue cq_;

View File

@@ -52,8 +52,10 @@ TEST(RemoteProfilerSession, Simple) {
   absl::Duration elapsed = absl::Now() - approx_start;
   // At end of session this evaluates to true still.
   EXPECT_TRUE(status.ok());
-  EXPECT_FALSE(response->empty_trace());
-  EXPECT_GT(response->tool_data_size(), 0);
+  // True because there was no workload traced and subsequently no XEvents.
+  EXPECT_TRUE(response->empty_trace());
+  // XSpaces are serialized and not returned as tools in ProfileResponse.
+  EXPECT_EQ(response->tool_data_size(), 0);
   EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
 }
 
@@ -86,8 +88,9 @@ TEST(RemoteProfilerSession, Timeout) {
   auto response = remote_session->WaitForCompletion(status);
   // At end of session we will have a timeout error.
   EXPECT_TRUE(errors::IsDeadlineExceeded(status));
-  EXPECT_FALSE(response->empty_trace());  // This defaults to false.
+  // True because there was no workload traced and subsequently no XEvents.
+  EXPECT_TRUE(response->empty_trace());
+  // XSpaces are serialized and not returned as tools in ProfileResponse.
   EXPECT_EQ(response->tool_data_size(), 0);
 }
 
@@ -109,8 +112,10 @@ TEST(RemoteProfilerSession, LongDeadline) {
   absl::Duration elapsed = absl::Now() - approx_start;
   // At end of session this evaluates to true still.
   EXPECT_TRUE(status.ok());
-  EXPECT_FALSE(response->empty_trace());
-  EXPECT_GT(response->tool_data_size(), 0);
+  // True because there was no workload traced and subsequently no XEvents.
+  EXPECT_TRUE(response->empty_trace());
+  // XSpaces are serialized and not returned as tools in ProfileResponse.
+  EXPECT_EQ(response->tool_data_size(), 0);
   // Elapsed time is near profiling duration despite long grace period.
   EXPECT_THAT(elapsed, DurationNear(duration));
 }
 
@@ -134,8 +139,10 @@ TEST(RemoteProfilerSession, LongDuration) {
   absl::Duration elapsed = absl::Now() - approx_start;
   // At end of session this evaluates to true still.
   EXPECT_TRUE(status.ok());
-  EXPECT_FALSE(response->empty_trace());
-  EXPECT_GT(response->tool_data_size(), 0);
+  // True because there was no workload traced and subsequently no XEvents.
+  EXPECT_TRUE(response->empty_trace());
+  // XSpaces are serialized and not returned as tools in ProfileResponse.
+  EXPECT_EQ(response->tool_data_size(), 0);
   // Elapsed time takes longer to complete for larger traces.
   EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
 }

View File

@@ -37,14 +37,14 @@ namespace profiler {
 namespace test {
 
 inline std::unique_ptr<ProfilerServer> StartServer(
-    absl::Duration duration, std::string* service_addresses,
+    absl::Duration duration, std::string* service_address,
     ProfileRequest* request = nullptr) {
   auto profiler_server = absl::make_unique<ProfilerServer>();
   int port = testing::PickUnusedPortOrDie();
   profiler_server->StartProfilerServer(port);
-  DCHECK(service_addresses);
-  *service_addresses = absl::StrCat("localhost:", port);
+  DCHECK(service_address);
+  *service_address = absl::StrCat("localhost:", port);
   if (request) {
     request->set_duration_ms(absl::ToInt64Milliseconds(duration));
@@ -53,10 +53,11 @@ inline std::unique_ptr<ProfilerServer> StartServer(
     request->mutable_opts()->set_duration_ms(
         absl::ToInt64Milliseconds(duration));
     request->set_session_id("test_session");
-    request->set_host_name(*service_addresses);
+    request->set_host_name(*service_address);
+    request->set_repository_root(testing::TmpDir());
   }
-  LOG(INFO) << "Started " << *service_addresses << " at " << absl::Now();
+  LOG(INFO) << "Started " << *service_address << " at " << absl::Now();
   LOG(INFO) << "Duration: " << duration;
   return profiler_server;

View File

@@ -26,47 +26,20 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/rpc/client/save_profile.h"
+#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
 #include "tensorflow/core/profiler/utils/time_utils.h"
-#include "tensorflow/core/protobuf/error_codes.pb.h"
 
 namespace tensorflow {
 namespace profiler {
-namespace {
-
-constexpr uint64 kMaxEvents = 1000000;
-
-// TODO(yisitu) merge with the implementation in capture_profile.
-void PopulateProfileRequest(const RemoteProfilerSessionManagerOptions& options,
-                            absl::string_view session_id,
-                            absl::string_view host_name,
-                            ProfileRequest* request) {
-  request->set_max_events(kMaxEvents);
-  request->set_repository_root(options.profiler_options().repository_path());
-  request->set_session_id(session_id.data(), session_id.size());
-  request->add_tools("trace_viewer");
-  request->add_tools("op_profile");
-  request->add_tools("input_pipeline");
-  request->add_tools("kernel_stats");
-  request->add_tools("memory_viewer");
-  request->add_tools("memory_profile");
-  request->add_tools("overview_page");
-  request->add_tools("pod_viewer");
-  request->add_tools("tensorflow_stats");
-  request->set_host_name(host_name.data(), host_name.size());
-  *request->mutable_opts() = options.profiler_options();
-  request->set_duration_ms(options.profiler_options().duration_ms());
-}
-
-}  // namespace
 
 /*static*/ std::unique_ptr<RemoteProfilerSessionManager>
 RemoteProfilerSessionManager::Create(
     const RemoteProfilerSessionManagerOptions& options,
-    tensorflow::Status& out_status, AddressResolver resolver) {
+    const ProfileRequest& request, tensorflow::Status& out_status,
+    AddressResolver resolver) {
   VLOG(1) << "Creating a RemoteProfilerSessionManager.";
-  auto session_manager =
-      absl::WrapUnique(new RemoteProfilerSessionManager(options, resolver));
+  auto session_manager = absl::WrapUnique(
+      new RemoteProfilerSessionManager(options, request, resolver));
   out_status = session_manager->Init();
   if (!out_status.ok()) {
     return nullptr;
@@ -75,8 +48,9 @@ RemoteProfilerSessionManager::Create(
 }
 
 RemoteProfilerSessionManager::RemoteProfilerSessionManager(
-    RemoteProfilerSessionManagerOptions options, AddressResolver resolver)
-    : options_(std::move(options)) {
+    RemoteProfilerSessionManagerOptions options, ProfileRequest request,
+    AddressResolver resolver)
+    : options_(std::move(options)), request_(std::move(request)) {
   if (resolver) {
     resolver_ = std::move(resolver);
   } else {
@@ -91,14 +65,7 @@ RemoteProfilerSessionManager::~RemoteProfilerSessionManager() {
 Status RemoteProfilerSessionManager::Init() {
   mutex_lock lock(mutex_);
   VLOG(1) << "SessionManager initializing.";
-  // TODO(b/169482824) Move validation to call site.
-  Status status = ValidateOptionsLocked();
-  if (!status.ok()) {
-    LOG(ERROR) << status;
-    return status;
-  }
-
-  std::string session_id = GetCurrentTimeStampAsString();
+
   const absl::Time session_created_ts =
       absl::FromUnixNanos(options_.session_creation_timestamp_ns());
   const absl::Time deadline =
@@ -115,16 +82,14 @@ Status RemoteProfilerSessionManager::Init() {
 
   // Prepare a list of clients.
   clients_.reserve(options_.service_addresses_size());
-  for (auto& service_addr : options_.service_addresses()) {
-    std::string resolved_service_addr = resolver_(service_addr);
-    ProfileRequest profile_request;
-    PopulateProfileRequest(options_, session_id, resolved_service_addr,
-                           &profile_request);
+  for (auto& service_address : options_.service_addresses()) {
+    std::string resolved_service_address = resolver_(service_address);
+    ProfileRequest request = request_;
+    request.set_host_name(resolved_service_address);
     // Creation also issues Profile RPC asynchronously.
     auto client = RemoteProfilerSession::Create(
-        std::move(resolved_service_addr), deadline, std::move(profile_request));
+        std::move(resolved_service_address), deadline, std::move(request));
     clients_.push_back(std::move(client));
   }
 
@@ -132,41 +97,18 @@ Status RemoteProfilerSessionManager::Init() {
   return Status::OK();
 }
 
-Status RemoteProfilerSessionManager::ValidateOptionsLocked() {
-  if (!options_.service_addresses_size()) {
-    return errors::InvalidArgument("No service addresses specified.");
-  }
-
-  if (options_.profiler_options().duration_ms() == 0) {
-    if (options_.max_session_duration_ms() != 0) {
-      return errors::InvalidArgument(
-          "If local profiler duration is unbounded, profiling session duration "
-          "must be unbounded.");
-    }
-  }
-
-  if (options_.max_session_duration_ms() <
-      options_.profiler_options().duration_ms()) {
-    return errors::InvalidArgument(
-        "The maximum profiling session duration must be greater than or equal "
-        "to the local profiler duration.");
-  }
-  return Status::OK();
-}
-
 std::vector<RemoteProfilerSessionManager::Response>
 RemoteProfilerSessionManager::WaitForCompletion() {
   mutex_lock lock(mutex_);
-  std::vector<RemoteProfilerSessionManager::Response> remote_responses;
-  remote_responses.reserve(clients_.size());
-
-  for (auto& client : clients_) {
-    remote_responses.emplace_back();
-    auto* profile_response = &remote_responses.back().profile_response;
-    Status& status = remote_responses.back().status;
-    std::string* service_addr = &remote_responses.back().service_addr;
-    *profile_response = client->WaitForCompletion(status);
-    *service_addr = std::string(client->GetServiceAddress());
+  std::vector<RemoteProfilerSessionManager::Response> remote_responses(
+      clients_.size());
+
+  for (int32 idx = 0; idx < clients_.size(); ++idx) {
+    auto& remote_response = remote_responses[idx];
+    auto* client = clients_[idx].get();
+    remote_response.profile_response =
+        client->WaitForCompletion(remote_response.status);
+    remote_response.service_address = std::string(client->GetServiceAddress());
   }
   return remote_responses;
 }

View File

@@ -26,9 +26,6 @@ limitations under the License.
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/lib/profiler_session.h"
-#include "tensorflow/core/profiler/profiler_options.pb.h"
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/rpc/client/profiler_client.h"
 
 namespace tensorflow {
@@ -40,21 +37,16 @@ using AddressResolver = std::function<std::string(absl::string_view)>;
 class RemoteProfilerSessionManager {
  public:
   struct Response {
-    std::string service_addr;
+    std::string service_address;
     std::unique_ptr<ProfileResponse> profile_response;
     Status status;
   };
   // Instantiates a collection of RemoteProfilerSessions starts profiling on
-  // each of them immediately.
+  // each of them immediately. Assumes that options have already been validated.
   static std::unique_ptr<RemoteProfilerSessionManager> Create(
       const RemoteProfilerSessionManagerOptions& options,
-      tensorflow::Status& out_status, AddressResolver resolver = nullptr);
-
-  static RemoteProfilerSessionManagerOptions DefaultOptions() {
-    RemoteProfilerSessionManagerOptions options;
-    *options.mutable_profiler_options() = ProfilerSession::DefaultOptions();
-    return options;
-  }
+      const ProfileRequest& request, tensorflow::Status& out_status,
+      AddressResolver resolver = nullptr);
 
   // Awaits for responses from remote profiler sessions and returns them as a
   // list. Subsequent calls beyond the first will yield a list of errors.
@@ -69,16 +61,16 @@ class RemoteProfilerSessionManager {
  private:
   explicit RemoteProfilerSessionManager(
-      RemoteProfilerSessionManagerOptions options, AddressResolver resolver);
+      RemoteProfilerSessionManagerOptions options, ProfileRequest request,
+      AddressResolver resolver);
 
   // Initialization of all client contexts.
   Status Init();
 
-  Status ValidateOptionsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
   mutex mutex_;
   // Remote profiler session options.
   RemoteProfilerSessionManagerOptions options_ TF_GUARDED_BY(mutex_);
+  ProfileRequest request_ TF_GUARDED_BY(mutex_);
   // List of clients, each connects to a profiling service.
   std::vector<std::unique_ptr<RemoteProfilerSession>> clients_
       TF_GUARDED_BY(mutex_);

View File

@@ -35,46 +35,73 @@ namespace {
 using ::tensorflow::profiler::test::DurationApproxLess;
 using ::tensorflow::profiler::test::DurationNear;
 using ::tensorflow::profiler::test::StartServer;
+using ::tensorflow::testing::TmpDir;
 using Response = tensorflow::profiler::RemoteProfilerSessionManager::Response;
 
+// Copied from capture_profile to not introduce a dependency.
+ProfileRequest PopulateProfileRequest(
+    absl::string_view repository_root, absl::string_view session_id,
+    absl::string_view host_name,
+    const RemoteProfilerSessionManagerOptions& options) {
+  constexpr uint64 kMaxEvents = 1000000;
+  const absl::string_view kXPlanePb = "xplane.pb";
+  ProfileRequest request;
+  // TODO(b/169976117) Remove duration from request.
+  request.set_duration_ms(options.profiler_options().duration_ms());
+  request.set_max_events(kMaxEvents);
+  request.set_repository_root(repository_root.data(), repository_root.size());
+  request.set_session_id(session_id.data(), session_id.size());
+  request.set_host_name(host_name.data(), host_name.size());
+  // XPlane tool is only used by OSS profiler and safely ignored by TPU
+  // profiler.
+  request.add_tools(kXPlanePb.data(), kXPlanePb.size());
+  *request.mutable_opts() = options.profiler_options();
+  return request;
+}
+
 TEST(RemoteProfilerSessionManagerTest, Simple) {
   absl::Duration duration = absl::Milliseconds(30);
-  RemoteProfilerSessionManagerOptions options =
-      RemoteProfilerSessionManager::DefaultOptions();
+  RemoteProfilerSessionManagerOptions options;
+  *options.mutable_profiler_options() =
+      tensorflow::ProfilerSession::DefaultOptions();
   options.mutable_profiler_options()->set_duration_ms(
       absl::ToInt64Milliseconds(duration));
-  std::string service_addresses;
-  auto server = StartServer(duration, &service_addresses);
-  options.add_service_addresses(service_addresses);
+  std::string service_address;
+  auto server = StartServer(duration, &service_address);
+  options.add_service_addresses(service_address);
   absl::Time approx_start = absl::Now();
   absl::Duration grace = absl::Seconds(1);
   absl::Duration max_duration = duration + grace;
   options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
   options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
 
+  ProfileRequest request =
+      PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
   Status status;
-  auto sessions = RemoteProfilerSessionManager::Create(options, status);
+  auto sessions =
+      RemoteProfilerSessionManager::Create(options, request, status);
   EXPECT_TRUE(status.ok());
   std::vector<Response> responses = sessions->WaitForCompletion();
   absl::Duration elapsed = absl::Now() - approx_start;
   ASSERT_EQ(responses.size(), 1);
   EXPECT_TRUE(responses.back().status.ok());
-  EXPECT_FALSE(responses.back().profile_response->empty_trace());
-  EXPECT_GT(responses.back().profile_response->tool_data_size(), 0);
+  EXPECT_TRUE(responses.back().profile_response->empty_trace());
+  EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
   EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
 }
 
 TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) {
   absl::Duration duration = absl::Milliseconds(30);
-  RemoteProfilerSessionManagerOptions options =
-      RemoteProfilerSessionManager::DefaultOptions();
+  RemoteProfilerSessionManagerOptions options;
+  *options.mutable_profiler_options() =
+      tensorflow::ProfilerSession::DefaultOptions();
   options.mutable_profiler_options()->set_duration_ms(
       absl::ToInt64Milliseconds(duration));
-  std::string service_addresses;
-  auto server = StartServer(duration, &service_addresses);
-  options.add_service_addresses(service_addresses);
+  std::string service_address;
+  auto server = StartServer(duration, &service_address);
+  options.add_service_addresses(service_address);
   absl::Duration grace = absl::Seconds(1);
   absl::Duration max_duration = duration + grace;
   options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
@@ -82,28 +109,32 @@ TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) {
   options.set_session_creation_timestamp_ns(0);
   absl::Time approx_start = absl::Now();
 
+  ProfileRequest request =
+      PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
   Status status;
-  auto sessions = RemoteProfilerSessionManager::Create(options, status);
+  auto sessions =
+      RemoteProfilerSessionManager::Create(options, request, status);
   EXPECT_TRUE(status.ok());
   std::vector<Response> responses = sessions->WaitForCompletion();
   absl::Duration elapsed = absl::Now() - approx_start;
   EXPECT_THAT(elapsed, DurationNear(absl::Seconds(0)));
   ASSERT_EQ(responses.size(), 1);
   EXPECT_TRUE(errors::IsDeadlineExceeded(responses.back().status));
-  EXPECT_FALSE(responses.back().profile_response->empty_trace());
+  EXPECT_TRUE(responses.back().profile_response->empty_trace());
   EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
 }
 
 TEST(RemoteProfilerSessionManagerTest, LongSession) {
   absl::Duration duration = absl::Seconds(3);
-  RemoteProfilerSessionManagerOptions options =
-      RemoteProfilerSessionManager::DefaultOptions();
+  RemoteProfilerSessionManagerOptions options;
+  *options.mutable_profiler_options() =
+      tensorflow::ProfilerSession::DefaultOptions();
   options.mutable_profiler_options()->set_duration_ms(
       absl::ToInt64Milliseconds(duration));
-  std::string service_addresses;
-  auto server = StartServer(duration, &service_addresses);
-  options.add_service_addresses(service_addresses);
+  std::string service_address;
+  auto server = StartServer(duration, &service_address);
+  options.add_service_addresses(service_address);
   absl::Time approx_start = absl::Now();
   // Empirically determined value.
   absl::Duration grace = absl::Seconds(20);
@@ -111,15 +142,18 @@ TEST(RemoteProfilerSessionManagerTest, LongSession) {
   options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
   options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
 
+  ProfileRequest request =
+      PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
   Status status;
-  auto sessions = RemoteProfilerSessionManager::Create(options, status);
+  auto sessions =
+      RemoteProfilerSessionManager::Create(options, request, status);
   EXPECT_TRUE(status.ok());
   std::vector<Response> responses = sessions->WaitForCompletion();
   absl::Duration elapsed = absl::Now() - approx_start;
   ASSERT_EQ(responses.size(), 1);
   EXPECT_TRUE(responses.back().status.ok());
-  EXPECT_FALSE(responses.back().profile_response->empty_trace());
-  EXPECT_GT(responses.back().profile_response->tool_data_size(), 0);
+  EXPECT_TRUE(responses.back().profile_response->empty_trace());
+  EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
   EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
 }

View File

@@ -20,6 +20,7 @@ limitations under the License.
 #include "grpcpp/support/status.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/memory/memory.h"
+#include "absl/strings/str_replace.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/errors.h"
@@ -31,6 +32,8 @@ limitations under the License.
 #include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
 #include "tensorflow/core/profiler/profiler_service.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/file_system_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_utils.h"
 
 namespace tensorflow {
 namespace profiler {
@@ -38,15 +41,31 @@ namespace {
 const absl::string_view kXPlanePb = "xplane.pb";
 
-Status CollectDataToResponse(const ProfileRequest& req,
-                             ProfilerSession* profiler,
-                             ProfileResponse* response) {
-  profiler::XSpace xspace;
+// Collects data in XSpace format. The data is saved to a repository
+// unconditionally.
+Status CollectDataToRepository(const ProfileRequest& request,
+                               ProfilerSession* profiler,
+                               ProfileResponse* response) {
+  response->set_empty_trace(true);
+  // Read the profile data into xspace.
+  XSpace xspace;
   TF_RETURN_IF_ERROR(profiler->CollectData(&xspace));
-  auto* tool_data = response->add_tool_data();
-  tool_data->set_name(kXPlanePb.data(), kXPlanePb.size());
-  xspace.SerializeToString(tool_data->mutable_data());
-  return Status::OK();
+  VLOG(3) << "Collected XSpace to repository.";
+  response->set_empty_trace(IsEmpty(xspace));
+
+  std::string log_dir_path =
+      ProfilerJoinPath(request.repository_root(), request.session_id());
+  VLOG(1) << "Creating " << log_dir_path;
+  TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(log_dir_path));
+
+  std::string file_name = absl::StrCat(request.host_name(), ".", kXPlanePb);
+  // Windows file names do not support colons.
+  absl::StrReplaceAll({{":", "_"}}, &file_name);
+  // Dumps profile data to <repository_root>/<run>/<host>_<port>.<kXPlanePb>
+  std::string out_path = ProfilerJoinPath(log_dir_path, file_name);
+  LOG(INFO) << "Collecting XSpace to repository: " << out_path;
+
+  return WriteBinaryProto(Env::Default(), out_path, xspace);
 }
 
 class ProfilerServiceImpl : public grpc::ProfilerService::Service {
@@ -68,7 +87,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
     }
 
     Env* env = Env::Default();
-    for (uint64 i = 0; i < req->duration_ms(); ++i) {
+    for (uint64 i = 0; i < req->opts().duration_ms(); ++i) {
       env->SleepForMicroseconds(EnvTime::kMillisToMicros);
       if (ctx->IsCancelled()) {
         return ::grpc::Status::CANCELLED;
@@ -80,7 +99,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
       }
     }
 
-    status = CollectDataToResponse(*req, profiler.get(), response);
+    status = CollectDataToRepository(*req, profiler.get(), response);
     if (!status.ok()) {
       return ::grpc::Status(::grpc::StatusCode::INTERNAL,
                             status.error_message());
@@ -116,5 +135,4 @@ std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService() {
 }
 
 }  // namespace profiler
 }  // namespace tensorflow
-

View File

@@ -172,7 +172,7 @@ class ServerLibTest(test.TestCase):
     # return UnavailableError with no trace events collected string.
     with self.assertRaises(errors.UnavailableError) as error:
       profiler_client.trace(worker._address, tempfile.mkdtemp(), duration_ms=10)
-    self.assertEqual("No trace event is collected", str(error.exception))
+    self.assertStartsWith(str(error.exception), "No trace event was collected")
 
 
 if __name__ == "__main__":

View File

@@ -19,6 +19,7 @@ cuda_py_test(
     srcs = ["profiler_api_test.py"],
     python_version = "PY3",
    tags = [
+        "external",  # So that test suite reruns unconditionally.
        "no_pip",
        "no_rocm",
    ],

View File

@@ -67,10 +67,15 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
         'kernel_stats.pb',
     ]
     for file in expected_files:
-      path = os.path.join(logdir, 'plugins/profile/*/*{}'.format(file))
+      path = os.path.join(logdir, 'plugins', 'profile', '*', '*{}'.format(file))
       self.assertEqual(1, len(glob.glob(path)),
                        'Expected one path match: ' + path)
 
+  def _check_xspace_pb_exist(self, logdir):
+    path = os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb')
+    self.assertEqual(1, len(glob.glob(path)),
+                     'Expected one path match: ' + path)
+
   def test_single_worker_no_profiling(self):
     """Test single worker without profiling."""
 
@@ -86,7 +91,6 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
       profiler.start_server(port)
       _, steps, train_ds, model = _model_setup()
       model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
-      logging.info('worker finishing')
 
     def on_profile(port, logdir):
       # Request for 30 milliseconds of profile.
@@ -109,7 +113,7 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
     thread_profiler.start()
     thread_profiler.join()
     thread_worker.join(120)
-    self._check_tools_pb_exist(logdir)
+    self._check_xspace_pb_exist(logdir)
 
   def test_single_worker_programmatic_mode(self):
     """Test single worker programmatic mode."""

View File

@@ -130,6 +130,7 @@ tf_python_pybind_extension(
         "//tensorflow/python:pybind11_status",
         "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
        "@pybind11",
    ],
 )

View File

@ -14,11 +14,17 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <memory> #include <memory>
#include <string>
#include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "pybind11/cast.h" #include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/pytypes.h" #include "pybind11/pytypes.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
@ -38,7 +44,12 @@ namespace py = ::pybind11;
namespace { namespace {
tensorflow::Status ValidateHostPortPair(const std::string& host_port) { using ::tensorflow::RemoteProfilerSessionManagerOptions;
// Profiler gives grace after profiling duration to terminate.
constexpr absl::Duration kSessionGraceTime = absl::Seconds(5);
tensorflow::Status ValidateHostPortPair(absl::string_view host_port) {
tensorflow::uint32 port; tensorflow::uint32 port;
std::vector<absl::string_view> parts = absl::StrSplit(host_port, ':'); std::vector<absl::string_view> parts = absl::StrSplit(host_port, ':');
// Must be host:port, port must be a number, host must not contain a '/', // Must be host:port, port must be a number, host must not contain a '/',
@ -51,34 +62,156 @@ tensorflow::Status ValidateHostPortPair(const std::string& host_port) {
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
// Takes profiler options in a py::dict and returns a ProfileOptions. tensorflow::Status ValidateOptions(
const RemoteProfilerSessionManagerOptions& options) {
if (options.service_addresses().empty()) {
return tensorflow::errors::InvalidArgument("No service address provided.");
}
if (options.profiler_options().duration_ms() == 0) {
return tensorflow::errors::InvalidArgument(
"duration_ms must be greater than zero.");
}
for (absl::string_view host_port : options.service_addresses()) {
TF_RETURN_IF_ERROR(ValidateHostPortPair(host_port));
}
if (options.max_session_duration_ms() <
options.profiler_options().duration_ms()) {
return tensorflow::errors::InvalidArgument(
"The maximum profiling session duration must be greater than or equal "
"to the local profiler duration.");
}
return tensorflow::Status::OK();
}
// Receives a comma delimited list of service_addresses and adds them to
// RemoteProfilerSessionManagerOptions::service_addresses.
void AddServiceAddresses(absl::string_view service_addresses,
RemoteProfilerSessionManagerOptions* options) {
for (absl::string_view server : absl::StrSplit(service_addresses, ',')) {
options->add_service_addresses(server.data(), server.size());
}
}
// Sets gRPC deadline to a grace period based on the profiling duration.
void UpdateMaxSessionDuration(RemoteProfilerSessionManagerOptions& options) {
auto local_profiler_duration = options.profiler_options().duration_ms();
auto session_creation_ts = options.session_creation_timestamp_ns();
auto requested_start_ts = options.profiler_options().start_timestamp_ns();
// User only needs to set maximal session duration if the profiling duration
// is bounded.
DCHECK_GT(local_profiler_duration, 0);
VLOG(3) << "duration_ms was given as " << local_profiler_duration;
// Max session duration includes the profiling session and grace time.
auto profile_duration =
absl::Milliseconds(local_profiler_duration) + kSessionGraceTime;
absl::Duration delay_duration;
// When requested start timestamp is 0, profiling starts immediately.
if (requested_start_ts > 0) {
delay_duration =
absl::Nanoseconds(requested_start_ts - session_creation_ts);
}
auto max_session_duration = profile_duration + delay_duration;
options.set_max_session_duration_ms(
absl::ToInt64Milliseconds(max_session_duration));
VLOG(1) << "max_session_duration set to " << max_session_duration;
}
// Takes profiler options in a py::dict and returns a
// RemoteProfilerSessionManagerOptions.
// This must be called under GIL because it reads Python objects. Reading Python // This must be called under GIL because it reads Python objects. Reading Python
// objects require GIL because the objects can be mutated by other Python // objects require GIL because the objects can be mutated by other Python
// threads. In addition, Python objects are reference counted; reading py::dict // threads. In addition, Python objects are reference counted; reading py::dict
// will increase its reference count. // will increase its reference count.
tensorflow::ProfileOptions GetOptionsLocked(const py::dict& opts) { RemoteProfilerSessionManagerOptions GetOptionsLocked(absl::string_view logdir,
tensorflow::ProfileOptions options = const py::dict& opts) {
RemoteProfilerSessionManagerOptions options;
*options.mutable_profiler_options() =
tensorflow::ProfilerSession::DefaultOptions(); tensorflow::ProfilerSession::DefaultOptions();
// Store a timestamp of when this session was created. This will be the basis
// of gRPC deadline afterwards.
auto now = absl::Now();
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(now));
VLOG(2) << "set_session_creation_timestamp_ns set to "
<< options.session_creation_timestamp_ns() << " [" << now << "]";
// Set the path of where to store XSpaces.
options.mutable_profiler_options()->set_repository_path(logdir.data(),
logdir.size());
VLOG(2) << "repository_path set to "
<< options.profiler_options().repository_path();
for (const auto& kw : opts) { for (const auto& kw : opts) {
std::string key = py::cast<std::string>(kw.first); std::string key = py::cast<std::string>(kw.first);
if (key == "host_tracer_level") { if (key == "host_tracer_level") {
options.set_host_tracer_level(py::cast<int>(kw.second)); auto value = py::cast<int>(kw.second);
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level(); options.mutable_profiler_options()->set_host_tracer_level(value);
VLOG(1) << "host_tracer_level set to " << value;
} else if (key == "device_tracer_level") { } else if (key == "device_tracer_level") {
options.set_device_tracer_level(py::cast<int>(kw.second)); auto value = py::cast<int>(kw.second);
VLOG(1) << "device_tracer_level set to " << options.device_tracer_level(); options.mutable_profiler_options()->set_device_tracer_level(value);
VLOG(1) << "device_tracer_level set to " << value;
} else if (key == "python_tracer_level") { } else if (key == "python_tracer_level") {
options.set_python_tracer_level(py::cast<int>(kw.second)); auto value = py::cast<int>(kw.second);
VLOG(1) << "python_tracer_level set to " << options.python_tracer_level(); options.mutable_profiler_options()->set_python_tracer_level(value);
VLOG(1) << "python_tracer_level set to " << value;
} else {
LOG(WARNING) << "Unrecognised key: " << key;
} }
} }
return options;
}
RemoteProfilerSessionManagerOptions GetOptionsLocked(
    absl::string_view service_addresses, absl::string_view logdir,
    absl::string_view worker_list, bool include_dataset_ops,
    tensorflow::int32 duration_ms, py::dict opts, bool* is_cloud_tpu_session) {
  RemoteProfilerSessionManagerOptions options = GetOptionsLocked(logdir, opts);
  // Remote profiling does not support any use cases where the following options
  // are set by `py::dict opts`. e.g. `opts['service_addrs']` will not happen.
  DCHECK(options.service_addresses().empty());
  // In remote profiling, duration is always passed by value explicitly and not
  // set in py::dict opts.
  DCHECK_EQ(options.profiler_options().duration_ms(), 0);
  // Because duration_ms is not set from py::dict opts, it follows that
  // max_session_duration_ms must be unset as well.
  DCHECK_EQ(options.max_session_duration_ms(), 0);
  // Worker_list is only used for TensorBoard TPU capture cases. For a TPU
  // cluster, service_address is the Master, which can already be found in the
  // list of workers. These sessions will be used with the ProfileAnalysis
  // service.
  *is_cloud_tpu_session = !worker_list.empty();
  AddServiceAddresses(*is_cloud_tpu_session ? worker_list : service_addresses,
                      &options);
  // Set local profiler duration and profiler session durations.
  options.mutable_profiler_options()->set_include_dataset_ops(
      include_dataset_ops);
  options.mutable_profiler_options()->set_duration_ms(duration_ms);
  UpdateMaxSessionDuration(options);
  for (int idx = 0; idx < options.service_addresses_size(); ++idx) {
    VLOG(1) << "service_addr " << idx << " set to "
            << options.service_addresses(idx);
  }
  VLOG(1) << "include_dataset_ops set to " << include_dataset_ops;
  VLOG(1) << "duration_ms set to " << duration_ms;
  return options;
}
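A minimal Python sketch of the target-selection rule above, under the assumption that the semantics are exactly as shown; the function name is hypothetical, not part of the wrapper.

```python
# Hypothetical sketch: a non-empty worker_list marks a Cloud TPU session and
# its workers become the profiling targets; otherwise the comma-delimited
# service addresses are used.
def select_profiling_targets(service_addresses, worker_list):
  is_cloud_tpu_session = bool(worker_list)
  chosen = worker_list if is_cloud_tpu_session else service_addresses
  return is_cloud_tpu_session, chosen.split(',')

# select_profiling_targets('grpc://10.0.0.2:8466', '10.0.0.3,10.0.0.4')
# -> (True, ['10.0.0.3', '10.0.0.4'])
```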
class ProfilerSessionWrapper {
 public:
  void Start(const char* logdir, const py::dict& options) {
    auto opts = GetOptionsLocked(logdir, options);
    session_ = tensorflow::ProfilerSession::Create(opts.profiler_options());
    logdir_ = logdir;
    tensorflow::MaybeRaiseRegisteredFromStatus(session_->Status());
  }
@ -130,26 +263,28 @@ PYBIND11_MODULE(_pywrap_profiler, m) {
    profiler_server.release();
  });

  m.def("trace", [](const char* service_addr, const char* logdir,
                    const char* worker_list, bool include_dataset_ops,
                    int duration_ms, int num_tracing_attempts,
                    py::dict options) {
    // TPU capture is true if the user sets worker_list.
    bool is_cloud_tpu_session = false;
    // Normalize py::dict into a well defined and validated proto.
    tensorflow::RemoteProfilerSessionManagerOptions opts =
        GetOptionsLocked(service_addr, logdir, worker_list, include_dataset_ops,
                         duration_ms, options, &is_cloud_tpu_session);
    tensorflow::Status status = ValidateOptions(opts);
    tensorflow::MaybeRaiseRegisteredFromStatus(status);
    {
      // Release the lock to keep the lock scope to a minimum, and allow
      // other threads to proceed.
      py::gil_scoped_release release;
      status = tensorflow::profiler::Trace(logdir, num_tracing_attempts, opts,
                                           is_cloud_tpu_session);
    }
    tensorflow::MaybeRaiseRegisteredFromStatus(status);
  });
m.def("monitor", [](const char* service_addr, int duration_ms, m.def("monitor", [](const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp) { int monitoring_level, bool display_timestamp) {

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import errors
from tensorflow.python.profiler.internal import _pywrap_profiler
from tensorflow.python.util.tf_export import tf_export
@ -32,38 +33,57 @@ def trace(service_addr,
          worker_list='',
          num_tracing_attempts=3,
          options=None):
  """Sends gRPC requests to one or more profiler servers to perform on-demand profiling.

  This method will block the calling thread until it receives responses from
  all servers or until deadline expiration. Both single host and multiple host
  profiling are supported on CPU, GPU, and TPU.
  The profiled results will be saved by each server to the specified TensorBoard
  log directory (i.e. the directory you save your model checkpoints). Use the
  TensorBoard profile plugin to view the visualization and analysis results.

  Args:
    service_addr: A comma delimited string of gRPC addresses of the workers to
      profile.
      e.g. service_addr='grpc://localhost:6009'
           service_addr='grpc://10.0.0.2:8466,grpc://10.0.0.3:8466'
           service_addr='grpc://localhost:12345,grpc://localhost:23456'
    logdir: Path to save profile data to, typically a TensorBoard log directory.
      This path must be accessible to both the client and server.
      e.g. logdir='gs://your_tb_dir'
    duration_ms: Duration of tracing or monitoring in milliseconds. Must be
      greater than zero.
    worker_list: An optional TPU-only configuration. The list of workers to
      profile in the current session.
    num_tracing_attempts: Optional. Automatically retry N times when no trace
      event is collected (default 3).
    options: profiler.experimental.ProfilerOptions namedtuple for miscellaneous
      profiler options.

  Raises:
    InvalidArgumentError: When arguments fail validation checks.
    UnavailableError: If no trace event was collected.

  Example usage (CPU/GPU):
  # Start a profiler server before your model runs.

  ```python
  tf.profiler.experimental.server.start(6009)
  # (Model code goes here).
  ```

  # Send gRPC request to the profiler server to collect a trace of your model.

  ```python
  tf.profiler.experimental.client.trace('grpc://localhost:6009',
                                        '/nfs/tb_log', 2000)
  ```

  Example usage (Multiple GPUs):
  # E.g. your worker IP addresses are 10.0.0.2, 10.0.0.3, 10.0.0.4, and you
  # would like to schedule start of profiling 1 second from now, for a duration
  # of 2 seconds.
  options['delay_ms'] = 1000
  tf.profiler.experimental.client.trace(
      'grpc://10.0.0.2:8466,grpc://10.0.0.3:8466,grpc://10.0.0.4:8466',
      'gs://your_tb_dir',
      2000,
      options=options)

  Example usage (TPU):
  # Send gRPC request to a TPU worker to collect a trace of your model. A
@ -82,16 +102,19 @@ def trace(service_addr,
  # profile for 2 seconds.
  tf.profiler.experimental.client.trace('grpc://10.0.0.2:8466',
                                        'gs://your_tb_dir',
                                        2000, '10.0.0.2,10.0.0.3,10.0.0.4')

  Launch TensorBoard and point it to the same logdir you provided to this API.
  $ tensorboard --logdir=/tmp/tb_log (or gs://your_tb_dir in the above examples)
  Open your browser and go to localhost:6006/#profile to view profiling results.

  """
  if duration_ms <= 0:
    raise errors.InvalidArgumentError(None, None,
                                      'duration_ms must be greater than zero.')

  opts = dict(options._asdict()) if options is not None else {}
  _pywrap_profiler.trace(
      _strip_addresses(service_addr, _GRPC_PREFIX), logdir, worker_list, True,
      duration_ms, num_tracing_attempts, opts)
@ -127,3 +150,7 @@ def monitor(service_addr, duration_ms, level=1):
def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s


def _strip_addresses(addresses, prefix):
  return ','.join([_strip_prefix(s, prefix) for s in addresses.split(',')])
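A quick usage sketch of the new helper, assuming _GRPC_PREFIX is the 'grpc://' scheme prefix (an assumption; the constant is defined elsewhere in this module).

```python
# Hypothetical usage of _strip_addresses from the diff above.
addresses = 'grpc://10.0.0.2:8466,grpc://10.0.0.3:8466'
print(_strip_addresses(addresses, 'grpc://'))
# -> 10.0.0.2:8466,10.0.0.3:8466
```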

View File

@ -38,7 +38,7 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
    with self.assertRaises(errors.UnavailableError) as error:
      profiler_client.trace(
          'localhost:' + str(test_port), self.get_temp_dir(), duration_ms=10)
    self.assertStartsWith(str(error.exception), 'No trace event was collected')

  def testTrace_ProfileIdleServerWithOptions(self):
    test_port = portpicker.pick_unused_port()
@ -54,7 +54,7 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
          self.get_temp_dir(),
          duration_ms=10,
          options=options)
    self.assertStartsWith(str(error.exception), 'No trace event was collected')

  def testMonitor_ProcessInvalidAddress(self):
    # Monitor is only supported in cloud TPU. Test invalid address instead.

View File

@ -346,6 +346,10 @@ tensorflow::profiler::ProfilerServer::~ProfilerServer
tensorflow::profiler::ProfileGrpc
tensorflow::profiler::NewSessionGrpc
tensorflow::profiler::MonitorGrpc
tensorflow::profiler::RemoteProfilerSession::Create
tensorflow::profiler::RemoteProfilerSession::GetServiceAddress
tensorflow::profiler::RemoteProfilerSession::WaitForCompletion
tensorflow::profiler::RemoteProfilerSession::~RemoteProfilerSession
[status_macros] # tfcompile
xla::status_macros::MakeErrorStream::Impl::Impl