prepare to pass options through python_client.py.

PiperOrigin-RevId: 305148445
Change-Id: I8d643117d095d277393a7fd3cbbce13b4c69733f
This commit is contained in:
A. Unique TensorFlower 2020-04-06 16:57:29 -07:00 committed by TensorFlower Gardener
parent b91970316f
commit f67fde60cf
4 changed files with 11 additions and 10 deletions

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/profiler/profiler_analysis.grpc.pb.h"
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
#include "tensorflow/core/util/events_writer.h"
@ -182,8 +181,8 @@ Status ValidateHostPortPair(const string& host_port) {
// given logdir. If no trace was collected, retries tracing for
// num_tracing_attempts.
Status Trace(const string& service_addr, const string& logdir,
const string& workers_list, bool include_dataset_ops,
int duration_ms, int num_tracing_attempts) {
const string& workers_list, int duration_ms,
int num_tracing_attempts, const ProfileOptions& opts) {
// Use the current timestamp as the run name.
tensorflow::string session_id = GetCurrentTimeStampAsString();
std::vector<string> hostnames;
@ -193,8 +192,6 @@ Status Trace(const string& service_addr, const string& logdir,
Status status = Status::OK();
int remaining_attempts = num_tracing_attempts;
ProfileOptions opts;
opts.set_include_dataset_ops(include_dataset_ops);
while (true) {
std::cout << "Starting to trace for " << duration_ms << " ms. "
<< "Remaining attempt(s): " << --remaining_attempts << std::endl;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
namespace tensorflow {
namespace profiler {
@ -34,8 +35,8 @@ Status Monitor(const string& service_addr, int duration_ms,
// given logdir. If no trace was collected, retries tracing for
// num_tracing_attempts.
Status Trace(const string& service_addr, const string& logdir,
const string& workers_list, bool include_dataset_ops,
int duration_ms, int num_tracing_attempts);
const string& workers_list, int duration_ms,
int num_tracing_attempts, const ProfileOptions& opts);
} // namespace profiler
} // namespace tensorflow

View File

@ -118,6 +118,7 @@ tf_python_pybind_extension(
],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/profiler:profiler_service_proto_cc",
"//tensorflow/core/profiler/convert:xplane_to_profile_response",
"//tensorflow/core/profiler/convert:xplane_to_trace_events",
"//tensorflow/core/profiler/lib:profiler_session_headers",

View File

@ -137,9 +137,11 @@ PYBIND11_MODULE(_pywrap_profiler, m) {
tensorflow::Status status =
tensorflow::profiler::ValidateHostPortPair(service_addr);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
status = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::ProfileOptions opts;
opts.set_include_dataset_ops(include_dataset_ops);
status =
tensorflow::profiler::Trace(service_addr, logdir, worker_list,
duration_ms, num_tracing_attempts, opts);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
});