diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD index 9d24b2c6f0b..5962e15171c 100644 --- a/tensorflow/core/profiler/internal/gpu/BUILD +++ b/tensorflow/core/profiler/internal/gpu/BUILD @@ -74,6 +74,9 @@ tf_cc_test_gpu( "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/utils:xplane_schema", + "//tensorflow/core/profiler/utils:xplane_utils", + "//tensorflow/core/profiler/utils:xplane_visitor", ], ) diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc index b18f9422f35..298ccb1326a 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/profiler/internal/profiler_interface.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" @@ -270,7 +271,18 @@ TEST_F(DeviceTracerTest, TraceToXSpace) { TF_ASSERT_OK(tracer->CollectData(&space)); // At least one gpu plane and one host plane for launching events. EXPECT_NE(FindPlaneWithName(space, kHostThreads), nullptr); - EXPECT_NE(FindPlaneWithName(space, StrCat(kGpuPlanePrefix, 0)), nullptr); + + const XPlane* device_plane = + FindPlaneWithName(space, StrCat(kGpuPlanePrefix, 0)); + EXPECT_NE(device_plane, nullptr); // Check if device plane is serialized. + // Check if device capacity is serialized. + XPlaneVisitor plane(device_plane); + EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr); + EXPECT_NE(plane.GetStats(kDevCapCoreCount), nullptr); + EXPECT_NE(plane.GetStats(kDevCapMemoryBandwidth), nullptr); + EXPECT_NE(plane.GetStats(kDevCapMemorySize), nullptr); + EXPECT_NE(plane.GetStats(kDevCapComputeCapMajor), nullptr); + EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr); } } // namespace diff --git a/tensorflow/core/profiler/utils/xplane_visitor.cc b/tensorflow/core/profiler/utils/xplane_visitor.cc index 39fd7cd92e2..919cdc2a2f0 100644 --- a/tensorflow/core/profiler/utils/xplane_visitor.cc +++ b/tensorflow/core/profiler/utils/xplane_visitor.cc @@ -27,12 +27,14 @@ XStatVisitor::XStatVisitor(const XPlaneVisitor* plane, const XStat* stat) XEventVisitor::XEventVisitor(const XPlaneVisitor* plane, const XLine* line, const XEvent* event) - : plane_(plane), + : XStatsOwner(plane, event), + plane_(plane), line_(line), event_(event), metadata_(plane->GetEventMetadata(event_->metadata_id())) {} -XPlaneVisitor::XPlaneVisitor(const XPlane* plane) : plane_(plane) { +XPlaneVisitor::XPlaneVisitor(const XPlane* plane) + : XStatsOwner(this, plane), plane_(plane) { for (const auto& stat_metadata : plane->stat_metadata()) { StatType type = tensorflow::profiler::GetStatType(stat_metadata.second.name()); diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h index 09152831be8..4acdec34563 100644 --- a/tensorflow/core/profiler/utils/xplane_visitor.h +++ b/tensorflow/core/profiler/utils/xplane_visitor.h @@ -62,7 +62,29 @@ class XStatVisitor { const StatType type_; }; -class XEventVisitor { +template +class XStatsOwner { + public: + XStatsOwner(const XPlaneVisitor* metadata, const T* stats_owner) + : stats_owner_(stats_owner), metadata_(metadata) {} + + // For each plane level stats, call the specified lambda. + template + void ForEachStat(ForEachStatFunc&& for_each_stat) const { + for (const XStat& stat : stats_owner_->stats()) { + for_each_stat(XStatVisitor(metadata_, &stat)); + } + } + + // Shortcut to get a specfic stat type, nullptr if it is absent. + const XStat* GetStats(StatType stat_type) const; + + private: + const T* stats_owner_; + const XPlaneVisitor* metadata_; +}; + +class XEventVisitor : public XStatsOwner { public: XEventVisitor(const XPlaneVisitor* plane, const XLine* line, const XEvent* event); @@ -99,13 +121,6 @@ class XEventVisitor { int64 NumOccurrences() const { return event_->num_occurrences(); } - template - void ForEachStat(ForEachStatFunc&& for_each_stat) const { - for (const XStat& stat : event_->stats()) { - for_each_stat(XStatVisitor(plane_, &stat)); - } - } - bool operator<(const XEventVisitor& other) const { return GetTimespan() < other.GetTimespan(); } @@ -155,7 +170,7 @@ class XLineVisitor { const XLine* line_; }; -class XPlaneVisitor { +class XPlaneVisitor : public XStatsOwner { public: explicit XPlaneVisitor(const XPlane* plane); @@ -186,6 +201,17 @@ class XPlaneVisitor { absl::flat_hash_map stat_type_map_; }; +template +const XStat* XStatsOwner::GetStats(StatType stat_type) const { + absl::optional stat_metadata_id = + metadata_->GetStatMetadataId(stat_type); + if (!stat_metadata_id) return nullptr; // type does not exist in the XPlane. + for (const XStat& stat : stats_owner_->stats()) { + if (stat.metadata_id() == *stat_metadata_id) return &stat; + } + return nullptr; // type does not exist in this owner. +} + } // namespace profiler } // namespace tensorflow