From 6adc38555f66357426cd4c2028574ef00ba147cf Mon Sep 17 00:00:00 2001
From: David Norman <DavidNorman@users.noreply.github.com>
Date: Wed, 3 May 2017 22:16:18 +0100
Subject: [PATCH] Allocation tracking of pre-tracked buffer returned in a tuple
 should increment the buffer by 2 (#9408)

* Allocation tracking of pre-tracked buffer returned in a tuple should increment the buffer
ref count by 2, not 1.

A buffer not pre-tracked has its ref-count set to 2 when it is part of a tuple. A buffer
pre-tracked has its value increased by 1, whether it is part of a tuple or not.  This is
an error.

* Adjust range check following review

* Fix silly off by 2 error
---
 tensorflow/compiler/xla/service/allocation_tracker.cc | 5 +++--
 tensorflow/compiler/xla/service/allocation_tracker.h  | 6 +++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index 49a621810ef..83759a7a0c6 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -64,8 +64,9 @@ GlobalDataHandle AllocationTracker::RegisterInternal(
     auto& allocation = FindOrDie(handle_to_allocation_, handle);
     int ref_count = allocation->ref_count();
     CHECK_GT(ref_count, 0);
-    VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count + 1;
-    allocation->increment_ref_count();
+    VLOG(2) << "ref_count: " << ref_count << " -> " <<
+            (ref_count + initial_ref_count);
+    allocation->increment_ref_count(initial_ref_count);
   } else {
     handle = next_handle_++;
     VLOG(2) << "ref_count: " << initial_ref_count;
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index e0076800162..ebbf35b6fe8 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -63,10 +63,10 @@ class Allocation {
     CHECK_GE(ref_count_, 0);
     return ref_count_;
   }
-  void increment_ref_count() {
+  void increment_ref_count(int inc) {
     CHECK_GT(ref_count_, 0);
-    CHECK_LT(ref_count_, INT_MAX);
-    ++ref_count_;
+    CHECK_LE(ref_count_, INT_MAX - inc);
+    ref_count_ += inc;
   }
   void decrement_ref_count() {
     CHECK_GT(ref_count_, 0);