From 22e53cc8c4c24616419eaf7c8018f624e67f79ef Mon Sep 17 00:00:00 2001
From: Karl Lessard <karl@kubx.ca>
Date: Sun, 16 Jun 2019 22:11:57 -0400
Subject: [PATCH] Fix eager session testing with GC

---
 .../java/org/tensorflow/EagerSession.java     |  27 ++--
 .../java/org/tensorflow/EagerSessionTest.java | 117 +++++++++++-------
 2 files changed, 87 insertions(+), 57 deletions(-)

diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java
index d3bb43a8958..cda6156be33 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java
@@ -179,7 +179,12 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
 
     /** Builds an eager session with the selected options. */
     public EagerSession build() {
-      return new EagerSession(this);
+      return new EagerSession(this, new ReferenceQueue<Object>());
+    }
+    
+    // For garbage-collection tests only
+    EagerSession buildForGcTest(ReferenceQueue<Object> gcQueue) {
+      return new EagerSession(this, gcQueue);
     }
 
     private boolean async;
@@ -344,6 +349,10 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
     return nativeHandle;
   }
 
+  ResourceCleanupStrategy resourceCleanupStrategy() {
+    return resourceCleanupStrategy;
+  }
+
   /**
    * A reference to one or more allocated native resources.
    *
@@ -411,6 +420,10 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
    * longer needed.
    */
   private static class NativeResourceCollector {
+    
+    NativeResourceCollector(ReferenceQueue<Object> garbageQueue) {
+      this.garbageQueue = garbageQueue;
+    }
 
     void attach(NativeReference nativeRef) {
       synchronized (nativeRefs) {
@@ -484,17 +497,18 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
 
     private final ExecutorService cleanupService = Executors.newSingleThreadExecutor();
     private final Map<NativeReference, Void> nativeRefs = new IdentityHashMap<>();
-    private final ReferenceQueue<Object> garbageQueue = new ReferenceQueue<>();
+    private final ReferenceQueue<Object> garbageQueue;
     private volatile boolean cleanupInBackground = false;
   }
 
   private static volatile EagerSession defaultSession = null;
 
-  private final NativeResourceCollector nativeResources = new NativeResourceCollector();
+  private final NativeResourceCollector nativeResources;
   private final ResourceCleanupStrategy resourceCleanupStrategy;
   private long nativeHandle;
 
-  private EagerSession(Options options) {
+  private EagerSession(Options options, ReferenceQueue<Object> garbageQueue) {
+    this.nativeResources = new NativeResourceCollector(garbageQueue);
     this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
     this.resourceCleanupStrategy = options.resourceCleanupStrategy;
 
@@ -509,11 +523,6 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
     }
   }
 
-  // For tests
-  ResourceCleanupStrategy resourceCleanupStrategy() {
-    return resourceCleanupStrategy;
-  }
-
   private static native long allocate(boolean async, int devicePlacementPolicy, byte[] config);
 
   private static native void delete(long handle);
diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java
index 7db1cecb943..b4f50c6e7c6 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java
@@ -21,8 +21,13 @@ import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
+import java.lang.ref.Reference;
+import java.lang.ref.ReferenceQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
-import org.junit.Ignore;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -40,74 +45,67 @@ public class EagerSessionTest {
 
   @Test
   public void cleanupResourceOnSessionClose() {
-    AtomicBoolean deleted = new AtomicBoolean();
-
+    TestReference ref;
     try (EagerSession s =
         EagerSession.options()
             .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE)
             .build()) {
+      ref = new TestReference(s, new Object());
+      assertFalse(ref.isDeleted());
 
-      new TestReference(s, new Object(), deleted);
-
-      assertFalse(deleted.get());
-      runGC();
-      assertFalse(deleted.get());
-
+      // check that reaching safe point did not release resources
       buildOp(s);
-      assertFalse(deleted.get()); // reaching safe point did not release resources
+      assertFalse(ref.isDeleted());
     }
-    assertTrue(deleted.get());
+    assertTrue(ref.isDeleted());
   }
 
-  // TODO(b/135541743): Re-enable once fixed.
-  // Disabled due to flakiness with -c opt --config=cuda
-  @Ignore
+  @Test
   public void cleanupResourceOnSafePoints() {
-    AtomicBoolean deleted = new AtomicBoolean();
-
+    TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue();
     try (EagerSession s =
         EagerSession.options()
             .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SAFE_POINTS)
-            .build()) {
+            .buildForGcTest(gcQueue)) {
 
-      new TestReference(s, new Object(), deleted);
-
-      assertFalse(deleted.get());
-      runGC();
-      assertFalse(deleted.get());
-
-      buildOp(s);
-      assertTrue(deleted.get()); // reaching safe point released resources
+      TestReference ref = new TestReference(s, new Object());
+      assertFalse(ref.isDeleted());
+      
+      // garbage collecting the reference won't release until we reached safe point
+      gcQueue.collect(ref);
+      assertFalse(ref.isDeleted());
+      buildOp(s); // safe point
+      assertTrue(ref.isDeleted());
+      assertTrue(gcQueue.isEmpty());
     }
   }
 
   @Test
   public void cleanupResourceInBackground() {
-    AtomicBoolean deleted = new AtomicBoolean();
-
+    TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue();
     try (EagerSession s =
         EagerSession.options()
             .resourceCleanupStrategy(ResourceCleanupStrategy.IN_BACKGROUND)
-            .build()) {
+            .buildForGcTest(gcQueue)) {
 
-      new TestReference(s, new Object(), deleted);
+      TestReference ref = new TestReference(s, new Object());
+      assertFalse(ref.isDeleted());
 
-      assertFalse(deleted.get());
-      runGC();
+      gcQueue.collect(ref);
       sleep(50); // allow some time to the background thread for cleaning up resources
-      assertTrue(deleted.get());
+      assertTrue(ref.isDeleted());
+      assertTrue(gcQueue.isEmpty());
     }
   }
 
   @Test
   public void clearedResourcesAreNotCleanedUp() {
-    AtomicBoolean deleted = new AtomicBoolean();
-
+    TestReference ref;
     try (EagerSession s = EagerSession.create()) {
-      TestReference ref = new TestReference(s, new Object(), deleted);
+      ref = new TestReference(s, new Object());
       ref.clear();
     }
-    assertFalse(deleted.get());
+    assertFalse(ref.isDeleted());
   }
 
   @Test
@@ -127,7 +125,7 @@ public class EagerSessionTest {
     EagerSession s = EagerSession.create();
     s.close();
     try {
-      new TestReference(s, new Object(), new AtomicBoolean());
+      new TestReference(s, new Object());
       fail();
     } catch (IllegalStateException e) {
       // ok
@@ -158,9 +156,8 @@ public class EagerSessionTest {
 
   private static class TestReference extends EagerSession.NativeReference {
 
-    TestReference(EagerSession session, Object referent, AtomicBoolean deleted) {
+    TestReference(EagerSession session, Object referent) {
       super(session, referent);
-      this.deleted = deleted;
     }
 
     @Override
@@ -169,8 +166,40 @@ public class EagerSessionTest {
         fail("Reference was deleted more than once");
       }
     }
+    
+    boolean isDeleted() {
+      return deleted.get();
+    }
+    
+    private final AtomicBoolean deleted = new AtomicBoolean();
+  }
+  
+  private static class TestGarbageCollectorQueue extends ReferenceQueue<Object> {
 
-    private final AtomicBoolean deleted;
+    @Override
+    public Reference<? extends Object> poll() {
+      return garbage.poll();
+    }
+
+    @Override
+    public Reference<? extends Object> remove() throws InterruptedException {
+      return garbage.take();
+    }
+
+    @Override
+    public Reference<? extends Object> remove(long timeout) throws IllegalArgumentException, InterruptedException {
+      return garbage.poll(timeout, TimeUnit.MILLISECONDS);
+    }
+    
+    void collect(TestReference ref) {
+      garbage.add(ref);
+    }
+    
+    boolean isEmpty() {
+      return garbage.isEmpty();
+    }
+
+    private final BlockingQueue<TestReference> garbage = new LinkedBlockingQueue<>();
   }
 
   private static void buildOp(EagerSession s) {
@@ -182,14 +211,6 @@ public class EagerSessionTest {
     }
   }
 
-  private static void runGC() {
-    // Warning: There is no way to force the garbage collector to run, so here we simply to our best
-    // to get it triggered but it might be sufficient on some platforms. Adjust accordingly if some
-    // cleanup tests start to fail.
-    System.gc();
-    System.runFinalization();
-  }
-
   private static void sleep(int millis) {
     try {
       Thread.sleep(millis);