Reduce memory overhead of TraceMeRecorder::Consume
PiperOrigin-RevId: 342668511 Change-Id: I6de0cf78379cc33bcdebce10a13c6292d598c5d1
This commit is contained in:
parent
8a13606bbe
commit
6d1218d210
@ -22,7 +22,6 @@ cc_library(
|
||||
"//tensorflow/core/profiler/utils:tf_op_utils",
|
||||
"//tensorflow/core/profiler/utils:xplane_builder",
|
||||
"//tensorflow/core/profiler/utils:xplane_utils",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -107,7 +107,6 @@ Status HostTracer::CollectData(RunMetadata* run_metadata) {
|
||||
if (recording_) {
|
||||
return errors::Internal("TraceMeRecorder not stopped");
|
||||
}
|
||||
MakeCompleteEvents(&events_);
|
||||
|
||||
StepStats* step_stats = run_metadata->mutable_step_stats();
|
||||
DeviceStepStats* dev_stats = step_stats->add_dev_stats();
|
||||
@ -117,7 +116,9 @@ Status HostTracer::CollectData(RunMetadata* run_metadata) {
|
||||
constexpr char kUserMetadataMarker = '#';
|
||||
for (TraceMeRecorder::ThreadEvents& thread : events_) {
|
||||
thread_names->insert({thread.thread.tid, thread.thread.name});
|
||||
for (TraceMeRecorder::Event& event : thread.events) {
|
||||
while (!thread.events.empty()) {
|
||||
auto event = std::move(thread.events.front());
|
||||
thread.events.pop_front();
|
||||
if (event.start_time && event.end_time) {
|
||||
NodeExecStats* ns = dev_stats->add_node_stats();
|
||||
if (event.name.back() != kUserMetadataMarker) {
|
||||
@ -149,10 +150,9 @@ Status HostTracer::CollectData(XSpace* space) {
|
||||
if (recording_) {
|
||||
return errors::Internal("TraceMeRecorder not stopped");
|
||||
}
|
||||
MakeCompleteEvents(&events_);
|
||||
XPlane* plane = FindOrAddMutablePlaneWithName(space, kHostThreadsPlaneName);
|
||||
ConvertCompleteEventsToXPlane(start_timestamp_ns_, events_, plane);
|
||||
events_.clear();
|
||||
ConvertCompleteEventsToXPlane(start_timestamp_ns_, std::exchange(events_, {}),
|
||||
plane);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -16,9 +16,7 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h"
|
||||
@ -30,64 +28,45 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
namespace {
|
||||
|
||||
void MakeCompleteEvents(TraceMeRecorder::Events* events) {
|
||||
// Track events created by ActivityStart and copy their data to events created
|
||||
// by ActivityEnd. TraceMe records events in its destructor, so this results
|
||||
// in complete events sorted by their end_time in the thread they ended.
|
||||
// Within the same thread, the record created by ActivityStart must appear
|
||||
// before the record created by ActivityEnd. Cross-thread events must be
|
||||
// processed in a separate pass. A single map can be used because the
|
||||
// activity_id is globally unique.
|
||||
absl::flat_hash_map<int64, TraceMeRecorder::Event*> start_events;
|
||||
std::vector<TraceMeRecorder::Event*> end_events;
|
||||
for (auto& thread : *events) {
|
||||
for (auto& event : thread.events) {
|
||||
if (event.IsStart()) {
|
||||
start_events.emplace(event.ActivityId(), &event);
|
||||
} else if (event.IsEnd()) {
|
||||
auto iter = start_events.find(event.ActivityId());
|
||||
if (iter != start_events.end()) { // same thread
|
||||
auto* start_event = iter->second;
|
||||
event.name = std::move(start_event->name);
|
||||
event.start_time = start_event->start_time;
|
||||
start_events.erase(iter);
|
||||
} else { // cross-thread
|
||||
end_events.push_back(&event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto* event : end_events) { // cross-thread
|
||||
auto iter = start_events.find(event->ActivityId());
|
||||
if (iter != start_events.end()) {
|
||||
auto* start_event = iter->second;
|
||||
event->name = std::move(start_event->name);
|
||||
event->start_time = start_event->start_time;
|
||||
start_events.erase(iter);
|
||||
}
|
||||
void MayAddDisplayName(XEventMetadata* xevent_metadata) {
|
||||
if (!xevent_metadata->display_name().empty()) return;
|
||||
std::string tf_op_event_name = TfOpEventName(xevent_metadata->name());
|
||||
if (tf_op_event_name != xevent_metadata->name()) {
|
||||
xevent_metadata->set_display_name(std::move(tf_op_event_name));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ConvertCompleteEventsToXPlane(uint64 start_timestamp_ns,
|
||||
const TraceMeRecorder::Events& events,
|
||||
TraceMeRecorder::Events&& events,
|
||||
XPlane* raw_plane) {
|
||||
XPlaneBuilder xplane(raw_plane);
|
||||
for (const auto& thread : events) {
|
||||
for (auto& thread : events) {
|
||||
XLineBuilder xline = xplane.GetOrCreateLine(thread.thread.tid);
|
||||
xline.SetName(thread.thread.name);
|
||||
xline.SetTimestampNs(start_timestamp_ns);
|
||||
xline.ReserveEvents(thread.events.size());
|
||||
for (const auto& event : thread.events) {
|
||||
while (!thread.events.empty()) {
|
||||
auto event = std::move(thread.events.front());
|
||||
thread.events.pop_front();
|
||||
if (!event.IsComplete()) continue;
|
||||
if (event.start_time < start_timestamp_ns) continue;
|
||||
if (!HasMetadata(event.name)) {
|
||||
XEventMetadata* xevent_metadata =
|
||||
xplane.GetOrCreateEventMetadata(std::move(event.name));
|
||||
MayAddDisplayName(xevent_metadata);
|
||||
XEventBuilder xevent = xline.AddEvent(*xevent_metadata);
|
||||
xevent.SetTimestampNs(event.start_time);
|
||||
xevent.SetEndTimestampNs(event.end_time);
|
||||
continue;
|
||||
}
|
||||
Annotation annotation = ParseAnnotation(event.name);
|
||||
XEventMetadata* xevent_metadata =
|
||||
xplane.GetOrCreateEventMetadata(annotation.name);
|
||||
std::string tf_op_event_name = TfOpEventName(annotation.name);
|
||||
if (tf_op_event_name != annotation.name) {
|
||||
xevent_metadata->set_display_name(std::move(tf_op_event_name));
|
||||
}
|
||||
MayAddDisplayName(xevent_metadata);
|
||||
XEventBuilder xevent = xline.AddEvent(*xevent_metadata);
|
||||
xevent.SetTimestampNs(event.start_time);
|
||||
xevent.SetEndTimestampNs(event.end_time);
|
||||
|
@ -22,13 +22,9 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
// Combine events created by TraceMe::ActivityStart and TraceMe::ActivityEnd,
|
||||
// which can be paired up by their activity_id.
|
||||
void MakeCompleteEvents(TraceMeRecorder::Events* events);
|
||||
|
||||
// Convert complete events to XPlane format.
|
||||
void ConvertCompleteEventsToXPlane(uint64 start_timestamp_ns,
|
||||
const TraceMeRecorder::Events& events,
|
||||
TraceMeRecorder::Events&& events,
|
||||
XPlane* raw_plane);
|
||||
|
||||
} // namespace profiler
|
||||
|
@ -45,6 +45,53 @@ static_assert(ATOMIC_INT_LOCK_FREE == 2, "Assumed atomic<int> was lock free");
|
||||
|
||||
namespace {
|
||||
|
||||
// Track events created by ActivityStart and merge their data into events
|
||||
// created by ActivityEnd. TraceMe records events in its destructor, so this
|
||||
// results in complete events sorted by their end_time in the thread they ended.
|
||||
// Within the same thread, the record created by ActivityStart must appear
|
||||
// before the record created by ActivityEnd. Cross-thread events must be
|
||||
// processed in a separate pass. A single map can be used because the
|
||||
// activity_id is globally unique.
|
||||
class SplitEventTracker {
|
||||
public:
|
||||
void AddStart(TraceMeRecorder::Event&& event) {
|
||||
DCHECK(event.IsStart());
|
||||
start_events_.emplace(event.ActivityId(), std::move(event));
|
||||
}
|
||||
|
||||
void AddEnd(TraceMeRecorder::Event* event) {
|
||||
DCHECK(event->IsEnd());
|
||||
if (!FindStartAndMerge(event)) {
|
||||
end_events_.push_back(event);
|
||||
}
|
||||
}
|
||||
|
||||
void HandleCrossThreadEvents() {
|
||||
for (auto* event : end_events_) {
|
||||
FindStartAndMerge(event);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Finds the start of the given event and merges data into it.
|
||||
bool FindStartAndMerge(TraceMeRecorder::Event* event) {
|
||||
auto iter = start_events_.find(event->ActivityId());
|
||||
if (iter == start_events_.end()) return false;
|
||||
auto& start_event = iter->second;
|
||||
event->name = std::move(start_event.name);
|
||||
event->start_time = start_event.start_time;
|
||||
start_events_.erase(iter);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Start events are collected from each ThreadLocalRecorder::Consume() call.
|
||||
// Their data is merged into end_events.
|
||||
absl::flat_hash_map<int64, TraceMeRecorder::Event> start_events_;
|
||||
|
||||
// End events are stored in the output of TraceMeRecorder::Consume().
|
||||
std::vector<TraceMeRecorder::Event*> end_events_;
|
||||
};
|
||||
|
||||
// A single-producer single-consumer queue of Events.
|
||||
//
|
||||
// Implemented as a linked-list of blocks containing numbered slots, with start
|
||||
@ -114,13 +161,25 @@ class EventQueue {
|
||||
// Consume is only called from ThreadLocalRecorder::Clear, which in turn is
|
||||
// only called while holding TraceMeRecorder::Mutex, so Consume has a single
|
||||
// caller at a time.
|
||||
TF_MUST_USE_RESULT std::vector<TraceMeRecorder::Event> Consume() {
|
||||
TF_MUST_USE_RESULT std::deque<TraceMeRecorder::Event> Consume(
|
||||
SplitEventTracker* split_event_tracker) {
|
||||
// Read index before contents.
|
||||
size_t end = end_.load(std::memory_order_acquire);
|
||||
std::vector<TraceMeRecorder::Event> result;
|
||||
result.reserve(end - start_);
|
||||
std::deque<TraceMeRecorder::Event> result;
|
||||
while (start_ != end) {
|
||||
result.emplace_back(Pop());
|
||||
TraceMeRecorder::Event event = Pop();
|
||||
// Copy data from start events to end events. TraceMe records events in
|
||||
// its destructor, so this results in complete events sorted by their
|
||||
// end_time in the thread they ended. Within the same thread, the start
|
||||
// event must appear before the corresponding end event.
|
||||
if (event.IsStart()) {
|
||||
split_event_tracker->AddStart(std::move(event));
|
||||
continue;
|
||||
}
|
||||
result.emplace_back(std::move(event));
|
||||
if (result.back().IsEnd()) {
|
||||
split_event_tracker->AddEnd(&result.back());
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -205,8 +264,9 @@ class TraceMeRecorder::ThreadLocalRecorder {
|
||||
void Clear() { queue_.Clear(); }
|
||||
|
||||
// Consume is called from the control thread when tracing stops.
|
||||
TF_MUST_USE_RESULT TraceMeRecorder::ThreadEvents Consume() {
|
||||
return {info_, queue_.Consume()};
|
||||
TF_MUST_USE_RESULT TraceMeRecorder::ThreadEvents Consume(
|
||||
SplitEventTracker* split_event_tracker) {
|
||||
return {info_, queue_.Consume(split_event_tracker)};
|
||||
}
|
||||
|
||||
private:
|
||||
@ -283,9 +343,11 @@ void TraceMeRecorder::Clear() {
|
||||
TraceMeRecorder::Events TraceMeRecorder::Consume() {
|
||||
TraceMeRecorder::Events result;
|
||||
result.reserve(threads_.size());
|
||||
SplitEventTracker split_event_tracker;
|
||||
for (auto iter = threads_.begin(); iter != threads_.end();) {
|
||||
auto& recorder = iter->second;
|
||||
TraceMeRecorder::ThreadEvents events = recorder->Consume();
|
||||
TraceMeRecorder::ThreadEvents events =
|
||||
recorder->Consume(&split_event_tracker);
|
||||
if (!events.events.empty()) {
|
||||
result.push_back(std::move(events));
|
||||
}
|
||||
@ -297,6 +359,7 @@ TraceMeRecorder::Events TraceMeRecorder::Consume() {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
split_event_tracker.HandleCrossThreadEvents();
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_TRACEME_RECORDER_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -75,7 +76,7 @@ class TraceMeRecorder {
|
||||
};
|
||||
struct ThreadEvents {
|
||||
ThreadInfo thread;
|
||||
std::vector<Event> events;
|
||||
std::deque<Event> events;
|
||||
};
|
||||
using Events = std::vector<ThreadEvents>;
|
||||
|
||||
|
@ -30,7 +30,7 @@ namespace {
|
||||
std::vector<absl::string_view> SplitNameAndMetadata(
|
||||
absl::string_view annotation) {
|
||||
std::vector<absl::string_view> parts;
|
||||
if (annotation.empty() || annotation.back() != '#') {
|
||||
if (!HasMetadata(annotation)) {
|
||||
parts.emplace_back(annotation);
|
||||
} else {
|
||||
annotation.remove_suffix(1);
|
||||
|
@ -37,6 +37,11 @@ struct Annotation {
|
||||
};
|
||||
Annotation ParseAnnotation(absl::string_view annotation);
|
||||
|
||||
inline bool HasMetadata(absl::string_view annotation) {
|
||||
constexpr char kUserMetadataMarker = '#';
|
||||
return !annotation.empty() && annotation.back() == kUserMetadataMarker;
|
||||
}
|
||||
|
||||
std::vector<Annotation> ParseAnnotationStack(
|
||||
absl::string_view annotation_stack);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user