diff --git a/tensorflow/lite/profiling/profile_summarizer.cc b/tensorflow/lite/profiling/profile_summarizer.cc index a4c763e4b28..acf630c93cf 100644 --- a/tensorflow/lite/profiling/profile_summarizer.cc +++ b/tensorflow/lite/profiling/profile_summarizer.cc @@ -89,8 +89,8 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, } // namespace ProfileSummarizer::ProfileSummarizer( - std::unique_ptr summary_formatter) - : summary_formatter_(std::move(summary_formatter)) { + std::shared_ptr summary_formatter) + : summary_formatter_(summary_formatter) { // Create stats calculator for the primary graph. stats_calculator_map_[0] = std::unique_ptr( new tensorflow::StatsCalculator( diff --git a/tensorflow/lite/profiling/profile_summarizer.h b/tensorflow/lite/profiling/profile_summarizer.h index 1348761b792..960c6ba7c3d 100644 --- a/tensorflow/lite/profiling/profile_summarizer.h +++ b/tensorflow/lite/profiling/profile_summarizer.h @@ -32,8 +32,8 @@ namespace profiling { class ProfileSummarizer { public: explicit ProfileSummarizer( - std::unique_ptr summary_formatter = - std::make_unique()); + std::shared_ptr summary_formatter = + std::make_shared()); virtual ~ProfileSummarizer() {} // Process profile events to update statistics for operator invocations. @@ -70,7 +70,7 @@ class ProfileSummarizer { std::unique_ptr delegate_stats_calculator_; // Summary formatter for customized output formats. - std::unique_ptr summary_formatter_; + std::shared_ptr summary_formatter_; }; } // namespace profiling diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 72968fc8e24..5a413112e2f 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -118,6 +118,7 @@ cc_library( deps = [ ":benchmark_model_lib", "//tensorflow/lite/profiling:profile_summarizer", + "//tensorflow/lite/profiling:profile_summary_formatter", "//tensorflow/lite/profiling:profiler", ], ) diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 23b76a921c5..6b1e9819312 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -185,6 +185,13 @@ std::vector TfLiteIntArrayToVector(const TfLiteIntArray* int_array) { return values; } +std::shared_ptr +CreateProfileSummaryFormatter(bool format_as_csv) { + return format_as_csv + ? std::make_shared() + : std::make_shared(); +} + } // namespace BenchmarkParams BenchmarkTfLiteModel::DefaultParams() { @@ -566,7 +573,9 @@ BenchmarkTfLiteModel::MayCreateProfilingListener() const { if (!params_.Get("enable_op_profiling")) return nullptr; return std::unique_ptr(new ProfilingListener( interpreter_.get(), params_.Get("max_profiling_buffer_entries"), - params_.Get("profiling_output_csv_file"))); + params_.Get("profiling_output_csv_file"), + CreateProfileSummaryFormatter( + !params_.Get("profiling_output_csv_file").empty()))); } TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); } diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h index 1d056bdf0cf..a0bcce843ab 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h @@ -24,7 +24,6 @@ limitations under the License. #include #include "tensorflow/lite/model.h" -#include "tensorflow/lite/profiling/profile_summary_formatter.h" #include "tensorflow/lite/profiling/profiler.h" #include "tensorflow/lite/tools/benchmark/benchmark_model.h" diff --git a/tensorflow/lite/tools/benchmark/profiling_listener.cc b/tensorflow/lite/tools/benchmark/profiling_listener.cc index 8d7a0fe3537..50df69c4b7c 100644 --- a/tensorflow/lite/tools/benchmark/profiling_listener.cc +++ b/tensorflow/lite/tools/benchmark/profiling_listener.cc @@ -20,14 +20,15 @@ limitations under the License. namespace tflite { namespace benchmark { -ProfilingListener::ProfilingListener(Interpreter* interpreter, - uint32_t max_num_entries, - const std::string& csv_file_path) - : interpreter_(interpreter), - profiler_(max_num_entries), - run_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())), - init_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())), - csv_file_path_(csv_file_path) { +ProfilingListener::ProfilingListener( + Interpreter* interpreter, uint32_t max_num_entries, + const std::string& csv_file_path, + std::shared_ptr summarizer_formatter) + : run_summarizer_(summarizer_formatter), + init_summarizer_(summarizer_formatter), + csv_file_path_(csv_file_path), + interpreter_(interpreter), + profiler_(max_num_entries) { TFLITE_BENCHMARK_CHECK(interpreter); interpreter_->SetProfiler(&profiler_); @@ -85,12 +86,5 @@ void ProfilingListener::WriteOutput(const std::string& header, (*stream) << data << std::endl; } -std::unique_ptr -ProfilingListener::CreateProfileSummaryFormatter(bool format_as_csv) const { - return format_as_csv - ? std::make_unique() - : std::make_unique(); -} - } // namespace benchmark } // namespace tflite diff --git a/tensorflow/lite/tools/benchmark/profiling_listener.h b/tensorflow/lite/tools/benchmark/profiling_listener.h index 9c0f6745bbb..0b2772baea1 100644 --- a/tensorflow/lite/tools/benchmark/profiling_listener.h +++ b/tensorflow/lite/tools/benchmark/profiling_listener.h @@ -16,8 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_ #define TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_ +#include + #include "tensorflow/lite/profiling/buffered_profiler.h" #include "tensorflow/lite/profiling/profile_summarizer.h" +#include "tensorflow/lite/profiling/profile_summary_formatter.h" #include "tensorflow/lite/tools/benchmark/benchmark_model.h" namespace tflite { @@ -26,8 +29,11 @@ namespace benchmark { // Dumps profiling events if profiling is enabled. class ProfilingListener : public BenchmarkListener { public: - explicit ProfilingListener(Interpreter* interpreter, uint32_t max_num_entries, - const std::string& csv_file_path = ""); + ProfilingListener( + Interpreter* interpreter, uint32_t max_num_entries, + const std::string& csv_file_path = "", + std::shared_ptr summarizer_formatter = + std::make_shared()); void OnBenchmarkStart(const BenchmarkParams& params) override; @@ -38,18 +44,15 @@ class ProfilingListener : public BenchmarkListener { void OnBenchmarkEnd(const BenchmarkResults& results) override; protected: - // Allow subclasses to create a customized summary writer during init. - virtual std::unique_ptr - CreateProfileSummaryFormatter(bool format_as_csv) const; + profiling::ProfileSummarizer run_summarizer_; + profiling::ProfileSummarizer init_summarizer_; + std::string csv_file_path_; private: void WriteOutput(const std::string& header, const string& data, std::ostream* stream); Interpreter* interpreter_; profiling::BufferedProfiler profiler_; - profiling::ProfileSummarizer run_summarizer_; - profiling::ProfileSummarizer init_summarizer_; - std::string csv_file_path_; }; } // namespace benchmark