Remove TraceMeRecorder::Event::activity_id
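Reduces TraceMe memory overhead: the activity_id is no longer a separate Event field but is encoded as a negative start_time or end_time.
PiperOrigin-RevId: 342359162
Change-Id: Idbe185aa6e3eeb030a3788567413d8c676e41e75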
This commit is contained in:
parent
aa56651979
commit
a3b0b2bf41
@ -109,6 +109,8 @@ tf_cc_test(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/profiler/utils:time_utils",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -39,14 +39,14 @@ void MakeCompleteEvents(TraceMeRecorder::Events* events) {
|
||||
// before the record created by ActivityEnd. Cross-thread events must be
|
||||
// processed in a separate pass. A single map can be used because the
|
||||
// activity_id is globally unique.
|
||||
absl::flat_hash_map<uint64, TraceMeRecorder::Event*> start_events;
|
||||
absl::flat_hash_map<int64, TraceMeRecorder::Event*> start_events;
|
||||
std::vector<TraceMeRecorder::Event*> end_events;
|
||||
for (auto& thread : *events) {
|
||||
for (auto& event : thread.events) {
|
||||
if (IsStartEvent(event)) {
|
||||
start_events.emplace(event.activity_id, &event);
|
||||
} else if (IsEndEvent(event)) {
|
||||
auto iter = start_events.find(event.activity_id);
|
||||
if (event.IsStart()) {
|
||||
start_events.emplace(event.ActivityId(), &event);
|
||||
} else if (event.IsEnd()) {
|
||||
auto iter = start_events.find(event.ActivityId());
|
||||
if (iter != start_events.end()) { // same thread
|
||||
auto* start_event = iter->second;
|
||||
event.name = std::move(start_event->name);
|
||||
@ -59,7 +59,7 @@ void MakeCompleteEvents(TraceMeRecorder::Events* events) {
|
||||
}
|
||||
}
|
||||
for (auto* event : end_events) { // cross-thread
|
||||
auto iter = start_events.find(event->activity_id);
|
||||
auto iter = start_events.find(event->ActivityId());
|
||||
if (iter != start_events.end()) {
|
||||
auto* start_event = iter->second;
|
||||
event->name = std::move(start_event->name);
|
||||
@ -79,7 +79,7 @@ void ConvertCompleteEventsToXPlane(uint64 start_timestamp_ns,
|
||||
xline.SetTimestampNs(start_timestamp_ns);
|
||||
xline.ReserveEvents(thread.events.size());
|
||||
for (const auto& event : thread.events) {
|
||||
if (!IsCompleteEvent(event)) continue;
|
||||
if (!event.IsComplete()) continue;
|
||||
if (event.start_time < start_timestamp_ns) continue;
|
||||
Annotation annotation = ParseAnnotation(event.name);
|
||||
XEventMetadata* xevent_metadata =
|
||||
|
@ -22,22 +22,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
// Returns true if event was created by TraceMe::ActivityStart.
|
||||
inline bool IsStartEvent(const TraceMeRecorder::Event& event) {
|
||||
return (event.start_time != 0) && (event.end_time == 0);
|
||||
}
|
||||
|
||||
// Returns true if event was created by TraceMe::ActivityEnd.
|
||||
inline bool IsEndEvent(const TraceMeRecorder::Event& event) {
|
||||
return (event.start_time == 0) && (event.end_time != 0);
|
||||
}
|
||||
|
||||
// Returns true if event was created by TraceMe::Stop or MakeCompleteEvents
|
||||
// below.
|
||||
inline bool IsCompleteEvent(const TraceMeRecorder::Event& event) {
|
||||
return (event.start_time != 0) && (event.end_time != 0);
|
||||
}
|
||||
|
||||
// Combine events created by TraceMe::ActivityStart and TraceMe::ActivityEnd,
|
||||
// which can be paired up by their activity_id.
|
||||
void MakeCompleteEvents(TraceMeRecorder::Events* events);
|
||||
|
@ -254,7 +254,7 @@ bool TraceMeRecorder::StartRecording(int level) {
|
||||
return started;
|
||||
}
|
||||
|
||||
void TraceMeRecorder::Record(Event event) {
|
||||
void TraceMeRecorder::Record(Event&& event) {
|
||||
static thread_local ThreadLocalRecorder thread_local_recorder;
|
||||
thread_local_recorder.Record(std::move(event));
|
||||
}
|
||||
@ -270,16 +270,16 @@ TraceMeRecorder::Events TraceMeRecorder::StopRecording() {
|
||||
return events;
|
||||
}
|
||||
|
||||
/*static*/ uint64 TraceMeRecorder::NewActivityId() {
|
||||
/*static*/ int64 TraceMeRecorder::NewActivityId() {
|
||||
// Activity IDs: To avoid contention over a counter, the top 32 bits identify
|
||||
// the originating thread, the bottom 32 bits name the event within a thread.
|
||||
// IDs may be reused after 4 billion events on one thread, or 4 billion
|
||||
// IDs may be reused after 4 billion events on one thread, or 2 billion
|
||||
// threads.
|
||||
static std::atomic<uint32> thread_counter(1); // avoid kUntracedActivity
|
||||
const thread_local static uint32 thread_id =
|
||||
static std::atomic<int32> thread_counter(1); // avoid kUntracedActivity
|
||||
const thread_local static int32 thread_id =
|
||||
thread_counter.fetch_add(1, std::memory_order_relaxed);
|
||||
thread_local static uint32 per_thread_activity_id = 0;
|
||||
return static_cast<uint64>(thread_id) << 32 | per_thread_activity_id++;
|
||||
return static_cast<int64>(thread_id) << 32 | per_thread_activity_id++;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
@ -44,18 +44,29 @@ TF_EXPORT extern std::atomic<int> g_trace_level;
|
||||
//
|
||||
// This is the backend for TraceMe instrumentation.
|
||||
// The profiler starts the recorder, the TraceMe destructor records complete
|
||||
// events. TraceMe::ActivityStart records begin events, and TraceMe::ActivityEnd
|
||||
// events. TraceMe::ActivityStart records start events, and TraceMe::ActivityEnd
|
||||
// records end events. The profiler then stops the recorder and finds start/end
|
||||
// pairs. (Unpaired start/end events are discarded at that point).
|
||||
class TraceMeRecorder {
|
||||
public:
|
||||
// An Event is either the start of a TraceMe, the end of a TraceMe, or both.
|
||||
// Times are in ns since the Unix epoch.
|
||||
// A start event stores the negated activity_id in end_time, and an end event
|
||||
// stores it in start_time; the activity_id pairs up a start with its end.
|
||||
struct Event {
|
||||
uint64 activity_id;
|
||||
bool IsComplete() const { return start_time > 0 && end_time > 0; }
|
||||
bool IsStart() const { return end_time < 0; }
|
||||
bool IsEnd() const { return start_time < 0; }
|
||||
|
||||
int64 ActivityId() const {
|
||||
if (IsStart()) return -end_time;
|
||||
if (IsEnd()) return -start_time;
|
||||
return 1; // complete
|
||||
}
|
||||
|
||||
std::string name;
|
||||
uint64 start_time; // 0 = missing
|
||||
uint64 end_time; // 0 = missing
|
||||
int64 start_time;
|
||||
int64 end_time;
|
||||
};
|
||||
struct ThreadInfo {
|
||||
uint32 tid;
|
||||
@ -85,10 +96,10 @@ class TraceMeRecorder {
|
||||
static constexpr int kTracingDisabled = -1;
|
||||
|
||||
// Records an event. Non-blocking.
|
||||
static void Record(Event event);
|
||||
static void Record(Event&& event);
|
||||
|
||||
// Returns an activity_id for TraceMe::ActivityStart.
|
||||
static uint64 NewActivityId();
|
||||
static int64 NewActivityId();
|
||||
|
||||
private:
|
||||
class ThreadLocalRecorder;
|
||||
|
@ -15,12 +15,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <istream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/env_time.h"
|
||||
@ -29,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/utils/time_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
@ -41,15 +43,15 @@ MATCHER_P(Named, name, "") { return arg.name == name; }
|
||||
constexpr static uint64 kNanosInSec = 1000000000;
|
||||
|
||||
TEST(RecorderTest, SingleThreaded) {
|
||||
uint64 start_time = Env::Default()->NowNanos();
|
||||
uint64 end_time = start_time + kNanosInSec;
|
||||
int64 start_time = GetCurrentTimeNanos();
|
||||
int64 end_time = start_time + kNanosInSec;
|
||||
|
||||
TraceMeRecorder::Record({1, "before", start_time, end_time});
|
||||
TraceMeRecorder::Record({"before", start_time, end_time});
|
||||
TraceMeRecorder::Start(/*level=*/1);
|
||||
TraceMeRecorder::Record({2, "during1", start_time, end_time});
|
||||
TraceMeRecorder::Record({3, "during2", start_time, end_time});
|
||||
TraceMeRecorder::Record({"during1", start_time, end_time});
|
||||
TraceMeRecorder::Record({"during2", start_time, end_time});
|
||||
auto results = TraceMeRecorder::Stop();
|
||||
TraceMeRecorder::Record({4, "after", start_time, end_time});
|
||||
TraceMeRecorder::Record({"after", start_time, end_time});
|
||||
|
||||
ASSERT_EQ(results.size(), 1);
|
||||
EXPECT_THAT(results[0].events,
|
||||
@ -83,15 +85,14 @@ TEST(RecorderTest, Multithreaded) {
|
||||
thread::ThreadPool pool(Env::Default(), "testpool", kNumThreads);
|
||||
std::atomic<int> thread_count = {0};
|
||||
for (int i = 0; i < kNumThreads; i++) {
|
||||
pool.Schedule([&start, &stop, &thread_count, i] {
|
||||
pool.Schedule([&start, &stop, &thread_count] {
|
||||
uint64 j = 0;
|
||||
bool was_active = false;
|
||||
auto record_event = [&j, i]() {
|
||||
uint64 start_time = Env::Default()->NowNanos();
|
||||
uint64 end_time = start_time + kNanosInSec;
|
||||
TraceMeRecorder::Record({/*activity_id=*/j++,
|
||||
/*name=*/absl::StrCat(i), start_time,
|
||||
end_time});
|
||||
auto record_event = [&j]() {
|
||||
int64 start_time = GetCurrentTimeNanos();
|
||||
int64 end_time = start_time + kNanosInSec;
|
||||
TraceMeRecorder::Record(
|
||||
{/*name=*/absl::StrCat(j++), start_time, end_time});
|
||||
};
|
||||
thread_count.fetch_add(1, std::memory_order_relaxed);
|
||||
start.WaitForNotification();
|
||||
@ -121,15 +122,17 @@ TEST(RecorderTest, Multithreaded) {
|
||||
}
|
||||
|
||||
// For each thread, keep track of which events we've seen.
|
||||
struct {
|
||||
struct ThreadState {
|
||||
bool split_session = false;
|
||||
bool overlapping_sessions = false;
|
||||
std::set<uint64> events;
|
||||
} thread_state[kNumThreads];
|
||||
};
|
||||
absl::flat_hash_map<uint32 /*tid*/, ThreadState> thread_state;
|
||||
// We expect each thread to eventually have multiple events, not all in a
|
||||
// contiguous range.
|
||||
auto done = [&thread_state] {
|
||||
for (const auto& t : thread_state) {
|
||||
for (const auto& id_and_thread : thread_state) {
|
||||
auto& t = id_and_thread.second;
|
||||
if (t.events.size() < 2) return false;
|
||||
}
|
||||
return true;
|
||||
@ -153,20 +156,19 @@ TEST(RecorderTest, Multithreaded) {
|
||||
auto results = TraceMeRecorder::Stop();
|
||||
for (const auto& thread : results) {
|
||||
if (thread.events.empty()) continue;
|
||||
std::istringstream ss(thread.events.front().name);
|
||||
int thread_index = 0;
|
||||
ss >> thread_index;
|
||||
auto& state = thread_state[thread_index];
|
||||
auto& state = thread_state[thread.thread.tid];
|
||||
|
||||
std::set<uint64> session_events;
|
||||
uint64 current = 0;
|
||||
for (const auto& event : thread.events) {
|
||||
session_events.emplace(event.activity_id);
|
||||
uint64 activity_id;
|
||||
ASSERT_TRUE(absl::SimpleAtoi(event.name, &activity_id));
|
||||
session_events.emplace(activity_id);
|
||||
// Session events should be contiguous.
|
||||
if (current != 0 && event.activity_id != current + 1) {
|
||||
if (current != 0 && activity_id != current + 1) {
|
||||
state.split_session = true;
|
||||
}
|
||||
current = event.activity_id;
|
||||
current = activity_id;
|
||||
}
|
||||
|
||||
for (const auto& event : session_events) {
|
||||
@ -182,7 +184,8 @@ TEST(RecorderTest, Multithreaded) {
|
||||
}
|
||||
stop.Notify();
|
||||
|
||||
for (const auto& thread : thread_state) {
|
||||
for (const auto& id_and_thread : thread_state) {
|
||||
auto& thread = id_and_thread.second;
|
||||
EXPECT_FALSE(thread.split_session)
|
||||
<< "Expected contiguous events in a session";
|
||||
EXPECT_FALSE(thread.overlapping_sessions) << "Expected disjoint sessions";
|
||||
|
@ -174,8 +174,8 @@ class TraceMe {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
if (TF_PREDICT_FALSE(start_time_ != kUntracedActivity)) {
|
||||
if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) {
|
||||
TraceMeRecorder::Record({kCompleteActivity, std::move(no_init_.name),
|
||||
start_time_, GetCurrentTimeNanos()});
|
||||
TraceMeRecorder::Record(
|
||||
{std::move(no_init_.name), start_time_, GetCurrentTimeNanos()});
|
||||
}
|
||||
no_init_.name.~string();
|
||||
start_time_ = kUntracedActivity;
|
||||
@ -210,12 +210,12 @@ class TraceMe {
|
||||
// Returns the activity ID, which is used to stop the activity.
|
||||
// Calls `name_generator` to get the name for activity.
|
||||
template <typename NameGeneratorT>
|
||||
static uint64 ActivityStart(NameGeneratorT name_generator, int level = 1) {
|
||||
static int64 ActivityStart(NameGeneratorT name_generator, int level = 1) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
|
||||
uint64 activity_id = TraceMeRecorder::NewActivityId();
|
||||
TraceMeRecorder::Record({activity_id, name_generator(),
|
||||
GetCurrentTimeNanos(), /*end_time=*/0});
|
||||
int64 activity_id = TraceMeRecorder::NewActivityId();
|
||||
TraceMeRecorder::Record(
|
||||
{name_generator(), GetCurrentTimeNanos(), -activity_id});
|
||||
return activity_id;
|
||||
}
|
||||
#endif
|
||||
@ -224,12 +224,12 @@ class TraceMe {
|
||||
|
||||
// Record the start time of an activity.
|
||||
// Returns the activity ID, which is used to stop the activity.
|
||||
static uint64 ActivityStart(absl::string_view name, int level = 1) {
|
||||
static int64 ActivityStart(absl::string_view name, int level = 1) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
|
||||
uint64 activity_id = TraceMeRecorder::NewActivityId();
|
||||
TraceMeRecorder::Record({activity_id, std::string(name),
|
||||
GetCurrentTimeNanos(), /*end_time=*/0});
|
||||
int64 activity_id = TraceMeRecorder::NewActivityId();
|
||||
TraceMeRecorder::Record(
|
||||
{std::string(name), GetCurrentTimeNanos(), -activity_id});
|
||||
return activity_id;
|
||||
}
|
||||
#endif
|
||||
@ -237,23 +237,23 @@ class TraceMe {
|
||||
}
|
||||
|
||||
// Same as ActivityStart above, an overload for "const std::string&"
|
||||
static uint64 ActivityStart(const std::string& name, int level = 1) {
|
||||
static int64 ActivityStart(const std::string& name, int level = 1) {
|
||||
return ActivityStart(absl::string_view(name), level);
|
||||
}
|
||||
|
||||
// Same as ActivityStart above, an overload for "const char*"
|
||||
static uint64 ActivityStart(const char* name, int level = 1) {
|
||||
static int64 ActivityStart(const char* name, int level = 1) {
|
||||
return ActivityStart(absl::string_view(name), level);
|
||||
}
|
||||
|
||||
// Record the end time of an activity started by ActivityStart().
|
||||
static void ActivityEnd(uint64 activity_id) {
|
||||
static void ActivityEnd(int64 activity_id) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// We don't check the level again (see TraceMe::Stop()).
|
||||
if (TF_PREDICT_FALSE(activity_id != kUntracedActivity)) {
|
||||
if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) {
|
||||
TraceMeRecorder::Record({activity_id, /*name=*/std::string(),
|
||||
/*start_time=*/0, GetCurrentTimeNanos()});
|
||||
TraceMeRecorder::Record(
|
||||
{std::string(), -activity_id, GetCurrentTimeNanos()});
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@ -264,9 +264,9 @@ class TraceMe {
|
||||
static void InstantActivity(NameGeneratorT name_generator, int level = 1) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
|
||||
uint64 now = GetCurrentTimeNanos();
|
||||
TraceMeRecorder::Record({kCompleteActivity, name_generator(),
|
||||
/*start_time=*/now, /*end_time=*/now});
|
||||
int64 now = GetCurrentTimeNanos();
|
||||
TraceMeRecorder::Record(
|
||||
{name_generator(), /*start_time=*/now, /*end_time=*/now});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -279,7 +279,7 @@ class TraceMe {
|
||||
#endif
|
||||
}
|
||||
|
||||
static uint64 NewActivityId() {
|
||||
static int64 NewActivityId() {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
return TraceMeRecorder::NewActivityId();
|
||||
#else
|
||||
@ -288,10 +288,8 @@ class TraceMe {
|
||||
}
|
||||
|
||||
private:
|
||||
// Activity ID or start time used when tracing is disabled.
|
||||
constexpr static uint64 kUntracedActivity = 0;
|
||||
// Activity ID used as a placeholder when both start and end are present.
|
||||
constexpr static uint64 kCompleteActivity = 1;
|
||||
// Start time (or activity ID) used when tracing is disabled.
|
||||
constexpr static int64 kUntracedActivity = 0;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TraceMe);
|
||||
|
||||
@ -303,7 +301,7 @@ class TraceMe {
|
||||
std::string name;
|
||||
} no_init_;
|
||||
|
||||
uint64 start_time_ = kUntracedActivity;
|
||||
int64 start_time_ = kUntracedActivity;
|
||||
};
|
||||
|
||||
// Whether OpKernel::TraceString will populate additional information for
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
uint64 GetCurrentTimeNanos() {
|
||||
int64 GetCurrentTimeNanos() {
|
||||
// absl::GetCurrentTimeNanos() is much faster than EnvTime::NowNanos().
|
||||
// It is wrapped under tensorflow::profiler::GetCurrentTimeNanos to avoid ODR
|
||||
// violation and to allow switching to yet another implementation if required.
|
||||
|
@ -38,7 +38,7 @@ inline double MillisToSeconds(double ms) { return ms / 1E3; }
|
||||
inline uint64 SecondsToNanos(double s) { return s * 1E9; }
|
||||
|
||||
// Returns the current CPU wallclock time in nanoseconds.
|
||||
uint64 GetCurrentTimeNanos();
|
||||
int64 GetCurrentTimeNanos();
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user