1. Move the SubgraphAwareProfiler as a private member class inside Subgraph as it's tightly coupled w/ a Subgraph.

2. Adjust the Profiler APIs to avoid certain confusion, esp. where we ignore the passed subgraph index argument.

PiperOrigin-RevId: 276044446
Change-Id: I9a3f6542af89ebb386a8055b30019d05c79f3c15
This commit is contained in:
Chao Mei 2019-10-22 05:20:05 -07:00 committed by TensorFlower Gardener
parent 48993d1e2d
commit 5460af59bc
5 changed files with 53 additions and 59 deletions

View File

@ -32,55 +32,22 @@ class Profiler {
virtual ~Profiler() {}
// Signals the beginning of an event from a subgraph, returning a handle to
// the profile event.
// Signals the beginning of an event from a subgraph indexed at
// 'event_subgraph_index', returning a handle to the profile event.
virtual uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata,
uint32_t event_subgraph_index) = 0;
// Signals the beginning of an event, returning a handle to the profile event.
// The subgraph where an event comes from will be determined implicilty.
// Similar w/ the above, but the event comes from the primary subgraph that's
// indexed at 0.
virtual uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata) = 0;
uint32_t event_metadata) {
return BeginEvent(tag, event_type, event_metadata, /*primary subgraph*/ 0);
}
// Signals an end to the specified profile event.
virtual void EndEvent(uint32_t event_handle) = 0;
};
// SubgraphAwareProfiler is a profiler that takes care of event tracing in a
// certain subgraph.
class SubgraphAwareProfiler : public Profiler {
public:
// Constructor should be called with the non-nullptr profiler argument.
explicit SubgraphAwareProfiler(Profiler* profiler, uint32_t subgraph_index)
: profiler_(profiler), subgraph_index_(subgraph_index) {}
~SubgraphAwareProfiler() override {}
uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata,
uint32_t subgraph_index /*ignore*/) override {
// It assumes that this profiler only produces events from the subgraph that
// is provided at the creation of the profiler.
return BeginEvent(tag, event_type, event_metadata);
}
uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata) override {
// Assume that the wrapped profiler is not nullptr.
return profiler_->BeginEvent(tag, event_type, event_metadata,
subgraph_index_);
}
void EndEvent(uint32_t event_handle) override {
// Assume that the wrapped profiler is not nullptr.
profiler_->EndEvent(event_handle);
}
private:
Profiler* const profiler_;
const uint32_t subgraph_index_;
};
// Adds a profile event to `profiler` that begins with the construction
// of the object and ends when the object goes out of scope.
// The lifetime of tag should be at least the lifetime of `profiler`.

View File

@ -285,9 +285,15 @@ class Subgraph {
// WARNING: This is an experimental API and subject to change.
TfLiteStatus ResetVariableTensors();
void SetProfiler(std::unique_ptr<Profiler> profiler) {
profiler_ = std::move(profiler);
context_.profiler = profiler_.get();
void SetProfiler(Profiler* profiler, int associated_subgraph_idx) {
if (!profiler) {
profiler_.reset(nullptr);
context_.profiler = nullptr;
} else {
profiler_.reset(
new SubgraphAwareProfiler(profiler, associated_subgraph_idx));
context_.profiler = profiler_.get();
}
}
Profiler* GetProfiler() { return profiler_.get(); }
@ -302,6 +308,40 @@ class Subgraph {
bool HasDynamicTensors() { return has_dynamic_tensors_; }
private:
// SubgraphAwareProfiler wraps an actual TFLite profiler, such as a
// BufferedProfiler instance, and takes care of event profiling/tracing in a
// certain subgraph.
class SubgraphAwareProfiler : public Profiler {
public:
// Constructor should be called with the non-nullptr profiler argument.
SubgraphAwareProfiler(Profiler* profiler, uint32_t subgraph_index)
: profiler_(profiler), subgraph_index_(subgraph_index) {}
~SubgraphAwareProfiler() override {}
uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata,
uint32_t subgraph_index) override {
if (!profiler_) return 0;
return profiler_->BeginEvent(tag, event_type, event_metadata,
subgraph_index);
}
uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata) override {
return BeginEvent(tag, event_type, event_metadata, subgraph_index_);
}
void EndEvent(uint32_t event_handle) override {
if (!profiler_) return;
profiler_->EndEvent(event_handle);
}
private:
// Not own the memory.
Profiler* const profiler_;
const uint32_t subgraph_index_;
};
// Prevent 'context_' from accessing functions that are only available to
// delegated kernels.
void SwitchToKernelContext();
@ -570,7 +610,7 @@ class Subgraph {
bool tensor_resized_since_op_invoke_ = false;
// Profiler for this interpreter instance.
std::unique_ptr<Profiler> profiler_;
std::unique_ptr<SubgraphAwareProfiler> profiler_;
// A pointer to vector of subgraphs. The vector is owned by the interpreter.
std::vector<std::unique_ptr<Subgraph>>* subgraphs_ = nullptr;

View File

@ -316,13 +316,7 @@ TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
void Interpreter::SetProfiler(Profiler* profiler) {
for (int subgraph_index = 0; subgraph_index < subgraphs_.size();
++subgraph_index) {
if (profiler != nullptr) {
subgraphs_[subgraph_index]->SetProfiler(std::unique_ptr<Profiler>(
new SubgraphAwareProfiler(profiler, subgraph_index)));
} else {
subgraphs_[subgraph_index]->SetProfiler(
std::unique_ptr<Profiler>(nullptr));
}
subgraphs_[subgraph_index]->SetProfiler(profiler, subgraph_index);
}
}

View File

@ -84,12 +84,6 @@ class BufferedProfiler : public tflite::Profiler {
event_subgraph_index);
}
uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata) override {
return buffer_.BeginEvent(tag, event_type, event_metadata,
/*primary graph*/ 0);
}
void EndEvent(uint32_t event_handle) override {
buffer_.EndEvent(event_handle);
}

View File

@ -27,12 +27,11 @@ namespace profiling {
class NoopProfiler : public tflite::Profiler {
public:
NoopProfiler() {}
NoopProfiler(int max_profiling_buffer_entries) {}
explicit NoopProfiler(int max_profiling_buffer_entries) {}
uint32_t BeginEvent(const char*, EventType, uint32_t, uint32_t) override {
return 0;
}
uint32_t BeginEvent(const char*, EventType, uint32_t) override { return 0; }
void EndEvent(uint32_t) override {}
void StartProfiling() {}