diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 197907ec9e8..2edbbd06ec4 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -573,12 +573,8 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
   // Install profilers if necessary right after interpreter is created so that
   // any memory allocations inside the TFLite runtime could be recorded if the
   // installed profiler profile memory usage information.
-  if (params_.Get<bool>("enable_op_profiling")) {
-    profiling_listener_.reset(new ProfilingListener(
-        interpreter_.get(),
-        params_.Get<int32_t>("max_profiling_buffer_entries")));
-    AddListener(profiling_listener_.get());
-  }
+  profiling_listener_ = MayCreateProfilingListener();
+  if (profiling_listener_) AddListener(profiling_listener_.get());
 
   interpreter_->UseNNAPI(params_.Get<bool>("use_legacy_nnapi"));
   interpreter_->SetAllowFp16PrecisionForFp32(params_.Get<bool>("allow_fp16"));
@@ -771,6 +767,14 @@ std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
   return std::unique_ptr<tflite::OpResolver>(resolver);
 }
 
+std::unique_ptr<BenchmarkListener>
+BenchmarkTfLiteModel::MayCreateProfilingListener() const {
+  if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
+  return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
+      interpreter_.get(),
+      params_.Get<int32_t>("max_profiling_buffer_entries")));
+}
+
 TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }
 
 }  // namespace benchmark
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
index a6fc38a6180..3778cc968bd 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
@@ -71,6 +71,10 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
   // Allow subclasses to create a customized Op resolver during init.
   virtual std::unique_ptr<tflite::OpResolver> GetOpResolver() const;
 
+  // Create a BenchmarkListener that's specifically for TFLite profiling if
+  // necessary.
+  virtual std::unique_ptr<BenchmarkListener> MayCreateProfilingListener() const;
+
   void CleanUp();
 
   std::unique_ptr<tflite::FlatBufferModel> model_;
@@ -103,8 +107,8 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
   std::vector<InputLayerInfo> inputs_;
   std::vector<InputTensorData> inputs_data_;
-  std::unique_ptr<BenchmarkListener> profiling_listener_;
-  std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_;
+  std::unique_ptr<BenchmarkListener> profiling_listener_ = nullptr;
+  std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_ = nullptr;
   TfLiteDelegatePtrMap delegates_;
   std::mt19937 random_engine_;
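
The net effect of this change is that profiling-listener creation becomes a virtual factory method, so a subclass of BenchmarkTfLiteModel can install its own BenchmarkListener without re-implementing Init(). Below is a minimal sketch of such an override. Only MayCreateProfilingListener() and the enable_op_profiling flag are confirmed by this diff; the subclass and listener names are hypothetical, and the OnBenchmarkEnd/BenchmarkResults hook is assumed to match the BenchmarkListener interface in benchmark_model.h.

// my_benchmark_tflite_model.h (hypothetical subclass illustrating the new
// extension point added in this change).
#include <memory>

#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"

namespace tflite {
namespace benchmark {

// Stand-in for a custom profiler. The OnBenchmarkEnd hook and BenchmarkResults
// type are assumed from benchmark_model.h, not shown in this diff.
class LoggingProfilingListener : public BenchmarkListener {
 public:
  void OnBenchmarkEnd(const BenchmarkResults& results) override {
    // A real listener would summarize per-op profiling data here.
  }
};

class MyBenchmarkTfLiteModel : public BenchmarkTfLiteModel {
 protected:
  // Honor the existing enable_op_profiling flag, but return the custom
  // listener instead of the default ProfilingListener. Init() will call this
  // and register the result via AddListener(), as shown in the .cc hunk above.
  std::unique_ptr<BenchmarkListener> MayCreateProfilingListener()
      const override {
    if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
    return std::unique_ptr<BenchmarkListener>(new LoggingProfilingListener());
  }
};

}  // namespace benchmark
}  // namespace tflite

Because the member is now typed as std::unique_ptr<BenchmarkListener> rather than a concrete listener type, any BenchmarkListener subclass returned from the override is stored and cleaned up the same way as the default one.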