consolidate tf::profiler::ProfilerOptions with tf::ProfileOptions.

Remove ProfilerInterface::GetDeviceType();

PiperOrigin-RevId: 305506928
Change-Id: I18150a5af56d9a0950fa0f0fb5190150a495e806
This commit is contained in:
A. Unique TensorFlower 2020-04-08 10:36:49 -07:00 committed by TensorFlower Gardener
parent 2dd36891b5
commit 3acb43588d
22 changed files with 160 additions and 157 deletions

View File

@ -238,6 +238,7 @@ FRAMEWORK_PROTO_SRCS = [
PROFILER_PROTO_SRCS = [
"//tensorflow/core/profiler/protobuf:xplane.proto",
"//tensorflow/core/profiler:profiler_options.proto",
]
ERROR_CODES_PROTO_SRCS = [
@ -2162,6 +2163,7 @@ tf_proto_library(
"//tensorflow/core/framework:protos_all",
"//tensorflow/core/lib/core:error_codes_proto",
"//tensorflow/core/profiler/protobuf:xplane_proto",
"//tensorflow/core/profiler:profiler_options_proto",
"//tensorflow/core/util:protos_all",
"//tensorflow/core/util:test_log_proto_impl",
],

View File

@ -589,7 +589,10 @@ def tf_protos_all():
)
def tf_protos_profiler_impl():
return [clean_dep("//tensorflow/core/profiler/protobuf:xplane_proto_cc_impl")]
return [
clean_dep("//tensorflow/core/profiler/protobuf:xplane_proto_cc_impl"),
clean_dep("//tensorflow/core/profiler:profiler_options_proto_cc_impl"),
]
def tf_protos_grappler_impl():
return [clean_dep("//tensorflow/core/grappler/costs:op_performance_data_cc_impl")]

View File

@ -28,6 +28,20 @@ tf_proto_library(
visibility = ["//visibility:public"],
)
tf_proto_library(
name = "profiler_options_proto",
srcs = ["profiler_options.proto"],
cc_api_version = 2,
make_default_target_header_only = True,
visibility = ["//visibility:public"],
)
# This is needed because of how tf_android_core_proto_sources parses proto paths.
exports_files(
srcs = ["profiler_options.proto"],
visibility = ["//tensorflow/core:__pkg__"],
)
tf_proto_library(
name = "profiler_service_proto",
srcs = ["profiler_service.proto"],
@ -35,6 +49,7 @@ tf_proto_library(
cc_api_version = 2,
cc_grpc_version = 1,
protodeps = [
":profiler_options_proto",
":profiler_service_monitor_result_proto",
],
use_grpc_namespace = True,

View File

@ -434,6 +434,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
],
)

View File

@ -50,6 +50,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:xplane_schema",

View File

@ -55,8 +55,6 @@ class HostTracer : public ProfilerInterface {
Status CollectData(XSpace* space) override;
DeviceType GetDeviceType() override { return DeviceType::kCpu; }
private:
// Level of host tracing.
const int host_trace_level_;
@ -154,9 +152,9 @@ Status HostTracer::CollectData(XSpace* space) {
// Not in anonymous namespace for testing purposes.
std::unique_ptr<ProfilerInterface> CreateHostTracer(
const profiler::ProfilerOptions& options) {
if (options.host_tracer_level == 0) return nullptr;
return absl::make_unique<HostTracer>(options.host_tracer_level);
const ProfileOptions& options) {
if (options.host_tracer_level() == 0) return nullptr;
return absl::make_unique<HostTracer>(options.host_tracer_level());
}
auto register_host_tracer_factory = [] {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
@ -31,7 +32,7 @@ namespace tensorflow {
namespace profiler {
std::unique_ptr<ProfilerInterface> CreateHostTracer(
const ProfilerOptions& options);
const ProfileOptions& options);
namespace {
@ -77,7 +78,7 @@ inline ::testing::PolymorphicMatcher<NodeStatsMatcher> EqualsNodeStats(
TEST(HostTracerTest, CollectsTraceMeEventsAsRunMetadata) {
uint32 thread_id = Env::Default()->GetCurrentThreadId();
auto tracer = CreateHostTracer(ProfilerOptions());
auto tracer = CreateHostTracer(ProfilerSession::DefaultOptions());
TF_ASSERT_OK(tracer->Start());
{ TraceMe traceme("hello"); }
@ -122,7 +123,7 @@ TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) {
ASSERT_TRUE(Env::Default()->GetCurrentThreadName(&thread_name));
thread_id = Env::Default()->GetCurrentThreadId();
auto tracer = CreateHostTracer(ProfilerOptions());
auto tracer = CreateHostTracer(ProfilerSession::DefaultOptions());
TF_ASSERT_OK(tracer->Start());
{ TraceMe traceme("hello"); }

View File

@ -74,8 +74,6 @@ class MetadataCollector : public ProfilerInterface {
return Status::OK();
}
DeviceType GetDeviceType() override { return DeviceType::kCpu; }
private:
std::vector<xla::gpu::GpuModuleDebugInfo> debug_info_;
bool trace_active_ = false;
@ -84,9 +82,9 @@ class MetadataCollector : public ProfilerInterface {
};
std::unique_ptr<ProfilerInterface> CreatMetadataCollector(
const profiler::ProfilerOptions& options) {
return options.enable_hlo_proto ? absl::make_unique<MetadataCollector>()
: nullptr;
const ProfileOptions& options) {
return options.enable_hlo_proto() ? absl::make_unique<MetadataCollector>()
: nullptr;
}
} // namespace

View File

@ -50,8 +50,6 @@ class PythonTracer : public ProfilerInterface {
Status CollectData(XSpace* space) override;
DeviceType GetDeviceType() override { return DeviceType::kCpu; }
private:
bool recording_ = false;
@ -107,10 +105,10 @@ Status PythonTracer::CollectData(XSpace* space) {
// Not in anonymous namespace for testing purposes.
std::unique_ptr<ProfilerInterface> CreatePythonTracer(
const profiler::ProfilerOptions& options) {
if (!options.enable_python_tracer) return nullptr;
const ProfileOptions& options) {
if (options.python_tracer_level() == 0) return nullptr;
// This ProfilerInterface relies on TraceMeRecorder being active.
if (options.host_tracer_level == 0) return nullptr;
if (options.host_tracer_level() == 0) return nullptr;
return absl::make_unique<PythonTracer>();
}

View File

@ -76,6 +76,7 @@ tf_cc_test_gpu(
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",

View File

@ -513,9 +513,6 @@ class GpuTracer : public profiler::ProfilerInterface {
Status Stop() override;
Status CollectData(RunMetadata* run_metadata) override;
Status CollectData(XSpace* space) override;
profiler::DeviceType GetDeviceType() override {
return profiler::DeviceType::kGpu;
}
private:
Status DoStart();
@ -679,9 +676,9 @@ Status GpuTracer::CollectData(XSpace* space) {
// Not in anonymous namespace for testing purposes.
std::unique_ptr<profiler::ProfilerInterface> CreateGpuTracer(
const profiler::ProfilerOptions& options) {
if (options.device_type != profiler::DeviceType::kGpu &&
options.device_type != profiler::DeviceType::kUnspecified)
const ProfileOptions& options) {
if (options.device_type() != ProfileOptions::GPU &&
options.device_type() != ProfileOptions::UNSPECIFIED)
return nullptr;
profiler::CuptiTracer* cupti_tracer =
profiler::CuptiTracer::GetCuptiTracerSingleton();

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
@ -45,14 +46,15 @@ namespace tensorflow {
namespace profiler {
#if GOOGLE_CUDA
std::unique_ptr<ProfilerInterface> CreateGpuTracer(
const ProfilerOptions& options);
extern std::unique_ptr<ProfilerInterface> CreateGpuTracer(
const ProfileOptions& options);
std::unique_ptr<ProfilerInterface> CreateGpuTracer() {
ProfileOptions options = ProfilerSession::DefaultOptions();
return CreateGpuTracer(options);
}
#else
// We don't have device tracer for non-cuda case.
std::unique_ptr<ProfilerInterface> CreateGpuTracer(
const ProfilerOptions& options) {
return nullptr;
}
std::unique_ptr<ProfilerInterface> CreateGpuTracer() { return nullptr; }
#endif
namespace {
@ -111,24 +113,21 @@ class DeviceTracerTest : public ::testing::Test {
};
TEST_F(DeviceTracerTest, StartStop) {
profiler::ProfilerOptions options;
auto tracer = CreateGpuTracer(options);
auto tracer = CreateGpuTracer();
if (!tracer) return;
TF_EXPECT_OK(tracer->Start());
TF_EXPECT_OK(tracer->Stop());
}
TEST_F(DeviceTracerTest, StopBeforeStart) {
profiler::ProfilerOptions options;
auto tracer = CreateGpuTracer(options);
auto tracer = CreateGpuTracer();
if (!tracer) return;
TF_EXPECT_OK(tracer->Stop());
TF_EXPECT_OK(tracer->Stop());
}
TEST_F(DeviceTracerTest, CollectBeforeStart) {
profiler::ProfilerOptions options;
auto tracer = CreateGpuTracer(options);
auto tracer = CreateGpuTracer();
if (!tracer) return;
RunMetadata run_metadata;
TF_EXPECT_OK(tracer->CollectData(&run_metadata));
@ -136,8 +135,7 @@ TEST_F(DeviceTracerTest, CollectBeforeStart) {
}
TEST_F(DeviceTracerTest, CollectBeforeStop) {
profiler::ProfilerOptions options;
auto tracer = CreateGpuTracer(options);
auto tracer = CreateGpuTracer();
if (!tracer) return;
TF_EXPECT_OK(tracer->Start());
RunMetadata run_metadata;
@ -147,9 +145,8 @@ TEST_F(DeviceTracerTest, CollectBeforeStop) {
}
TEST_F(DeviceTracerTest, StartTwoTracers) {
profiler::ProfilerOptions options;
auto tracer1 = CreateGpuTracer(options);
auto tracer2 = CreateGpuTracer(options);
auto tracer1 = CreateGpuTracer();
auto tracer2 = CreateGpuTracer();
if (!tracer1 || !tracer2) return;
TF_EXPECT_OK(tracer1->Start());
@ -162,8 +159,7 @@ TEST_F(DeviceTracerTest, StartTwoTracers) {
TEST_F(DeviceTracerTest, RunWithTracer) {
// On non-GPU platforms, we may not support DeviceTracer.
profiler::ProfilerOptions options;
auto tracer = CreateGpuTracer(options);
auto tracer = CreateGpuTracer();
if (!tracer) return;
Initialize({3, 2, -1, 0});
@ -190,8 +186,7 @@ TEST_F(DeviceTracerTest, RunWithTracer) {
}
TEST_F(DeviceTracerTest, TraceToStepStatsCollector) {
profiler::ProfilerOptions options;
auto tracer = CreateGpuTracer(options);
auto tracer = CreateGpuTracer();
if (!tracer) return;
Initialize({3, 2, -1, 0});
@ -244,8 +239,7 @@ TEST_F(DeviceTracerTest, RunWithTraceOption) {
}
TEST_F(DeviceTracerTest, TraceToXSpace) {
profiler::ProfilerOptions options;
auto tracer = CreateGpuTracer(options);
auto tracer = CreateGpuTracer();
if (!tracer) return;
Initialize({3, 2, -1, 0});

View File

@ -36,7 +36,7 @@ void RegisterProfilerFactory(ProfilerFactory factory) {
}
void CreateProfilers(
const profiler::ProfilerOptions& options,
const ProfileOptions& options,
std::vector<std::unique_ptr<profiler::ProfilerInterface>>* result) {
mutex_lock lock(mu);
for (auto factory : *GetFactories()) {

View File

@ -24,11 +24,11 @@ namespace tensorflow {
namespace profiler {
using ProfilerFactory =
std::unique_ptr<ProfilerInterface> (*)(const ProfilerOptions&);
std::unique_ptr<ProfilerInterface> (*)(const ProfileOptions&);
void RegisterProfilerFactory(ProfilerFactory factory);
void CreateProfilers(const ProfilerOptions& options,
void CreateProfilers(const ProfileOptions& options,
std::vector<std::unique_ptr<ProfilerInterface>>* result);
} // namespace profiler

View File

@ -16,50 +16,13 @@ limitations under the License.
#define TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
namespace profiler {
enum class DeviceType {
kUnspecified,
kCpu,
kGpu,
kTpu,
};
struct ProfilerOptions {
// DeviceType::kUnspecified: All registered device profiler will be enabled.
// DeviceType::kCpu: only CPU will be profiled.
// DeviceType::kGpu: only CPU/GPU will be profiled.
// DeviceType::kTpu: only CPU/TPU will be profiled.
DeviceType device_type = DeviceType::kUnspecified;
// Levels of host tracing:
// - Level 0 is used to disable host traces.
// - Level 1 enables tracing of only user instrumented (or default) TraceMe.
// - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high
// level program execution details (expensive TF ops, XLA ops, etc).
// This is the default.
// - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose
// (low-level) program execution details (cheap TF ops, etc).
uint32 host_tracer_level = 2;
// Levels of device tracing:
// - Level 0 is used to disable device traces.
// - Level 1 is used to enable device traces.
// - More levels might be defined for specific device for controlling the
// verbosity of the trace.
uint32 device_tracer_level = 1;
// Whether to enable python function calls tracer.
bool enable_python_tracer = false;
// Whether to capture HLO protos from XLA runtime.
bool enable_hlo_proto = true;
};
// Interface for tensorflow profiler plugins.
//
// ProfileSession calls each of these methods at most once per instance, and
@ -87,9 +50,6 @@ class ProfilerInterface {
// After this or the overload above are called once, subsequent calls might
// return empty data.
virtual Status CollectData(XSpace* space) = 0;
// Which device this ProfilerInterface is used for.
virtual DeviceType GetDeviceType() = 0;
};
} // namespace profiler

View File

@ -26,6 +26,7 @@ cc_library(
deps = [
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
] + if_static([
@ -50,6 +51,7 @@ cc_library(
"//tensorflow/core/platform",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/util:ptr_util",
] + if_not_android([
":profiler_utils",

View File

@ -33,8 +33,17 @@ limitations under the License.
namespace tensorflow {
namespace {
ProfileOptions GetOptions(const ProfileOptions& opts) {
if (opts.version()) return opts;
ProfileOptions options = ProfilerSession::DefaultOptions();
options.set_include_dataset_ops(opts.include_dataset_ops());
return options;
}
}; // namespace
/*static*/ std::unique_ptr<ProfilerSession> ProfilerSession::Create(
const profiler::ProfilerOptions& options) {
const ProfileOptions& options) {
return WrapUnique(new ProfilerSession(options));
}
@ -45,12 +54,24 @@ namespace tensorflow {
if (!s.ok()) {
LOG(WARNING) << "ProfilerSession: " << s.error_message();
}
profiler::ProfilerOptions options;
options.host_tracer_level = host_tracer_level;
ProfileOptions options = DefaultOptions();
options.set_host_tracer_level(host_tracer_level);
return Create(options);
}
Status ProfilerSession::Status() {
/*static*/ ProfileOptions ProfilerSession::DefaultOptions() {
ProfileOptions options;
options.set_version(1);
options.set_device_tracer_level(1);
options.set_host_tracer_level(2);
options.set_device_type(ProfileOptions::UNSPECIFIED);
options.set_python_tracer_level(0);
options.set_enable_hlo_proto(false);
options.set_include_dataset_ops(true);
return options;
}
tensorflow::Status ProfilerSession::Status() {
mutex_lock l(mutex_);
return status_;
}
@ -122,14 +143,14 @@ Status ProfilerSession::CollectData(RunMetadata* run_metadata) {
return Status::OK();
}
ProfilerSession::ProfilerSession(const profiler::ProfilerOptions& options)
ProfilerSession::ProfilerSession(const ProfileOptions& options)
#if !defined(IS_MOBILE_PLATFORM)
: active_(profiler::AcquireProfilerLock()),
#else
: active_(false),
#endif
start_time_ns_(EnvTime::NowNanos()) {
start_time_ns_(EnvTime::NowNanos()),
options_(GetOptions(options)) {
if (!active_) {
#if !defined(IS_MOBILE_PLATFORM)
status_ = tensorflow::Status(error::UNAVAILABLE,
@ -145,7 +166,7 @@ ProfilerSession::ProfilerSession(const profiler::ProfilerOptions& options)
LOG(INFO) << "Profiler session started.";
#if !defined(IS_MOBILE_PLATFORM)
CreateProfilers(options, &profilers_);
CreateProfilers(options_, &profilers_);
#endif
status_ = Status::OK();

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
namespace tensorflow {
@ -36,9 +37,9 @@ namespace tensorflow {
class ProfilerSession {
public:
// Creates a ProfilerSession and starts profiling.
static std::unique_ptr<ProfilerSession> Create(
const profiler::ProfilerOptions& options);
static std::unique_ptr<ProfilerSession> Create(const ProfileOptions& options);
static std::unique_ptr<ProfilerSession> Create();
static ProfileOptions DefaultOptions();
// Deletes an existing Profiler and enables starting a new one.
~ProfilerSession();
@ -53,7 +54,7 @@ class ProfilerSession {
private:
// Constructs an instance of the class and starts profiling
explicit ProfilerSession(const profiler::ProfilerOptions& options);
explicit ProfilerSession(const ProfileOptions& options);
// ProfilerSession is neither copyable or movable.
ProfilerSession(const ProfilerSession&) = delete;
@ -68,6 +69,7 @@ class ProfilerSession {
tensorflow::Status status_ TF_GUARDED_BY(mutex_);
const uint64 start_time_ns_;
mutex mutex_;
ProfileOptions options_;
};
} // namespace tensorflow

View File

@ -0,0 +1,54 @@
syntax = "proto3";
package tensorflow;
message ProfileOptions {
// Some default values of options are not proto3 default values. Use this
// version to determine whether we should use the default option value
// instead of the proto3 default value.
uint32 version = 5;
enum DeviceType {
UNSPECIFIED = 0;
CPU = 1;
GPU = 2;
TPU = 3;
}
// Device type to profile/trace: (version >= 1)
// DeviceType::UNSPECIFIED: All registered device profiler will be enabled.
// DeviceType::CPU: only CPU will be profiled.
// DeviceType::GPU: only CPU/GPU will be profiled.
// DeviceType::TPU: only CPU/TPU will be profiled.
DeviceType device_type = 6;
// We don't collect the dataset ops by default for better trace-viewer
// scalability. The caller can manually set this field to include the ops.
bool include_dataset_ops = 1;
// Levels of host tracing: (version >= 1)
// - Level 0 is used to disable host traces.
// - Level 1 enables tracing of only user instrumented (or default) TraceMe.
// - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high
// level program execution details (expensive TF ops, XLA ops, etc).
// This is the default.
// - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose
// (low-level) program execution details (cheap TF ops, etc).
uint32 host_tracer_level = 2;
// Levels of device tracing: (version >= 1)
// - Level 0 is used to disable device traces.
// - Level 1 is used to enable device traces.
// - More levels might be defined for specific device for controlling the
// verbosity of the trace.
uint32 device_tracer_level = 3;
// Whether to enable tracing of Python function calls. Runtime overhead
// ensues if enabled. Default off. (version >= 1)
uint32 python_tracer_level = 4;
// Whether to serialize hlo_proto when XLA is used. (version >= 1)
bool enable_hlo_proto = 7;
// next-field: 8
}

View File

@ -2,6 +2,7 @@ syntax = "proto3";
package tensorflow;
import "tensorflow/core/profiler/profiler_options.proto";
import "tensorflow/core/profiler/profiler_service_monitor_result.proto";
// The ProfilerService service retrieves performance information about
@ -13,40 +14,6 @@ service ProfilerService {
rpc Monitor(MonitorRequest) returns (MonitorResponse) {}
}
message ProfileOptions {
// Some default value of option are not proto3 default value. Use this version
// to determine if we should use default option value instead of proto3
// default value.
uint32 version = 5;
// We don't collect the dataset ops by default for better trace-viewer
// scalability. The caller can mannually set this field to include the ops.
bool include_dataset_ops = 1;
// Levels of host tracing: (version >= 1)
// - Level 0 is used to disable host traces.
// - Level 1 enables tracing of only user instrumented (or default) TraceMe.
// - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high
// level program execution details (expensive TF ops, XLA ops, etc).
// This is the default.
// - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose
// (low-level) program execution details (cheap TF ops, etc).
uint32 host_tracer_level = 2;
// Levels of device tracing: (version >= 1)
// - Level 0 is used to disable device traces.
// - Level 1 is used to enable device traces.
// - More levels might be defined for specific device for controlling the
// verbosity of the trace.
uint32 device_tracer_level = 3;
// Whether enable python function calls tracing. Runtime overhead ensues if
// enabled. Default off. (version >= 1)
uint32 python_tracer_level = 4;
// next-field: 6
}
message ToolRequestOptions {
// Required formats for the tool, it should be one of "json", "proto", "raw"
// etc. If not specified (backward compatible), use default format, i.e. most

View File

@ -53,7 +53,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
ProfileResponse* response) override {
VLOG(1) << "Received a profile request: " << req->DebugString();
std::unique_ptr<ProfilerSession> profiler =
ProfilerSession::Create(GetOptions(req->opts()));
ProfilerSession::Create(req->opts());
Status status = profiler->Status();
if (!status.ok()) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
@ -76,19 +76,6 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
return ::grpc::Status::OK;
}
private:
profiler::ProfilerOptions GetOptions(const tensorflow::ProfileOptions& opts) {
profiler::ProfilerOptions options;
if (opts.version()) {
options.host_tracer_level = opts.host_tracer_level();
options.device_tracer_level = opts.device_tracer_level();
options.enable_python_tracer = opts.python_tracer_level() > 0;
} else {
// use default options value;
}
return options;
}
};
} // namespace

View File

@ -93,17 +93,18 @@ class ProfilerSessionWrapper {
}
private:
tensorflow::profiler::ProfilerOptions GetOptions(const py::dict& opts) {
tensorflow::profiler::ProfilerOptions options;
tensorflow::ProfileOptions GetOptions(const py::dict& opts) {
tensorflow::ProfileOptions options =
tensorflow::ProfilerSession::DefaultOptions();
for (const auto& kw : opts) {
std::string key = py::cast<std::string>(kw.first);
if (key == "host_tracer_level") {
options.host_tracer_level = py::cast<int>(kw.second);
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level;
options.set_host_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level();
} else if (key == "python_tracer_level") {
options.enable_python_tracer = py::cast<int>(kw.second) > 0;
options.set_python_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "enable_python_tracer set to "
<< options.enable_python_tracer;
<< options.python_tracer_level();
}
}
return options;