Add an option to use platform-wide tracing via TFLite platform profiler.

PiperOrigin-RevId: 298577942
Change-Id: I749dfdf89f4560128b128ac0493e30429f7aad36
This commit is contained in:
Chao Mei 2020-03-03 04:48:42 -08:00 committed by TensorFlower Gardener
parent 27e215e2ed
commit e3790a6422
4 changed files with 37 additions and 1 deletions

View File

@ -148,6 +148,7 @@ cc_library(
"//tensorflow/lite:string_util",
"//tensorflow/lite/experimental/ruy/profiler",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/profiling:platform_profiler",
"//tensorflow/lite/profiling:profiler",
"//tensorflow/lite/profiling:profile_summary_formatter",
"//tensorflow/lite/tools/evaluation:utils",

View File

@ -82,6 +82,11 @@ and the following optional parameters:
blank, passive mode is used by default.
* `enable_op_profiling`: `bool` (default=false) \
Whether to enable per-operator profiling measurement.
* `enable_platform_tracing`: `bool` (default=false) \
Whether to enable platform-wide tracing. Needs to be combined with
'enable_op_profiling'. Note, the platform-wide tracing might not work if
the tool runs as a commandline native binary. For example, on Android, the
ATrace-based tracing only works when the tool is launched as an APK.
* `hexagon_profiling`: `bool` (default=false) \
Whether to profile ops running on hexagon. Needs to be combined with
`enable_op_profiling`. When this is set to true the profile of ops

View File

@ -84,6 +84,8 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
params.AddParam("max_delegated_partitions", BenchmarkParam::Create<int>(0));
params.AddParam("profiling_output_csv_file",
BenchmarkParam::Create<std::string>(""));
params.AddParam("enable_platform_tracing",
BenchmarkParam::Create<bool>(false));
return params;
}

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/profiling/platform_profiler.h"
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
@ -59,6 +60,20 @@ constexpr int kOpProfilingEnabledDefault = true;
constexpr int kOpProfilingEnabledDefault = false;
#endif
// Dumps platform-wide tracing files via a platform-based profiler that's built
// upon platform tracing tools, like ATrace on Android etc.
class PlatformProfilingListener : public BenchmarkListener {
public:
explicit PlatformProfilingListener(Interpreter* interpreter) {
TFLITE_BENCHMARK_CHECK(interpreter);
platform_profiler_ = profiling::CreatePlatformProfiler();
interpreter->SetProfiler(platform_profiler_.get());
}
private:
std::unique_ptr<tflite::Profiler> platform_profiler_;
};
// Dumps ruy profiling events if the ruy profiler is enabled.
class RuyProfileListener : public BenchmarkListener {
public:
@ -261,6 +276,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
BenchmarkParam::Create<std::string>(""));
default_params.AddParam("max_delegated_partitions",
BenchmarkParam::Create<int32_t>(0));
default_params.AddParam("enable_platform_tracing",
BenchmarkParam::Create<bool>(false));
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
delegate_util->AddParams(&default_params);
@ -313,7 +330,10 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
"File path to export profile data as CSV, if not set "
"prints to stdout."),
CreateFlag<int>("max_delegated_partitions", &params_,
"Max partitions to be delegated.")};
"Max partitions to be delegated."),
CreateFlag<bool>("enable_platform_tracing", &params_,
"enable platform-wide tracing, only meaningful when "
"--enable_op_profiling is set to true.")};
flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
@ -356,6 +376,8 @@ void BenchmarkTfLiteModel::LogParams() {
<< "]";
TFLITE_LOG(INFO) << "Max number of delegated partitions : ["
<< params_.Get<int32_t>("max_delegated_partitions") << "]";
TFLITE_LOG(INFO) << "Enable platform-wide tracing: ["
<< params_.Get<bool>("enable_platform_tracing") << "]";
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
delegate_util->LogParams(params_);
@ -684,6 +706,12 @@ std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
std::unique_ptr<BenchmarkListener>
BenchmarkTfLiteModel::MayCreateProfilingListener() const {
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
if (params_.Get<bool>("enable_platform_tracing")) {
return std::unique_ptr<BenchmarkListener>(
new PlatformProfilingListener(interpreter_.get()));
}
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
params_.Get<std::string>("profiling_output_csv_file"),