diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 94f1b1f2044..b467e7b311e 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -82,14 +82,18 @@ class MklSmallSizeAllocator : public Allocator {
     port::AlignedFree(ptr);
   }
 
-  void GetStats(AllocatorStats* stats) override {
+  absl::optional<AllocatorStats> GetStats() override {
     mutex_lock l(mutex_);
-    *stats = stats_;
+    return stats_;
   }
 
   void ClearStats() override {
     mutex_lock l(mutex_);
-    stats_.Clear();
+    stats_.num_allocs = 0;
+    stats_.peak_bytes_in_use = 0;
+    stats_.largest_alloc_size = 0;
+    stats_.bytes_in_use = 0;
+    stats_.bytes_limit = 0;
   }
 
  private:
@@ -98,10 +102,10 @@ class MklSmallSizeAllocator : public Allocator {
     mutex_lock l(mutex_);
     ++stats_.num_allocs;
     stats_.bytes_in_use += alloc_size;
-    stats_.max_bytes_in_use =
-        std::max(stats_.max_bytes_in_use, stats_.bytes_in_use);
-    stats_.max_alloc_size =
-        std::max(alloc_size, static_cast<size_t>(stats_.max_alloc_size));
+    stats_.peak_bytes_in_use =
+        std::max(stats_.peak_bytes_in_use, stats_.bytes_in_use);
+    stats_.largest_alloc_size =
+        std::max(alloc_size, static_cast<size_t>(stats_.largest_alloc_size));
   }
 
   // Decrement statistics for the allocator handling small allocations.
@@ -244,22 +248,26 @@ class MklCPUAllocator : public Allocator {
     }
   }
 
-  void GetStats(AllocatorStats* stats) override {
-    AllocatorStats l_stats, s_stats;
-    small_size_allocator_->GetStats(&s_stats);
-    large_size_allocator_->GetStats(&l_stats);
+  absl::optional<AllocatorStats> GetStats() override {
+    auto s_stats = small_size_allocator_->GetStats();
+    auto l_stats = large_size_allocator_->GetStats();
 
     // Combine statistics from small-size and large-size allocator.
-    stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs;
-    stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use;
-    stats->max_bytes_in_use =
-        l_stats.max_bytes_in_use + s_stats.max_bytes_in_use;
+    // stats_ is GUARDED_BY(mutex_), so hold the lock while updating and
+    // copying it. The lock is taken only after both sub-allocators have
+    // been queried.
+    mutex_lock l(mutex_);
+    stats_.num_allocs = l_stats->num_allocs + s_stats->num_allocs;
+    stats_.bytes_in_use = l_stats->bytes_in_use + s_stats->bytes_in_use;
+    stats_.peak_bytes_in_use =
+        l_stats->peak_bytes_in_use + s_stats->peak_bytes_in_use;
 
     // Since small-size allocations go to MklSmallSizeAllocator,
     // max_alloc_size from large_size_allocator would be the maximum
     // size allocated by MklCPUAllocator.
-    stats->max_alloc_size = l_stats.max_alloc_size;
-    stats->bytes_limit = std::max(s_stats.bytes_limit, l_stats.bytes_limit);
+    stats_.largest_alloc_size = l_stats->largest_alloc_size;
+    stats_.bytes_limit = std::max(s_stats->bytes_limit, l_stats->bytes_limit);
+    return stats_;
   }
 
   void ClearStats() override {
@@ -308,6 +316,7 @@ class MklCPUAllocator : public Allocator {
   SubAllocator* sub_allocator_;  // not owned by this class
 
   mutable mutex mutex_;
+  AllocatorStats stats_ GUARDED_BY(mutex_);
 
   // Hash map to keep track of "BFC" allocations
   // We do not use BFC allocator for small allocations.
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
index e08ab576385..ee1d9cd281b 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
@@ -24,22 +24,21 @@ limitations under the License.
 namespace tensorflow {
 
 TEST(MKLBFCAllocatorTest, TestMaxLimit) {
-  AllocatorStats stats;
   setenv(MklCPUAllocator::kMaxLimitStr, "1000", 1);
   MklCPUAllocator a;
   TF_EXPECT_OK(a.Initialize());
-  a.GetStats(&stats);
-  EXPECT_EQ(stats.bytes_limit, 1000);
+  auto stats = a.GetStats();
+  EXPECT_EQ(stats->bytes_limit, 1000);
 
   unsetenv(MklCPUAllocator::kMaxLimitStr);
   TF_EXPECT_OK(a.Initialize());
-  a.GetStats(&stats);
+  stats = a.GetStats();
   uint64 max_mem_bytes = MklCPUAllocator::kDefaultMaxLimit;
 #if defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE)
   max_mem_bytes =
       (uint64)sysconf(_SC_PHYS_PAGES) * (uint64)sysconf(_SC_PAGESIZE);
 #endif
-  EXPECT_EQ(stats.bytes_limit, max_mem_bytes);
+  EXPECT_EQ(stats->bytes_limit, max_mem_bytes);
 
   setenv(MklCPUAllocator::kMaxLimitStr, "wrong-input", 1);
   EXPECT_TRUE(errors::IsInvalidArgument(a.Initialize()));