Fix ProfilingListener for subclasses to override.
Fix BenchmarkTfLiteModel to pass ProfileSummaryFormatter to ProfilingListener. PiperOrigin-RevId: 296362673 Change-Id: I9e494202c03d8794effdf11eb1bdf1f69d62d35c
This commit is contained in:
parent
41b6bae3d1
commit
e38ef04eca
@ -89,8 +89,8 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
ProfileSummarizer::ProfileSummarizer(
|
ProfileSummarizer::ProfileSummarizer(
|
||||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter)
|
std::shared_ptr<ProfileSummaryFormatter> summary_formatter)
|
||||||
: summary_formatter_(std::move(summary_formatter)) {
|
: summary_formatter_(summary_formatter) {
|
||||||
// Create stats calculator for the primary graph.
|
// Create stats calculator for the primary graph.
|
||||||
stats_calculator_map_[0] = std::unique_ptr<tensorflow::StatsCalculator>(
|
stats_calculator_map_[0] = std::unique_ptr<tensorflow::StatsCalculator>(
|
||||||
new tensorflow::StatsCalculator(
|
new tensorflow::StatsCalculator(
|
||||||
|
@ -32,8 +32,8 @@ namespace profiling {
|
|||||||
class ProfileSummarizer {
|
class ProfileSummarizer {
|
||||||
public:
|
public:
|
||||||
explicit ProfileSummarizer(
|
explicit ProfileSummarizer(
|
||||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter =
|
std::shared_ptr<ProfileSummaryFormatter> summary_formatter =
|
||||||
std::make_unique<ProfileSummaryDefaultFormatter>());
|
std::make_shared<ProfileSummaryDefaultFormatter>());
|
||||||
virtual ~ProfileSummarizer() {}
|
virtual ~ProfileSummarizer() {}
|
||||||
|
|
||||||
// Process profile events to update statistics for operator invocations.
|
// Process profile events to update statistics for operator invocations.
|
||||||
@ -70,7 +70,7 @@ class ProfileSummarizer {
|
|||||||
std::unique_ptr<tensorflow::StatsCalculator> delegate_stats_calculator_;
|
std::unique_ptr<tensorflow::StatsCalculator> delegate_stats_calculator_;
|
||||||
|
|
||||||
// Summary formatter for customized output formats.
|
// Summary formatter for customized output formats.
|
||||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter_;
|
std::shared_ptr<ProfileSummaryFormatter> summary_formatter_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace profiling
|
} // namespace profiling
|
||||||
|
@ -118,6 +118,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":benchmark_model_lib",
|
":benchmark_model_lib",
|
||||||
"//tensorflow/lite/profiling:profile_summarizer",
|
"//tensorflow/lite/profiling:profile_summarizer",
|
||||||
|
"//tensorflow/lite/profiling:profile_summary_formatter",
|
||||||
"//tensorflow/lite/profiling:profiler",
|
"//tensorflow/lite/profiling:profiler",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -185,6 +185,13 @@ std::vector<int> TfLiteIntArrayToVector(const TfLiteIntArray* int_array) {
|
|||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<profiling::ProfileSummaryFormatter>
|
||||||
|
CreateProfileSummaryFormatter(bool format_as_csv) {
|
||||||
|
return format_as_csv
|
||||||
|
? std::make_shared<profiling::ProfileSummaryCSVFormatter>()
|
||||||
|
: std::make_shared<profiling::ProfileSummaryDefaultFormatter>();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
|
BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
|
||||||
@ -566,7 +573,9 @@ BenchmarkTfLiteModel::MayCreateProfilingListener() const {
|
|||||||
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
|
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
|
||||||
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
|
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
|
||||||
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
|
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
|
||||||
params_.Get<std::string>("profiling_output_csv_file")));
|
params_.Get<std::string>("profiling_output_csv_file"),
|
||||||
|
CreateProfileSummaryFormatter(
|
||||||
|
!params_.Get<std::string>("profiling_output_csv_file").empty())));
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }
|
TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }
|
||||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
|
||||||
#include "tensorflow/lite/profiling/profiler.h"
|
#include "tensorflow/lite/profiling/profiler.h"
|
||||||
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
|
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
|
||||||
|
|
||||||
|
@ -20,14 +20,15 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace benchmark {
|
namespace benchmark {
|
||||||
|
|
||||||
ProfilingListener::ProfilingListener(Interpreter* interpreter,
|
ProfilingListener::ProfilingListener(
|
||||||
uint32_t max_num_entries,
|
Interpreter* interpreter, uint32_t max_num_entries,
|
||||||
const std::string& csv_file_path)
|
const std::string& csv_file_path,
|
||||||
: interpreter_(interpreter),
|
std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter)
|
||||||
profiler_(max_num_entries),
|
: run_summarizer_(summarizer_formatter),
|
||||||
run_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())),
|
init_summarizer_(summarizer_formatter),
|
||||||
init_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())),
|
csv_file_path_(csv_file_path),
|
||||||
csv_file_path_(csv_file_path) {
|
interpreter_(interpreter),
|
||||||
|
profiler_(max_num_entries) {
|
||||||
TFLITE_BENCHMARK_CHECK(interpreter);
|
TFLITE_BENCHMARK_CHECK(interpreter);
|
||||||
interpreter_->SetProfiler(&profiler_);
|
interpreter_->SetProfiler(&profiler_);
|
||||||
|
|
||||||
@ -85,12 +86,5 @@ void ProfilingListener::WriteOutput(const std::string& header,
|
|||||||
(*stream) << data << std::endl;
|
(*stream) << data << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<profiling::ProfileSummaryFormatter>
|
|
||||||
ProfilingListener::CreateProfileSummaryFormatter(bool format_as_csv) const {
|
|
||||||
return format_as_csv
|
|
||||||
? std::make_unique<profiling::ProfileSummaryDefaultFormatter>()
|
|
||||||
: std::make_unique<profiling::ProfileSummaryCSVFormatter>();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace benchmark
|
} // namespace benchmark
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -16,8 +16,11 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_
|
#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_
|
||||||
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_
|
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/lite/profiling/buffered_profiler.h"
|
#include "tensorflow/lite/profiling/buffered_profiler.h"
|
||||||
#include "tensorflow/lite/profiling/profile_summarizer.h"
|
#include "tensorflow/lite/profiling/profile_summarizer.h"
|
||||||
|
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||||
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
|
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -26,8 +29,11 @@ namespace benchmark {
|
|||||||
// Dumps profiling events if profiling is enabled.
|
// Dumps profiling events if profiling is enabled.
|
||||||
class ProfilingListener : public BenchmarkListener {
|
class ProfilingListener : public BenchmarkListener {
|
||||||
public:
|
public:
|
||||||
explicit ProfilingListener(Interpreter* interpreter, uint32_t max_num_entries,
|
ProfilingListener(
|
||||||
const std::string& csv_file_path = "");
|
Interpreter* interpreter, uint32_t max_num_entries,
|
||||||
|
const std::string& csv_file_path = "",
|
||||||
|
std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter =
|
||||||
|
std::make_shared<profiling::ProfileSummaryDefaultFormatter>());
|
||||||
|
|
||||||
void OnBenchmarkStart(const BenchmarkParams& params) override;
|
void OnBenchmarkStart(const BenchmarkParams& params) override;
|
||||||
|
|
||||||
@ -38,18 +44,15 @@ class ProfilingListener : public BenchmarkListener {
|
|||||||
void OnBenchmarkEnd(const BenchmarkResults& results) override;
|
void OnBenchmarkEnd(const BenchmarkResults& results) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Allow subclasses to create a customized summary writer during init.
|
profiling::ProfileSummarizer run_summarizer_;
|
||||||
virtual std::unique_ptr<profiling::ProfileSummaryFormatter>
|
profiling::ProfileSummarizer init_summarizer_;
|
||||||
CreateProfileSummaryFormatter(bool format_as_csv) const;
|
std::string csv_file_path_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void WriteOutput(const std::string& header, const string& data,
|
void WriteOutput(const std::string& header, const string& data,
|
||||||
std::ostream* stream);
|
std::ostream* stream);
|
||||||
Interpreter* interpreter_;
|
Interpreter* interpreter_;
|
||||||
profiling::BufferedProfiler profiler_;
|
profiling::BufferedProfiler profiler_;
|
||||||
profiling::ProfileSummarizer run_summarizer_;
|
|
||||||
profiling::ProfileSummarizer init_summarizer_;
|
|
||||||
std::string csv_file_path_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace benchmark
|
} // namespace benchmark
|
||||||
|
Loading…
Reference in New Issue
Block a user