Extract out the profiling listener creation as a virtual method so that subclasses could overwrite w/ its own profiling listener.
PiperOrigin-RevId: 284142645 Change-Id: Ie66ca638c9e797f5aaf7493b8b5c5180a89eea1f
This commit is contained in:
parent
acb8b516ee
commit
9b3930304b
@ -573,12 +573,8 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
|
|||||||
// Install profilers if necessary right after interpreter is created so that
|
// Install profilers if necessary right after interpreter is created so that
|
||||||
// any memory allocations inside the TFLite runtime could be recorded if the
|
// any memory allocations inside the TFLite runtime could be recorded if the
|
||||||
// installed profiler profile memory usage information.
|
// installed profiler profile memory usage information.
|
||||||
if (params_.Get<bool>("enable_op_profiling")) {
|
profiling_listener_ = MayCreateProfilingListener();
|
||||||
profiling_listener_.reset(new ProfilingListener(
|
if (profiling_listener_) AddListener(profiling_listener_.get());
|
||||||
interpreter_.get(),
|
|
||||||
params_.Get<int32_t>("max_profiling_buffer_entries")));
|
|
||||||
AddListener(profiling_listener_.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
interpreter_->UseNNAPI(params_.Get<bool>("use_legacy_nnapi"));
|
interpreter_->UseNNAPI(params_.Get<bool>("use_legacy_nnapi"));
|
||||||
interpreter_->SetAllowFp16PrecisionForFp32(params_.Get<bool>("allow_fp16"));
|
interpreter_->SetAllowFp16PrecisionForFp32(params_.Get<bool>("allow_fp16"));
|
||||||
@ -771,6 +767,14 @@ std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
|
|||||||
return std::unique_ptr<tflite::OpResolver>(resolver);
|
return std::unique_ptr<tflite::OpResolver>(resolver);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<BenchmarkListener>
|
||||||
|
BenchmarkTfLiteModel::MayCreateProfilingListener() const {
|
||||||
|
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
|
||||||
|
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
|
||||||
|
interpreter_.get(),
|
||||||
|
params_.Get<int32_t>("max_profiling_buffer_entries")));
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }
|
TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }
|
||||||
|
|
||||||
} // namespace benchmark
|
} // namespace benchmark
|
||||||
|
@ -71,6 +71,10 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
|
|||||||
// Allow subclasses to create a customized Op resolver during init.
|
// Allow subclasses to create a customized Op resolver during init.
|
||||||
virtual std::unique_ptr<tflite::OpResolver> GetOpResolver() const;
|
virtual std::unique_ptr<tflite::OpResolver> GetOpResolver() const;
|
||||||
|
|
||||||
|
// Create a BenchmarkListener that's specifically for TFLite profiling if
|
||||||
|
// necessary.
|
||||||
|
virtual std::unique_ptr<BenchmarkListener> MayCreateProfilingListener() const;
|
||||||
|
|
||||||
void CleanUp();
|
void CleanUp();
|
||||||
|
|
||||||
std::unique_ptr<tflite::FlatBufferModel> model_;
|
std::unique_ptr<tflite::FlatBufferModel> model_;
|
||||||
@ -103,8 +107,8 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
|
|||||||
|
|
||||||
std::vector<InputLayerInfo> inputs_;
|
std::vector<InputLayerInfo> inputs_;
|
||||||
std::vector<InputTensorData> inputs_data_;
|
std::vector<InputTensorData> inputs_data_;
|
||||||
std::unique_ptr<BenchmarkListener> profiling_listener_;
|
std::unique_ptr<BenchmarkListener> profiling_listener_ = nullptr;
|
||||||
std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_;
|
std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_ = nullptr;
|
||||||
TfLiteDelegatePtrMap delegates_;
|
TfLiteDelegatePtrMap delegates_;
|
||||||
|
|
||||||
std::mt19937 random_engine_;
|
std::mt19937 random_engine_;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user