Added multi-worker sampling mode.

* Profile() now calls RemoteProfilerSessionManager (RPSM).
* Profiler server saves XSpace to repository instead of returning XSpace by ProfileResponse.
  * ProfileResponse will have a non-empty trace iff there are any XPlanes.

PiperOrigin-RevId: 336807134
Change-Id: I664400c9715d73605b9daa4cfdcf6475dee4e959
This commit is contained in:
Yi Situ 2020-10-12 21:54:35 -07:00 committed by TensorFlower Gardener
parent 78ba66c122
commit 8813a286b8
20 changed files with 462 additions and 291 deletions

View File

@ -47,8 +47,11 @@ cc_library(
"//tensorflow/core/profiler:profiler_service_proto_cc",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:file_system_utils",
"//tensorflow/core/profiler/utils:xplane_utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
tf_grpc_cc_dependency(),
],
)

View File

@ -27,6 +27,7 @@ cc_library(
],
deps = [
":profiler_client_for_pybind",
":remote_profiler_session_manager",
":save_profile",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@ -133,14 +134,11 @@ cc_library(
srcs = ["remote_profiler_session_manager.cc"],
hdrs = ["remote_profiler_session_manager.h"],
copts = tf_profiler_copts(),
visibility = ["//tensorflow/core/profiler:internal"],
deps = [
":profiler_client_for_pybind",
":save_profile",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/utils:time_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",

View File

@ -30,12 +30,16 @@ limitations under the License.
#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
namespace tensorflow {
namespace profiler {
namespace {
using ::tensorflow::profiler::RemoteProfilerSessionManager;
using Response = ::tensorflow::profiler::RemoteProfilerSessionManager::Response;
constexpr uint64 kMaxEvents = 1000000;
const absl::string_view kXPlanePb = "xplane.pb";
@ -48,17 +52,18 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
return request;
}
ProfileRequest PopulateProfileRequest(int duration_ms,
const std::string& repository_root,
const std::string& session_id,
const std::string& host_name,
const ProfileOptions& opts) {
ProfileRequest PopulateProfileRequest(
absl::string_view repository_root, absl::string_view session_id,
absl::string_view host_name,
const RemoteProfilerSessionManagerOptions& options) {
ProfileRequest request;
request.set_duration_ms(duration_ms);
// TODO(b/169976117) Remove duration from request.
request.set_duration_ms(options.profiler_options().duration_ms());
request.set_max_events(kMaxEvents);
request.set_repository_root(repository_root);
request.set_session_id(session_id);
request.set_host_name(host_name);
request.set_repository_root(repository_root.data(), repository_root.size());
request.set_session_id(session_id.data(), session_id.size());
request.set_host_name(host_name.data(), host_name.size());
// These tools are only used by TPU profiler.
request.add_tools("trace_viewer");
request.add_tools("op_profile");
request.add_tools("input_pipeline");
@ -68,21 +73,26 @@ ProfileRequest PopulateProfileRequest(int duration_ms,
request.add_tools("overview_page");
request.add_tools("pod_viewer");
request.add_tools("tensorflow_stats");
*request.mutable_opts() = opts;
// XPlane tool is only used by OSS profiler and safely ignored by TPU
// profiler.
request.add_tools(kXPlanePb.data(), kXPlanePb.size());
*request.mutable_opts() = options.profiler_options();
return request;
}
NewProfileSessionRequest PopulateNewProfileSessionRequest(
const std::string& service_addr, const std::string& repository_root,
const std::vector<string>& hostnames, int duration_ms,
const std::string& session_id, const ProfileOptions& opts) {
absl::string_view repository_root, absl::string_view session_id,
const RemoteProfilerSessionManagerOptions& opts) {
NewProfileSessionRequest request;
std::vector<std::string> parts = absl::StrSplit(service_addr, ':');
*request.mutable_request() = PopulateProfileRequest(
duration_ms, repository_root, session_id, parts[0], opts);
request.set_repository_root(repository_root);
request.set_session_id(session_id);
for (const auto& hostname : hostnames) {
std::vector<absl::string_view> parts =
absl::StrSplit(opts.service_addresses(0), ':');
DCHECK(!parts.empty());
*request.mutable_request() =
PopulateProfileRequest(repository_root, session_id, parts[0], opts);
request.set_repository_root(repository_root.data(), repository_root.size());
request.set_session_id(session_id.data(), session_id.size());
for (const auto& hostname : opts.service_addresses()) {
request.add_hosts(hostname);
}
return request;
@ -99,44 +109,40 @@ inline bool ShouldRetryTracing(Status status) {
status.error_message() == "Stream removed");
}
// If the ProfileResponse has a single 'xplane.pb' tool, converts the XSpace to
// the other tools and adds them to the ProfileResponse. Otherwise, the
// ProfileResponse has already been converted; simply returns OK.
Status ConvertXSpaceToToolsInProfileResponse(const ProfileRequest& request,
                                             ProfileResponse* response) {
  // No conversion needed unless the response holds exactly one tool and that
  // tool is the serialized XSpace.
  if (response->tool_data_size() != 1) return Status::OK();
  if (response->tool_data(0).name() != kXPlanePb) return Status::OK();
  XSpace xspace;
  // ParseFromString returns false on malformed bytes; previously this was
  // ignored and a corrupted payload silently produced an empty XSpace.
  if (!xspace.ParseFromString(response->tool_data(0).data())) {
    return errors::InvalidArgument("Failed to parse XSpace from tool data.");
  }
  TF_RETURN_IF_ERROR(ConvertXSpaceToProfileResponse(xspace, request, response));
  return Status::OK();
}
Status Profile(const std::string& repository_root,
const std::string& session_id,
const RemoteProfilerSessionManagerOptions& opts) {
Status status;
// Host name will be overwritten by RemoteProfilerSessionManager later.
ProfileRequest request = PopulateProfileRequest(repository_root, session_id,
/*host_name=*/"", opts);
auto session = RemoteProfilerSessionManager::Create(opts, request, status);
TF_RETURN_IF_ERROR(status);
// Expect one or more service addresses.
DCHECK_GT(opts.service_addresses_size(), 0);
std::vector<Response> responses = session->WaitForCompletion();
// Expect responses to have the same size as clients.
DCHECK_EQ(responses.size(), opts.service_addresses_size());
Status Profile(const std::string& service_addr,
const std::string& repository_root, int duration_ms,
const std::string& session_id, const ProfileOptions& opts) {
std::vector<std::string> parts = absl::StrSplit(service_addr, ':');
ProfileRequest request = PopulateProfileRequest(duration_ms, repository_root,
session_id, parts[0], opts);
ProfileResponse response;
TF_RETURN_IF_ERROR(ProfileGrpc(service_addr, request, &response));
if (!response.empty_trace()) {
TF_RETURN_IF_ERROR(
ConvertXSpaceToToolsInProfileResponse(request, &response));
TF_RETURN_IF_ERROR(SaveProfile(repository_root, session_id,
request.host_name(), response, &std::cout));
// Print this at the end so that it's not buried in irrelevant LOG messages.
std::cout
<< "NOTE: using the trace duration " << duration_ms << "ms.\n"
<< "Set an appropriate duration (with --duration_ms) if you "
"don't see a full step in your trace or the captured trace is too "
"large."
<< std::endl;
bool has_trace_data = false;
for (const auto& client_response : responses) {
ProfileResponse& response = *client_response.profile_response;
if (response.empty_trace()) {
LOG(WARNING) << "No trace event is collected from "
<< client_response.service_address;
} else {
has_trace_data = true;
}
if (!client_response.status.ok()) {
LOG(WARNING) << client_response.service_address << " returned "
<< client_response.status;
}
}
if (response.empty_trace()) {
return Status(error::Code::UNAVAILABLE, "No trace event is collected");
if (!has_trace_data) {
return Status(error::Code::UNAVAILABLE,
"No trace event was collected because there were no responses"
" from clients or the responses did not have trace data.");
}
return Status::OK();
}
@ -144,52 +150,47 @@ Status Profile(const std::string& service_addr,
// Starts a new profiling session that includes all the hosts listed in
// hostnames, for the time interval of duration_ms. Possibly saves the
// profiling result in the directory specified by repository_root and
// session_id.
Status NewSession(const std::string& service_addr,
const std::string& repository_root,
const std::vector<string>& hostnames, int duration_ms,
const std::string& session_id, const ProfileOptions& opts) {
NewProfileSessionRequest request = PopulateNewProfileSessionRequest(
service_addr, repository_root, hostnames, duration_ms, session_id, opts);
Status NewSession(absl::string_view repository_root,
absl::string_view session_id,
const RemoteProfilerSessionManagerOptions& opts) {
NewProfileSessionRequest request =
PopulateNewProfileSessionRequest(repository_root, session_id, opts);
NewProfileSessionResponse response;
TF_RETURN_IF_ERROR(NewSessionGrpc(service_addr, request, &response));
TF_RETURN_IF_ERROR(
NewSessionGrpc(opts.service_addresses(0), request, &response));
std::cout << "Profile session succeed for host(s):"
<< absl::StrJoin(hostnames, ",") << std::endl;
<< absl::StrJoin(opts.service_addresses(), ",") << std::endl;
if (response.empty_trace()) {
return Status(error::Code::UNAVAILABLE, "No trace event is collected");
return errors::Unavailable("No trace event is collected");
}
return Status::OK();
}
} // namespace
// Starts tracing on a single or multiple hosts and saves the result in the
// given logdir. If no trace was collected, retries tracing for
// num_tracing_attempts.
Status Trace(const std::string& service_addr, const std::string& logdir,
const std::string& workers_list, int duration_ms,
int num_tracing_attempts, const ProfileOptions& opts) {
Status Trace(const std::string& logdir, int num_tracing_attempts,
const RemoteProfilerSessionManagerOptions& opts,
bool is_cloud_tpu_session) {
DCHECK_GT(opts.profiler_options().duration_ms(), 0);
DCHECK(!opts.service_addresses().empty());
// Use the current timestamp as the run name.
std::string session_id = GetCurrentTimeStampAsString();
std::vector<std::string> hostnames;
if (!workers_list.empty()) {
hostnames = absl::StrSplit(workers_list, ',');
}
std::string repository_root = GetTensorBoardProfilePluginDir(logdir);
auto duration_ms = opts.profiler_options().duration_ms();
TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
std::string repository_root =
profiler::GetTensorBoardProfilePluginDir(logdir);
Status status = Status::OK();
Status status;
int remaining_attempts = num_tracing_attempts;
while (true) {
std::cout << "Starting to trace for " << duration_ms << " ms. "
<< "Remaining attempt(s): " << --remaining_attempts << std::endl;
if (hostnames.empty()) {
status =
Profile(service_addr, repository_root, duration_ms, session_id, opts);
if (is_cloud_tpu_session) {
status = NewSession(repository_root, session_id, opts);
} else {
status = NewSession(service_addr, repository_root, hostnames, duration_ms,
session_id, opts);
status = Profile(repository_root, session_id, opts);
}
if (remaining_attempts <= 0 || status.ok() || !ShouldRetryTracing(status))
break;
@ -223,11 +224,10 @@ Status ExportToTensorBoard(const XSpace& xspace, const std::string& logdir) {
ProfileResponse response;
ProfileRequest request = PopulateProfileRequest(
/*duration_ms=*/0, GetTensorBoardProfilePluginDir(logdir),
GetCurrentTimeStampAsString(), port::Hostname(), /*opts=*/{});
GetTensorBoardProfilePluginDir(logdir), GetCurrentTimeStampAsString(),
port::Hostname(), /*options=*/{});
TF_RETURN_IF_ERROR(
ConvertXSpaceToProfileResponse(xspace, request, &response));
std::stringstream ss; // Record LOG messages.
TF_RETURN_IF_ERROR(SaveProfile(request.repository_root(),
request.session_id(), request.host_name(),

View File

@ -36,12 +36,12 @@ Status Monitor(const std::string& service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
std::string* result);
// Starts tracing on a single or multiple hosts and saves the result in the
// given logdir. If no trace was collected, retries tracing for
// num_tracing_attempts.
Status Trace(const std::string& service_addr, const std::string& logdir,
const std::string& workers_list, int duration_ms,
int num_tracing_attempts, const ProfileOptions& opts);
// Starts tracing on a single or multiple hosts. Each host will save the result
// in the given logdir. If no trace was collected, retries tracing for
// num_tracing_attempts. Assumes that options have been validated.
Status Trace(const std::string& logdir, int num_tracing_attempts,
const RemoteProfilerSessionManagerOptions& opts,
bool is_cloud_tpu_session);
} // namespace profiler
} // namespace tensorflow

View File

@ -99,10 +99,11 @@ RemoteProfilerSession::RemoteProfilerSession(std::string service_address,
service_address_(std::move(service_address)),
stub_(CreateStub<grpc::ProfilerService>(service_address_)),
deadline_(deadline),
profile_request_(std::move(profile_request)) {}
profile_request_(std::move(profile_request)) {
response_->set_empty_trace(true);
}
RemoteProfilerSession::~RemoteProfilerSession() {
LOG(INFO) << "Waiting for completion.";
Status dummy;
WaitForCompletion(dummy);
grpc_context_.TryCancel();
@ -113,6 +114,8 @@ void RemoteProfilerSession::ProfileAsync() {
grpc_context_.set_deadline(absl::ToChronoTime(deadline_));
VLOG(1) << "Deadline set to " << deadline_;
rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_);
// Connection failure will create lame channel whereby grpc_status_ will be an
// error.
rpc_->Finish(response_.get(), &grpc_status_,
static_cast<void*>(&status_on_completion_));
VLOG(2) << "Asynchronous gRPC Profile() issued." << absl::Now();
@ -125,6 +128,7 @@ std::unique_ptr<ProfileResponse> RemoteProfilerSession::WaitForCompletion(
"WaitForCompletion must only be called once.");
return nullptr;
}
LOG(INFO) << "Waiting for completion.";
void* got_tag = nullptr;
bool ok = false;

View File

@ -82,7 +82,7 @@ class RemoteProfilerSession {
absl::Time deadline_;
::grpc::ClientContext grpc_context_;
std::unique_ptr<::grpc::ClientAsyncResponseReader<ProfileResponse>> rpc_;
::grpc::Status grpc_status_;
::grpc::Status grpc_status_ = ::grpc::Status::OK;
// Asynchronous completion queue states.
::grpc::CompletionQueue cq_;

View File

@ -52,8 +52,10 @@ TEST(RemoteProfilerSession, Simple) {
absl::Duration elapsed = absl::Now() - approx_start;
// At end of session this evaluates to true still.
EXPECT_TRUE(status.ok());
EXPECT_FALSE(response->empty_trace());
EXPECT_GT(response->tool_data_size(), 0);
// True because there was no workload traced and subsequently no XEvents.
EXPECT_TRUE(response->empty_trace());
// XSpaces are serialized and not returned as tools in ProfileResponse.
EXPECT_EQ(response->tool_data_size(), 0);
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
}
@ -86,8 +88,9 @@ TEST(RemoteProfilerSession, Timeout) {
auto response = remote_session->WaitForCompletion(status);
// At end of session we will have a timeout error.
EXPECT_TRUE(errors::IsDeadlineExceeded(status));
EXPECT_FALSE(response->empty_trace()); // This defaults to false.
// True because there was no workload traced and subsequently no XEvents.
EXPECT_TRUE(response->empty_trace());
// XSpaces are serialized and not returned as tools in ProfileResponse.
EXPECT_EQ(response->tool_data_size(), 0);
}
@ -109,8 +112,10 @@ TEST(RemoteProfilerSession, LongDeadline) {
absl::Duration elapsed = absl::Now() - approx_start;
// At end of session this evaluates to true still.
EXPECT_TRUE(status.ok());
EXPECT_FALSE(response->empty_trace());
EXPECT_GT(response->tool_data_size(), 0);
// True because there was no workload traced and subsequently no XEvents.
EXPECT_TRUE(response->empty_trace());
// XSpaces are serialized and not returned as tools in ProfileResponse.
EXPECT_EQ(response->tool_data_size(), 0);
// Elapsed time is near profiling duration despite long grace period.
EXPECT_THAT(elapsed, DurationNear(duration));
}
@ -134,8 +139,10 @@ TEST(RemoteProfilerSession, LongDuration) {
absl::Duration elapsed = absl::Now() - approx_start;
// At end of session this evaluates to true still.
EXPECT_TRUE(status.ok());
EXPECT_FALSE(response->empty_trace());
EXPECT_GT(response->tool_data_size(), 0);
// True because there was no workload traced and subsequently no XEvents.
EXPECT_TRUE(response->empty_trace());
// XSpaces are serialized and not returned as tools in ProfileResponse.
EXPECT_EQ(response->tool_data_size(), 0);
// Elapsed time takes longer to complete for larger traces.
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
}

View File

@ -37,14 +37,14 @@ namespace profiler {
namespace test {
inline std::unique_ptr<ProfilerServer> StartServer(
absl::Duration duration, std::string* service_addresses,
absl::Duration duration, std::string* service_address,
ProfileRequest* request = nullptr) {
auto profiler_server = absl::make_unique<ProfilerServer>();
int port = testing::PickUnusedPortOrDie();
profiler_server->StartProfilerServer(port);
DCHECK(service_addresses);
*service_addresses = absl::StrCat("localhost:", port);
DCHECK(service_address);
*service_address = absl::StrCat("localhost:", port);
if (request) {
request->set_duration_ms(absl::ToInt64Milliseconds(duration));
@ -53,10 +53,11 @@ inline std::unique_ptr<ProfilerServer> StartServer(
request->mutable_opts()->set_duration_ms(
absl::ToInt64Milliseconds(duration));
request->set_session_id("test_session");
request->set_host_name(*service_addresses);
request->set_host_name(*service_address);
request->set_repository_root(testing::TmpDir());
}
LOG(INFO) << "Started " << *service_addresses << " at " << absl::Now();
LOG(INFO) << "Started " << *service_address << " at " << absl::Now();
LOG(INFO) << "Duration: " << duration;
return profiler_server;

View File

@ -26,47 +26,20 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace profiler {
namespace {
constexpr uint64 kMaxEvents = 1000000;
// TODO(yisitu) merge with the implementation in capture_profile.
// Fills `request` with the profiling parameters derived from `options`,
// tagging it with the given session id and host name.
void PopulateProfileRequest(const RemoteProfilerSessionManagerOptions& options,
                            absl::string_view session_id,
                            absl::string_view host_name,
                            ProfileRequest* request) {
  const auto& profiler_options = options.profiler_options();
  request->set_max_events(kMaxEvents);
  request->set_duration_ms(profiler_options.duration_ms());
  request->set_repository_root(profiler_options.repository_path());
  request->set_session_id(session_id.data(), session_id.size());
  request->set_host_name(host_name.data(), host_name.size());
  // Tools requested from the profiled host.
  static constexpr const char* kTools[] = {
      "trace_viewer",  "op_profile",     "input_pipeline",
      "kernel_stats",  "memory_viewer",  "memory_profile",
      "overview_page", "pod_viewer",     "tensorflow_stats"};
  for (const char* tool : kTools) {
    request->add_tools(tool);
  }
  *request->mutable_opts() = profiler_options;
}
} // namespace
/*static*/ std::unique_ptr<RemoteProfilerSessionManager>
RemoteProfilerSessionManager::Create(
const RemoteProfilerSessionManagerOptions& options,
tensorflow::Status& out_status, AddressResolver resolver) {
const ProfileRequest& request, tensorflow::Status& out_status,
AddressResolver resolver) {
VLOG(1) << "Creating a RemoteProfilerSessionManager.";
auto session_manager =
absl::WrapUnique(new RemoteProfilerSessionManager(options, resolver));
auto session_manager = absl::WrapUnique(
new RemoteProfilerSessionManager(options, request, resolver));
out_status = session_manager->Init();
if (!out_status.ok()) {
return nullptr;
@ -75,8 +48,9 @@ RemoteProfilerSessionManager::Create(
}
RemoteProfilerSessionManager::RemoteProfilerSessionManager(
RemoteProfilerSessionManagerOptions options, AddressResolver resolver)
: options_(std::move(options)) {
RemoteProfilerSessionManagerOptions options, ProfileRequest request,
AddressResolver resolver)
: options_(std::move(options)), request_(std::move(request)) {
if (resolver) {
resolver_ = std::move(resolver);
} else {
@ -91,14 +65,7 @@ RemoteProfilerSessionManager::~RemoteProfilerSessionManager() {
Status RemoteProfilerSessionManager::Init() {
mutex_lock lock(mutex_);
VLOG(1) << "SessionManager initializing.";
// TODO(b/169482824) Move validation to call site.
Status status = ValidateOptionsLocked();
if (!status.ok()) {
LOG(ERROR) << status;
return status;
}
std::string session_id = GetCurrentTimeStampAsString();
const absl::Time session_created_ts =
absl::FromUnixNanos(options_.session_creation_timestamp_ns());
const absl::Time deadline =
@ -115,16 +82,14 @@ Status RemoteProfilerSessionManager::Init() {
// Prepare a list of clients.
clients_.reserve(options_.service_addresses_size());
for (auto& service_addr : options_.service_addresses()) {
std::string resolved_service_addr = resolver_(service_addr);
ProfileRequest profile_request;
PopulateProfileRequest(options_, session_id, resolved_service_addr,
&profile_request);
for (auto& service_address : options_.service_addresses()) {
std::string resolved_service_address = resolver_(service_address);
ProfileRequest request = request_;
request.set_host_name(resolved_service_address);
// Creation also issues Profile RPC asynchronously.
auto client = RemoteProfilerSession::Create(
std::move(resolved_service_addr), deadline, std::move(profile_request));
std::move(resolved_service_address), deadline, std::move(request));
clients_.push_back(std::move(client));
}
@ -132,41 +97,18 @@ Status RemoteProfilerSessionManager::Init() {
return Status::OK();
}
// Verifies that the session options are internally consistent: at least one
// service address, and a session duration compatible with the local profiler
// duration.
Status RemoteProfilerSessionManager::ValidateOptionsLocked() {
  if (options_.service_addresses_size() == 0) {
    return errors::InvalidArgument("No service addresses specified.");
  }
  const auto local_duration_ms = options_.profiler_options().duration_ms();
  const auto session_duration_ms = options_.max_session_duration_ms();
  // An unbounded (zero) local duration is only valid with an unbounded
  // session duration.
  if (local_duration_ms == 0 && session_duration_ms != 0) {
    return errors::InvalidArgument(
        "If local profiler duration is unbounded, profiling session duration "
        "must be unbounded.");
  }
  if (session_duration_ms < local_duration_ms) {
    return errors::InvalidArgument(
        "The maximum profiling session duration must be greater than or equal "
        "to the local profiler duration.");
  }
  return Status::OK();
}
std::vector<RemoteProfilerSessionManager::Response>
RemoteProfilerSessionManager::WaitForCompletion() {
mutex_lock lock(mutex_);
std::vector<RemoteProfilerSessionManager::Response> remote_responses;
remote_responses.reserve(clients_.size());
std::vector<RemoteProfilerSessionManager::Response> remote_responses(
clients_.size());
for (auto& client : clients_) {
remote_responses.emplace_back();
auto* profile_response = &remote_responses.back().profile_response;
Status& status = remote_responses.back().status;
std::string* service_addr = &remote_responses.back().service_addr;
*profile_response = client->WaitForCompletion(status);
*service_addr = std::string(client->GetServiceAddress());
for (int32 idx = 0; idx < clients_.size(); ++idx) {
auto& remote_response = remote_responses[idx];
auto* client = clients_[idx].get();
remote_response.profile_response =
client->WaitForCompletion(remote_response.status);
remote_response.service_address = std::string(client->GetServiceAddress());
}
return remote_responses;
}

View File

@ -26,9 +26,6 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
namespace tensorflow {
@ -40,21 +37,16 @@ using AddressResolver = std::function<std::string(absl::string_view)>;
class RemoteProfilerSessionManager {
public:
struct Response {
std::string service_addr;
std::string service_address;
std::unique_ptr<ProfileResponse> profile_response;
Status status;
};
// Instantiates a collection of RemoteProfilerSessions starts profiling on
// each of them immediately.
// each of them immediately. Assumes that options have already been validated.
static std::unique_ptr<RemoteProfilerSessionManager> Create(
const RemoteProfilerSessionManagerOptions& options,
tensorflow::Status& out_status, AddressResolver resolver = nullptr);
static RemoteProfilerSessionManagerOptions DefaultOptions() {
RemoteProfilerSessionManagerOptions options;
*options.mutable_profiler_options() = ProfilerSession::DefaultOptions();
return options;
}
const ProfileRequest& request, tensorflow::Status& out_status,
AddressResolver resolver = nullptr);
// Awaits for responses from remote profiler sessions and returns them as a
// list. Subsequent calls beyond the first will yield a list of errors.
@ -69,16 +61,16 @@ class RemoteProfilerSessionManager {
private:
explicit RemoteProfilerSessionManager(
RemoteProfilerSessionManagerOptions options, AddressResolver resolver);
RemoteProfilerSessionManagerOptions options, ProfileRequest request,
AddressResolver resolver);
// Initialization of all client contexts.
Status Init();
Status ValidateOptionsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
mutex mutex_;
// Remote profiler session options.
RemoteProfilerSessionManagerOptions options_ TF_GUARDED_BY(mutex_);
ProfileRequest request_ TF_GUARDED_BY(mutex_);
// List of clients, each connects to a profiling service.
std::vector<std::unique_ptr<RemoteProfilerSession>> clients_
TF_GUARDED_BY(mutex_);

View File

@ -35,46 +35,73 @@ namespace {
using ::tensorflow::profiler::test::DurationApproxLess;
using ::tensorflow::profiler::test::DurationNear;
using ::tensorflow::profiler::test::StartServer;
using ::tensorflow::testing::TmpDir;
using Response = tensorflow::profiler::RemoteProfilerSessionManager::Response;
// Copied from capture_profile to not introduce a dependency.
// Builds a ProfileRequest equivalent to the one built by capture_profile,
// duplicated here so the test does not depend on that target.
ProfileRequest PopulateProfileRequest(
    absl::string_view repository_root, absl::string_view session_id,
    absl::string_view host_name,
    const RemoteProfilerSessionManagerOptions& options) {
  constexpr uint64 kMaxEvents = 1000000;
  const absl::string_view kXPlanePb = "xplane.pb";
  const auto& profiler_options = options.profiler_options();
  ProfileRequest request;
  request.set_max_events(kMaxEvents);
  // TODO(b/169976117) Remove duration from request.
  request.set_duration_ms(profiler_options.duration_ms());
  *request.mutable_opts() = profiler_options;
  request.set_repository_root(repository_root.data(), repository_root.size());
  request.set_session_id(session_id.data(), session_id.size());
  request.set_host_name(host_name.data(), host_name.size());
  // XPlane tool is only used by OSS profiler and safely ignored by TPU
  // profiler.
  request.add_tools(kXPlanePb.data(), kXPlanePb.size());
  return request;
}
TEST(RemoteProfilerSessionManagerTest, Simple) {
absl::Duration duration = absl::Milliseconds(30);
RemoteProfilerSessionManagerOptions options =
RemoteProfilerSessionManager::DefaultOptions();
RemoteProfilerSessionManagerOptions options;
*options.mutable_profiler_options() =
tensorflow::ProfilerSession::DefaultOptions();
options.mutable_profiler_options()->set_duration_ms(
absl::ToInt64Milliseconds(duration));
std::string service_addresses;
auto server = StartServer(duration, &service_addresses);
options.add_service_addresses(service_addresses);
std::string service_address;
auto server = StartServer(duration, &service_address);
options.add_service_addresses(service_address);
absl::Time approx_start = absl::Now();
absl::Duration grace = absl::Seconds(1);
absl::Duration max_duration = duration + grace;
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
ProfileRequest request =
PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
Status status;
auto sessions = RemoteProfilerSessionManager::Create(options, status);
auto sessions =
RemoteProfilerSessionManager::Create(options, request, status);
EXPECT_TRUE(status.ok());
std::vector<Response> responses = sessions->WaitForCompletion();
absl::Duration elapsed = absl::Now() - approx_start;
ASSERT_EQ(responses.size(), 1);
EXPECT_TRUE(responses.back().status.ok());
EXPECT_FALSE(responses.back().profile_response->empty_trace());
EXPECT_GT(responses.back().profile_response->tool_data_size(), 0);
EXPECT_TRUE(responses.back().profile_response->empty_trace());
EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
}
TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) {
absl::Duration duration = absl::Milliseconds(30);
RemoteProfilerSessionManagerOptions options =
RemoteProfilerSessionManager::DefaultOptions();
RemoteProfilerSessionManagerOptions options;
*options.mutable_profiler_options() =
tensorflow::ProfilerSession::DefaultOptions();
options.mutable_profiler_options()->set_duration_ms(
absl::ToInt64Milliseconds(duration));
std::string service_addresses;
auto server = StartServer(duration, &service_addresses);
options.add_service_addresses(service_addresses);
std::string service_address;
auto server = StartServer(duration, &service_address);
options.add_service_addresses(service_address);
absl::Duration grace = absl::Seconds(1);
absl::Duration max_duration = duration + grace;
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
@ -82,28 +109,32 @@ TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) {
options.set_session_creation_timestamp_ns(0);
absl::Time approx_start = absl::Now();
ProfileRequest request =
PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
Status status;
auto sessions = RemoteProfilerSessionManager::Create(options, status);
auto sessions =
RemoteProfilerSessionManager::Create(options, request, status);
EXPECT_TRUE(status.ok());
std::vector<Response> responses = sessions->WaitForCompletion();
absl::Duration elapsed = absl::Now() - approx_start;
EXPECT_THAT(elapsed, DurationNear(absl::Seconds(0)));
ASSERT_EQ(responses.size(), 1);
EXPECT_TRUE(errors::IsDeadlineExceeded(responses.back().status));
EXPECT_FALSE(responses.back().profile_response->empty_trace());
EXPECT_TRUE(responses.back().profile_response->empty_trace());
EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
}
TEST(RemoteProfilerSessionManagerTest, LongSession) {
absl::Duration duration = absl::Seconds(3);
RemoteProfilerSessionManagerOptions options =
RemoteProfilerSessionManager::DefaultOptions();
RemoteProfilerSessionManagerOptions options;
*options.mutable_profiler_options() =
tensorflow::ProfilerSession::DefaultOptions();
options.mutable_profiler_options()->set_duration_ms(
absl::ToInt64Milliseconds(duration));
std::string service_addresses;
auto server = StartServer(duration, &service_addresses);
options.add_service_addresses(service_addresses);
std::string service_address;
auto server = StartServer(duration, &service_address);
options.add_service_addresses(service_address);
absl::Time approx_start = absl::Now();
// Empirically determined value.
absl::Duration grace = absl::Seconds(20);
@ -111,15 +142,18 @@ TEST(RemoteProfilerSessionManagerTest, LongSession) {
options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
ProfileRequest request =
PopulateProfileRequest(TmpDir(), "session_id", service_address, options);
Status status;
auto sessions = RemoteProfilerSessionManager::Create(options, status);
auto sessions =
RemoteProfilerSessionManager::Create(options, request, status);
EXPECT_TRUE(status.ok());
std::vector<Response> responses = sessions->WaitForCompletion();
absl::Duration elapsed = absl::Now() - approx_start;
ASSERT_EQ(responses.size(), 1);
EXPECT_TRUE(responses.back().status.ok());
EXPECT_FALSE(responses.back().profile_response->empty_trace());
EXPECT_GT(responses.back().profile_response->tool_data_size(), 0);
EXPECT_TRUE(responses.back().profile_response->empty_trace());
EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
}

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "grpcpp/support/status.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/errors.h"
@ -31,6 +32,8 @@ limitations under the License.
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/file_system_utils.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
namespace tensorflow {
namespace profiler {
@ -38,15 +41,31 @@ namespace {
const absl::string_view kXPlanePb = "xplane.pb";
Status CollectDataToResponse(const ProfileRequest& req,
ProfilerSession* profiler,
ProfileResponse* response) {
profiler::XSpace xspace;
// Collects data in XSpace format. The data is saved to a repository
// unconditionally.
Status CollectDataToRepository(const ProfileRequest& request,
ProfilerSession* profiler,
ProfileResponse* response) {
response->set_empty_trace(true);
// Read the profile data into xspace.
XSpace xspace;
TF_RETURN_IF_ERROR(profiler->CollectData(&xspace));
auto* tool_data = response->add_tool_data();
tool_data->set_name(kXPlanePb.data(), kXPlanePb.size());
xspace.SerializeToString(tool_data->mutable_data());
return Status::OK();
VLOG(3) << "Collected XSpace to repository.";
response->set_empty_trace(IsEmpty(xspace));
std::string log_dir_path =
ProfilerJoinPath(request.repository_root(), request.session_id());
VLOG(1) << "Creating " << log_dir_path;
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(log_dir_path));
std::string file_name = absl::StrCat(request.host_name(), ".", kXPlanePb);
// Windows file names do not support colons.
absl::StrReplaceAll({{":", "_"}}, &file_name);
// Dumps profile data to <repository_root>/<run>/<host>_<port>.<kXPlanePb>
std::string out_path = ProfilerJoinPath(log_dir_path, file_name);
LOG(INFO) << "Collecting XSpace to repository: " << out_path;
return WriteBinaryProto(Env::Default(), out_path, xspace);
}
class ProfilerServiceImpl : public grpc::ProfilerService::Service {
@ -68,7 +87,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
}
Env* env = Env::Default();
for (uint64 i = 0; i < req->duration_ms(); ++i) {
for (uint64 i = 0; i < req->opts().duration_ms(); ++i) {
env->SleepForMicroseconds(EnvTime::kMillisToMicros);
if (ctx->IsCancelled()) {
return ::grpc::Status::CANCELLED;
@ -80,7 +99,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
}
}
status = CollectDataToResponse(*req, profiler.get(), response);
status = CollectDataToRepository(*req, profiler.get(), response);
if (!status.ok()) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
status.error_message());
@ -116,5 +135,4 @@ std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService() {
}
} // namespace profiler
} // namespace tensorflow

View File

@ -172,7 +172,7 @@ class ServerLibTest(test.TestCase):
# return UnavailableError with no trace events collected string.
with self.assertRaises(errors.UnavailableError) as error:
profiler_client.trace(worker._address, tempfile.mkdtemp(), duration_ms=10)
self.assertEqual("No trace event is collected", str(error.exception))
self.assertStartsWith(str(error.exception), "No trace event was collected")
if __name__ == "__main__":

View File

@ -19,6 +19,7 @@ cuda_py_test(
srcs = ["profiler_api_test.py"],
python_version = "PY3",
tags = [
"external", # So that test suite reruns unconditionally.
"no_pip",
"no_rocm",
],

View File

@ -67,10 +67,15 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
'kernel_stats.pb',
]
for file in expected_files:
path = os.path.join(logdir, 'plugins/profile/*/*{}'.format(file))
path = os.path.join(logdir, 'plugins', 'profile', '*', '*{}'.format(file))
self.assertEqual(1, len(glob.glob(path)),
'Expected one path match: ' + path)
def _check_xspace_pb_exist(self, logdir):
path = os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb')
self.assertEqual(1, len(glob.glob(path)),
'Expected one path match: ' + path)
def test_single_worker_no_profiling(self):
"""Test single worker without profiling."""
@ -86,7 +91,6 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
profiler.start_server(port)
_, steps, train_ds, model = _model_setup()
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
logging.info('worker finishing')
def on_profile(port, logdir):
# Request for 30 milliseconds of profile.
@ -109,7 +113,7 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
thread_profiler.start()
thread_profiler.join()
thread_worker.join(120)
self._check_tools_pb_exist(logdir)
self._check_xspace_pb_exist(logdir)
def test_single_worker_programmatic_mode(self):
"""Test single worker programmatic mode."""

View File

@ -130,6 +130,7 @@ tf_python_pybind_extension(
"//tensorflow/python:pybind11_status",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@pybind11",
],
)

View File

@ -14,11 +14,17 @@ limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "pybind11/cast.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/core/platform/env.h"
@ -38,7 +44,12 @@ namespace py = ::pybind11;
namespace {
tensorflow::Status ValidateHostPortPair(const std::string& host_port) {
using ::tensorflow::RemoteProfilerSessionManagerOptions;
// Profiler gives grace after profiling duration to terminate.
constexpr absl::Duration kSessionGraceTime = absl::Seconds(5);
tensorflow::Status ValidateHostPortPair(absl::string_view host_port) {
tensorflow::uint32 port;
std::vector<absl::string_view> parts = absl::StrSplit(host_port, ':');
// Must be host:port, port must be a number, host must not contain a '/',
@ -51,34 +62,156 @@ tensorflow::Status ValidateHostPortPair(const std::string& host_port) {
return tensorflow::Status::OK();
}
// Takes profiler options in a py::dict and returns a ProfileOptions.
tensorflow::Status ValidateOptions(
const RemoteProfilerSessionManagerOptions& options) {
if (options.service_addresses().empty()) {
return tensorflow::errors::InvalidArgument("No service address provided.");
}
if (options.profiler_options().duration_ms() == 0) {
return tensorflow::errors::InvalidArgument(
"duration_ms must be greater than zero.");
}
for (absl::string_view host_port : options.service_addresses()) {
TF_RETURN_IF_ERROR(ValidateHostPortPair(host_port));
}
if (options.max_session_duration_ms() <
options.profiler_options().duration_ms()) {
return tensorflow::errors::InvalidArgument(
"The maximum profiling session duration must be greater than or equal "
"to the local profiler duration.");
}
return tensorflow::Status::OK();
}
// Receives a comma delimited list of service_addresses and adds them to
// RemoteProfilerSessionManagerOptions::service_addresses.
void AddServiceAddresses(absl::string_view service_addresses,
RemoteProfilerSessionManagerOptions* options) {
for (absl::string_view server : absl::StrSplit(service_addresses, ',')) {
options->add_service_addresses(server.data(), server.size());
}
}
// Sets gRPC deadline to a grace period based on the profiling duration.
void UpdateMaxSessionDuration(RemoteProfilerSessionManagerOptions& options) {
auto local_profiler_duration = options.profiler_options().duration_ms();
auto session_creation_ts = options.session_creation_timestamp_ns();
auto requested_start_ts = options.profiler_options().start_timestamp_ns();
// User only needs to set maximal session duration if the profiling duration
// is bounded.
DCHECK_GT(local_profiler_duration, 0);
VLOG(3) << "duration_ms was given as " << local_profiler_duration;
// Max session duration includes the profiling session and grace time.
auto profile_duration =
absl::Milliseconds(local_profiler_duration) + kSessionGraceTime;
absl::Duration delay_duration;
// When requested start timestamp is 0, profiling starts immediately.
if (requested_start_ts > 0) {
delay_duration =
absl::Nanoseconds(requested_start_ts - session_creation_ts);
}
auto max_session_duration = profile_duration + delay_duration;
options.set_max_session_duration_ms(
absl::ToInt64Milliseconds(max_session_duration));
VLOG(1) << "max_session_duration set to " << max_session_duration;
}
// Takes profiler options in a py::dict and returns a
// RemoteProfilerSessionManagerOptions.
// This must be called under GIL because it reads Python objects. Reading Python
// objects require GIL because the objects can be mutated by other Python
// threads. In addition, Python objects are reference counted; reading py::dict
// will increase its reference count.
tensorflow::ProfileOptions GetOptionsLocked(const py::dict& opts) {
tensorflow::ProfileOptions options =
RemoteProfilerSessionManagerOptions GetOptionsLocked(absl::string_view logdir,
const py::dict& opts) {
RemoteProfilerSessionManagerOptions options;
*options.mutable_profiler_options() =
tensorflow::ProfilerSession::DefaultOptions();
// Store a timestamp of when this session was created. This will be the basis
// of gRPC deadline afterwards.
auto now = absl::Now();
options.set_session_creation_timestamp_ns(absl::ToUnixNanos(now));
VLOG(2) << "set_session_creation_timestamp_ns set to "
<< options.session_creation_timestamp_ns() << " [" << now << "]";
// Set the path of where to store XSpaces.
options.mutable_profiler_options()->set_repository_path(logdir.data(),
logdir.size());
VLOG(2) << "repository_path set to "
<< options.profiler_options().repository_path();
for (const auto& kw : opts) {
std::string key = py::cast<std::string>(kw.first);
if (key == "host_tracer_level") {
options.set_host_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level();
auto value = py::cast<int>(kw.second);
options.mutable_profiler_options()->set_host_tracer_level(value);
VLOG(1) << "host_tracer_level set to " << value;
} else if (key == "device_tracer_level") {
options.set_device_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "device_tracer_level set to " << options.device_tracer_level();
auto value = py::cast<int>(kw.second);
options.mutable_profiler_options()->set_device_tracer_level(value);
VLOG(1) << "device_tracer_level set to " << value;
} else if (key == "python_tracer_level") {
options.set_python_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "python_tracer_level set to " << options.python_tracer_level();
auto value = py::cast<int>(kw.second);
options.mutable_profiler_options()->set_python_tracer_level(value);
VLOG(1) << "python_tracer_level set to " << value;
} else {
LOG(WARNING) << "Unrecognised key: " << key;
}
}
return options;
}
RemoteProfilerSessionManagerOptions GetOptionsLocked(
absl::string_view service_addresses, absl::string_view logdir,
absl::string_view worker_list, bool include_dataset_ops,
tensorflow::int32 duration_ms, py::dict opts, bool* is_cloud_tpu_session) {
RemoteProfilerSessionManagerOptions options = GetOptionsLocked(logdir, opts);
// Remote profiling does not support any use cases where the following options
// are set by `py::dict opts`. e.g. `opts['service_addrs']` will not happen.
DCHECK(options.service_addresses().empty());
// In remote profiling, duration is always passed by value explicitly and not
// set in py::dict opts.
DCHECK_EQ(options.profiler_options().duration_ms(), 0);
// Because duration_ms is not set from py::dict opts, it follows that
// max_session_duration_ms must be unset as well.
DCHECK_EQ(options.max_session_duration_ms(), 0);
// Worker_list is only used for TensorBoard TPU capture cases. For a TPU
// cluster, service_address is the Master, which can already be found in the
// list of workers. These sessions will be used with the ProfileAnalysis
// service.
*is_cloud_tpu_session = !worker_list.empty();
AddServiceAddresses(*is_cloud_tpu_session ? worker_list : service_addresses,
&options);
// Set local profiler duration and profiler session durations.
options.mutable_profiler_options()->set_include_dataset_ops(
include_dataset_ops);
options.mutable_profiler_options()->set_duration_ms(duration_ms);
UpdateMaxSessionDuration(options);
for (int idx = 0; idx < options.service_addresses_size(); ++idx) {
VLOG(1) << "service_addr " << idx << " set to "
<< options.service_addresses(idx);
}
VLOG(1) << "include_dataset_ops set to " << include_dataset_ops;
VLOG(1) << "duration_ms set to " << duration_ms;
return options;
}
class ProfilerSessionWrapper {
public:
void Start(const char* logdir, const py::dict& options) {
session_ = tensorflow::ProfilerSession::Create(GetOptionsLocked(options));
auto opts = GetOptionsLocked(logdir, options);
session_ = tensorflow::ProfilerSession::Create(opts.profiler_options());
logdir_ = logdir;
tensorflow::MaybeRaiseRegisteredFromStatus(session_->Status());
}
@ -130,26 +263,28 @@ PYBIND11_MODULE(_pywrap_profiler, m) {
profiler_server.release();
});
m.def("trace",
[](const char* service_addr, const char* logdir,
const char* worker_list, bool include_dataset_ops, int duration_ms,
int num_tracing_attempts, py::dict options) {
// Normalize py::dict into a well defined proto.
tensorflow::ProfileOptions opts = GetOptionsLocked(options);
m.def("trace", [](const char* service_addr, const char* logdir,
const char* worker_list, bool include_dataset_ops,
int duration_ms, int num_tracing_attempts,
py::dict options) {
// TPU capture is true if the user sets worker_list.
bool is_cloud_tpu_session = false;
// Normalize py::dict into a well defined and validated proto.
tensorflow::RemoteProfilerSessionManagerOptions opts =
GetOptionsLocked(service_addr, logdir, worker_list, include_dataset_ops,
duration_ms, options, &is_cloud_tpu_session);
tensorflow::Status status = ValidateOptions(opts);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
tensorflow::Status status = ValidateHostPortPair(service_addr);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
opts.set_include_dataset_ops(include_dataset_ops);
{
// Release the lock to keep the lock scope to a minimum, and allow
// other threads to proceed.
py::gil_scoped_release release;
status = tensorflow::profiler::Trace(service_addr, logdir,
worker_list, duration_ms,
num_tracing_attempts, opts);
}
tensorflow::MaybeRaiseRegisteredFromStatus(status);
});
{
// Release the lock to keep the lock scope to a minimum, and allow
// other threads to proceed.
py::gil_scoped_release release;
status = tensorflow::profiler::Trace(logdir, num_tracing_attempts, opts,
is_cloud_tpu_session);
}
tensorflow::MaybeRaiseRegisteredFromStatus(status);
});
m.def("monitor", [](const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp) {

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import errors
from tensorflow.python.profiler.internal import _pywrap_profiler
from tensorflow.python.util.tf_export import tf_export
@ -32,38 +33,57 @@ def trace(service_addr,
worker_list='',
num_tracing_attempts=3,
options=None):
"""Sends grpc requests to profiler server to perform on-demand profiling.
"""Sends gRPC requests to one or more profiler servers to perform on-demand profiling.
This method will block caller thread until it receives tracing result. This
method supports CPU, GPU, and Cloud TPU. This method supports profiling a
single host for CPU, GPU, TPU, as well as multiple TPU workers.
The profiled results will be saved to your specified TensorBoard log
directory (e.g. the directory you save your model checkpoints). Use the
This method will block the calling thread until it receives responses from all
servers or until deadline expiration. Both single host and multiple host
profiling are supported on CPU, GPU, and TPU.
The profiled results will be saved by each server to the specified TensorBoard
log directory (i.e. the directory you save your model checkpoints). Use the
TensorBoard profile plugin to view the visualization and analysis results.
Args:
service_addr: gRPC address of profiler service e.g. grpc://localhost:6009.
logdir: Path of TensorBoard log directory e.g. /tmp/tb_log.
duration_ms: Duration of tracing or monitoring in ms.
worker_list: Optional. The list of workers that we are about to profile in
the current session (TPU only).
service_addr: A comma delimited string of gRPC addresses of the workers to
profile.
e.g. service_addr='grpc://localhost:6009'
service_addr='grpc://10.0.0.2:8466,grpc://10.0.0.3:8466'
service_addr='grpc://localhost:12345,grpc://localhost:23456'
logdir: Path to save profile data to, typically a TensorBoard log directory.
This path must be accessible to both the client and server.
e.g. logdir='gs://your_tb_dir'
duration_ms: Duration of tracing or monitoring in mliiseconds. Must be
greater than zero.
worker_list: An optional TPU only configuration. The list of workers to
profile in the current session.
num_tracing_attempts: Optional. Automatically retry N times when no trace
event is collected (default 3).
options: profiler.experimental.ProfilerOptions namedtuple for miscellaneous
profiler options.
Raises:
UnavailableError: If no trace event is collected.
InvalidArgumentError: For when arguments fail validation checks.
UnavailableError: If no trace event was collected.
Example usage (CPU/GPU):
# Start a profiler server before your model runs.
```python
tf.profiler.experimental.server.start(6009)
# your model code.
# (Model code goes here).
# Send gRPC request to the profiler server to collect a trace of your model.
```python
tf.profiler.experimental.client.trace('grpc://localhost:6009',
'/tmp/tb_log', 2000)
'/nfs/tb_log', 2000)
Example usage (Multiple GPUs):
# E.g. your worker IP addresses are 10.0.0.2, 10.0.0.3, 10.0.0.4, and you
# would like to schedule start of profiling 1 second from now, for a duration
# of 2 seconds.
options['delay_ms'] = 1000
tf.profiler.experimental.client.trace(
'grpc://10.0.0.2:8466,grpc://10.0.0.3:8466,grpc://10.0.0.4:8466',
'gs://your_tb_dir',
2000,
options=options)
Example usage (TPU):
# Send gRPC request to a TPU worker to collect a trace of your model. A
@ -82,16 +102,19 @@ def trace(service_addr,
# profile for 2 seconds.
tf.profiler.experimental.client.trace('grpc://10.0.0.2:8466',
'gs://your_tb_dir',
2000, '10.0.0.3,10.0.0.4')
2000, '10.0.0.2,10.0.0.3,10.0.0.4')
Launch TensorBoard and point it to the same logdir you provided to this API.
$ tensorboard --logdir=/tmp/tb_log (or gs://your_tb_dir in the above examples)
Open your browser and go to localhost:6006/#profile to view profiling results.
"""
if duration_ms <= 0:
raise errors.InvalidArgumentError(None, None,
'duration_ms must be greater than zero.')
opts = dict(options._asdict()) if options is not None else {}
_pywrap_profiler.trace(
_strip_prefix(service_addr, _GRPC_PREFIX), logdir, worker_list, True,
_strip_addresses(service_addr, _GRPC_PREFIX), logdir, worker_list, True,
duration_ms, num_tracing_attempts, opts)
@ -127,3 +150,7 @@ def monitor(service_addr, duration_ms, level=1):
def _strip_prefix(s, prefix):
return s[len(prefix):] if s.startswith(prefix) else s
def _strip_addresses(addresses, prefix):
return ','.join([_strip_prefix(s, prefix) for s in addresses.split(',')])

View File

@ -38,7 +38,7 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.UnavailableError) as error:
profiler_client.trace(
'localhost:' + str(test_port), self.get_temp_dir(), duration_ms=10)
self.assertEqual('No trace event is collected', str(error.exception))
self.assertStartsWith(str(error.exception), 'No trace event was collected')
def testTrace_ProfileIdleServerWithOptions(self):
test_port = portpicker.pick_unused_port()
@ -54,7 +54,7 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
self.get_temp_dir(),
duration_ms=10,
options=options)
self.assertEqual('No trace event is collected', str(error.exception))
self.assertStartsWith(str(error.exception), 'No trace event was collected')
def testMonitor_ProcessInvalidAddress(self):
# Monitor is only supported in cloud TPU. Test invalid address instead.

View File

@ -346,6 +346,10 @@ tensorflow::profiler::ProfilerServer::~ProfilerServer
tensorflow::profiler::ProfileGrpc
tensorflow::profiler::NewSessionGrpc
tensorflow::profiler::MonitorGrpc
tensorflow::profiler::RemoteProfilerSession::Create
tensorflow::profiler::RemoteProfilerSession::GetServiceAddress
tensorflow::profiler::RemoteProfilerSession::WaitForCompletion
tensorflow::profiler::RemoteProfilerSession::~RemoteProfilerSession
[status_macros] # tfcompile
xla::status_macros::MakeErrorStream::Impl::Impl