XStatsOwner::GetStat returns optional<XStatVisitor>

PiperOrigin-RevId: 315305610
Change-Id: I14ea35d662c585789dc63877d1938ed250ca32fd
This commit is contained in:
Jose Baiocchi 2020-06-08 10:40:42 -07:00 committed by TensorFlower Gardener
parent cb8342c4eb
commit ee4f27a524
8 changed files with 95 additions and 96 deletions

View File

@ -148,11 +148,10 @@ void CollectTfActivities(const XLineVisitor& line,
if (tf_op != nullptr) { if (tf_op != nullptr) {
++tf_op_id; ++tf_op_id;
bool is_eager = false; bool is_eager = false;
event.ForEachStat([&](const XStatVisitor& stat) { if (absl::optional<XStatVisitor> stat =
if (stat.Type() == StatType::kIsEager) { event.GetStat(StatType::kIsEager)) {
is_eager = stat.IntValue(); is_eager = stat->IntValue();
} }
});
Timespan span(event.TimestampPs(), event.DurationPs()); Timespan span(event.TimestampPs(), event.DurationPs());
tf_activities->push_back( tf_activities->push_back(
{span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager}); {span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager});

View File

@ -49,25 +49,29 @@ namespace {
DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) { DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) {
DeviceCapabilities cap; DeviceCapabilities cap;
XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_plane); XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_plane);
if (auto clock_rate_khz = plane.GetStats(kDevCapClockRateKHz)) { plane.ForEachStat([&cap](const XStatVisitor& stat) {
cap.set_clock_rate_in_ghz(clock_rate_khz->int64_value() / 1000000.0); if (!stat.Type().has_value()) return;
} switch (stat.Type().value()) {
if (auto core_count = plane.GetStats(kDevCapCoreCount)) { case kDevCapClockRateKHz:
cap.set_num_cores(core_count->int64_value()); cap.set_clock_rate_in_ghz(stat.IntValue() / 1000000.0);
} break;
// Set memory bandwidth in bytes/s. case kDevCapCoreCount:
if (auto memory_bw = plane.GetStats(kDevCapMemoryBandwidth)) { cap.set_num_cores(stat.IntValue());
cap.set_memory_bandwidth(memory_bw->int64_value()); break;
} case kDevCapMemoryBandwidth:
if (auto memory_size_in_bytes = plane.GetStats(kDevCapMemorySize)) { cap.set_memory_bandwidth(stat.IntValue()); // bytes/s
cap.set_memory_size_in_bytes(memory_size_in_bytes->uint64_value()); break;
} case kDevCapMemorySize:
if (auto cap_major = plane.GetStats(kDevCapComputeCapMajor)) { cap.set_memory_size_in_bytes(stat.UintValue());
cap.mutable_compute_capability()->set_major(cap_major->int64_value()); break;
} case kDevCapComputeCapMajor:
if (auto cap_minor = plane.GetStats(kDevCapComputeCapMinor)) { cap.mutable_compute_capability()->set_major(stat.IntValue());
cap.mutable_compute_capability()->set_minor(cap_minor->int64_value()); break;
} case kDevCapComputeCapMinor:
cap.mutable_compute_capability()->set_minor(stat.IntValue());
break;
}
});
return cap; return cap;
} }

View File

@ -112,14 +112,11 @@ StepEvents ConvertHostThreadsXPlaneToStepEvents(
StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) { StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) {
StepEvents result; StepEvents result;
line.ForEachEvent([&](const XEventVisitor& event) { line.ForEachEvent([&](const XEventVisitor& event) {
event.ForEachStat([&](const XStatVisitor& stat) { if (absl::optional<XStatVisitor> stat = event.GetStat(StatType::kGroupId)) {
if (stat.Type() == StatType::kGroupId) { result[stat->IntValue()].AddMarker(
result[stat.IntValue()].AddMarker( StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(),
StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(), Timespan(event.TimestampPs(), event.DurationPs())));
Timespan(event.TimestampPs(), event.DurationPs()))); }
return;
}
});
}); });
return result; return result;
} }

View File

