From a1b5beec58c7d815abd2b585ac7e6d39acf05ec5 Mon Sep 17 00:00:00 2001
From: Stephan Herhut <herhut@google.com>
Date: Fri, 5 Mar 2021 02:31:01 -0800
Subject: [PATCH] Implement allocation tracking support for XLA allocated tensors.

We recently found the same issue with mlir generated kernels and this ports the solution to XLA.

PiperOrigin-RevId: 361103973
Change-Id: Iaa0f69638c9c34754a62ee631ffb1cd067ebd646
---
 tensorflow/compiler/jit/xla_launch_util.h        | 15 ++++++++++++++-
 tensorflow/python/kernel_tests/benchmark_test.py |  2 --
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 8b939365ee5..97b82324a7f 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -208,7 +208,20 @@ class XlaTensorBuffer : public TensorBuffer {
   TensorBuffer* root_buffer() override { return this; }
 
   void FillAllocationDescription(AllocationDescription* proto) const override {
-    proto->set_allocated_bytes(actual_size_);
+    proto->set_requested_bytes(static_cast<int64>(expected_size_));
+    proto->set_allocator_name(allocator_->Name());
+    proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
+    if (allocator_->TracksAllocationSizes()) {
+      auto ab = static_cast<int64>(allocator_->AllocatedSize(data()));
+      proto->set_allocated_bytes(ab);
+      int64 id = allocator_->AllocationId(data());
+      if (id > 0) {
+        proto->set_allocation_id(id);
+      }
+      if (RefCountIsOne()) {
+        proto->set_has_single_reference(true);
+      }
+    }
   }
 
  private:
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
index 3e64f9d5c15..8865c8b972b 100644
--- a/tensorflow/python/kernel_tests/benchmark_test.py
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -26,7 +26,6 @@ import numpy as np
 from tensorflow.core.util import test_log_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import benchmark
 from tensorflow.python.platform import gfile
@@ -127,7 +126,6 @@ class BenchmarkTest(test.TestCase):
     self.assertFalse(_ran_somebenchmark_2[0])
     self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
 
-  @test_util.disable_xla("b/123744455")  # GPU memory is incorrect
   def testReportingBenchmark(self):
     tempdir = test.get_temp_dir()
     try: