From 6adc38555f66357426cd4c2028574ef00ba147cf Mon Sep 17 00:00:00 2001 From: David Norman <DavidNorman@users.noreply.github.com> Date: Wed, 3 May 2017 22:16:18 +0100 Subject: [PATCH] Allocation tracking of pre-tracked buffer returned in a tuple should increment the buffer by 2 (#9408) * Allocation tracking of pre-tracked buffer returned in a tuple should increment the buffer ref count by 2, not 1. A buffer not pre-tracked has its ref-count set to 2 when it is part of a tuple. A buffer pre-tracked has its value increased by 1, whether it is part of a tuple or not. This is an error. * Adjust range check following review * Fix silly off by 2 error --- tensorflow/compiler/xla/service/allocation_tracker.cc | 5 +++-- tensorflow/compiler/xla/service/allocation_tracker.h | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 49a621810ef..83759a7a0c6 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -64,8 +64,9 @@ GlobalDataHandle AllocationTracker::RegisterInternal( auto& allocation = FindOrDie(handle_to_allocation_, handle); int ref_count = allocation->ref_count(); CHECK_GT(ref_count, 0); - VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count + 1; - allocation->increment_ref_count(); + VLOG(2) << "ref_count: " << ref_count << " -> " << + (ref_count + initial_ref_count); + allocation->increment_ref_count(initial_ref_count); } else { handle = next_handle_++; VLOG(2) << "ref_count: " << initial_ref_count; diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index e0076800162..ebbf35b6fe8 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -63,10 +63,10 @@ class Allocation { CHECK_GE(ref_count_, 0); return ref_count_; } - void increment_ref_count() { + void increment_ref_count(int inc) { CHECK_GT(ref_count_, 0); - CHECK_LT(ref_count_, INT_MAX); - ++ref_count_; + CHECK_LE(ref_count_, INT_MAX - inc); + ref_count_ += inc; } void decrement_ref_count() { CHECK_GT(ref_count_, 0);