diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index 4a369b8b96a..4abe5740969 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -148,11 +148,10 @@ void CollectTfActivities(const XLineVisitor& line, if (tf_op != nullptr) { ++tf_op_id; bool is_eager = false; - event.ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kIsEager) { - is_eager = stat.IntValue(); - } - }); + if (absl::optional stat = + event.GetStat(StatType::kIsEager)) { + is_eager = stat->IntValue(); + } Timespan span(event.TimestampPs(), event.DurationPs()); tf_activities->push_back( {span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager}); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 4d2a45747e0..ccd7c54fa19 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -49,25 +49,29 @@ namespace { DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) { DeviceCapabilities cap; XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_plane); - if (auto clock_rate_khz = plane.GetStats(kDevCapClockRateKHz)) { - cap.set_clock_rate_in_ghz(clock_rate_khz->int64_value() / 1000000.0); - } - if (auto core_count = plane.GetStats(kDevCapCoreCount)) { - cap.set_num_cores(core_count->int64_value()); - } - // Set memory bandwidth in bytes/s. - if (auto memory_bw = plane.GetStats(kDevCapMemoryBandwidth)) { - cap.set_memory_bandwidth(memory_bw->int64_value()); - } - if (auto memory_size_in_bytes = plane.GetStats(kDevCapMemorySize)) { - cap.set_memory_size_in_bytes(memory_size_in_bytes->uint64_value()); - } - if (auto cap_major = plane.GetStats(kDevCapComputeCapMajor)) { - cap.mutable_compute_capability()->set_major(cap_major->int64_value()); - } - if (auto cap_minor = plane.GetStats(kDevCapComputeCapMinor)) { - cap.mutable_compute_capability()->set_minor(cap_minor->int64_value()); - } + plane.ForEachStat([&cap](const XStatVisitor& stat) { + if (!stat.Type().has_value()) return; + switch (stat.Type().value()) { + case kDevCapClockRateKHz: + cap.set_clock_rate_in_ghz(stat.IntValue() / 1000000.0); + break; + case kDevCapCoreCount: + cap.set_num_cores(stat.IntValue()); + break; + case kDevCapMemoryBandwidth: + cap.set_memory_bandwidth(stat.IntValue()); // bytes/s + break; + case kDevCapMemorySize: + cap.set_memory_size_in_bytes(stat.UintValue()); + break; + case kDevCapComputeCapMajor: + cap.mutable_compute_capability()->set_major(stat.IntValue()); + break; + case kDevCapComputeCapMinor: + cap.mutable_compute_capability()->set_minor(stat.IntValue()); + break; + } + }); return cap; } diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc index 7bb7cd6943c..bfe0ac86ef4 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc @@ -112,14 +112,11 @@ StepEvents ConvertHostThreadsXPlaneToStepEvents( StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) { StepEvents result; line.ForEachEvent([&](const XEventVisitor& event) { - event.ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kGroupId) { - result[stat.IntValue()].AddMarker( - StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(), - Timespan(event.TimestampPs(), event.DurationPs()))); - return; - } - }); + if (absl::optional stat = event.GetStat(StatType::kGroupId)) { + result[stat->IntValue()].AddMarker( + StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(), + Timespan(event.TimestampPs(), event.DurationPs()))); + } }); return result; } diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc index e6aacb66b89..6fc19e776e1 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc @@ -274,12 +274,12 @@ TEST_F(DeviceTracerTest, TraceToXSpace) { EXPECT_EQ(device_plane->event_metadata_size(), 4); // Check if device capacity is serialized. XPlaneVisitor plane = CreateTfXPlaneVisitor(device_plane); - EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr); - EXPECT_NE(plane.GetStats(kDevCapCoreCount), nullptr); - EXPECT_NE(plane.GetStats(kDevCapMemoryBandwidth), nullptr); - EXPECT_NE(plane.GetStats(kDevCapMemorySize), nullptr); - EXPECT_NE(plane.GetStats(kDevCapComputeCapMajor), nullptr); - EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr); + EXPECT_TRUE(plane.GetStat(kDevCapClockRateKHz).has_value()); + EXPECT_TRUE(plane.GetStat(kDevCapCoreCount).has_value()); + EXPECT_TRUE(plane.GetStat(kDevCapMemoryBandwidth).has_value()); + EXPECT_TRUE(plane.GetStat(kDevCapMemorySize).has_value()); + EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMajor).has_value()); + EXPECT_TRUE(plane.GetStat(kDevCapComputeCapMinor).has_value()); // Check if the device events timestamps are set. int total_events = 0; diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc index 8b4d68a0668..99c6136fc84 100644 --- a/tensorflow/core/profiler/utils/group_events.cc +++ b/tensorflow/core/profiler/utils/group_events.cc @@ -26,7 +26,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -225,30 +224,30 @@ EventNode::EventNode(const EventNode& event_node) : EventNode(event_node.plane_, event_node.raw_line_, event_node.raw_event_) {} -const XStat* EventNode::GetContextStat(int64 stat_type) const { - if (const XStat* stat = visitor_.GetStats(stat_type)) { - return stat; - } else if (parent_) { - return parent_->GetContextStat(stat_type); +absl::optional EventNode::GetContextStat(int64 stat_type) const { + for (const EventNode* node = this; node != nullptr; node = node->parent_) { + if (absl::optional stat = node->visitor_.GetStat(stat_type)) { + return stat; + } } - return nullptr; + return absl::nullopt; } std::string EventNode::GetGroupName() const { - std::vector name_parts; - if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) { - XStatVisitor stat(plane_, graph_type_stat); - name_parts.push_back(stat.ToString()); + std::string name; + if (absl::optional stat = + GetContextStat(StatType::kGraphType)) { + absl::StrAppend(&name, stat->StrOrRefValue(), " "); } int64 step_num = group_id_.value_or(0); - if (const XStat* step_num_stat = GetContextStat(StatType::kStepNum)) { - step_num = step_num_stat->int64_value(); + if (absl::optional stat = GetContextStat(StatType::kIterNum)) { + step_num = stat->IntValue(); + } else if (absl::optional stat = + GetContextStat(StatType::kStepNum)) { + step_num = stat->IntValue(); } - if (const XStat* iter_num_stat = GetContextStat(StatType::kIterNum)) { - step_num = iter_num_stat->int64_value(); - } - name_parts.push_back(absl::StrCat(step_num)); - return absl::StrJoin(name_parts, " "); + absl::StrAppend(&name, step_num); + return name; } void EventNode::PropagateGroupId(int64 group_id) { @@ -343,11 +342,12 @@ void EventForest::ConnectInterThread( for (const auto& parent_event_node : *parent_event_node_list) { std::vector stats; for (auto stat_type : parent_stat_types) { - const XStat* stat = parent_event_node->GetContextStat(stat_type); + absl::optional stat = + parent_event_node->GetContextStat(stat_type); if (!stat) break; - stats.push_back(stat->value_case() == stat->kInt64Value - ? stat->int64_value() - : stat->uint64_value()); + stats.push_back((stat->ValueCase() == XStat::kInt64Value) + ? stat->IntValue() + : stat->UintValue()); } if (stats.size() == parent_stat_types.size()) { connect_map[stats] = parent_event_node.get(); @@ -359,11 +359,12 @@ void EventForest::ConnectInterThread( for (const auto& child_event_node : *child_event_node_list) { std::vector stats; for (auto stat_type : *child_stat_types) { - const XStat* stat = child_event_node->GetContextStat(stat_type); + absl::optional stat = + child_event_node->GetContextStat(stat_type); if (!stat) break; - stats.push_back(stat->value_case() == stat->kInt64Value - ? stat->int64_value() - : stat->uint64_value()); + stats.push_back((stat->ValueCase() == XStat::kInt64Value) + ? stat->IntValue() + : stat->UintValue()); } if (stats.size() == child_stat_types->size()) { if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) { @@ -429,14 +430,14 @@ void EventForest::ProcessTensorFlowLoop() { if (!executor_event_list) return; for (auto& executor_event : *executor_event_list) { if (IsTfDataEvent(*executor_event)) continue; - const XStat* step_id_stat = + absl::optional step_id_stat = executor_event->GetContextStat(StatType::kStepId); - const XStat* iter_num_stat = + absl::optional iter_num_stat = executor_event->GetContextStat(StatType::kIterNum); if (!step_id_stat || !iter_num_stat) continue; - int64 step_id = step_id_stat->int64_value(); + int64 step_id = step_id_stat->IntValue(); TensorFlowLoop& tf_loop = tf_loops[step_id]; - TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->int64_value()]; + TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()]; if (!iteration.first_event || executor_event->StartsBefore(*iteration.first_event)) { iteration.first_event = executor_event.get(); diff --git a/tensorflow/core/profiler/utils/group_events.h b/tensorflow/core/profiler/utils/group_events.h index 2d10480a64f..388da0f5d67 100644 --- a/tensorflow/core/profiler/utils/group_events.h +++ b/tensorflow/core/profiler/utils/group_events.h @@ -78,7 +78,7 @@ class EventNode { const XEventVisitor& GetEventVisitor() const { return visitor_; } - const XStat* GetContextStat(int64 stat_type) const; + absl::optional GetContextStat(int64 stat_type) const; void AddStepName(absl::string_view step_name); diff --git a/tensorflow/core/profiler/utils/group_events_test.cc b/tensorflow/core/profiler/utils/group_events_test.cc index ea378b7cb70..6ff069dc1ae 100644 --- a/tensorflow/core/profiler/utils/group_events_test.cc +++ b/tensorflow/core/profiler/utils/group_events_test.cc @@ -174,12 +174,10 @@ TEST(GroupEventsTest, GroupFunctionalOp) { line.ForEachEvent( [&](const tensorflow::profiler::XEventVisitor& event) { absl::optional group_id; - event.ForEachStat( - [&](const tensorflow::profiler::XStatVisitor& stat) { - if (stat.Type() == StatType::kGroupId) { - group_id = stat.IntValue(); - } - }); + if (absl::optional stat = + event.GetStat(StatType::kGroupId)) { + group_id = stat->IntValue(); + } EXPECT_TRUE(group_id.has_value()); EXPECT_EQ(*group_id, 0); }); @@ -305,12 +303,10 @@ TEST(GroupEventsTest, SemanticArgTest) { line.ForEachEvent( [&](const tensorflow::profiler::XEventVisitor& event) { absl::optional group_id; - event.ForEachStat( - [&](const tensorflow::profiler::XStatVisitor& stat) { - if (stat.Type() == StatType::kGroupId) { - group_id = stat.IntValue(); - } - }); + if (absl::optional stat = + event.GetStat(StatType::kGroupId)) { + group_id = stat->IntValue(); + } EXPECT_TRUE(group_id.has_value()); EXPECT_EQ(*group_id, 0); }); @@ -339,12 +335,10 @@ TEST(GroupEventsTest, AsyncEventTest) { line.ForEachEvent( [&](const tensorflow::profiler::XEventVisitor& event) { absl::optional group_id; - event.ForEachStat( - [&](const tensorflow::profiler::XStatVisitor& stat) { - if (stat.Type() == StatType::kGroupId) { - group_id = stat.IntValue(); - } - }); + if (absl::optional stat = + event.GetStat(StatType::kGroupId)) { + group_id = stat->IntValue(); + } if (event.Name() == kAsync) { EXPECT_FALSE(group_id.has_value()); } else { diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h index 4120a2821ca..a838825c773 100644 --- a/tensorflow/core/profiler/utils/xplane_visitor.h +++ b/tensorflow/core/profiler/utils/xplane_visitor.h @@ -86,8 +86,10 @@ class XStatsOwner { } } - // Shortcut to get a specfic stat type, nullptr if it is absent. - const XStat* GetStats(int64 stat_type) const; + // Shortcut to get a specific stat type, nullopt if absent. + // This function performs a linear search for the requested stat value. + // Prefer ForEachStat above when multiple stat values are necessary. + absl::optional GetStat(int64 stat_type) const; private: const T* stats_owner_; @@ -241,14 +243,16 @@ class XPlaneVisitor : public XStatsOwner { }; template -const XStat* XStatsOwner::GetStats(int64 stat_type) const { - absl::optional stat_metadata_id = - metadata_->GetStatMetadataId(stat_type); - if (!stat_metadata_id) return nullptr; // type does not exist in the XPlane. - for (const XStat& stat : stats_owner_->stats()) { - if (stat.metadata_id() == *stat_metadata_id) return &stat; +absl::optional XStatsOwner::GetStat(int64 stat_type) const { + if (absl::optional stat_metadata_id = + metadata_->GetStatMetadataId(stat_type)) { + for (const XStat& stat : stats_owner_->stats()) { + if (stat.metadata_id() == *stat_metadata_id) { + return XStatVisitor(metadata_, &stat); + } + } } - return nullptr; // type does not exist in this owner. + return absl::nullopt; // type does not exist in this owner. } } // namespace profiler