diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 49a621810ef..83759a7a0c6 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -64,8 +64,9 @@ GlobalDataHandle AllocationTracker::RegisterInternal( auto& allocation = FindOrDie(handle_to_allocation_, handle); int ref_count = allocation->ref_count(); CHECK_GT(ref_count, 0); - VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count + 1; - allocation->increment_ref_count(); + VLOG(2) << "ref_count: " << ref_count << " -> " << + (ref_count + initial_ref_count); + allocation->increment_ref_count(initial_ref_count); } else { handle = next_handle_++; VLOG(2) << "ref_count: " << initial_ref_count; diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index e0076800162..ebbf35b6fe8 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -63,10 +63,10 @@ class Allocation { CHECK_GE(ref_count_, 0); return ref_count_; } - void increment_ref_count() { + void increment_ref_count(int inc) { CHECK_GT(ref_count_, 0); - CHECK_LT(ref_count_, INT_MAX); - ++ref_count_; + CHECK_LE(ref_count_, INT_MAX - inc); + ref_count_ += inc; } void decrement_ref_count() { CHECK_GT(ref_count_, 0);