Use the number of occurrences when combining the memory access breakdown.
PiperOrigin-RevId: 314407603 Change-Id: Iaec9ad2b121035a1dfed45bec8565a578f1acf4f
This commit is contained in:
parent
67edd1327a
commit
6425351101
@ -57,6 +57,7 @@ cc_library(
|
||||
hdrs = ["op_metrics_db_combiner.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:protobuf",
|
||||
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
|
||||
"//tensorflow/core/profiler/utils:op_metrics_db_utils",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
@ -25,31 +25,6 @@ namespace {
|
||||
|
||||
using OperationType = OpMetrics::MemoryAccessed::OperationType;
|
||||
|
||||
void CombineMemoryAccessedBreakdown(const OpMetrics& src, OpMetrics* dst) {
|
||||
absl::flat_hash_map<std::pair<uint64 /*memory_space*/, OperationType>,
|
||||
OpMetrics_MemoryAccessed*>
|
||||
dst_memory_accessed_map;
|
||||
for (auto& dst_memory_accessed : *dst->mutable_memory_accessed_breakdown()) {
|
||||
dst_memory_accessed_map[{dst_memory_accessed.memory_space(),
|
||||
dst_memory_accessed.operation_type()}] =
|
||||
&dst_memory_accessed;
|
||||
}
|
||||
for (const auto& src_memory_accessed : src.memory_accessed_breakdown()) {
|
||||
uint64 memory_space = src_memory_accessed.memory_space();
|
||||
OperationType operation_type = src_memory_accessed.operation_type();
|
||||
auto*& dst_memory_accessed =
|
||||
dst_memory_accessed_map[{memory_space, operation_type}];
|
||||
if (dst_memory_accessed == nullptr) {
|
||||
dst_memory_accessed = dst->add_memory_accessed_breakdown();
|
||||
dst_memory_accessed->set_memory_space(memory_space);
|
||||
dst_memory_accessed->set_operation_type(operation_type);
|
||||
}
|
||||
dst_memory_accessed->set_bytes_accessed(
|
||||
src_memory_accessed.bytes_accessed() +
|
||||
dst_memory_accessed->bytes_accessed());
|
||||
}
|
||||
}
|
||||
|
||||
// Combines the src OpMetrics into the dst OpMetrics.
|
||||
void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst) {
|
||||
DCHECK(dst != nullptr);
|
||||
@ -70,7 +45,8 @@ void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst) {
|
||||
dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps());
|
||||
dst->set_flops(src.flops() + dst->flops());
|
||||
dst->set_bytes_accessed(src.bytes_accessed() + dst->bytes_accessed());
|
||||
CombineMemoryAccessedBreakdown(src, dst);
|
||||
CombineMemoryAccessedBreakdown(src.memory_accessed_breakdown(),
|
||||
dst->mutable_memory_accessed_breakdown());
|
||||
dst->set_dma_stall_ps(src.dma_stall_ps() + dst->dma_stall_ps());
|
||||
}
|
||||
|
||||
@ -81,6 +57,33 @@ void CombinePrecisionStats(const PrecisionStats& src, PrecisionStats* dst) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void CombineMemoryAccessedBreakdown(
|
||||
const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>& src,
|
||||
protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>* dst) {
|
||||
absl::flat_hash_map<std::pair<uint64 /*memory_space*/, OperationType>,
|
||||
OpMetrics_MemoryAccessed*>
|
||||
dst_memory_accessed_map;
|
||||
for (auto& dst_memory_accessed : *dst) {
|
||||
dst_memory_accessed_map[{dst_memory_accessed.memory_space(),
|
||||
dst_memory_accessed.operation_type()}] =
|
||||
&dst_memory_accessed;
|
||||
}
|
||||
for (const auto& src_memory_accessed : src) {
|
||||
uint64 memory_space = src_memory_accessed.memory_space();
|
||||
OperationType operation_type = src_memory_accessed.operation_type();
|
||||
auto*& dst_memory_accessed =
|
||||
dst_memory_accessed_map[{memory_space, operation_type}];
|
||||
if (dst_memory_accessed == nullptr) {
|
||||
dst_memory_accessed = dst->Add();
|
||||
dst_memory_accessed->set_memory_space(memory_space);
|
||||
dst_memory_accessed->set_operation_type(operation_type);
|
||||
}
|
||||
dst_memory_accessed->set_bytes_accessed(
|
||||
src_memory_accessed.bytes_accessed() +
|
||||
dst_memory_accessed->bytes_accessed());
|
||||
}
|
||||
}
|
||||
|
||||
void OpMetricsDbCombiner::Combine(const OpMetricsDb& src) {
|
||||
OpMetricsDb* dst = db();
|
||||
dst->set_total_host_infeed_enq_duration_ps(
|
||||
|
@ -16,12 +16,18 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_
|
||||
#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_
|
||||
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
|
||||
#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
// Combines the memory access breakdown.
|
||||
void CombineMemoryAccessedBreakdown(
|
||||
const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>& src,
|
||||
protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>* dst);
|
||||
|
||||
// Helper to combine op metrics databases.
|
||||
class OpMetricsDbCombiner : public OpMetricsDbBuilder {
|
||||
public:
|
||||
|
@ -81,6 +81,8 @@ cc_library(
|
||||
":op_metrics_db_utils",
|
||||
":tf_op_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:protobuf",
|
||||
"//tensorflow/core/profiler/convert:op_metrics_db_combiner",
|
||||
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
|
||||
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
|
||||
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
|
||||
|
||||
@ -67,7 +68,8 @@ void DeviceOpMetricsDbBuilder::EnterOp(
|
||||
uint64 program_id, absl::string_view name, absl::string_view category,
|
||||
absl::string_view provenance, bool is_eager, uint64 occurrences,
|
||||
uint64 time_ps, uint64 children_time_ps, int64 flops, int64 bytes_accessed,
|
||||
const std::vector<OpMetrics::MemoryAccessed>& memory_accessed_breakdown) {
|
||||
const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>&
|
||||
memory_accessed_breakdown) {
|
||||
uint64 self_time_ps = time_ps - children_time_ps;
|
||||
DCHECK_GE(time_ps, self_time_ps);
|
||||
OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, name);
|
||||
@ -87,9 +89,9 @@ void DeviceOpMetricsDbBuilder::EnterOp(
|
||||
op_metrics->bytes_accessed() +
|
||||
GetCappedPerf(bytes_accessed * occurrences, self_time_ps,
|
||||
peak_hbm_bw_giga_bytes_per_second_ / 1000));
|
||||
for (const auto& memory_accessed : memory_accessed_breakdown) {
|
||||
*op_metrics->add_memory_accessed_breakdown() = memory_accessed;
|
||||
}
|
||||
CombineMemoryAccessedBreakdown(
|
||||
memory_accessed_breakdown,
|
||||
op_metrics->mutable_memory_accessed_breakdown());
|
||||
db()->set_total_op_time_ps(db()->total_op_time_ps() + self_time_ps);
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
|
||||
#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
|
||||
@ -75,7 +76,7 @@ class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder {
|
||||
absl::string_view category, absl::string_view provenance,
|
||||
bool is_eager, uint64 occurrences, uint64 time_ps,
|
||||
uint64 children_time_ps, int64 flops, int64 bytes_accessed,
|
||||
const std::vector<OpMetrics::MemoryAccessed>&
|
||||
const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>&
|
||||
memory_accessed_breakdown = {});
|
||||
|
||||
protected:
|
||||
|
Loading…
x
Reference in New Issue
Block a user