XStatsOwner::GetStat returns optional<XStatVisitor>
PiperOrigin-RevId: 315305610 Change-Id: I14ea35d662c585789dc63877d1938ed250ca32fd
This commit is contained in:
parent
cb8342c4eb
commit
ee4f27a524
|
@ -148,11 +148,10 @@ void CollectTfActivities(const XLineVisitor& line,
|
|||
if (tf_op != nullptr) {
|
||||
++tf_op_id;
|
||||
bool is_eager = false;
|
||||
event.ForEachStat([&](const XStatVisitor& stat) {
|
||||
if (stat.Type() == StatType::kIsEager) {
|
||||
is_eager = stat.IntValue();
|
||||
if (absl::optional<XStatVisitor> stat =
|
||||
event.GetStat(StatType::kIsEager)) {
|
||||
is_eager = stat->IntValue();
|
||||
}
|
||||
});
|
||||
Timespan span(event.TimestampPs(), event.DurationPs());
|
||||
tf_activities->push_back(
|
||||
{span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager});
|
||||
|
|
|
@ -49,25 +49,29 @@ namespace {
|
|||
DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) {
|
||||
DeviceCapabilities cap;
|
||||
XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_plane);
|
||||
if (auto clock_rate_khz = plane.GetStats(kDevCapClockRateKHz)) {
|
||||
cap.set_clock_rate_in_ghz(clock_rate_khz->int64_value() / 1000000.0);
|
||||
}
|
||||
if (auto core_count = plane.GetStats(kDevCapCoreCount)) {
|
||||
cap.set_num_cores(core_count->int64_value());
|
||||
}
|
||||
// Set memory bandwidth in bytes/s.
|
||||
if (auto memory_bw = plane.GetStats(kDevCapMemoryBandwidth)) {
|
||||
cap.set_memory_bandwidth(memory_bw->int64_value());
|
||||
}
|
||||
if (auto memory_size_in_bytes = plane.GetStats(kDevCapMemorySize)) {
|
||||
cap.set_memory_size_in_bytes(memory_size_in_bytes->uint64_value());
|
||||
}
|
||||
if (auto cap_major = plane.GetStats(kDevCapComputeCapMajor)) {
|
||||
cap.mutable_compute_capability()->set_major(cap_major->int64_value());
|
||||
}
|
||||
if (auto cap_minor = plane.GetStats(kDevCapComputeCapMinor)) {
|
||||
cap.mutable_compute_capability()->set_minor(cap_minor->int64_value());
|
||||
plane.ForEachStat([&cap](const XStatVisitor& stat) {
|
||||
if (!stat.Type().has_value()) return;
|
||||
switch (stat.Type().value()) {
|
||||
case kDevCapClockRateKHz:
|
||||
cap.set_clock_rate_in_ghz(stat.IntValue() / 1000000.0);
|
||||
break;
|
||||
case kDevCapCoreCount:
|
||||
cap.set_num_cores(stat.IntValue());
|
||||
break;
|
||||
case kDevCapMemoryBandwidth:
|
||||
cap.set_memory_bandwidth(stat.IntValue()); // bytes/s
|
||||
break;
|
||||
case kDevCapMemorySize:
|
||||
cap.set_memory_size_in_bytes(stat.UintValue());
|
||||
break;
|
||||
case kDevCapComputeCapMajor:
|
||||
cap.mutable_compute_capability()->set_major(stat.IntValue());
|
||||
break;
|
||||
case kDevCapComputeCapMinor:
|
||||
cap.mutable_compute_capability()->set_minor(stat.IntValue());
|
||||
break;
|
||||
}
|
||||
});
|
||||
return cap;
|
||||
}
|
||||
|
||||
|
|
|
@ -112,15 +112,12 @@ StepEvents ConvertHostThreadsXPlaneToStepEvents(
|
|||
StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) {
|
||||
StepEvents result;
|
||||
line.ForEachEvent([&](const XEventVisitor& event) {
|
||||
event.ForEachStat([&](const XStatVisitor& stat) {
|
||||
if (stat.Type() == StatType::kGroupId) {
|
||||
result[stat.IntValue()].AddMarker(
|
||||
if (absl::optional<XStatVisitor> stat = event.GetStat(StatType::kGroupId)) {
|
||||
result[stat->IntValue()].AddMarker(
|
||||
StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(),
|
||||
Timespan(event.TimestampPs(), event.DurationPs())));
|
||||
return;
|
||||
}
|
||||
});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -274,12 +274,12 @@ TEST_F(DeviceTracerTest, TraceToXSpace) {
|
|||
EXPECT_EQ(device_plane->event_metadata_size(), 4);
|
||||
// Check if device capacity is serialized.
|
||||
XPlaneVisitor plane = CreateTfXPlaneVisitor(device_plane);
|
||||
EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapCoreCount), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapMemoryBandwidth), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapMemorySize), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapComputeCapMajor), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr);
|
||||
EXPECT_TRUE(plane.GetStat(kDevCapClockRateKHz).has_value());
|
||||
EXPECT_TRUE(plane.GetStat(kDevCapCoreCount).has_value());
|
||||
EXPECT_TRUE(plane.GetStat(kDevCapMemoryBandwidth).has_value());
|
||||
EXPECT_TRUE(plane.GetStat(kDevCapMemorySize).has_value());
|
||||
EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMajor).has_value());
|
||||
EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMinor).has_value());
|
||||
|
||||
// Check if the device events timestamps are set.
|
||||
int total_events = 0;
|
||||
|
|
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
|
@ -225,30 +224,30 @@ EventNode::EventNode(const EventNode& event_node)
|
|||
: EventNode(event_node.plane_, event_node.raw_line_,
|
||||
event_node.raw_event_) {}
|
||||
|
||||
const XStat* EventNode::GetContextStat(int64 stat_type) const {
|
||||
if (const XStat* stat = visitor_.GetStats(stat_type)) {
|
||||
absl::optional<XStatVisitor> EventNode::GetContextStat(int64 stat_type) const {
|
||||
for (const EventNode* node = this; node != nullptr; node = node->parent_) {
|
||||
if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
|
||||
return stat;
|
||||
} else if (parent_) {
|
||||
return parent_->GetContextStat(stat_type);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
std::string EventNode::GetGroupName() const {
|
||||
std::vector<std::string> name_parts;
|
||||
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
|
||||
XStatVisitor stat(plane_, graph_type_stat);
|
||||
name_parts.push_back(stat.ToString());
|
||||
std::string name;
|
||||
if (absl::optional<XStatVisitor> stat =
|
||||
GetContextStat(StatType::kGraphType)) {
|
||||
absl::StrAppend(&name, stat->StrOrRefValue(), " ");
|
||||
}
|
||||
int64 step_num = group_id_.value_or(0);
|
||||
if (const XStat* step_num_stat = GetContextStat(StatType::kStepNum)) {
|
||||
step_num = step_num_stat->int64_value();
|
||||
if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
|
||||
step_num = stat->IntValue();
|
||||
} else if (absl::optional<XStatVisitor> stat =
|
||||
GetContextStat(StatType::kStepNum)) {
|
||||
step_num = stat->IntValue();
|
||||
}
|
||||
if (const XStat* iter_num_stat = GetContextStat(StatType::kIterNum)) {
|
||||
step_num = iter_num_stat->int64_value();
|
||||
}
|
||||
name_parts.push_back(absl::StrCat(step_num));
|
||||
return absl::StrJoin(name_parts, " ");
|
||||
absl::StrAppend(&name, step_num);
|
||||
return name;
|
||||
}
|
||||
|
||||
void EventNode::PropagateGroupId(int64 group_id) {
|
||||
|
@ -343,11 +342,12 @@ void EventForest::ConnectInterThread(
|
|||
for (const auto& parent_event_node : *parent_event_node_list) {
|
||||
std::vector<int64> stats;
|
||||
for (auto stat_type : parent_stat_types) {
|
||||
const XStat* stat = parent_event_node->GetContextStat(stat_type);
|
||||
absl::optional<XStatVisitor> stat =
|
||||
parent_event_node->GetContextStat(stat_type);
|
||||
if (!stat) break;
|
||||
stats.push_back(stat->value_case() == stat->kInt64Value
|
||||
? stat->int64_value()
|
||||
: stat->uint64_value());
|
||||
stats.push_back((stat->ValueCase() == XStat::kInt64Value)
|
||||
? stat->IntValue()
|
||||
: stat->UintValue());
|
||||
}
|
||||
if (stats.size() == parent_stat_types.size()) {
|
||||
connect_map[stats] = parent_event_node.get();
|
||||
|
@ -359,11 +359,12 @@ void EventForest::ConnectInterThread(
|
|||
for (const auto& child_event_node : *child_event_node_list) {
|
||||
std::vector<int64> stats;
|
||||
for (auto stat_type : *child_stat_types) {
|
||||
const XStat* stat = child_event_node->GetContextStat(stat_type);
|
||||
absl::optional<XStatVisitor> stat =
|
||||
child_event_node->GetContextStat(stat_type);
|
||||
if (!stat) break;
|
||||
stats.push_back(stat->value_case() == stat->kInt64Value
|
||||
? stat->int64_value()
|
||||
: stat->uint64_value());
|
||||
stats.push_back((stat->ValueCase() == XStat::kInt64Value)
|
||||
? stat->IntValue()
|
||||
: stat->UintValue());
|
||||
}
|
||||
if (stats.size() == child_stat_types->size()) {
|
||||
if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
|
||||
|
@ -429,14 +430,14 @@ void EventForest::ProcessTensorFlowLoop() {
|
|||
if (!executor_event_list) return;
|
||||
for (auto& executor_event : *executor_event_list) {
|
||||
if (IsTfDataEvent(*executor_event)) continue;
|
||||
const XStat* step_id_stat =
|
||||
absl::optional<XStatVisitor> step_id_stat =
|
||||
executor_event->GetContextStat(StatType::kStepId);
|
||||
const XStat* iter_num_stat =
|
||||
absl::optional<XStatVisitor> iter_num_stat =
|
||||
executor_event->GetContextStat(StatType::kIterNum);
|
||||
if (!step_id_stat || !iter_num_stat) continue;
|
||||
int64 step_id = step_id_stat->int64_value();
|
||||
int64 step_id = step_id_stat->IntValue();
|
||||
TensorFlowLoop& tf_loop = tf_loops[step_id];
|
||||
TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->int64_value()];
|
||||
TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
|
||||
if (!iteration.first_event ||
|
||||
executor_event->StartsBefore(*iteration.first_event)) {
|
||||
iteration.first_event = executor_event.get();
|
||||
|
|
|
@ -78,7 +78,7 @@ class EventNode {
|
|||
|
||||
const XEventVisitor& GetEventVisitor() const { return visitor_; }
|
||||
|
||||
const XStat* GetContextStat(int64 stat_type) const;
|
||||
absl::optional<XStatVisitor> GetContextStat(int64 stat_type) const;
|
||||
|
||||
void AddStepName(absl::string_view step_name);
|
||||
|
||||
|
|
|
@ -174,12 +174,10 @@ TEST(GroupEventsTest, GroupFunctionalOp) {
|
|||
line.ForEachEvent(
|
||||
[&](const tensorflow::profiler::XEventVisitor& event) {
|
||||
absl::optional<int64> group_id;
|
||||
event.ForEachStat(
|
||||
[&](const tensorflow::profiler::XStatVisitor& stat) {
|
||||
if (stat.Type() == StatType::kGroupId) {
|
||||
group_id = stat.IntValue();
|
||||
if (absl::optional<XStatVisitor> stat =
|
||||
event.GetStat(StatType::kGroupId)) {
|
||||
group_id = stat->IntValue();
|
||||
}
|
||||
});
|
||||
EXPECT_TRUE(group_id.has_value());
|
||||
EXPECT_EQ(*group_id, 0);
|
||||
});
|
||||
|
@ -305,12 +303,10 @@ TEST(GroupEventsTest, SemanticArgTest) {
|
|||
line.ForEachEvent(
|
||||
[&](const tensorflow::profiler::XEventVisitor& event) {
|
||||
absl::optional<int64> group_id;
|
||||
event.ForEachStat(
|
||||
[&](const tensorflow::profiler::XStatVisitor& stat) {
|
||||
if (stat.Type() == StatType::kGroupId) {
|
||||
group_id = stat.IntValue();
|
||||
if (absl::optional<XStatVisitor> stat =
|
||||
event.GetStat(StatType::kGroupId)) {
|
||||
group_id = stat->IntValue();
|
||||
}
|
||||
});
|
||||
EXPECT_TRUE(group_id.has_value());
|
||||
EXPECT_EQ(*group_id, 0);
|
||||
});
|
||||
|
@ -339,12 +335,10 @@ TEST(GroupEventsTest, AsyncEventTest) {
|
|||
line.ForEachEvent(
|
||||
[&](const tensorflow::profiler::XEventVisitor& event) {
|
||||
absl::optional<int64> group_id;
|
||||
event.ForEachStat(
|
||||
[&](const tensorflow::profiler::XStatVisitor& stat) {
|
||||
if (stat.Type() == StatType::kGroupId) {
|
||||
group_id = stat.IntValue();
|
||||
if (absl::optional<XStatVisitor> stat =
|
||||
event.GetStat(StatType::kGroupId)) {
|
||||
group_id = stat->IntValue();
|
||||
}
|
||||
});
|
||||
if (event.Name() == kAsync) {
|
||||
EXPECT_FALSE(group_id.has_value());
|
||||
} else {
|
||||
|
|
|
@ -86,8 +86,10 @@ class XStatsOwner {
|
|||
}
|
||||
}
|
||||
|
||||
// Shortcut to get a specfic stat type, nullptr if it is absent.
|
||||
const XStat* GetStats(int64 stat_type) const;
|
||||
// Shortcut to get a specific stat type, nullopt if absent.
|
||||
// This function performs a linear search for the requested stat value.
|
||||
// Prefer ForEachStat above when multiple stat values are necessary.
|
||||
absl::optional<XStatVisitor> GetStat(int64 stat_type) const;
|
||||
|
||||
private:
|
||||
const T* stats_owner_;
|
||||
|
@ -241,14 +243,16 @@ class XPlaneVisitor : public XStatsOwner<XPlane> {
|
|||
};
|
||||
|
||||
template <class T>
|
||||
const XStat* XStatsOwner<T>::GetStats(int64 stat_type) const {
|
||||
absl::optional<int64> stat_metadata_id =
|
||||
metadata_->GetStatMetadataId(stat_type);
|
||||
if (!stat_metadata_id) return nullptr; // type does not exist in the XPlane.
|
||||
absl::optional<XStatVisitor> XStatsOwner<T>::GetStat(int64 stat_type) const {
|
||||
if (absl::optional<int64> stat_metadata_id =
|
||||
metadata_->GetStatMetadataId(stat_type)) {
|
||||
for (const XStat& stat : stats_owner_->stats()) {
|
||||
if (stat.metadata_id() == *stat_metadata_id) return &stat;
|
||||
if (stat.metadata_id() == *stat_metadata_id) {
|
||||
return XStatVisitor(metadata_, &stat);
|
||||
}
|
||||
return nullptr; // type does not exist in this owner.
|
||||
}
|
||||
}
|
||||
return absl::nullopt; // type does not exist in this owner.
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
|
Loading…
Reference in New Issue