XStatsOwner::GetStat returns optional<XStatVisitor>

PiperOrigin-RevId: 315305610
Change-Id: I14ea35d662c585789dc63877d1938ed250ca32fd
This commit is contained in:
Jose Baiocchi 2020-06-08 10:40:42 -07:00 committed by TensorFlower Gardener
parent cb8342c4eb
commit ee4f27a524
8 changed files with 95 additions and 96 deletions

View File

@ -148,11 +148,10 @@ void CollectTfActivities(const XLineVisitor& line,
if (tf_op != nullptr) {
++tf_op_id;
bool is_eager = false;
event.ForEachStat([&](const XStatVisitor& stat) {
if (stat.Type() == StatType::kIsEager) {
is_eager = stat.IntValue();
}
});
if (absl::optional<XStatVisitor> stat =
event.GetStat(StatType::kIsEager)) {
is_eager = stat->IntValue();
}
Timespan span(event.TimestampPs(), event.DurationPs());
tf_activities->push_back(
{span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager});

View File

@ -49,25 +49,29 @@ namespace {
DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) {
DeviceCapabilities cap;
XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_plane);
if (auto clock_rate_khz = plane.GetStats(kDevCapClockRateKHz)) {
cap.set_clock_rate_in_ghz(clock_rate_khz->int64_value() / 1000000.0);
}
if (auto core_count = plane.GetStats(kDevCapCoreCount)) {
cap.set_num_cores(core_count->int64_value());
}
// Set memory bandwidth in bytes/s.
if (auto memory_bw = plane.GetStats(kDevCapMemoryBandwidth)) {
cap.set_memory_bandwidth(memory_bw->int64_value());
}
if (auto memory_size_in_bytes = plane.GetStats(kDevCapMemorySize)) {
cap.set_memory_size_in_bytes(memory_size_in_bytes->uint64_value());
}
if (auto cap_major = plane.GetStats(kDevCapComputeCapMajor)) {
cap.mutable_compute_capability()->set_major(cap_major->int64_value());
}
if (auto cap_minor = plane.GetStats(kDevCapComputeCapMinor)) {
cap.mutable_compute_capability()->set_minor(cap_minor->int64_value());
}
plane.ForEachStat([&cap](const XStatVisitor& stat) {
if (!stat.Type().has_value()) return;
switch (stat.Type().value()) {
case kDevCapClockRateKHz:
cap.set_clock_rate_in_ghz(stat.IntValue() / 1000000.0);
break;
case kDevCapCoreCount:
cap.set_num_cores(stat.IntValue());
break;
case kDevCapMemoryBandwidth:
cap.set_memory_bandwidth(stat.IntValue()); // bytes/s
break;
case kDevCapMemorySize:
cap.set_memory_size_in_bytes(stat.UintValue());
break;
case kDevCapComputeCapMajor:
cap.mutable_compute_capability()->set_major(stat.IntValue());
break;
case kDevCapComputeCapMinor:
cap.mutable_compute_capability()->set_minor(stat.IntValue());
break;
}
});
return cap;
}

View File

@ -112,14 +112,11 @@ StepEvents ConvertHostThreadsXPlaneToStepEvents(
StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) {
StepEvents result;
line.ForEachEvent([&](const XEventVisitor& event) {
event.ForEachStat([&](const XStatVisitor& stat) {
if (stat.Type() == StatType::kGroupId) {
result[stat.IntValue()].AddMarker(
StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(),
Timespan(event.TimestampPs(), event.DurationPs())));
return;
}
});
if (absl::optional<XStatVisitor> stat = event.GetStat(StatType::kGroupId)) {
result[stat->IntValue()].AddMarker(
StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(),
Timespan(event.TimestampPs(), event.DurationPs())));
}
});
return result;
}

View File

