Add support for iterating plane level stats.

Add unit test coverage for device cap serialization.

PiperOrigin-RevId: 289913169
Change-Id: Ie5969da915436f3163ca24fa6a30c4f25d04269f
This commit is contained in:
A. Unique TensorFlower 2020-01-15 12:26:53 -08:00 committed by TensorFlower Gardener
parent b7f05ca3e4
commit b716f45921
4 changed files with 55 additions and 12 deletions

View File

@ -74,6 +74,9 @@ tf_cc_test_gpu(
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
"//tensorflow/core/profiler/utils:xplane_visitor",
],
)

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -270,7 +271,18 @@ TEST_F(DeviceTracerTest, TraceToXSpace) {
TF_ASSERT_OK(tracer->CollectData(&space));
// At least one gpu plane and one host plane for launching events.
EXPECT_NE(FindPlaneWithName(space, kHostThreads), nullptr);
EXPECT_NE(FindPlaneWithName(space, StrCat(kGpuPlanePrefix, 0)), nullptr);
const XPlane* device_plane =
FindPlaneWithName(space, StrCat(kGpuPlanePrefix, 0));
EXPECT_NE(device_plane, nullptr); // Check if device plane is serialized.
// Check if device capacity is serialized.
XPlaneVisitor plane(device_plane);
EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr);
EXPECT_NE(plane.GetStats(kDevCapCoreCount), nullptr);
EXPECT_NE(plane.GetStats(kDevCapMemoryBandwidth), nullptr);
EXPECT_NE(plane.GetStats(kDevCapMemorySize), nullptr);
EXPECT_NE(plane.GetStats(kDevCapComputeCapMajor), nullptr);
EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr);
}
} // namespace

View File

@ -27,12 +27,14 @@ XStatVisitor::XStatVisitor(const XPlaneVisitor* plane, const XStat* stat)
XEventVisitor::XEventVisitor(const XPlaneVisitor* plane, const XLine* line,
const XEvent* event)
: plane_(plane),
: XStatsOwner<XEvent>(plane, event),
plane_(plane),
line_(line),
event_(event),
metadata_(plane->GetEventMetadata(event_->metadata_id())) {}
XPlaneVisitor::XPlaneVisitor(const XPlane* plane) : plane_(plane) {
XPlaneVisitor::XPlaneVisitor(const XPlane* plane)
: XStatsOwner<XPlane>(this, plane), plane_(plane) {
for (const auto& stat_metadata : plane->stat_metadata()) {
StatType type =
tensorflow::profiler::GetStatType(stat_metadata.second.name());

View File

@ -62,7 +62,29 @@ class XStatVisitor {
const StatType type_;
};
class XEventVisitor {
template <class T>
class XStatsOwner {
public:
XStatsOwner(const XPlaneVisitor* metadata, const T* stats_owner)
: stats_owner_(stats_owner), metadata_(metadata) {}
// For each plane level stats, call the specified lambda.
template <typename ForEachStatFunc>
void ForEachStat(ForEachStatFunc&& for_each_stat) const {
for (const XStat& stat : stats_owner_->stats()) {
for_each_stat(XStatVisitor(metadata_, &stat));
}
}
// Shortcut to get a specfic stat type, nullptr if it is absent.
const XStat* GetStats(StatType stat_type) const;
private:
const T* stats_owner_;
const XPlaneVisitor* metadata_;
};
class XEventVisitor : public XStatsOwner<XEvent> {
public:
XEventVisitor(const XPlaneVisitor* plane, const XLine* line,
const XEvent* event);
@ -99,13 +121,6 @@ class XEventVisitor {
int64 NumOccurrences() const { return event_->num_occurrences(); }
template <typename ForEachStatFunc>
void ForEachStat(ForEachStatFunc&& for_each_stat) const {
for (const XStat& stat : event_->stats()) {
for_each_stat(XStatVisitor(plane_, &stat));
}
}
bool operator<(const XEventVisitor& other) const {
return GetTimespan() < other.GetTimespan();
}
@ -155,7 +170,7 @@ class XLineVisitor {
const XLine* line_;
};
class XPlaneVisitor {
class XPlaneVisitor : public XStatsOwner<XPlane> {
public:
explicit XPlaneVisitor(const XPlane* plane);
@ -186,6 +201,17 @@ class XPlaneVisitor {
absl::flat_hash_map<StatType, const XStatMetadata*> stat_type_map_;
};
template <class T>
const XStat* XStatsOwner<T>::GetStats(StatType stat_type) const {
absl::optional<int64> stat_metadata_id =
metadata_->GetStatMetadataId(stat_type);
if (!stat_metadata_id) return nullptr; // type does not exist in the XPlane.
for (const XStat& stat : stats_owner_->stats()) {
if (stat.metadata_id() == *stat_metadata_id) return &stat;
}
return nullptr; // type does not exist in this owner.
}
} // namespace profiler
} // namespace tensorflow