Remove TraceMeRecorder::Event::activity_id

Reduces TraceMe memory overhead.

PiperOrigin-RevId: 342359162
Change-Id: Idbe185aa6e3eeb030a3788567413d8c676e41e75
This commit is contained in:
Jose Baiocchi 2020-11-13 16:15:59 -08:00 committed by TensorFlower Gardener
parent aa56651979
commit a3b0b2bf41
9 changed files with 84 additions and 86 deletions

View File

@ -109,6 +109,8 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler/utils:time_utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)

View File

@ -39,14 +39,14 @@ void MakeCompleteEvents(TraceMeRecorder::Events* events) {
// before the record created by ActivityEnd. Cross-thread events must be
// processed in a separate pass. A single map can be used because the
// activity_id is globally unique.
absl::flat_hash_map<uint64, TraceMeRecorder::Event*> start_events;
absl::flat_hash_map<int64, TraceMeRecorder::Event*> start_events;
std::vector<TraceMeRecorder::Event*> end_events;
for (auto& thread : *events) {
for (auto& event : thread.events) {
if (IsStartEvent(event)) {
start_events.emplace(event.activity_id, &event);
} else if (IsEndEvent(event)) {
auto iter = start_events.find(event.activity_id);
if (event.IsStart()) {
start_events.emplace(event.ActivityId(), &event);
} else if (event.IsEnd()) {
auto iter = start_events.find(event.ActivityId());
if (iter != start_events.end()) { // same thread
auto* start_event = iter->second;
event.name = std::move(start_event->name);
@ -59,7 +59,7 @@ void MakeCompleteEvents(TraceMeRecorder::Events* events) {
}
}
for (auto* event : end_events) { // cross-thread
auto iter = start_events.find(event->activity_id);
auto iter = start_events.find(event->ActivityId());
if (iter != start_events.end()) {
auto* start_event = iter->second;
event->name = std::move(start_event->name);
@ -79,7 +79,7 @@ void ConvertCompleteEventsToXPlane(uint64 start_timestamp_ns,
xline.SetTimestampNs(start_timestamp_ns);
xline.ReserveEvents(thread.events.size());
for (const auto& event : thread.events) {
if (!IsCompleteEvent(event)) continue;
if (!event.IsComplete()) continue;
if (event.start_time < start_timestamp_ns) continue;
Annotation annotation = ParseAnnotation(event.name);
XEventMetadata* xevent_metadata =

View File

@ -22,22 +22,6 @@ limitations under the License.
namespace tensorflow {
namespace profiler {
// Returns true if event was created by TraceMe::ActivityStart.
inline bool IsStartEvent(const TraceMeRecorder::Event& event) {
return (event.start_time != 0) && (event.end_time == 0);
}
// Returns true if event was created by TraceMe::ActivityEnd.
inline bool IsEndEvent(const TraceMeRecorder::Event& event) {
return (event.start_time == 0) && (event.end_time != 0);
}
// Returns true if event was created by TraceMe::Stop or MakeCompleteEvents
// below.
inline bool IsCompleteEvent(const TraceMeRecorder::Event& event) {
return (event.start_time != 0) && (event.end_time != 0);
}
// Combine events created by TraceMe::ActivityStart and TraceMe::ActivityEnd,
// which can be paired up by their activity_id.
void MakeCompleteEvents(TraceMeRecorder::Events* events);

View File

@ -254,7 +254,7 @@ bool TraceMeRecorder::StartRecording(int level) {
return started;
}
void TraceMeRecorder::Record(Event event) {
void TraceMeRecorder::Record(Event&& event) {
static thread_local ThreadLocalRecorder thread_local_recorder;
thread_local_recorder.Record(std::move(event));
}
@ -270,16 +270,16 @@ TraceMeRecorder::Events TraceMeRecorder::StopRecording() {
return events;
}
/*static*/ uint64 TraceMeRecorder::NewActivityId() {
/*static*/ int64 TraceMeRecorder::NewActivityId() {
// Activity IDs: To avoid contention over a counter, the top 32 bits identify
// the originating thread, the bottom 32 bits name the event within a thread.
// IDs may be reused after 4 billion events on one thread, or 4 billion
// IDs may be reused after 4 billion events on one thread, or 2 billion
// threads.
static std::atomic<uint32> thread_counter(1); // avoid kUntracedActivity
const thread_local static uint32 thread_id =
static std::atomic<int32> thread_counter(1); // avoid kUntracedActivity
const thread_local static int32 thread_id =
thread_counter.fetch_add(1, std::memory_order_relaxed);
thread_local static uint32 per_thread_activity_id = 0;
return static_cast<uint64>(thread_id) << 32 | per_thread_activity_id++;
return static_cast<int64>(thread_id) << 32 | per_thread_activity_id++;
}
} // namespace profiler

View File

@ -44,18 +44,29 @@ TF_EXPORT extern std::atomic<int> g_trace_level;
//
// This is the backend for TraceMe instrumentation.
// The profiler starts the recorder, the TraceMe destructor records complete
// events. TraceMe::ActivityStart records begin events, and TraceMe::ActivityEnd
// events. TraceMe::ActivityStart records start events, and TraceMe::ActivityEnd
// records end events. The profiler then stops the recorder and finds start/end
// pairs. (Unpaired start/end events are discarded at that point).
class TraceMeRecorder {
public:
// An Event is either the start of a TraceMe, the end of a TraceMe, or both.
// Times are in ns since the Unix epoch.
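// A negative time encodes the activity_id used to pair up the start of an
// event with its end: a start event stores the negated activity_id in
// end_time, and an end event stores it in start_time.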
struct Event {
uint64 activity_id;
bool IsComplete() const { return start_time > 0 && end_time > 0; }
bool IsStart() const { return end_time < 0; }
bool IsEnd() const { return start_time < 0; }
int64 ActivityId() const {
if (IsStart()) return -end_time;
if (IsEnd()) return -start_time;
return 1; // complete
}
std::string name;
uint64 start_time; // 0 = missing
uint64 end_time; // 0 = missing
int64 start_time;
int64 end_time;
};
struct ThreadInfo {
uint32 tid;
@ -85,10 +96,10 @@ class TraceMeRecorder {
static constexpr int kTracingDisabled = -1;
// Records an event. Non-blocking.
static void Record(Event event);
static void Record(Event&& event);
// Returns an activity_id for TraceMe::ActivityStart.
static uint64 NewActivityId();
static int64 NewActivityId();
private:
class ThreadLocalRecorder;

View File

@ -15,12 +15,13 @@ limitations under the License.
#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h"
#include <atomic>
#include <istream>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/env_time.h"
@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
namespace tensorflow {
namespace profiler {
@ -41,15 +43,15 @@ MATCHER_P(Named, name, "") { return arg.name == name; }
constexpr static uint64 kNanosInSec = 1000000000;
TEST(RecorderTest, SingleThreaded) {
uint64 start_time = Env::Default()->NowNanos();
uint64 end_time = start_time + kNanosInSec;
int64 start_time = GetCurrentTimeNanos();
int64 end_time = start_time + kNanosInSec;
TraceMeRecorder::Record({1, "before", start_time, end_time});
TraceMeRecorder::Record({"before", start_time, end_time});
TraceMeRecorder::Start(/*level=*/1);
TraceMeRecorder::Record({2, "during1", start_time, end_time});
TraceMeRecorder::Record({3, "during2", start_time, end_time});
TraceMeRecorder::Record({"during1", start_time, end_time});
TraceMeRecorder::Record({"during2", start_time, end_time});
auto results = TraceMeRecorder::Stop();
TraceMeRecorder::Record({4, "after", start_time, end_time});
TraceMeRecorder::Record({"after", start_time, end_time});
ASSERT_EQ(results.size(), 1);
EXPECT_THAT(results[0].events,
@ -83,15 +85,14 @@ TEST(RecorderTest, Multithreaded) {
thread::ThreadPool pool(Env::Default(), "testpool", kNumThreads);
std::atomic<int> thread_count = {0};
for (int i = 0; i < kNumThreads; i++) {
pool.Schedule([&start, &stop, &thread_count, i] {
pool.Schedule([&start, &stop, &thread_count] {
uint64 j = 0;
bool was_active = false;
auto record_event = [&j, i]() {
uint64 start_time = Env::Default()->NowNanos();
uint64 end_time = start_time + kNanosInSec;
TraceMeRecorder::Record({/*activity_id=*/j++,
/*name=*/absl::StrCat(i), start_time,
end_time});
auto record_event = [&j]() {
int64 start_time = GetCurrentTimeNanos();
int64 end_time = start_time + kNanosInSec;
TraceMeRecorder::Record(
{/*name=*/absl::StrCat(j++), start_time, end_time});
};
thread_count.fetch_add(1, std::memory_order_relaxed);
start.WaitForNotification();
@ -121,15 +122,17 @@ TEST(RecorderTest, Multithreaded) {
}
// For each thread, keep track of which events we've seen.
struct {
struct ThreadState {
bool split_session = false;
bool overlapping_sessions = false;
std::set<uint64> events;
} thread_state[kNumThreads];
};
absl::flat_hash_map<uint32 /*tid*/, ThreadState> thread_state;
// We expect each thread to eventually have multiple events, not all in a
// contiguous range.
auto done = [&thread_state] {
for (const auto& t : thread_state) {
for (const auto& id_and_thread : thread_state) {
auto& t = id_and_thread.second;
if (t.events.size() < 2) return false;
}
return true;
@ -153,20 +156,19 @@ TEST(RecorderTest, Multithreaded) {
auto results = TraceMeRecorder::Stop();
for (const auto& thread : results) {
if (thread.events.empty()) continue;
std::istringstream ss(thread.events.front().name);
int thread_index = 0;
ss >> thread_index;
auto& state = thread_state[thread_index];
auto& state = thread_state[thread.thread.tid];
std::set<uint64> session_events;
uint64 current = 0;
for (const auto& event : thread.events) {
session_events.emplace(event.activity_id);
uint64 activity_id;
ASSERT_TRUE(absl::SimpleAtoi(event.name, &activity_id));
session_events.emplace(activity_id);
// Session events should be contiguous.
if (current != 0 && event.activity_id != current + 1) {
if (current != 0 && activity_id != current + 1) {
state.split_session = true;
}
current = event.activity_id;
current = activity_id;
}
for (const auto& event : session_events) {
@ -182,7 +184,8 @@ TEST(RecorderTest, Multithreaded) {
}
stop.Notify();
for (const auto& thread : thread_state) {
for (const auto& id_and_thread : thread_state) {
auto& thread = id_and_thread.second;
EXPECT_FALSE(thread.split_session)
<< "Expected contiguous events in a session";
EXPECT_FALSE(thread.overlapping_sessions) << "Expected disjoint sessions";

View File

@ -174,8 +174,8 @@ class TraceMe {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(start_time_ != kUntracedActivity)) {
if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) {
TraceMeRecorder::Record({kCompleteActivity, std::move(no_init_.name),
start_time_, GetCurrentTimeNanos()});
TraceMeRecorder::Record(
{std::move(no_init_.name), start_time_, GetCurrentTimeNanos()});
}
no_init_.name.~string();
start_time_ = kUntracedActivity;
@ -210,12 +210,12 @@ class TraceMe {
// Returns the activity ID, which is used to stop the activity.
// Calls `name_generator` to get the name for the activity.
template <typename NameGeneratorT>
static uint64 ActivityStart(NameGeneratorT name_generator, int level = 1) {
static int64 ActivityStart(NameGeneratorT name_generator, int level = 1) {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
uint64 activity_id = TraceMeRecorder::NewActivityId();
TraceMeRecorder::Record({activity_id, name_generator(),
GetCurrentTimeNanos(), /*end_time=*/0});
int64 activity_id = TraceMeRecorder::NewActivityId();
TraceMeRecorder::Record(
{name_generator(), GetCurrentTimeNanos(), -activity_id});
return activity_id;
}
#endif
@ -224,12 +224,12 @@ class TraceMe {
// Record the start time of an activity.
// Returns the activity ID, which is used to stop the activity.
static uint64 ActivityStart(absl::string_view name, int level = 1) {
static int64 ActivityStart(absl::string_view name, int level = 1) {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
uint64 activity_id = TraceMeRecorder::NewActivityId();
TraceMeRecorder::Record({activity_id, std::string(name),
GetCurrentTimeNanos(), /*end_time=*/0});
int64 activity_id = TraceMeRecorder::NewActivityId();
TraceMeRecorder::Record(
{std::string(name), GetCurrentTimeNanos(), -activity_id});
return activity_id;
}
#endif
@ -237,23 +237,23 @@ class TraceMe {
}
// Same as ActivityStart above, an overload for "const std::string&"
static uint64 ActivityStart(const std::string& name, int level = 1) {
static int64 ActivityStart(const std::string& name, int level = 1) {
return ActivityStart(absl::string_view(name), level);
}
// Same as ActivityStart above, an overload for "const char*"
static uint64 ActivityStart(const char* name, int level = 1) {
static int64 ActivityStart(const char* name, int level = 1) {
return ActivityStart(absl::string_view(name), level);
}
// Record the end time of an activity started by ActivityStart().
static void ActivityEnd(uint64 activity_id) {
static void ActivityEnd(int64 activity_id) {
#if !defined(IS_MOBILE_PLATFORM)
// We don't check the level again (see TraceMe::Stop()).
if (TF_PREDICT_FALSE(activity_id != kUntracedActivity)) {
if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) {
TraceMeRecorder::Record({activity_id, /*name=*/std::string(),
/*start_time=*/0, GetCurrentTimeNanos()});
TraceMeRecorder::Record(
{std::string(), -activity_id, GetCurrentTimeNanos()});
}
}
#endif
@ -264,9 +264,9 @@ class TraceMe {
static void InstantActivity(NameGeneratorT name_generator, int level = 1) {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
uint64 now = GetCurrentTimeNanos();
TraceMeRecorder::Record({kCompleteActivity, name_generator(),
/*start_time=*/now, /*end_time=*/now});
int64 now = GetCurrentTimeNanos();
TraceMeRecorder::Record(
{name_generator(), /*start_time=*/now, /*end_time=*/now});
}
#endif
}
@ -279,7 +279,7 @@ class TraceMe {
#endif
}
static uint64 NewActivityId() {
static int64 NewActivityId() {
#if !defined(IS_MOBILE_PLATFORM)
return TraceMeRecorder::NewActivityId();
#else
@ -288,10 +288,8 @@ class TraceMe {
}
private:
// Activity ID or start time used when tracing is disabled.
constexpr static uint64 kUntracedActivity = 0;
// Activity ID used as a placeholder when both start and end are present.
constexpr static uint64 kCompleteActivity = 1;
// Start time or activity ID used when tracing is disabled.
constexpr static int64 kUntracedActivity = 0;
TF_DISALLOW_COPY_AND_ASSIGN(TraceMe);
@ -303,7 +301,7 @@ class TraceMe {
std::string name;
} no_init_;
uint64 start_time_ = kUntracedActivity;
int64 start_time_ = kUntracedActivity;
};
// Whether OpKernel::TraceString will populate additional information for

View File

@ -20,7 +20,7 @@ limitations under the License.
namespace tensorflow {
namespace profiler {
uint64 GetCurrentTimeNanos() {
int64 GetCurrentTimeNanos() {
// absl::GetCurrentTimeNanos() is much faster than EnvTime::NowNanos().
// It is wrapped under tensorflow::profiler::GetCurrentTimeNanos to avoid ODR
// violation and to allow switching to yet another implementation if required.

View File

@ -38,7 +38,7 @@ inline double MillisToSeconds(double ms) { return ms / 1E3; }
inline uint64 SecondsToNanos(double s) { return s * 1E9; }
// Returns the current CPU wallclock time in nanoseconds.
uint64 GetCurrentTimeNanos();
int64 GetCurrentTimeNanos();
} // namespace profiler
} // namespace tensorflow