Extract the profiling listener creation into a virtual method so that subclasses can override it with their own profiling listener.

PiperOrigin-RevId: 284142645
Change-Id: Ie66ca638c9e797f5aaf7493b8b5c5180a89eea1f
Chao Mei authored 2019-12-06 01:44:34 -08:00, committed by TensorFlower Gardener
parent acb8b516ee
commit 9b3930304b
2 changed files with 16 additions and 8 deletions
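For context, the point of the change is that a downstream benchmark can now plug in its own profiling listener by overriding the new virtual hook. Below is a minimal, hypothetical sketch (not part of this commit); it assumes params_ and interpreter_ are accessible to subclasses and that MyProfilingListener is some BenchmarkListener subclass defined elsewhere:

class MyTfLiteBenchmark : public BenchmarkTfLiteModel {
 protected:
  // Return a custom listener instead of the default ProfilingListener.
  // Returning nullptr (e.g. when op profiling is disabled) means no listener
  // is installed, matching the check in Init() below.
  std::unique_ptr<BenchmarkListener> MayCreateProfilingListener() const override {
    if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
    return std::unique_ptr<BenchmarkListener>(
        new MyProfilingListener(interpreter_.get()));
  }
};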


@@ -573,12 +573,8 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
   // Install profilers if necessary right after interpreter is created so that
   // any memory allocations inside the TFLite runtime could be recorded if the
   // installed profiler profiles memory usage information.
-  if (params_.Get<bool>("enable_op_profiling")) {
-    profiling_listener_.reset(new ProfilingListener(
-        interpreter_.get(),
-        params_.Get<int32_t>("max_profiling_buffer_entries")));
-    AddListener(profiling_listener_.get());
-  }
+  profiling_listener_ = MayCreateProfilingListener();
+  if (profiling_listener_) AddListener(profiling_listener_.get());

   interpreter_->UseNNAPI(params_.Get<bool>("use_legacy_nnapi"));
   interpreter_->SetAllowFp16PrecisionForFp32(params_.Get<bool>("allow_fp16"));
@@ -771,6 +767,14 @@ std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
   return std::unique_ptr<tflite::OpResolver>(resolver);
 }

+std::unique_ptr<BenchmarkListener>
+BenchmarkTfLiteModel::MayCreateProfilingListener() const {
+  if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
+  return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
+      interpreter_.get(),
+      params_.Get<int32_t>("max_profiling_buffer_entries")));
+}
+
 TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }

 }  // namespace benchmark


@@ -71,6 +71,10 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
   // Allow subclasses to create a customized Op resolver during init.
   virtual std::unique_ptr<tflite::OpResolver> GetOpResolver() const;

+  // Create a BenchmarkListener that's specifically for TFLite profiling if
+  // necessary.
+  virtual std::unique_ptr<BenchmarkListener> MayCreateProfilingListener() const;
+
   void CleanUp();

   std::unique_ptr<tflite::FlatBufferModel> model_;
@@ -103,8 +107,8 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
   std::vector<InputLayerInfo> inputs_;
   std::vector<InputTensorData> inputs_data_;
-  std::unique_ptr<BenchmarkListener> profiling_listener_;
-  std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_;
+  std::unique_ptr<BenchmarkListener> profiling_listener_ = nullptr;
+  std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_ = nullptr;
   TfLiteDelegatePtrMap delegates_;
   std::mt19937 random_engine_;