XStatsOwner::GetStat returns optional<XStatVisitor>
PiperOrigin-RevId: 315305610 Change-Id: I14ea35d662c585789dc63877d1938ed250ca32fd
This commit is contained in:
parent
cb8342c4eb
commit
ee4f27a524
|
@ -148,11 +148,10 @@ void CollectTfActivities(const XLineVisitor& line,
|
||||||
if (tf_op != nullptr) {
|
if (tf_op != nullptr) {
|
||||||
++tf_op_id;
|
++tf_op_id;
|
||||||
bool is_eager = false;
|
bool is_eager = false;
|
||||||
event.ForEachStat([&](const XStatVisitor& stat) {
|
if (absl::optional<XStatVisitor> stat =
|
||||||
if (stat.Type() == StatType::kIsEager) {
|
event.GetStat(StatType::kIsEager)) {
|
||||||
is_eager = stat.IntValue();
|
is_eager = stat->IntValue();
|
||||||
}
|
}
|
||||||
});
|
|
||||||
Timespan span(event.TimestampPs(), event.DurationPs());
|
Timespan span(event.TimestampPs(), event.DurationPs());
|
||||||
tf_activities->push_back(
|
tf_activities->push_back(
|
||||||
{span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager});
|
{span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager});
|
||||||
|
|
|
@ -49,25 +49,29 @@ namespace {
|
||||||
DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) {
|
DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) {
|
||||||
DeviceCapabilities cap;
|
DeviceCapabilities cap;
|
||||||
XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_plane);
|
XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_plane);
|
||||||
if (auto clock_rate_khz = plane.GetStats(kDevCapClockRateKHz)) {
|
plane.ForEachStat([&cap](const XStatVisitor& stat) {
|
||||||
cap.set_clock_rate_in_ghz(clock_rate_khz->int64_value() / 1000000.0);
|
if (!stat.Type().has_value()) return;
|
||||||
}
|
switch (stat.Type().value()) {
|
||||||
if (auto core_count = plane.GetStats(kDevCapCoreCount)) {
|
case kDevCapClockRateKHz:
|
||||||
cap.set_num_cores(core_count->int64_value());
|
cap.set_clock_rate_in_ghz(stat.IntValue() / 1000000.0);
|
||||||
}
|
break;
|
||||||
// Set memory bandwidth in bytes/s.
|
case kDevCapCoreCount:
|
||||||
if (auto memory_bw = plane.GetStats(kDevCapMemoryBandwidth)) {
|
cap.set_num_cores(stat.IntValue());
|
||||||
cap.set_memory_bandwidth(memory_bw->int64_value());
|
break;
|
||||||
}
|
case kDevCapMemoryBandwidth:
|
||||||
if (auto memory_size_in_bytes = plane.GetStats(kDevCapMemorySize)) {
|
cap.set_memory_bandwidth(stat.IntValue()); // bytes/s
|
||||||
cap.set_memory_size_in_bytes(memory_size_in_bytes->uint64_value());
|
break;
|
||||||
}
|
case kDevCapMemorySize:
|
||||||
if (auto cap_major = plane.GetStats(kDevCapComputeCapMajor)) {
|
cap.set_memory_size_in_bytes(stat.UintValue());
|
||||||
cap.mutable_compute_capability()->set_major(cap_major->int64_value());
|
break;
|
||||||
}
|
case kDevCapComputeCapMajor:
|
||||||
if (auto cap_minor = plane.GetStats(kDevCapComputeCapMinor)) {
|
cap.mutable_compute_capability()->set_major(stat.IntValue());
|
||||||
cap.mutable_compute_capability()->set_minor(cap_minor->int64_value());
|
break;
|
||||||
}
|
case kDevCapComputeCapMinor:
|
||||||
|
cap.mutable_compute_capability()->set_minor(stat.IntValue());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
return cap;
|
return cap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -112,14 +112,11 @@ StepEvents ConvertHostThreadsXPlaneToStepEvents(
|
||||||
StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) {
|
StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) {
|
||||||
StepEvents result;
|
StepEvents result;
|
||||||
line.ForEachEvent([&](const XEventVisitor& event) {
|
line.ForEachEvent([&](const XEventVisitor& event) {
|
||||||
event.ForEachStat([&](const XStatVisitor& stat) {
|
if (absl::optional<XStatVisitor> stat = event.GetStat(StatType::kGroupId)) {
|
||||||
if (stat.Type() == StatType::kGroupId) {
|
result[stat->IntValue()].AddMarker(
|
||||||
result[stat.IntValue()].AddMarker(
|
StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(),
|
||||||
StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(),
|
Timespan(event.TimestampPs(), event.DurationPs())));
|
||||||
Timespan(event.TimestampPs(), event.DurationPs())));
|
}
|
||||||
return;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -274,12 +274,12 @@ TEST_F(DeviceTracerTest, TraceToXSpace) {
|
||||||
EXPECT_EQ(device_plane->event_metadata_size(), 4);
|
EXPECT_EQ(device_plane->event_metadata_size(), 4);
|
||||||
// Check if device capacity is serialized.
|
// Check if device capacity is serialized.
|
||||||
XPlaneVisitor plane = CreateTfXPlaneVisitor(device_plane);
|
XPlaneVisitor plane = CreateTfXPlaneVisitor(device_plane);
|
||||||
EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr);
|
EXPECT_TRUE(plane.GetStat(kDevCapClockRateKHz).has_value());
|
||||||
EXPECT_NE(plane.GetStats(kDevCapCoreCount), nullptr);
|
EXPECT_TRUE(plane.GetStat(kDevCapCoreCount).has_value());
|
||||||
EXPECT_NE(plane.GetStats(kDevCapMemoryBandwidth), nullptr);
|
EXPECT_TRUE(plane.GetStat(kDevCapMemoryBandwidth).has_value());
|
||||||
EXPECT_NE(plane.GetStats(kDevCapMemorySize), nullptr);
|
EXPECT_TRUE(plane.GetStat(kDevCapMemorySize).has_value());
|
||||||
EXPECT_NE(plane.GetStats(kDevCapComputeCapMajor), nullptr);
|
EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMajor).has_value());
|
||||||
EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr);
|
EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMinor).has_value());
|
||||||
|
|
||||||
// Check if the device events timestamps are set.
|
// Check if the device events timestamps are set.
|
||||||
int total_events = 0;
|
int total_events = 0;
|
||||||
|
|
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
@ -225,30 +224,30 @@ EventNode::EventNode(const EventNode& event_node)
|
||||||
: EventNode(event_node.plane_, event_node.raw_line_,
|
: EventNode(event_node.plane_, event_node.raw_line_,
|
||||||
event_node.raw_event_) {}
|
event_node.raw_event_) {}
|
||||||
|
|
||||||
const XStat* EventNode::GetContextStat(int64 stat_type) const {
|
absl::optional<XStatVisitor> EventNode::GetContextStat(int64 stat_type) const {
|
||||||
if (const XStat* stat = visitor_.GetStats(stat_type)) {
|
for (const EventNode* node = this; node != nullptr; node = node->parent_) {
|
||||||
return stat;
|
if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
|
||||||
} else if (parent_) {
|
return stat;
|
||||||
return parent_->GetContextStat(stat_type);
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string EventNode::GetGroupName() const {
|
std::string EventNode::GetGroupName() const {
|
||||||
std::vector<std::string> name_parts;
|
std::string name;
|
||||||
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
|
if (absl::optional<XStatVisitor> stat =
|
||||||
XStatVisitor stat(plane_, graph_type_stat);
|
GetContextStat(StatType::kGraphType)) {
|
||||||
name_parts.push_back(stat.ToString());
|
absl::StrAppend(&name, stat->StrOrRefValue(), " ");
|
||||||
}
|
}
|
||||||
int64 step_num = group_id_.value_or(0);
|
int64 step_num = group_id_.value_or(0);
|
||||||
if (const XStat* step_num_stat = GetContextStat(StatType::kStepNum)) {
|
if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
|
||||||
step_num = step_num_stat->int64_value();
|
step_num = stat->IntValue();
|
||||||
|
} else if (absl::optional<XStatVisitor> stat =
|
||||||
|
GetContextStat(StatType::kStepNum)) {
|
||||||
|
step_num = stat->IntValue();
|
||||||
}
|
}
|
||||||
if (const XStat* iter_num_stat = GetContextStat(StatType::kIterNum)) {
|
absl::StrAppend(&name, step_num);
|
||||||
step_num = iter_num_stat->int64_value();
|
return name;
|
||||||
}
|
|
||||||
name_parts.push_back(absl::StrCat(step_num));
|
|
||||||
return absl::StrJoin(name_parts, " ");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void EventNode::PropagateGroupId(int64 group_id) {
|
void EventNode::PropagateGroupId(int64 group_id) {
|
||||||
|
@ -343,11 +342,12 @@ void EventForest::ConnectInterThread(
|
||||||
for (const auto& parent_event_node : *parent_event_node_list) {
|
for (const auto& parent_event_node : *parent_event_node_list) {
|
||||||
std::vector<int64> stats;
|
std::vector<int64> stats;
|
||||||
for (auto stat_type : parent_stat_types) {
|
for (auto stat_type : parent_stat_types) {
|
||||||
const XStat* stat = parent_event_node->GetContextStat(stat_type);
|
absl::optional<XStatVisitor> stat =
|
||||||
|
parent_event_node->GetContextStat(stat_type);
|
||||||
if (!stat) break;
|
if (!stat) break;
|
||||||
stats.push_back(stat->value_case() == stat->kInt64Value
|
stats.push_back((stat->ValueCase() == XStat::kInt64Value)
|
||||||
? stat->int64_value()
|
? stat->IntValue()
|
||||||
: stat->uint64_value());
|
: stat->UintValue());
|
||||||
}
|
}
|
||||||
if (stats.size() == parent_stat_types.size()) {
|
if (stats.size() == parent_stat_types.size()) {
|
||||||
connect_map[stats] = parent_event_node.get();
|
connect_map[stats] = parent_event_node.get();
|
||||||
|
@ -359,11 +359,12 @@ void EventForest::ConnectInterThread(
|
||||||
for (const auto& child_event_node : *child_event_node_list) {
|
for (const auto& child_event_node : *child_event_node_list) {
|
||||||
std::vector<int64> stats;
|
std::vector<int64> stats;
|
||||||
for (auto stat_type : *child_stat_types) {
|
for (auto stat_type : *child_stat_types) {
|
||||||
const XStat* stat = child_event_node->GetContextStat(stat_type);
|
absl::optional<XStatVisitor> stat =
|
||||||
|
child_event_node->GetContextStat(stat_type);
|
||||||
if (!stat) break;
|
if (!stat) break;
|
||||||
stats.push_back(stat->value_case() == stat->kInt64Value
|
stats.push_back((stat->ValueCase() == XStat::kInt64Value)
|
||||||
? stat->int64_value()
|
? stat->IntValue()
|
||||||
: stat->uint64_value());
|
: stat->UintValue());
|
||||||
}
|
}
|
||||||
if (stats.size() == child_stat_types->size()) {
|
if (stats.size() == child_stat_types->size()) {
|
||||||
if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
|
if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
|
||||||
|
@ -429,14 +430,14 @@ void EventForest::ProcessTensorFlowLoop() {
|
||||||
if (!executor_event_list) return;
|
if (!executor_event_list) return;
|
||||||
for (auto& executor_event : *executor_event_list) {
|
for (auto& executor_event : *executor_event_list) {
|
||||||
if (IsTfDataEvent(*executor_event)) continue;
|
if (IsTfDataEvent(*executor_event)) continue;
|
||||||
const XStat* step_id_stat =
|
absl::optional<XStatVisitor> step_id_stat =
|
||||||
executor_event->GetContextStat(StatType::kStepId);
|
executor_event->GetContextStat(StatType::kStepId);
|
||||||
const XStat* iter_num_stat =
|
absl::optional<XStatVisitor> iter_num_stat =
|
||||||
executor_event->GetContextStat(StatType::kIterNum);
|
executor_event->GetContextStat(StatType::kIterNum);
|
||||||
if (!step_id_stat || !iter_num_stat) continue;
|
if (!step_id_stat || !iter_num_stat) continue;
|
||||||
int64 step_id = step_id_stat->int64_value();
|
int64 step_id = step_id_stat->IntValue();
|
||||||
TensorFlowLoop& tf_loop = tf_loops[step_id];
|
TensorFlowLoop& tf_loop = tf_loops[step_id];
|
||||||
TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->int64_value()];
|
TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
|
||||||
if (!iteration.first_event ||
|
if (!iteration.first_event ||
|
||||||
executor_event->StartsBefore(*iteration.first_event)) {
|
executor_event->StartsBefore(*iteration.first_event)) {
|
||||||
iteration.first_event = executor_event.get();
|
iteration.first_event = executor_event.get();
|
||||||
|
|
|
@ -78,7 +78,7 @@ class EventNode {
|
||||||
|
|
||||||
const XEventVisitor& GetEventVisitor() const { return visitor_; }
|
const XEventVisitor& GetEventVisitor() const { return visitor_; }
|
||||||
|
|
||||||
const XStat* GetContextStat(int64 stat_type) const;
|
absl::optional<XStatVisitor> GetContextStat(int64 stat_type) const;
|
||||||
|
|
||||||
void AddStepName(absl::string_view step_name);
|
void AddStepName(absl::string_view step_name);
|
||||||
|
|
||||||
|
|
|
@ -174,12 +174,10 @@ TEST(GroupEventsTest, GroupFunctionalOp) {
|
||||||
line.ForEachEvent(
|
line.ForEachEvent(
|
||||||
[&](const tensorflow::profiler::XEventVisitor& event) {
|
[&](const tensorflow::profiler::XEventVisitor& event) {
|
||||||
absl::optional<int64> group_id;
|
absl::optional<int64> group_id;
|
||||||
event.ForEachStat(
|
if (absl::optional<XStatVisitor> stat =
|
||||||
[&](const tensorflow::profiler::XStatVisitor& stat) {
|
event.GetStat(StatType::kGroupId)) {
|
||||||
if (stat.Type() == StatType::kGroupId) {
|
group_id = stat->IntValue();
|
||||||
group_id = stat.IntValue();
|
}
|
||||||
}
|
|
||||||
});
|
|
||||||
EXPECT_TRUE(group_id.has_value());
|
EXPECT_TRUE(group_id.has_value());
|
||||||
EXPECT_EQ(*group_id, 0);
|
EXPECT_EQ(*group_id, 0);
|
||||||
});
|
});
|
||||||
|
@ -305,12 +303,10 @@ TEST(GroupEventsTest, SemanticArgTest) {
|
||||||
line.ForEachEvent(
|
line.ForEachEvent(
|
||||||
[&](const tensorflow::profiler::XEventVisitor& event) {
|
[&](const tensorflow::profiler::XEventVisitor& event) {
|
||||||
absl::optional<int64> group_id;
|
absl::optional<int64> group_id;
|
||||||
event.ForEachStat(
|
if (absl::optional<XStatVisitor> stat =
|
||||||
[&](const tensorflow::profiler::XStatVisitor& stat) {
|
event.GetStat(StatType::kGroupId)) {
|
||||||
if (stat.Type() == StatType::kGroupId) {
|
group_id = stat->IntValue();
|
||||||
group_id = stat.IntValue();
|
}
|
||||||
}
|
|
||||||
});
|
|
||||||
EXPECT_TRUE(group_id.has_value());
|
EXPECT_TRUE(group_id.has_value());
|
||||||
EXPECT_EQ(*group_id, 0);
|
EXPECT_EQ(*group_id, 0);
|
||||||
});
|
});
|
||||||
|
@ -339,12 +335,10 @@ TEST(GroupEventsTest, AsyncEventTest) {
|
||||||
line.ForEachEvent(
|
line.ForEachEvent(
|
||||||
[&](const tensorflow::profiler::XEventVisitor& event) {
|
[&](const tensorflow::profiler::XEventVisitor& event) {
|
||||||
absl::optional<int64> group_id;
|
absl::optional<int64> group_id;
|
||||||
event.ForEachStat(
|
if (absl::optional<XStatVisitor> stat =
|
||||||
[&](const tensorflow::profiler::XStatVisitor& stat) {
|
event.GetStat(StatType::kGroupId)) {
|
||||||
if (stat.Type() == StatType::kGroupId) {
|
group_id = stat->IntValue();
|
||||||
group_id = stat.IntValue();
|
}
|
||||||
}
|
|
||||||
});
|
|
||||||
if (event.Name() == kAsync) {
|
if (event.Name() == kAsync) {
|
||||||
EXPECT_FALSE(group_id.has_value());
|
EXPECT_FALSE(group_id.has_value());
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -86,8 +86,10 @@ class XStatsOwner {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shortcut to get a specfic stat type, nullptr if it is absent.
|
// Shortcut to get a specific stat type, nullopt if absent.
|
||||||
const XStat* GetStats(int64 stat_type) const;
|
// This function performs a linear search for the requested stat value.
|
||||||
|
// Prefer ForEachStat above when multiple stat values are necessary.
|
||||||
|
absl::optional<XStatVisitor> GetStat(int64 stat_type) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const T* stats_owner_;
|
const T* stats_owner_;
|
||||||
|
@ -241,14 +243,16 @@ class XPlaneVisitor : public XStatsOwner<XPlane> {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
const XStat* XStatsOwner<T>::GetStats(int64 stat_type) const {
|
absl::optional<XStatVisitor> XStatsOwner<T>::GetStat(int64 stat_type) const {
|
||||||
absl::optional<int64> stat_metadata_id =
|
if (absl::optional<int64> stat_metadata_id =
|
||||||
metadata_->GetStatMetadataId(stat_type);
|
metadata_->GetStatMetadataId(stat_type)) {
|
||||||
if (!stat_metadata_id) return nullptr; // type does not exist in the XPlane.
|
for (const XStat& stat : stats_owner_->stats()) {
|
||||||
for (const XStat& stat : stats_owner_->stats()) {
|
if (stat.metadata_id() == *stat_metadata_id) {
|
||||||
if (stat.metadata_id() == *stat_metadata_id) return &stat;
|
return XStatVisitor(metadata_, &stat);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nullptr; // type does not exist in this owner.
|
return absl::nullopt; // type does not exist in this owner.
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace profiler
|
} // namespace profiler
|
||||||
|
|
Loading…
Reference in New Issue