Extract out the profiling listener creation as a virtual method so that subclasses could overwrite w/ its own profiling listener.
PiperOrigin-RevId: 284142645 Change-Id: Ie66ca638c9e797f5aaf7493b8b5c5180a89eea1f
This commit is contained in:
parent
acb8b516ee
commit
9b3930304b
@ -573,12 +573,8 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
|
||||
// Install profilers if necessary right after interpreter is created so that
|
||||
// any memory allocations inside the TFLite runtime could be recorded if the
|
||||
// installed profiler profile memory usage information.
|
||||
if (params_.Get<bool>("enable_op_profiling")) {
|
||||
profiling_listener_.reset(new ProfilingListener(
|
||||
interpreter_.get(),
|
||||
params_.Get<int32_t>("max_profiling_buffer_entries")));
|
||||
AddListener(profiling_listener_.get());
|
||||
}
|
||||
profiling_listener_ = MayCreateProfilingListener();
|
||||
if (profiling_listener_) AddListener(profiling_listener_.get());
|
||||
|
||||
interpreter_->UseNNAPI(params_.Get<bool>("use_legacy_nnapi"));
|
||||
interpreter_->SetAllowFp16PrecisionForFp32(params_.Get<bool>("allow_fp16"));
|
||||
@ -771,6 +767,14 @@ std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
|
||||
return std::unique_ptr<tflite::OpResolver>(resolver);
|
||||
}
|
||||
|
||||
std::unique_ptr<BenchmarkListener>
|
||||
BenchmarkTfLiteModel::MayCreateProfilingListener() const {
|
||||
if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
|
||||
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
|
||||
interpreter_.get(),
|
||||
params_.Get<int32_t>("max_profiling_buffer_entries")));
|
||||
}
|
||||
|
||||
TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }
|
||||
|
||||
} // namespace benchmark
|
||||
|
@ -71,6 +71,10 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
|
||||
// Allow subclasses to create a customized Op resolver during init.
|
||||
virtual std::unique_ptr<tflite::OpResolver> GetOpResolver() const;
|
||||
|
||||
// Create a BenchmarkListener that's specifically for TFLite profiling if
|
||||
// necessary.
|
||||
virtual std::unique_ptr<BenchmarkListener> MayCreateProfilingListener() const;
|
||||
|
||||
void CleanUp();
|
||||
|
||||
std::unique_ptr<tflite::FlatBufferModel> model_;
|
||||
@ -103,8 +107,8 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
|
||||
|
||||
std::vector<InputLayerInfo> inputs_;
|
||||
std::vector<InputTensorData> inputs_data_;
|
||||
std::unique_ptr<BenchmarkListener> profiling_listener_;
|
||||
std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_;
|
||||
std::unique_ptr<BenchmarkListener> profiling_listener_ = nullptr;
|
||||
std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_ = nullptr;
|
||||
TfLiteDelegatePtrMap delegates_;
|
||||
|
||||
std::mt19937 random_engine_;
|
||||
|
Loading…
x
Reference in New Issue
Block a user