Add an option to use platform-wide tracing via TFLite platform profiler.
PiperOrigin-RevId: 298577942 Change-Id: I749dfdf89f4560128b128ac0493e30429f7aad36
This commit is contained in:
parent
27e215e2ed
commit
e3790a6422
@ -148,6 +148,7 @@ cc_library(
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/experimental/ruy/profiler",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/profiling:platform_profiler",
|
||||
"//tensorflow/lite/profiling:profiler",
|
||||
"//tensorflow/lite/profiling:profile_summary_formatter",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
|
@ -82,6 +82,11 @@ and the following optional parameters:
|
||||
blank, passive mode is used by default.
|
||||
* `enable_op_profiling`: `bool` (default=false) \
|
||||
Whether to enable per-operator profiling measurement.
|
||||
* `enable_platform_tracing`: `bool` (default=false) \
|
||||
Whether to enable platform-wide tracing. Needs to be combined with
|
||||
`enable_op_profiling`. Note that platform-wide tracing might not work if
|
||||
the tool runs as a command-line native binary. For example, on Android, the
|
||||
ATrace-based tracing only works when the tool is launched as an APK.
|
||||
* `hexagon_profiling`: `bool` (default=false) \
|
||||
Whether to profile ops running on hexagon. Needs to be combined with
|
||||
`enable_op_profiling`. When this is set to true the profile of ops
|
||||
|
@ -84,6 +84,8 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
|
||||
params.AddParam("max_delegated_partitions", BenchmarkParam::Create<int>(0));
|
||||
params.AddParam("profiling_output_csv_file",
|
||||
BenchmarkParam::Create<std::string>(""));
|
||||
params.AddParam("enable_platform_tracing",
|
||||
BenchmarkParam::Create<bool>(false));
|
||||
return params;
|
||||
}
|
||||
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/op_resolver.h"
|
||||
#include "tensorflow/lite/profiling/platform_profiler.h"
|
||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
|
||||
@ -59,6 +60,20 @@ constexpr int kOpProfilingEnabledDefault = true;
|
||||
constexpr int kOpProfilingEnabledDefault = false;
|
||||
#endif
|
||||
|
||||
// Dumps platform-wide tracing files via a platform-based profiler that's built
|
||||
// upon platform tracing tools, like ATrace on Android etc.
|
||||
// Benchmark listener whose only visible work happens in its constructor: it
// creates a platform profiler and attaches it to the given interpreter. No
// BenchmarkListener callbacks are overridden in this class.
class PlatformProfilingListener : public BenchmarkListener {
|
||||
public:
|
||||
// Creates the platform profiler and installs it on `interpreter`.
// `interpreter` is validated with TFLITE_BENCHMARK_CHECK (presumably fatal
// when null -- confirm against the macro definition). The interpreter is not
// stored by this listener.
explicit PlatformProfilingListener(Interpreter* interpreter) {
|
||||
TFLITE_BENCHMARK_CHECK(interpreter);
|
||||
platform_profiler_ = profiling::CreatePlatformProfiler();
|
||||
// Only a raw pointer is handed to the interpreter; ownership stays with
// platform_profiler_. NOTE(review): the interpreter presumably must not use
// this profiler after the listener is destroyed -- confirm lifetimes at the
// call site.
interpreter->SetProfiler(platform_profiler_.get());
|
||||
}
|
||||
|
||||
private:
|
||||
// Owns the profiler returned by profiling::CreatePlatformProfiler().
std::unique_ptr<tflite::Profiler> platform_profiler_;
|
||||
};
|
||||
|
||||
// Dumps ruy profiling events if the ruy profiler is enabled.
|
||||
class RuyProfileListener : public BenchmarkListener {
|
||||
public:
|
||||
@ -261,6 +276,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
|
||||
BenchmarkParam::Create<std::string>(""));
|
||||
default_params.AddParam("max_delegated_partitions",
|
||||
BenchmarkParam::Create<int32_t>(0));
|
||||
default_params.AddParam("enable_platform_tracing",
|
||||
BenchmarkParam::Create<bool>(false));
|
||||
|
||||
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
|
||||
delegate_util->AddParams(&default_params);
|
||||
@ -313,7 +330,10 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
|
||||
"File path to export profile data as CSV, if not set "
|
||||
"prints to stdout."),
|
||||
CreateFlag<int>("max_delegated_partitions", &params_,
|
||||
"Max partitions to be delegated.")};
|
||||
"Max partitions to be delegated."),
|
||||
CreateFlag<bool>("enable_platform_tracing", &params_,
|
||||
"enable platform-wide tracing, only meaningful when "
|
||||
"--enable_op_profiling is set to true.")};
|
||||
|
||||
flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
|
||||
|
||||
@ -356,6 +376,8 @@ void BenchmarkTfLiteModel::LogParams() {
|
||||
<< "]";
|
||||
TFLITE_LOG(INFO) << "Max number of delegated partitions : ["
|
||||
<< params_.Get<int32_t>("max_delegated_partitions") << "]";
|
||||
TFLITE_LOG(INFO) << "Enable platform-wide tracing: ["
|
||||
<< params_.Get<bool>("enable_platform_tracing") << "]";
|
||||
|
||||
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
|
||||
delegate_util->LogParams(params_);
|
||||
@ -684,6 +706,12 @@ std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
|
||||
std::unique_ptr<BenchmarkListener>
|
||||
BenchmarkTfLiteModel::MayCreateProfilingListener() const {
|
||||
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
|
||||
|
||||
if (params_.Get<bool>("enable_platform_tracing")) {
|
||||
return std::unique_ptr<BenchmarkListener>(
|
||||
new PlatformProfilingListener(interpreter_.get()));
|
||||
}
|
||||
|
||||
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
|
||||
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
|
||||
params_.Get<std::string>("profiling_output_csv_file"),
|
||||
|
Loading…
x
Reference in New Issue
Block a user