prepare to pass options through python_client.py.
PiperOrigin-RevId: 305148445 Change-Id: I8d643117d095d277393a7fd3cbbce13b4c69733f
This commit is contained in:
parent
b91970316f
commit
f67fde60cf
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/profiler/profiler_analysis.grpc.pb.h"
|
||||
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
|
||||
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
|
||||
#include "tensorflow/core/util/events_writer.h"
|
||||
|
||||
@ -182,8 +181,8 @@ Status ValidateHostPortPair(const string& host_port) {
|
||||
// given logdir. If no trace was collected, retries tracing for
|
||||
// num_tracing_attempts.
|
||||
Status Trace(const string& service_addr, const string& logdir,
|
||||
const string& workers_list, bool include_dataset_ops,
|
||||
int duration_ms, int num_tracing_attempts) {
|
||||
const string& workers_list, int duration_ms,
|
||||
int num_tracing_attempts, const ProfileOptions& opts) {
|
||||
// Use the current timestamp as the run name.
|
||||
tensorflow::string session_id = GetCurrentTimeStampAsString();
|
||||
std::vector<string> hostnames;
|
||||
@ -193,8 +192,6 @@ Status Trace(const string& service_addr, const string& logdir,
|
||||
|
||||
Status status = Status::OK();
|
||||
int remaining_attempts = num_tracing_attempts;
|
||||
ProfileOptions opts;
|
||||
opts.set_include_dataset_ops(include_dataset_ops);
|
||||
while (true) {
|
||||
std::cout << "Starting to trace for " << duration_ms << " ms. "
|
||||
<< "Remaining attempt(s): " << --remaining_attempts << std::endl;
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
@ -34,8 +35,8 @@ Status Monitor(const string& service_addr, int duration_ms,
|
||||
// given logdir. If no trace was collected, retries tracing for
|
||||
// num_tracing_attempts.
|
||||
Status Trace(const string& service_addr, const string& logdir,
|
||||
const string& workers_list, bool include_dataset_ops,
|
||||
int duration_ms, int num_tracing_attempts);
|
||||
const string& workers_list, int duration_ms,
|
||||
int num_tracing_attempts, const ProfileOptions& opts);
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
@ -118,6 +118,7 @@ tf_python_pybind_extension(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler:profiler_service_proto_cc",
|
||||
"//tensorflow/core/profiler/convert:xplane_to_profile_response",
|
||||
"//tensorflow/core/profiler/convert:xplane_to_trace_events",
|
||||
"//tensorflow/core/profiler/lib:profiler_session_headers",
|
||||
|
@ -137,9 +137,11 @@ PYBIND11_MODULE(_pywrap_profiler, m) {
|
||||
tensorflow::Status status =
|
||||
tensorflow::profiler::ValidateHostPortPair(service_addr);
|
||||
tensorflow::MaybeRaiseRegisteredFromStatus(status);
|
||||
status = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
|
||||
include_dataset_ops, duration_ms,
|
||||
num_tracing_attempts);
|
||||
tensorflow::ProfileOptions opts;
|
||||
opts.set_include_dataset_ops(include_dataset_ops);
|
||||
status =
|
||||
tensorflow::profiler::Trace(service_addr, logdir, worker_list,
|
||||
duration_ms, num_tracing_attempts, opts);
|
||||
tensorflow::MaybeRaiseRegisteredFromStatus(status);
|
||||
});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user