Add support for iterating plane level stats.
Add unit test coverage for device cap serialization. PiperOrigin-RevId: 289913169 Change-Id: Ie5969da915436f3163ca24fa6a30c4f25d04269f
This commit is contained in:
parent
b7f05ca3e4
commit
b716f45921
|
@ -74,6 +74,9 @@ tf_cc_test_gpu(
|
|||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
"//tensorflow/core/profiler/internal:profiler_interface",
|
||||
"//tensorflow/core/profiler/utils:xplane_schema",
|
||||
"//tensorflow/core/profiler/utils:xplane_utils",
|
||||
"//tensorflow/core/profiler/utils:xplane_visitor",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/profiler/internal/profiler_interface.h"
|
||||
#include "tensorflow/core/profiler/utils/xplane_schema.h"
|
||||
#include "tensorflow/core/profiler/utils/xplane_utils.h"
|
||||
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
|
@ -270,7 +271,18 @@ TEST_F(DeviceTracerTest, TraceToXSpace) {
|
|||
TF_ASSERT_OK(tracer->CollectData(&space));
|
||||
// At least one gpu plane and one host plane for launching events.
|
||||
EXPECT_NE(FindPlaneWithName(space, kHostThreads), nullptr);
|
||||
EXPECT_NE(FindPlaneWithName(space, StrCat(kGpuPlanePrefix, 0)), nullptr);
|
||||
|
||||
const XPlane* device_plane =
|
||||
FindPlaneWithName(space, StrCat(kGpuPlanePrefix, 0));
|
||||
EXPECT_NE(device_plane, nullptr); // Check if device plane is serialized.
|
||||
// Check if device capacity is serialized.
|
||||
XPlaneVisitor plane(device_plane);
|
||||
EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapCoreCount), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapMemoryBandwidth), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapMemorySize), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapComputeCapMajor), nullptr);
|
||||
EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -27,12 +27,14 @@ XStatVisitor::XStatVisitor(const XPlaneVisitor* plane, const XStat* stat)
|
|||
|
||||
XEventVisitor::XEventVisitor(const XPlaneVisitor* plane, const XLine* line,
|
||||
const XEvent* event)
|
||||
: plane_(plane),
|
||||
: XStatsOwner<XEvent>(plane, event),
|
||||
plane_(plane),
|
||||
line_(line),
|
||||
event_(event),
|
||||
metadata_(plane->GetEventMetadata(event_->metadata_id())) {}
|
||||
|
||||
XPlaneVisitor::XPlaneVisitor(const XPlane* plane) : plane_(plane) {
|
||||
XPlaneVisitor::XPlaneVisitor(const XPlane* plane)
|
||||
: XStatsOwner<XPlane>(this, plane), plane_(plane) {
|
||||
for (const auto& stat_metadata : plane->stat_metadata()) {
|
||||
StatType type =
|
||||
tensorflow::profiler::GetStatType(stat_metadata.second.name());
|
||||
|
|
|
@ -62,7 +62,29 @@ class XStatVisitor {
|
|||
const StatType type_;
|
||||
};
|
||||
|
||||
class XEventVisitor {
|
||||
template <class T>
|
||||
class XStatsOwner {
|
||||
public:
|
||||
XStatsOwner(const XPlaneVisitor* metadata, const T* stats_owner)
|
||||
: stats_owner_(stats_owner), metadata_(metadata) {}
|
||||
|
||||
// For each plane level stats, call the specified lambda.
|
||||
template <typename ForEachStatFunc>
|
||||
void ForEachStat(ForEachStatFunc&& for_each_stat) const {
|
||||
for (const XStat& stat : stats_owner_->stats()) {
|
||||
for_each_stat(XStatVisitor(metadata_, &stat));
|
||||
}
|
||||
}
|
||||
|
||||
// Shortcut to get a specfic stat type, nullptr if it is absent.
|
||||
const XStat* GetStats(StatType stat_type) const;
|
||||
|
||||
private:
|
||||
const T* stats_owner_;
|
||||
const XPlaneVisitor* metadata_;
|
||||
};
|
||||
|
||||
class XEventVisitor : public XStatsOwner<XEvent> {
|
||||
public:
|
||||
XEventVisitor(const XPlaneVisitor* plane, const XLine* line,
|
||||
const XEvent* event);
|
||||
|
@ -99,13 +121,6 @@ class XEventVisitor {
|
|||
|
||||
int64 NumOccurrences() const { return event_->num_occurrences(); }
|
||||
|
||||
template <typename ForEachStatFunc>
|
||||
void ForEachStat(ForEachStatFunc&& for_each_stat) const {
|
||||
for (const XStat& stat : event_->stats()) {
|
||||
for_each_stat(XStatVisitor(plane_, &stat));
|
||||
}
|
||||
}
|
||||
|
||||
bool operator<(const XEventVisitor& other) const {
|
||||
return GetTimespan() < other.GetTimespan();
|
||||
}
|
||||
|
@ -155,7 +170,7 @@ class XLineVisitor {
|
|||
const XLine* line_;
|
||||
};
|
||||
|
||||
class XPlaneVisitor {
|
||||
class XPlaneVisitor : public XStatsOwner<XPlane> {
|
||||
public:
|
||||
explicit XPlaneVisitor(const XPlane* plane);
|
||||
|
||||
|
@ -186,6 +201,17 @@ class XPlaneVisitor {
|
|||
absl::flat_hash_map<StatType, const XStatMetadata*> stat_type_map_;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
const XStat* XStatsOwner<T>::GetStats(StatType stat_type) const {
|
||||
absl::optional<int64> stat_metadata_id =
|
||||
metadata_->GetStatMetadataId(stat_type);
|
||||
if (!stat_metadata_id) return nullptr; // type does not exist in the XPlane.
|
||||
for (const XStat& stat : stats_owner_->stats()) {
|
||||
if (stat.metadata_id() == *stat_metadata_id) return &stat;
|
||||
}
|
||||
return nullptr; // type does not exist in this owner.
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
Loading…
Reference in New Issue