Fix ProfilingListener for subclasses to override.
Fix BenchmarkTfLiteModel to pass ProfileSummaryFormatter to ProfilingListener. PiperOrigin-RevId: 296362673 Change-Id: I9e494202c03d8794effdf11eb1bdf1f69d62d35c
This commit is contained in:
parent
41b6bae3d1
commit
e38ef04eca
@ -89,8 +89,8 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
|
||||
} // namespace
|
||||
|
||||
ProfileSummarizer::ProfileSummarizer(
|
||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter)
|
||||
: summary_formatter_(std::move(summary_formatter)) {
|
||||
std::shared_ptr<ProfileSummaryFormatter> summary_formatter)
|
||||
: summary_formatter_(summary_formatter) {
|
||||
// Create stats calculator for the primary graph.
|
||||
stats_calculator_map_[0] = std::unique_ptr<tensorflow::StatsCalculator>(
|
||||
new tensorflow::StatsCalculator(
|
||||
|
@ -32,8 +32,8 @@ namespace profiling {
|
||||
class ProfileSummarizer {
|
||||
public:
|
||||
explicit ProfileSummarizer(
|
||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter =
|
||||
std::make_unique<ProfileSummaryDefaultFormatter>());
|
||||
std::shared_ptr<ProfileSummaryFormatter> summary_formatter =
|
||||
std::make_shared<ProfileSummaryDefaultFormatter>());
|
||||
virtual ~ProfileSummarizer() {}
|
||||
|
||||
// Process profile events to update statistics for operator invocations.
|
||||
@ -70,7 +70,7 @@ class ProfileSummarizer {
|
||||
std::unique_ptr<tensorflow::StatsCalculator> delegate_stats_calculator_;
|
||||
|
||||
// Summary formatter for customized output formats.
|
||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter_;
|
||||
std::shared_ptr<ProfileSummaryFormatter> summary_formatter_;
|
||||
};
|
||||
|
||||
} // namespace profiling
|
||||
|
@ -118,6 +118,7 @@ cc_library(
|
||||
deps = [
|
||||
":benchmark_model_lib",
|
||||
"//tensorflow/lite/profiling:profile_summarizer",
|
||||
"//tensorflow/lite/profiling:profile_summary_formatter",
|
||||
"//tensorflow/lite/profiling:profiler",
|
||||
],
|
||||
)
|
||||
|
@ -185,6 +185,13 @@ std::vector<int> TfLiteIntArrayToVector(const TfLiteIntArray* int_array) {
|
||||
return values;
|
||||
}
|
||||
|
||||
std::shared_ptr<profiling::ProfileSummaryFormatter>
|
||||
CreateProfileSummaryFormatter(bool format_as_csv) {
|
||||
return format_as_csv
|
||||
? std::make_shared<profiling::ProfileSummaryCSVFormatter>()
|
||||
: std::make_shared<profiling::ProfileSummaryDefaultFormatter>();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
|
||||
@ -566,7 +573,9 @@ BenchmarkTfLiteModel::MayCreateProfilingListener() const {
|
||||
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
|
||||
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
|
||||
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
|
||||
params_.Get<std::string>("profiling_output_csv_file")));
|
||||
params_.Get<std::string>("profiling_output_csv_file"),
|
||||
CreateProfileSummaryFormatter(
|
||||
!params_.Get<std::string>("profiling_output_csv_file").empty())));
|
||||
}
|
||||
|
||||
TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||
#include "tensorflow/lite/profiling/profiler.h"
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
|
||||
|
||||
|
@ -20,14 +20,15 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace benchmark {
|
||||
|
||||
ProfilingListener::ProfilingListener(Interpreter* interpreter,
|
||||
uint32_t max_num_entries,
|
||||
const std::string& csv_file_path)
|
||||
: interpreter_(interpreter),
|
||||
profiler_(max_num_entries),
|
||||
run_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())),
|
||||
init_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())),
|
||||
csv_file_path_(csv_file_path) {
|
||||
ProfilingListener::ProfilingListener(
|
||||
Interpreter* interpreter, uint32_t max_num_entries,
|
||||
const std::string& csv_file_path,
|
||||
std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter)
|
||||
: run_summarizer_(summarizer_formatter),
|
||||
init_summarizer_(summarizer_formatter),
|
||||
csv_file_path_(csv_file_path),
|
||||
interpreter_(interpreter),
|
||||
profiler_(max_num_entries) {
|
||||
TFLITE_BENCHMARK_CHECK(interpreter);
|
||||
interpreter_->SetProfiler(&profiler_);
|
||||
|
||||
@ -85,12 +86,5 @@ void ProfilingListener::WriteOutput(const std::string& header,
|
||||
(*stream) << data << std::endl;
|
||||
}
|
||||
|
||||
std::unique_ptr<profiling::ProfileSummaryFormatter>
|
||||
ProfilingListener::CreateProfileSummaryFormatter(bool format_as_csv) const {
|
||||
return format_as_csv
|
||||
? std::make_unique<profiling::ProfileSummaryDefaultFormatter>()
|
||||
: std::make_unique<profiling::ProfileSummaryCSVFormatter>();
|
||||
}
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace tflite
|
||||
|
@ -16,8 +16,11 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/profiling/buffered_profiler.h"
|
||||
#include "tensorflow/lite/profiling/profile_summarizer.h"
|
||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -26,8 +29,11 @@ namespace benchmark {
|
||||
// Dumps profiling events if profiling is enabled.
|
||||
class ProfilingListener : public BenchmarkListener {
|
||||
public:
|
||||
explicit ProfilingListener(Interpreter* interpreter, uint32_t max_num_entries,
|
||||
const std::string& csv_file_path = "");
|
||||
ProfilingListener(
|
||||
Interpreter* interpreter, uint32_t max_num_entries,
|
||||
const std::string& csv_file_path = "",
|
||||
std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter =
|
||||
std::make_shared<profiling::ProfileSummaryDefaultFormatter>());
|
||||
|
||||
void OnBenchmarkStart(const BenchmarkParams& params) override;
|
||||
|
||||
@ -38,18 +44,15 @@ class ProfilingListener : public BenchmarkListener {
|
||||
void OnBenchmarkEnd(const BenchmarkResults& results) override;
|
||||
|
||||
protected:
|
||||
// Allow subclasses to create a customized summary writer during init.
|
||||
virtual std::unique_ptr<profiling::ProfileSummaryFormatter>
|
||||
CreateProfileSummaryFormatter(bool format_as_csv) const;
|
||||
profiling::ProfileSummarizer run_summarizer_;
|
||||
profiling::ProfileSummarizer init_summarizer_;
|
||||
std::string csv_file_path_;
|
||||
|
||||
private:
|
||||
void WriteOutput(const std::string& header, const string& data,
|
||||
std::ostream* stream);
|
||||
Interpreter* interpreter_;
|
||||
profiling::BufferedProfiler profiler_;
|
||||
profiling::ProfileSummarizer run_summarizer_;
|
||||
profiling::ProfileSummarizer init_summarizer_;
|
||||
std::string csv_file_path_;
|
||||
};
|
||||
|
||||
} // namespace benchmark
|
||||
|
Loading…
Reference in New Issue
Block a user