Fix ProfilingListener for subclasses to override.

Fix BenchmarkTfLiteModel to pass ProfileSummaryFormatter to ProfilingListener.

PiperOrigin-RevId: 296362673
Change-Id: I9e494202c03d8794effdf11eb1bdf1f69d62d35c
This commit is contained in:
Dayeong Lee 2020-02-20 21:23:59 -08:00 committed by TensorFlower Gardener
parent 41b6bae3d1
commit e38ef04eca
7 changed files with 36 additions and 30 deletions

View File

@ -89,8 +89,8 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
} // namespace
ProfileSummarizer::ProfileSummarizer(
std::unique_ptr<ProfileSummaryFormatter> summary_formatter)
: summary_formatter_(std::move(summary_formatter)) {
std::shared_ptr<ProfileSummaryFormatter> summary_formatter)
: summary_formatter_(summary_formatter) {
// Create stats calculator for the primary graph.
stats_calculator_map_[0] = std::unique_ptr<tensorflow::StatsCalculator>(
new tensorflow::StatsCalculator(

View File

@ -32,8 +32,8 @@ namespace profiling {
class ProfileSummarizer {
public:
explicit ProfileSummarizer(
std::unique_ptr<ProfileSummaryFormatter> summary_formatter =
std::make_unique<ProfileSummaryDefaultFormatter>());
std::shared_ptr<ProfileSummaryFormatter> summary_formatter =
std::make_shared<ProfileSummaryDefaultFormatter>());
virtual ~ProfileSummarizer() {}
// Process profile events to update statistics for operator invocations.
@ -70,7 +70,7 @@ class ProfileSummarizer {
std::unique_ptr<tensorflow::StatsCalculator> delegate_stats_calculator_;
// Summary formatter for customized output formats.
std::unique_ptr<ProfileSummaryFormatter> summary_formatter_;
std::shared_ptr<ProfileSummaryFormatter> summary_formatter_;
};
} // namespace profiling

View File

@ -118,6 +118,7 @@ cc_library(
deps = [
":benchmark_model_lib",
"//tensorflow/lite/profiling:profile_summarizer",
"//tensorflow/lite/profiling:profile_summary_formatter",
"//tensorflow/lite/profiling:profiler",
],
)

View File

@ -185,6 +185,13 @@ std::vector<int> TfLiteIntArrayToVector(const TfLiteIntArray* int_array) {
return values;
}
std::shared_ptr<profiling::ProfileSummaryFormatter>
CreateProfileSummaryFormatter(bool format_as_csv) {
return format_as_csv
? std::make_shared<profiling::ProfileSummaryCSVFormatter>()
: std::make_shared<profiling::ProfileSummaryDefaultFormatter>();
}
} // namespace
BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
@ -566,7 +573,9 @@ BenchmarkTfLiteModel::MayCreateProfilingListener() const {
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
params_.Get<std::string>("profiling_output_csv_file")));
params_.Get<std::string>("profiling_output_csv_file"),
CreateProfileSummaryFormatter(
!params_.Get<std::string>("profiling_output_csv_file").empty())));
}
TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }

View File

@ -24,7 +24,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"

View File

@ -20,14 +20,15 @@ limitations under the License.
namespace tflite {
namespace benchmark {
ProfilingListener::ProfilingListener(Interpreter* interpreter,
uint32_t max_num_entries,
const std::string& csv_file_path)
: interpreter_(interpreter),
profiler_(max_num_entries),
run_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())),
init_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())),
csv_file_path_(csv_file_path) {
ProfilingListener::ProfilingListener(
Interpreter* interpreter, uint32_t max_num_entries,
const std::string& csv_file_path,
std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter)
: run_summarizer_(summarizer_formatter),
init_summarizer_(summarizer_formatter),
csv_file_path_(csv_file_path),
interpreter_(interpreter),
profiler_(max_num_entries) {
TFLITE_BENCHMARK_CHECK(interpreter);
interpreter_->SetProfiler(&profiler_);
@ -85,12 +86,5 @@ void ProfilingListener::WriteOutput(const std::string& header,
(*stream) << data << std::endl;
}
std::unique_ptr<profiling::ProfileSummaryFormatter>
ProfilingListener::CreateProfileSummaryFormatter(bool format_as_csv) const {
return format_as_csv
? std::make_unique<profiling::ProfileSummaryDefaultFormatter>()
: std::make_unique<profiling::ProfileSummaryCSVFormatter>();
}
} // namespace benchmark
} // namespace tflite

View File

@ -16,8 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_
#include <memory>
#include "tensorflow/lite/profiling/buffered_profiler.h"
#include "tensorflow/lite/profiling/profile_summarizer.h"
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
namespace tflite {
@ -26,8 +29,11 @@ namespace benchmark {
// Dumps profiling events if profiling is enabled.
class ProfilingListener : public BenchmarkListener {
public:
explicit ProfilingListener(Interpreter* interpreter, uint32_t max_num_entries,
const std::string& csv_file_path = "");
ProfilingListener(
Interpreter* interpreter, uint32_t max_num_entries,
const std::string& csv_file_path = "",
std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter =
std::make_shared<profiling::ProfileSummaryDefaultFormatter>());
void OnBenchmarkStart(const BenchmarkParams& params) override;
@ -38,18 +44,15 @@ class ProfilingListener : public BenchmarkListener {
void OnBenchmarkEnd(const BenchmarkResults& results) override;
protected:
// Allow subclasses to create a customized summary writer during init.
virtual std::unique_ptr<profiling::ProfileSummaryFormatter>
CreateProfileSummaryFormatter(bool format_as_csv) const;
profiling::ProfileSummarizer run_summarizer_;
profiling::ProfileSummarizer init_summarizer_;
std::string csv_file_path_;
private:
void WriteOutput(const std::string& header, const string& data,
std::ostream* stream);
Interpreter* interpreter_;
profiling::BufferedProfiler profiler_;
profiling::ProfileSummarizer run_summarizer_;
profiling::ProfileSummarizer init_summarizer_;
std::string csv_file_path_;
};
} // namespace benchmark