diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD
index 66e027ed8ac..2274a227f4d 100644
--- a/tensorflow/core/profiler/convert/BUILD
+++ b/tensorflow/core/profiler/convert/BUILD
@@ -514,6 +514,7 @@ cc_library(
         "//tensorflow/core/profiler/utils:xplane_visitor",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:optional",
diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
index 9a5130f63be..3b67124ef27 100644
--- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
@@ -24,11 +24,13 @@ limitations under the License.
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
@@ -424,23 +426,86 @@ void ProcessActiveAllocations(int64 peak_bytes_profile_step_id,
                << memory_profile->active_allocations_size();
 }
 
+// A sampled snapshot: keeps the snapshot's original position so the
+// snapshot_index fields in ActiveAllocation can be remapped after sampling.
+struct Sample {
+  int64 orig_index;  // original index into snapshots.
+  MemoryProfileSnapshot* snapshot;  // non-owning; points into snapshots.
+};
+
+// This function samples max_num_snapshots from snapshots. We first keep the
+// snapshots referenced by active_allocations in the samples. After this, if
+// there is still room for more samples, we pick more from snapshots into the
+// samples. Then, we sort the samples in time (so that they can be correctly
+// displayed on the timeline). Finally, we need to adjust the original indices
+// (to snapshots) in active_allocations to the new indices in the samples.
 void SampleSnapshots(
     int64 max_num_snapshots,
-    protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots) {
+    protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots,
+    protobuf::RepeatedPtrField<ActiveAllocation>* active_allocations) {
   if (snapshots->size() <= max_num_snapshots) return;
-  absl::c_partial_sort(
-      *snapshots, snapshots->begin() + max_num_snapshots,
-      [](const MemoryProfileSnapshot& a, const MemoryProfileSnapshot& b) {
-        return a.aggregation_stats().free_memory_bytes() <
-               b.aggregation_stats().free_memory_bytes();
-      });
-  snapshots->erase(snapshots->begin() + max_num_snapshots, snapshots->end());
-  // Sort the memory_profile_snapshots by time_offset_ps (ascending) after
-  // sampling.
-  absl::c_sort(*snapshots, [](const MemoryProfileSnapshot& a,
-                              const MemoryProfileSnapshot& b) {
-    return a.time_offset_ps() < b.time_offset_ps();
+
+  std::vector<Sample> samples;
+
+  // First, puts the snapshots referenced by active_allocations in samples[].
+  absl::flat_hash_set<int64> allocation_snapshot_indices;
+  for (const auto& allocation : *active_allocations) {
+    auto orig_index = allocation.snapshot_index();
+    if (orig_index < 0) continue;
+    // Only sample a snapshot once even if several active allocations
+    // reference it; a duplicate Sample would make the final move loop move
+    // the same snapshot twice, leaving an empty duplicate in the output.
+    if (!allocation_snapshot_indices.insert(orig_index).second) continue;
+    samples.push_back({orig_index, &(*snapshots)[orig_index]});
+    if (allocation_snapshot_indices.size() >= max_num_snapshots) break;
+  }
+
+  // Second, extracts remaining samples from snapshots.
+  int64 num_samples_remained =
+      max_num_snapshots - allocation_snapshot_indices.size();
+  if (num_samples_remained > 0) {
+    std::vector<Sample> remaining;
+    for (int64 i = 0; i < snapshots->size(); i++) {
+      if (allocation_snapshot_indices.contains(i)) continue;
+      // snapshots[i] is not yet sampled; put it in remaining[] for further
+      // consideration.
+      remaining.push_back({i, &(*snapshots)[i]});
+    }
+    // Moves the num_samples_remained snapshots with least free bytes to the
+    // beginning of remaining[].
+    absl::c_partial_sort(
+        remaining, remaining.begin() + num_samples_remained,
+        [](const Sample& a, const Sample& b) {
+          return a.snapshot->aggregation_stats().free_memory_bytes() <
+                 b.snapshot->aggregation_stats().free_memory_bytes();
+        });
+    // Copies the first num_samples_remained in remaining[] to samples[].
+    for (int64 i = 0; i < num_samples_remained; i++)
+      samples.push_back(remaining[i]);
+  }
+
+  // Third, sorts samples[] in ascending order of time_offset_ps.
+  absl::c_sort(samples, [](const Sample& a, const Sample& b) {
+    return a.snapshot->time_offset_ps() < b.snapshot->time_offset_ps();
   });
+
+  // Fourth, constructs a map from the original snapshot index to samples index.
+  absl::flat_hash_map<int64, int64> index_map;
+  for (int64 i = 0; i < samples.size(); i++) {
+    index_map[samples[i].orig_index] = i;
+  }
+
+  // Fifth, changes the original snapshot indices in active_allocations to the
+  // sample indices. Unsampled references map to -1 (no snapshot).
+  for (auto& allocation : *active_allocations) {
+    auto orig_index = allocation.snapshot_index();
+    if (orig_index < 0) continue;
+    auto new_index = gtl::FindWithDefault(index_map, orig_index, -1);
+    allocation.set_snapshot_index(new_index);
+  }
+
+  // Sixth, replaces *snapshots by samples[].
+  protobuf::RepeatedPtrField<MemoryProfileSnapshot> new_snapshots;
+  new_snapshots.Reserve(samples.size());
+  for (const auto& sample : samples) {
+    *new_snapshots.Add() = std::move(*sample.snapshot);
+  }
+  *snapshots = std::move(new_snapshots);
 }
 
 // Post-process the memory profile to correctly update proto fields, and break
@@ -478,7 +543,8 @@ void ProcessMemoryProfileProto(int64 max_num_snapshots,
                                    .peak_bytes_in_use(),
                                allocator_memory_profile);
   ProcessActiveAllocations(peak_step_id, allocator_memory_profile);
-  SampleSnapshots(max_num_snapshots, snapshots);
+  SampleSnapshots(max_num_snapshots, snapshots,
+                  allocator_memory_profile->mutable_active_allocations());
 }
 
 }