diff --git a/tensorflow/lite/core/api/profiler.h b/tensorflow/lite/core/api/profiler.h index 8e5c4b517bf..fe8dbc6e179 100644 --- a/tensorflow/lite/core/api/profiler.h +++ b/tensorflow/lite/core/api/profiler.h @@ -32,55 +32,22 @@ class Profiler { virtual ~Profiler() {} - // Signals the beginning of an event from a subgraph, returning a handle to - // the profile event. + // Signals the beginning of an event from a subgraph indexed at + // 'event_subgraph_index', returning a handle to the profile event. virtual uint32_t BeginEvent(const char* tag, EventType event_type, uint32_t event_metadata, uint32_t event_subgraph_index) = 0; - - // Signals the beginning of an event, returning a handle to the profile event. - // The subgraph where an event comes from will be determined implicilty. + // Similar w/ the above, but the event comes from the primary subgraph that's + // indexed at 0. virtual uint32_t BeginEvent(const char* tag, EventType event_type, - uint32_t event_metadata) = 0; + uint32_t event_metadata) { + return BeginEvent(tag, event_type, event_metadata, /*primary subgraph*/ 0); + } // Signals an end to the specified profile event. virtual void EndEvent(uint32_t event_handle) = 0; }; -// SubgraphAwareProfiler is a profiler that takes care of event tracing in a -// certain subgraph. -class SubgraphAwareProfiler : public Profiler { - public: - // Constructor should be called with the non-nullptr profiler argument. - explicit SubgraphAwareProfiler(Profiler* profiler, uint32_t subgraph_index) - : profiler_(profiler), subgraph_index_(subgraph_index) {} - ~SubgraphAwareProfiler() override {} - - uint32_t BeginEvent(const char* tag, EventType event_type, - uint32_t event_metadata, - uint32_t subgraph_index /*ignore*/) override { - // It assumes that this profiler only produces events from the subgraph that - // is provided at the creation of the profiler. - return BeginEvent(tag, event_type, event_metadata); - } - - uint32_t BeginEvent(const char* tag, EventType event_type, - uint32_t event_metadata) override { - // Assume that the wrapped profiler is not nullptr. - return profiler_->BeginEvent(tag, event_type, event_metadata, - subgraph_index_); - } - - void EndEvent(uint32_t event_handle) override { - // Assume that the wrapped profiler is not nullptr. - profiler_->EndEvent(event_handle); - } - - private: - Profiler* const profiler_; - const uint32_t subgraph_index_; -}; - // Adds a profile event to `profiler` that begins with the construction // of the object and ends when the object goes out of scope. // The lifetime of tag should be at least the lifetime of `profiler`. diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 72de733c665..17310447c16 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -285,9 +285,15 @@ class Subgraph { // WARNING: This is an experimental API and subject to change. TfLiteStatus ResetVariableTensors(); - void SetProfiler(std::unique_ptr profiler) { - profiler_ = std::move(profiler); - context_.profiler = profiler_.get(); + void SetProfiler(Profiler* profiler, int associated_subgraph_idx) { + if (!profiler) { + profiler_.reset(nullptr); + context_.profiler = nullptr; + } else { + profiler_.reset( + new SubgraphAwareProfiler(profiler, associated_subgraph_idx)); + context_.profiler = profiler_.get(); + } } Profiler* GetProfiler() { return profiler_.get(); } @@ -302,6 +308,40 @@ class Subgraph { bool HasDynamicTensors() { return has_dynamic_tensors_; } private: + // SubgraphAwareProfiler wraps an actual TFLite profiler, such as a + // BufferedProfiler instance, and takes care of event profiling/tracing in a + // certain subgraph. + class SubgraphAwareProfiler : public Profiler { + public: + // Constructor should be called with the non-nullptr profiler argument. + SubgraphAwareProfiler(Profiler* profiler, uint32_t subgraph_index) + : profiler_(profiler), subgraph_index_(subgraph_index) {} + ~SubgraphAwareProfiler() override {} + + uint32_t BeginEvent(const char* tag, EventType event_type, + uint32_t event_metadata, + uint32_t subgraph_index) override { + if (!profiler_) return 0; + return profiler_->BeginEvent(tag, event_type, event_metadata, + subgraph_index); + } + + uint32_t BeginEvent(const char* tag, EventType event_type, + uint32_t event_metadata) override { + return BeginEvent(tag, event_type, event_metadata, subgraph_index_); + } + + void EndEvent(uint32_t event_handle) override { + if (!profiler_) return; + profiler_->EndEvent(event_handle); + } + + private: + // Not own the memory. + Profiler* const profiler_; + const uint32_t subgraph_index_; + }; + // Prevent 'context_' from accessing functions that are only available to // delegated kernels. void SwitchToKernelContext(); @@ -570,7 +610,7 @@ class Subgraph { bool tensor_resized_since_op_invoke_ = false; // Profiler for this interpreter instance. - std::unique_ptr profiler_; + std::unique_ptr profiler_; // A pointer to vector of subgraphs. The vector is owned by the interpreter. std::vector>* subgraphs_ = nullptr; diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index 8fef39f87f7..efa7a335fd9 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -316,13 +316,7 @@ TfLiteStatus Interpreter::GetBufferHandle(int tensor_index, void Interpreter::SetProfiler(Profiler* profiler) { for (int subgraph_index = 0; subgraph_index < subgraphs_.size(); ++subgraph_index) { - if (profiler != nullptr) { - subgraphs_[subgraph_index]->SetProfiler(std::unique_ptr( - new SubgraphAwareProfiler(profiler, subgraph_index))); - } else { - subgraphs_[subgraph_index]->SetProfiler( - std::unique_ptr(nullptr)); - } + subgraphs_[subgraph_index]->SetProfiler(profiler, subgraph_index); } } diff --git a/tensorflow/lite/profiling/buffered_profiler.h b/tensorflow/lite/profiling/buffered_profiler.h index fd7bc6fbc3a..844f1c58daa 100644 --- a/tensorflow/lite/profiling/buffered_profiler.h +++ b/tensorflow/lite/profiling/buffered_profiler.h @@ -84,12 +84,6 @@ class BufferedProfiler : public tflite::Profiler { event_subgraph_index); } - uint32_t BeginEvent(const char* tag, EventType event_type, - uint32_t event_metadata) override { - return buffer_.BeginEvent(tag, event_type, event_metadata, - /*primary graph*/ 0); - } - void EndEvent(uint32_t event_handle) override { buffer_.EndEvent(event_handle); } diff --git a/tensorflow/lite/profiling/noop_profiler.h b/tensorflow/lite/profiling/noop_profiler.h index f7db3f4f0f6..27363fc6788 100644 --- a/tensorflow/lite/profiling/noop_profiler.h +++ b/tensorflow/lite/profiling/noop_profiler.h @@ -27,12 +27,11 @@ namespace profiling { class NoopProfiler : public tflite::Profiler { public: NoopProfiler() {} - NoopProfiler(int max_profiling_buffer_entries) {} + explicit NoopProfiler(int max_profiling_buffer_entries) {} uint32_t BeginEvent(const char*, EventType, uint32_t, uint32_t) override { return 0; } - uint32_t BeginEvent(const char*, EventType, uint32_t) override { return 0; } void EndEvent(uint32_t) override {} void StartProfiling() {}