@ -274,12 +274,12 @@ TEST_F(DeviceTracerTest, TraceToXSpace) {
EXPECT_EQ(device_plane->event_metadata_size(), 4);
// Check if device capacity is serialized.
XPlaneVisitor plane = CreateTfXPlaneVisitor(device_plane);
EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr);
EXPECT_NE(plane.GetStats(kDevCapCoreCount), nullptr);
EXPECT_NE(plane.GetStats(kDevCapMemoryBandwidth), nullptr);
EXPECT_NE(plane.GetStats(kDevCapMemorySize), nullptr);
EXPECT_NE(plane.GetStats(kDevCapComputeCapMajor), nullptr);
EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr);
EXPECT_TRUE(plane.GetStat(kDevCapClockRateKHz).has_value());
EXPECT_TRUE(plane.GetStat(kDevCapCoreCount).has_value());
EXPECT_TRUE(plane.GetStat(kDevCapMemoryBandwidth).has_value());
EXPECT_TRUE(plane.GetStat(kDevCapMemorySize).has_value());
EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMajor).has_value());
EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMinor).has_value());
// Check if the device events timestamps are set.
int total_events = 0;

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@ -225,30 +224,30 @@ EventNode::EventNode(const EventNode& event_node)
: EventNode(event_node.plane_, event_node.raw_line_,
event_node.raw_event_) {}
const XStat* EventNode::GetContextStat(int64 stat_type) const {
if (const XStat* stat = visitor_.GetStats(stat_type)) {
return stat;
} else if (parent_) {
return parent_->GetContextStat(stat_type);
absl::optional<XStatVisitor> EventNode::GetContextStat(int64 stat_type) const {
for (const EventNode* node = this; node != nullptr; node = node->parent_) {
if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
return stat;
}
}
return nullptr;
return absl::nullopt;
}
std::string EventNode::GetGroupName() const {
std::vector<std::string> name_parts;
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
XStatVisitor stat(plane_, graph_type_stat);
name_parts.push_back(stat.ToString());
std::string name;
if (absl::optional<XStatVisitor> stat =
GetContextStat(StatType::kGraphType)) {
absl::StrAppend(&name, stat->StrOrRefValue(), " ");
}
int64 step_num = group_id_.value_or(0);
if (const XStat* step_num_stat = GetContextStat(StatType::kStepNum)) {
step_num = step_num_stat->int64_value();
if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
step_num = stat->IntValue();
} else if (absl::optional<XStatVisitor> stat =
GetContextStat(StatType::kStepNum)) {
step_num = stat->IntValue();
}
if (const XStat* iter_num_stat = GetContextStat(StatType::kIterNum)) {
step_num = iter_num_stat->int64_value();
}
name_parts.push_back(absl::StrCat(step_num));
return absl::StrJoin(name_parts, " ");
absl::StrAppend(&name, step_num);
return name;
}
void EventNode::PropagateGroupId(int64 group_id) {
@ -343,11 +342,12 @@ void EventForest::ConnectInterThread(
for (const auto& parent_event_node : *parent_event_node_list) {
std::vector<int64> stats;
for (auto stat_type : parent_stat_types) {
const XStat* stat = parent_event_node->GetContextStat(stat_type);
absl::optional<XStatVisitor> stat =
parent_event_node->GetContextStat(stat_type);
if (!stat) break;
stats.push_back(stat->value_case() == stat->kInt64Value
? stat->int64_value()
: stat->uint64_value());
stats.push_back((stat->ValueCase() == XStat::kInt64Value)
? stat->IntValue()
: stat->UintValue());
}
if (stats.size() == parent_stat_types.size()) {
connect_map[stats] = parent_event_node.get();
@ -359,11 +359,12 @@ void EventForest::ConnectInterThread(
for (const auto& child_event_node : *child_event_node_list) {
std::vector<int64> stats;
for (auto stat_type : *child_stat_types) {
const XStat* stat = child_event_node->GetContextStat(stat_type);
absl::optional<XStatVisitor> stat =
child_event_node->GetContextStat(stat_type);
if (!stat) break;
stats.push_back(stat->value_case() == stat->kInt64Value
? stat->int64_value()
: stat->uint64_value());
stats.push_back((stat->ValueCase() == XStat::kInt64Value)
? stat->IntValue()
: stat->UintValue());
}
if (stats.size() == child_stat_types->size()) {
if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
@ -429,14 +430,14 @@ void EventForest::ProcessTensorFlowLoop() {
if (!executor_event_list) return;
for (auto& executor_event : *executor_event_list) {
if (IsTfDataEvent(*executor_event)) continue;
const XStat* step_id_stat =
absl::optional<XStatVisitor> step_id_stat =
executor_event->GetContextStat(StatType::kStepId);
const XStat* iter_num_stat =
absl::optional<XStatVisitor> iter_num_stat =
executor_event->GetContextStat(StatType::kIterNum);
if (!step_id_stat || !iter_num_stat) continue;
int64 step_id = step_id_stat->int64_value();
int64 step_id = step_id_stat->IntValue();
TensorFlowLoop& tf_loop = tf_loops[step_id];
TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->int64_value()];
TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
if (!iteration.first_event ||
executor_event->StartsBefore(*iteration.first_event)) {
iteration.first_event = executor_event.get();

View File

@ -78,7 +78,7 @@ class EventNode {
const XEventVisitor& GetEventVisitor() const { return visitor_; }
const XStat* GetContextStat(int64 stat_type) const;
absl::optional<XStatVisitor> GetContextStat(int64 stat_type) const;
void AddStepName(absl::string_view step_name);

View File

@ -174,12 +174,10 @@ TEST(GroupEventsTest, GroupFunctionalOp) {
line.ForEachEvent(
[&](const tensorflow::profiler::XEventVisitor& event) {
absl::optional<int64> group_id;
event.ForEachStat(
[&](const tensorflow::profiler::XStatVisitor& stat) {
if (stat.Type() == StatType::kGroupId) {
group_id = stat.IntValue();
}
});
if (absl::optional<XStatVisitor> stat =
event.GetStat(StatType::kGroupId)) {
group_id = stat->IntValue();
}
EXPECT_TRUE(group_id.has_value());
EXPECT_EQ(*group_id, 0);
});
@ -305,12 +303,10 @@ TEST(GroupEventsTest, SemanticArgTest) {
line.ForEachEvent(
[&](const tensorflow::profiler::XEventVisitor& event) {
absl::optional<int64> group_id;
event.ForEachStat(
[&](const tensorflow::profiler::XStatVisitor& stat) {
if (stat.Type() == StatType::kGroupId) {
group_id = stat.IntValue();
}
});
if (absl::optional<XStatVisitor> stat =
event.GetStat(StatType::kGroupId)) {
group_id = stat->IntValue();
}
EXPECT_TRUE(group_id.has_value());
EXPECT_EQ(*group_id, 0);
});
@ -339,12 +335,10 @@ TEST(GroupEventsTest, AsyncEventTest) {
line.ForEachEvent(
[&](const tensorflow::profiler::XEventVisitor& event) {
absl::optional<int64> group_id;
event.ForEachStat(
[&](const tensorflow::profiler::XStatVisitor& stat) {
if (stat.Type() == StatType::kGroupId) {
group_id = stat.IntValue();
}
});
if (absl::optional<XStatVisitor> stat =
event.GetStat(StatType::kGroupId)) {
group_id = stat->IntValue();
}
if (event.Name() == kAsync) {
EXPECT_FALSE(group_id.has_value());
} else {

View File

@ -86,8 +86,10 @@ class XStatsOwner {
}
}
// Shortcut to get a specfic stat type, nullptr if it is absent.
const XStat* GetStats(int64 stat_type) const;
// Shortcut to get a specific stat type, nullopt if absent.
// This function performs a linear search for the requested stat value.
// Prefer ForEachStat above when multiple stat values are necessary.
absl::optional<XStatVisitor> GetStat(int64 stat_type) const;
private:
const T* stats_owner_;
@ -241,14 +243,16 @@ class XPlaneVisitor : public XStatsOwner<XPlane> {
};
template <class T>
const XStat* XStatsOwner<T>::GetStats(int64 stat_type) const {
absl::optional<int64> stat_metadata_id =
metadata_->GetStatMetadataId(stat_type);
if (!stat_metadata_id) return nullptr; // type does not exist in the XPlane.
for (const XStat& stat : stats_owner_->stats()) {
if (stat.metadata_id() == *stat_metadata_id) return &stat;
absl::optional<XStatVisitor> XStatsOwner<T>::GetStat(int64 stat_type) const {
if (absl::optional<int64> stat_metadata_id =
metadata_->GetStatMetadataId(stat_type)) {
for (const XStat& stat : stats_owner_->stats()) {
if (stat.metadata_id() == *stat_metadata_id) {
return XStatVisitor(metadata_, &stat);
}
}
}
return nullptr; // type does not exist in this owner.
return absl::nullopt; // type does not exist in this owner.
}
} // namespace profiler