Add an option to use platform-wide tracing via TFLite platform profiler.
PiperOrigin-RevId: 298577942 Change-Id: I749dfdf89f4560128b128ac0493e30429f7aad36
This commit is contained in:
parent
27e215e2ed
commit
e3790a6422
@ -148,6 +148,7 @@ cc_library(
|
|||||||
"//tensorflow/lite:string_util",
|
"//tensorflow/lite:string_util",
|
||||||
"//tensorflow/lite/experimental/ruy/profiler",
|
"//tensorflow/lite/experimental/ruy/profiler",
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
"//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"//tensorflow/lite/profiling:platform_profiler",
|
||||||
"//tensorflow/lite/profiling:profiler",
|
"//tensorflow/lite/profiling:profiler",
|
||||||
"//tensorflow/lite/profiling:profile_summary_formatter",
|
"//tensorflow/lite/profiling:profile_summary_formatter",
|
||||||
"//tensorflow/lite/tools/evaluation:utils",
|
"//tensorflow/lite/tools/evaluation:utils",
|
||||||
|
@ -82,6 +82,11 @@ and the following optional parameters:
|
|||||||
blank, passive mode is used by default.
|
blank, passive mode is used by default.
|
||||||
* `enable_op_profiling`: `bool` (default=false) \
|
* `enable_op_profiling`: `bool` (default=false) \
|
||||||
Whether to enable per-operator profiling measurement.
|
Whether to enable per-operator profiling measurement.
|
||||||
|
* `enable_platform_tracing`: `bool` (default=false) \
|
||||||
|
Whether to enable platform-wide tracing. Needs to be combined with
|
||||||
|
`enable_op_profiling`. Note that platform-wide tracing might not work if
|
||||||
|
the tool runs as a native command-line binary. For example, on Android, the
|
||||||
|
ATrace-based tracing only works when the tool is launched as an APK.
|
||||||
* `hexagon_profiling`: `bool` (default=false) \
|
* `hexagon_profiling`: `bool` (default=false) \
|
||||||
Whether to profile ops running on hexagon. Needs to be combined with
|
Whether to profile ops running on hexagon. Needs to be combined with
|
||||||
`enable_op_profiling`. When this is set to true the profile of ops
|
`enable_op_profiling`. When this is set to true the profile of ops
|
||||||
|
@ -84,6 +84,8 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
|
|||||||
params.AddParam("max_delegated_partitions", BenchmarkParam::Create<int>(0));
|
params.AddParam("max_delegated_partitions", BenchmarkParam::Create<int>(0));
|
||||||
params.AddParam("profiling_output_csv_file",
|
params.AddParam("profiling_output_csv_file",
|
||||||
BenchmarkParam::Create<std::string>(""));
|
BenchmarkParam::Create<std::string>(""));
|
||||||
|
params.AddParam("enable_platform_tracing",
|
||||||
|
BenchmarkParam::Create<bool>(false));
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/op_resolver.h"
|
#include "tensorflow/lite/op_resolver.h"
|
||||||
|
#include "tensorflow/lite/profiling/platform_profiler.h"
|
||||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
|
#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
|
||||||
@ -59,6 +60,20 @@ constexpr int kOpProfilingEnabledDefault = true;
|
|||||||
constexpr int kOpProfilingEnabledDefault = false;
|
constexpr int kOpProfilingEnabledDefault = false;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Benchmark listener that attaches a platform-based profiler to the interpreter.
|
||||||
|
// The profiler is built upon platform tracing tools, like ATrace on Android etc.
|
||||||
|
class PlatformProfilingListener : public BenchmarkListener {
|
||||||
|
public:
|
||||||
|
// Creates the platform profiler and registers it on `interpreter`.
// `interpreter` must be non-null; this is enforced by TFLITE_BENCHMARK_CHECK.
explicit PlatformProfilingListener(Interpreter* interpreter) {
|
||||||
|
TFLITE_BENCHMARK_CHECK(interpreter);
|
||||||
|
platform_profiler_ = profiling::CreatePlatformProfiler();
|
||||||
|
// The interpreter only receives a raw pointer; ownership stays with this
// listener. NOTE(review): nothing here unsets the profiler when the listener
// is destroyed -- confirm the interpreter does not use it afterwards.
interpreter->SetProfiler(platform_profiler_.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Owns the profiler that was handed to the interpreter in the constructor.
std::unique_ptr<tflite::Profiler> platform_profiler_;
|
||||||
|
};
|
||||||
|
|
||||||
// Dumps ruy profiling events if the ruy profiler is enabled.
|
// Dumps ruy profiling events if the ruy profiler is enabled.
|
||||||
class RuyProfileListener : public BenchmarkListener {
|
class RuyProfileListener : public BenchmarkListener {
|
||||||
public:
|
public:
|
||||||
@ -261,6 +276,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
|
|||||||
BenchmarkParam::Create<std::string>(""));
|
BenchmarkParam::Create<std::string>(""));
|
||||||
default_params.AddParam("max_delegated_partitions",
|
default_params.AddParam("max_delegated_partitions",
|
||||||
BenchmarkParam::Create<int32_t>(0));
|
BenchmarkParam::Create<int32_t>(0));
|
||||||
|
default_params.AddParam("enable_platform_tracing",
|
||||||
|
BenchmarkParam::Create<bool>(false));
|
||||||
|
|
||||||
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
|
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
|
||||||
delegate_util->AddParams(&default_params);
|
delegate_util->AddParams(&default_params);
|
||||||
@ -313,7 +330,10 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
|
|||||||
"File path to export profile data as CSV, if not set "
|
"File path to export profile data as CSV, if not set "
|
||||||
"prints to stdout."),
|
"prints to stdout."),
|
||||||
CreateFlag<int>("max_delegated_partitions", &params_,
|
CreateFlag<int>("max_delegated_partitions", &params_,
|
||||||
"Max partitions to be delegated.")};
|
"Max partitions to be delegated."),
|
||||||
|
CreateFlag<bool>("enable_platform_tracing", &params_,
|
||||||
|
"enable platform-wide tracing, only meaningful when "
|
||||||
|
"--enable_op_profiling is set to true.")};
|
||||||
|
|
||||||
flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
|
flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
|
||||||
|
|
||||||
@ -356,6 +376,8 @@ void BenchmarkTfLiteModel::LogParams() {
|
|||||||
<< "]";
|
<< "]";
|
||||||
TFLITE_LOG(INFO) << "Max number of delegated partitions : ["
|
TFLITE_LOG(INFO) << "Max number of delegated partitions : ["
|
||||||
<< params_.Get<int32_t>("max_delegated_partitions") << "]";
|
<< params_.Get<int32_t>("max_delegated_partitions") << "]";
|
||||||
|
TFLITE_LOG(INFO) << "Enable platform-wide tracing: ["
|
||||||
|
<< params_.Get<bool>("enable_platform_tracing") << "]";
|
||||||
|
|
||||||
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
|
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
|
||||||
delegate_util->LogParams(params_);
|
delegate_util->LogParams(params_);
|
||||||
@ -684,6 +706,12 @@ std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
|
|||||||
std::unique_ptr<BenchmarkListener>
|
std::unique_ptr<BenchmarkListener>
|
||||||
BenchmarkTfLiteModel::MayCreateProfilingListener() const {
|
BenchmarkTfLiteModel::MayCreateProfilingListener() const {
|
||||||
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
|
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
|
||||||
|
|
||||||
|
if (params_.Get<bool>("enable_platform_tracing")) {
|
||||||
|
return std::unique_ptr<BenchmarkListener>(
|
||||||
|
new PlatformProfilingListener(interpreter_.get()));
|
||||||
|
}
|
||||||
|
|
||||||
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
|
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
|
||||||
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
|
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
|
||||||
params_.Get<std::string>("profiling_output_csv_file"),
|
params_.Get<std::string>("profiling_output_csv_file"),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user