1. Move the SubgraphAwareProfiler as a private member class inside Subgraph as it's tightly coupled w/ a Subgraph.
2. Adjust the Profiler APIs to avoid certain confusion, esp. where we ignore the passed subgraph index argument. PiperOrigin-RevId: 276044446 Change-Id: I9a3f6542af89ebb386a8055b30019d05c79f3c15
This commit is contained in:
parent
48993d1e2d
commit
5460af59bc
@ -32,55 +32,22 @@ class Profiler {
|
||||
|
||||
virtual ~Profiler() {}
|
||||
|
||||
// Signals the beginning of an event from a subgraph, returning a handle to
|
||||
// the profile event.
|
||||
// Signals the beginning of an event from a subgraph indexed at
|
||||
// 'event_subgraph_index', returning a handle to the profile event.
|
||||
virtual uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata,
|
||||
uint32_t event_subgraph_index) = 0;
|
||||
|
||||
// Signals the beginning of an event, returning a handle to the profile event.
|
||||
// The subgraph where an event comes from will be determined implicilty.
|
||||
// Similar w/ the above, but the event comes from the primary subgraph that's
|
||||
// indexed at 0.
|
||||
virtual uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata) = 0;
|
||||
uint32_t event_metadata) {
|
||||
return BeginEvent(tag, event_type, event_metadata, /*primary subgraph*/ 0);
|
||||
}
|
||||
|
||||
// Signals an end to the specified profile event.
|
||||
virtual void EndEvent(uint32_t event_handle) = 0;
|
||||
};
|
||||
|
||||
// SubgraphAwareProfiler is a profiler that takes care of event tracing in a
|
||||
// certain subgraph.
|
||||
class SubgraphAwareProfiler : public Profiler {
|
||||
public:
|
||||
// Constructor should be called with the non-nullptr profiler argument.
|
||||
explicit SubgraphAwareProfiler(Profiler* profiler, uint32_t subgraph_index)
|
||||
: profiler_(profiler), subgraph_index_(subgraph_index) {}
|
||||
~SubgraphAwareProfiler() override {}
|
||||
|
||||
uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata,
|
||||
uint32_t subgraph_index /*ignore*/) override {
|
||||
// It assumes that this profiler only produces events from the subgraph that
|
||||
// is provided at the creation of the profiler.
|
||||
return BeginEvent(tag, event_type, event_metadata);
|
||||
}
|
||||
|
||||
uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata) override {
|
||||
// Assume that the wrapped profiler is not nullptr.
|
||||
return profiler_->BeginEvent(tag, event_type, event_metadata,
|
||||
subgraph_index_);
|
||||
}
|
||||
|
||||
void EndEvent(uint32_t event_handle) override {
|
||||
// Assume that the wrapped profiler is not nullptr.
|
||||
profiler_->EndEvent(event_handle);
|
||||
}
|
||||
|
||||
private:
|
||||
Profiler* const profiler_;
|
||||
const uint32_t subgraph_index_;
|
||||
};
|
||||
|
||||
// Adds a profile event to `profiler` that begins with the construction
|
||||
// of the object and ends when the object goes out of scope.
|
||||
// The lifetime of tag should be at least the lifetime of `profiler`.
|
||||
|
@ -285,9 +285,15 @@ class Subgraph {
|
||||
// WARNING: This is an experimental API and subject to change.
|
||||
TfLiteStatus ResetVariableTensors();
|
||||
|
||||
void SetProfiler(std::unique_ptr<Profiler> profiler) {
|
||||
profiler_ = std::move(profiler);
|
||||
context_.profiler = profiler_.get();
|
||||
void SetProfiler(Profiler* profiler, int associated_subgraph_idx) {
|
||||
if (!profiler) {
|
||||
profiler_.reset(nullptr);
|
||||
context_.profiler = nullptr;
|
||||
} else {
|
||||
profiler_.reset(
|
||||
new SubgraphAwareProfiler(profiler, associated_subgraph_idx));
|
||||
context_.profiler = profiler_.get();
|
||||
}
|
||||
}
|
||||
|
||||
Profiler* GetProfiler() { return profiler_.get(); }
|
||||
@ -302,6 +308,40 @@ class Subgraph {
|
||||
bool HasDynamicTensors() { return has_dynamic_tensors_; }
|
||||
|
||||
private:
|
||||
// SubgraphAwareProfiler wraps an actual TFLite profiler, such as a
|
||||
// BufferedProfiler instance, and takes care of event profiling/tracing in a
|
||||
// certain subgraph.
|
||||
class SubgraphAwareProfiler : public Profiler {
|
||||
public:
|
||||
// Constructor should be called with the non-nullptr profiler argument.
|
||||
SubgraphAwareProfiler(Profiler* profiler, uint32_t subgraph_index)
|
||||
: profiler_(profiler), subgraph_index_(subgraph_index) {}
|
||||
~SubgraphAwareProfiler() override {}
|
||||
|
||||
uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata,
|
||||
uint32_t subgraph_index) override {
|
||||
if (!profiler_) return 0;
|
||||
return profiler_->BeginEvent(tag, event_type, event_metadata,
|
||||
subgraph_index);
|
||||
}
|
||||
|
||||
uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata) override {
|
||||
return BeginEvent(tag, event_type, event_metadata, subgraph_index_);
|
||||
}
|
||||
|
||||
void EndEvent(uint32_t event_handle) override {
|
||||
if (!profiler_) return;
|
||||
profiler_->EndEvent(event_handle);
|
||||
}
|
||||
|
||||
private:
|
||||
// Not own the memory.
|
||||
Profiler* const profiler_;
|
||||
const uint32_t subgraph_index_;
|
||||
};
|
||||
|
||||
// Prevent 'context_' from accessing functions that are only available to
|
||||
// delegated kernels.
|
||||
void SwitchToKernelContext();
|
||||
@ -570,7 +610,7 @@ class Subgraph {
|
||||
bool tensor_resized_since_op_invoke_ = false;
|
||||
|
||||
// Profiler for this interpreter instance.
|
||||
std::unique_ptr<Profiler> profiler_;
|
||||
std::unique_ptr<SubgraphAwareProfiler> profiler_;
|
||||
|
||||
// A pointer to vector of subgraphs. The vector is owned by the interpreter.
|
||||
std::vector<std::unique_ptr<Subgraph>>* subgraphs_ = nullptr;
|
||||
|
@ -316,13 +316,7 @@ TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
|
||||
void Interpreter::SetProfiler(Profiler* profiler) {
|
||||
for (int subgraph_index = 0; subgraph_index < subgraphs_.size();
|
||||
++subgraph_index) {
|
||||
if (profiler != nullptr) {
|
||||
subgraphs_[subgraph_index]->SetProfiler(std::unique_ptr<Profiler>(
|
||||
new SubgraphAwareProfiler(profiler, subgraph_index)));
|
||||
} else {
|
||||
subgraphs_[subgraph_index]->SetProfiler(
|
||||
std::unique_ptr<Profiler>(nullptr));
|
||||
}
|
||||
subgraphs_[subgraph_index]->SetProfiler(profiler, subgraph_index);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,12 +84,6 @@ class BufferedProfiler : public tflite::Profiler {
|
||||
event_subgraph_index);
|
||||
}
|
||||
|
||||
uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata) override {
|
||||
return buffer_.BeginEvent(tag, event_type, event_metadata,
|
||||
/*primary graph*/ 0);
|
||||
}
|
||||
|
||||
void EndEvent(uint32_t event_handle) override {
|
||||
buffer_.EndEvent(event_handle);
|
||||
}
|
||||
|
@ -27,12 +27,11 @@ namespace profiling {
|
||||
class NoopProfiler : public tflite::Profiler {
|
||||
public:
|
||||
NoopProfiler() {}
|
||||
NoopProfiler(int max_profiling_buffer_entries) {}
|
||||
explicit NoopProfiler(int max_profiling_buffer_entries) {}
|
||||
|
||||
uint32_t BeginEvent(const char*, EventType, uint32_t, uint32_t) override {
|
||||
return 0;
|
||||
}
|
||||
uint32_t BeginEvent(const char*, EventType, uint32_t) override { return 0; }
|
||||
void EndEvent(uint32_t) override {}
|
||||
|
||||
void StartProfiling() {}
|
||||
|
Loading…
Reference in New Issue
Block a user