@ -274,12 +274,12 @@ TEST_F(DeviceTracerTest, TraceToXSpace) {
EXPECT_EQ(device_plane->event_metadata_size(), 4); EXPECT_EQ(device_plane->event_metadata_size(), 4);
// Check if device capacity is serialized. // Check if device capacity is serialized.
XPlaneVisitor plane = CreateTfXPlaneVisitor(device_plane); XPlaneVisitor plane = CreateTfXPlaneVisitor(device_plane);
EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr); EXPECT_TRUE(plane.GetStat(kDevCapClockRateKHz).has_value());
EXPECT_NE(plane.GetStats(kDevCapCoreCount), nullptr); EXPECT_TRUE(plane.GetStat(kDevCapCoreCount).has_value());
EXPECT_NE(plane.GetStats(kDevCapMemoryBandwidth), nullptr); EXPECT_TRUE(plane.GetStat(kDevCapMemoryBandwidth).has_value());
EXPECT_NE(plane.GetStats(kDevCapMemorySize), nullptr); EXPECT_TRUE(plane.GetStat(kDevCapMemorySize).has_value());
EXPECT_NE(plane.GetStats(kDevCapComputeCapMajor), nullptr); EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMajor).has_value());
EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr); EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMinor).has_value());
// Check if the device events timestamps are set. // Check if the device events timestamps are set.
int total_events = 0; int total_events = 0;

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/map_util.h"
@ -225,30 +224,30 @@ EventNode::EventNode(const EventNode& event_node)
: EventNode(event_node.plane_, event_node.raw_line_, : EventNode(event_node.plane_, event_node.raw_line_,
event_node.raw_event_) {} event_node.raw_event_) {}
const XStat* EventNode::GetContextStat(int64 stat_type) const { absl::optional<XStatVisitor> EventNode::GetContextStat(int64 stat_type) const {
if (const XStat* stat = visitor_.GetStats(stat_type)) { for (const EventNode* node = this; node != nullptr; node = node->parent_) {
return stat; if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
} else if (parent_) { return stat;
return parent_->GetContextStat(stat_type); }
} }
return nullptr; return absl::nullopt;
} }
std::string EventNode::GetGroupName() const { std::string EventNode::GetGroupName() const {
std::vector<std::string> name_parts; std::string name;
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) { if (absl::optional<XStatVisitor> stat =
XStatVisitor stat(plane_, graph_type_stat); GetContextStat(StatType::kGraphType)) {
name_parts.push_back(stat.ToString()); absl::StrAppend(&name, stat->StrOrRefValue(), " ");
} }
int64 step_num = group_id_.value_or(0); int64 step_num = group_id_.value_or(0);
if (const XStat* step_num_stat = GetContextStat(StatType::kStepNum)) { if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
step_num = step_num_stat->int64_value(); step_num = stat->IntValue();
} else if (absl::optional<XStatVisitor> stat =
GetContextStat(StatType::kStepNum)) {
step_num = stat->IntValue();
} }
if (const XStat* iter_num_stat = GetContextStat(StatType::kIterNum)) { absl::StrAppend(&name, step_num);
step_num = iter_num_stat->int64_value(); return name;
}
name_parts.push_back(absl::StrCat(step_num));
return absl::StrJoin(name_parts, " ");
} }
void EventNode::PropagateGroupId(int64 group_id) { void EventNode::PropagateGroupId(int64 group_id) {
@ -343,11 +342,12 @@ void EventForest::ConnectInterThread(
for (const auto& parent_event_node : *parent_event_node_list) { for (const auto& parent_event_node : *parent_event_node_list) {
std::vector<int64> stats; std::vector<int64> stats;
for (auto stat_type : parent_stat_types) { for (auto stat_type : parent_stat_types) {
const XStat* stat = parent_event_node->GetContextStat(stat_type); absl::optional<XStatVisitor> stat =
parent_event_node->GetContextStat(stat_type);
if (!stat) break; if (!stat) break;
stats.push_back(stat->value_case() == stat->kInt64Value stats.push_back((stat->ValueCase() == XStat::kInt64Value)
? stat->int64_value() ? stat->IntValue()
: stat->uint64_value()); : stat->UintValue());
} }
if (stats.size() == parent_stat_types.size()) { if (stats.size() == parent_stat_types.size()) {
connect_map[stats] = parent_event_node.get(); connect_map[stats] = parent_event_node.get();
@ -359,11 +359,12 @@ void EventForest::ConnectInterThread(
for (const auto& child_event_node : *child_event_node_list) { for (const auto& child_event_node : *child_event_node_list) {
std::vector<int64> stats; std::vector<int64> stats;
for (auto stat_type : *child_stat_types) { for (auto stat_type : *child_stat_types) {
const XStat* stat = child_event_node->GetContextStat(stat_type); absl::optional<XStatVisitor> stat =
child_event_node->GetContextStat(stat_type);
if (!stat) break; if (!stat) break;
stats.push_back(stat->value_case() == stat->kInt64Value stats.push_back((stat->ValueCase() == XStat::kInt64Value)
? stat->int64_value() ? stat->IntValue()
: stat->uint64_value()); : stat->UintValue());
} }
if (stats.size() == child_stat_types->size()) { if (stats.size() == child_stat_types->size()) {
if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) { if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
@ -429,14 +430,14 @@ void EventForest::ProcessTensorFlowLoop() {
if (!executor_event_list) return; if (!executor_event_list) return;
for (auto& executor_event : *executor_event_list) { for (auto& executor_event : *executor_event_list) {
if (IsTfDataEvent(*executor_event)) continue; if (IsTfDataEvent(*executor_event)) continue;
const XStat* step_id_stat = absl::optional<XStatVisitor> step_id_stat =
executor_event->GetContextStat(StatType::kStepId); executor_event->GetContextStat(StatType::kStepId);
const XStat* iter_num_stat = absl::optional<XStatVisitor> iter_num_stat =
executor_event->GetContextStat(StatType::kIterNum); executor_event->GetContextStat(StatType::kIterNum);
if (!step_id_stat || !iter_num_stat) continue; if (!step_id_stat || !iter_num_stat) continue;
int64 step_id = step_id_stat->int64_value(); int64 step_id = step_id_stat->IntValue();
TensorFlowLoop& tf_loop = tf_loops[step_id]; TensorFlowLoop& tf_loop = tf_loops[step_id];
TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->int64_value()]; TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
if (!iteration.first_event || if (!iteration.first_event ||
executor_event->StartsBefore(*iteration.first_event)) { executor_event->StartsBefore(*iteration.first_event)) {
iteration.first_event = executor_event.get(); iteration.first_event = executor_event.get();

View File

@ -78,7 +78,7 @@ class EventNode {
const XEventVisitor& GetEventVisitor() const { return visitor_; } const XEventVisitor& GetEventVisitor() const { return visitor_; }
const XStat* GetContextStat(int64 stat_type) const; absl::optional<XStatVisitor> GetContextStat(int64 stat_type) const;
void AddStepName(absl::string_view step_name); void AddStepName(absl::string_view step_name);

View File

@ -174,12 +174,10 @@ TEST(GroupEventsTest, GroupFunctionalOp) {
line.ForEachEvent( line.ForEachEvent(
[&](const tensorflow::profiler::XEventVisitor& event) { [&](const tensorflow::profiler::XEventVisitor& event) {
absl::optional<int64> group_id; absl::optional<int64> group_id;
event.ForEachStat( if (absl::optional<XStatVisitor> stat =
[&](const tensorflow::profiler::XStatVisitor& stat) { event.GetStat(StatType::kGroupId)) {
if (stat.Type() == StatType::kGroupId) { group_id = stat->IntValue();
group_id = stat.IntValue(); }
}
});
EXPECT_TRUE(group_id.has_value()); EXPECT_TRUE(group_id.has_value());
EXPECT_EQ(*group_id, 0); EXPECT_EQ(*group_id, 0);
}); });
@ -305,12 +303,10 @@ TEST(GroupEventsTest, SemanticArgTest) {
line.ForEachEvent( line.ForEachEvent(
[&](const tensorflow::profiler::XEventVisitor& event) { [&](const tensorflow::profiler::XEventVisitor& event) {
absl::optional<int64> group_id; absl::optional<int64> group_id;
event.ForEachStat( if (absl::optional<XStatVisitor> stat =
[&](const tensorflow::profiler::XStatVisitor& stat) { event.GetStat(StatType::kGroupId)) {
if (stat.Type() == StatType::kGroupId) { group_id = stat->IntValue();
group_id = stat.IntValue(); }
}
});
EXPECT_TRUE(group_id.has_value()); EXPECT_TRUE(group_id.has_value());
EXPECT_EQ(*group_id, 0); EXPECT_EQ(*group_id, 0);
}); });
@ -339,12 +335,10 @@ TEST(GroupEventsTest, AsyncEventTest) {
line.ForEachEvent( line.ForEachEvent(
[&](const tensorflow::profiler::XEventVisitor& event) { [&](const tensorflow::profiler::XEventVisitor& event) {
absl::optional<int64> group_id; absl::optional<int64> group_id;
event.ForEachStat( if (absl::optional<XStatVisitor> stat =
[&](const tensorflow::profiler::XStatVisitor& stat) { event.GetStat(StatType::kGroupId)) {
if (stat.Type() == StatType::kGroupId) { group_id = stat->IntValue();
group_id = stat.IntValue(); }
}
});
if (event.Name() == kAsync) { if (event.Name() == kAsync) {
EXPECT_FALSE(group_id.has_value()); EXPECT_FALSE(group_id.has_value());
} else { } else {

View File

@ -86,8 +86,10 @@ class XStatsOwner {
} }
} }
// Shortcut to get a specfic stat type, nullptr if it is absent. // Shortcut to get a specific stat type, nullopt if absent.
const XStat* GetStats(int64 stat_type) const; // This function performs a linear search for the requested stat value.
// Prefer ForEachStat above when multiple stat values are necessary.
absl::optional<XStatVisitor> GetStat(int64 stat_type) const;
private: private:
const T* stats_owner_; const T* stats_owner_;
@ -241,14 +243,16 @@ class XPlaneVisitor : public XStatsOwner<XPlane> {
}; };
template <class T> template <class T>
const XStat* XStatsOwner<T>::GetStats(int64 stat_type) const { absl::optional<XStatVisitor> XStatsOwner<T>::GetStat(int64 stat_type) const {
absl::optional<int64> stat_metadata_id = if (absl::optional<int64> stat_metadata_id =
metadata_->GetStatMetadataId(stat_type); metadata_->GetStatMetadataId(stat_type)) {
if (!stat_metadata_id) return nullptr; // type does not exist in the XPlane. for (const XStat& stat : stats_owner_->stats()) {
for (const XStat& stat : stats_owner_->stats()) { if (stat.metadata_id() == *stat_metadata_id) {
if (stat.metadata_id() == *stat_metadata_id) return &stat; return XStatVisitor(metadata_, &stat);
}
}
} }
return nullptr; // type does not exist in this owner. return absl::nullopt; // type does not exist in this owner.
} }
} // namespace profiler } // namespace profiler