diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java index d3bb43a8958..cda6156be33 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java @@ -179,7 +179,12 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { /** Builds an eager session with the selected options. */ public EagerSession build() { - return new EagerSession(this); + return new EagerSession(this, new ReferenceQueue()); + } + + // For garbage-collection tests only + EagerSession buildForGcTest(ReferenceQueue gcQueue) { + return new EagerSession(this, gcQueue); } private boolean async; @@ -344,6 +349,10 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { return nativeHandle; } + ResourceCleanupStrategy resourceCleanupStrategy() { + return resourceCleanupStrategy; + } + /** * A reference to one or more allocated native resources. * @@ -411,6 +420,10 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { * longer needed. */ private static class NativeResourceCollector { + + NativeResourceCollector(ReferenceQueue garbageQueue) { + this.garbageQueue = garbageQueue; + } void attach(NativeReference nativeRef) { synchronized (nativeRefs) { @@ -484,17 +497,18 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { private final ExecutorService cleanupService = Executors.newSingleThreadExecutor(); private final Map nativeRefs = new IdentityHashMap<>(); - private final ReferenceQueue garbageQueue = new ReferenceQueue<>(); + private final ReferenceQueue garbageQueue; private volatile boolean cleanupInBackground = false; } private static volatile EagerSession defaultSession = null; - private final NativeResourceCollector nativeResources = new NativeResourceCollector(); + private final NativeResourceCollector nativeResources; private final ResourceCleanupStrategy resourceCleanupStrategy; private long nativeHandle; - private EagerSession(Options options) { + private EagerSession(Options options, ReferenceQueue garbageQueue) { + this.nativeResources = new NativeResourceCollector(garbageQueue); this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); this.resourceCleanupStrategy = options.resourceCleanupStrategy; @@ -509,11 +523,6 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { } } - // For tests - ResourceCleanupStrategy resourceCleanupStrategy() { - return resourceCleanupStrategy; - } - private static native long allocate(boolean async, int devicePlacementPolicy, byte[] config); private static native void delete(long handle); diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java index 7db1cecb943..b4f50c6e7c6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java @@ -21,8 +21,13 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.Ignore; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -40,74 +45,67 @@ public class EagerSessionTest { @Test public void cleanupResourceOnSessionClose() { - AtomicBoolean deleted = new AtomicBoolean(); - + TestReference ref; try (EagerSession s = EagerSession.options() .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE) .build()) { + ref = new TestReference(s, new Object()); + assertFalse(ref.isDeleted()); - new TestReference(s, new Object(), deleted); - - assertFalse(deleted.get()); - runGC(); - assertFalse(deleted.get()); - + // check that reaching safe point did not release resources buildOp(s); - assertFalse(deleted.get()); // reaching safe point did not release resources + assertFalse(ref.isDeleted()); } - assertTrue(deleted.get()); + assertTrue(ref.isDeleted()); } - // TODO(b/135541743): Re-enable once fixed. - // Disabled due to flakiness with -c opt --config=cuda - @Ignore + @Test public void cleanupResourceOnSafePoints() { - AtomicBoolean deleted = new AtomicBoolean(); - + TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue(); try (EagerSession s = EagerSession.options() .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SAFE_POINTS) - .build()) { + .buildForGcTest(gcQueue)) { - new TestReference(s, new Object(), deleted); - - assertFalse(deleted.get()); - runGC(); - assertFalse(deleted.get()); - - buildOp(s); - assertTrue(deleted.get()); // reaching safe point released resources + TestReference ref = new TestReference(s, new Object()); + assertFalse(ref.isDeleted()); + + // garbage collecting the reference won't release until we reached safe point + gcQueue.collect(ref); + assertFalse(ref.isDeleted()); + buildOp(s); // safe point + assertTrue(ref.isDeleted()); + assertTrue(gcQueue.isEmpty()); } } @Test public void cleanupResourceInBackground() { - AtomicBoolean deleted = new AtomicBoolean(); - + TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue(); try (EagerSession s = EagerSession.options() .resourceCleanupStrategy(ResourceCleanupStrategy.IN_BACKGROUND) - .build()) { + .buildForGcTest(gcQueue)) { - new TestReference(s, new Object(), deleted); + TestReference ref = new TestReference(s, new Object()); + assertFalse(ref.isDeleted()); - assertFalse(deleted.get()); - runGC(); + gcQueue.collect(ref); sleep(50); // allow some time to the background thread for cleaning up resources - assertTrue(deleted.get()); + assertTrue(ref.isDeleted()); + assertTrue(gcQueue.isEmpty()); } } @Test public void clearedResourcesAreNotCleanedUp() { - AtomicBoolean deleted = new AtomicBoolean(); - + TestReference ref; try (EagerSession s = EagerSession.create()) { - TestReference ref = new TestReference(s, new Object(), deleted); + ref = new TestReference(s, new Object()); ref.clear(); } - assertFalse(deleted.get()); + assertFalse(ref.isDeleted()); } @Test @@ -127,7 +125,7 @@ public class EagerSessionTest { EagerSession s = EagerSession.create(); s.close(); try { - new TestReference(s, new Object(), new AtomicBoolean()); + new TestReference(s, new Object()); fail(); } catch (IllegalStateException e) { // ok @@ -158,9 +156,8 @@ public class EagerSessionTest { private static class TestReference extends EagerSession.NativeReference { - TestReference(EagerSession session, Object referent, AtomicBoolean deleted) { + TestReference(EagerSession session, Object referent) { super(session, referent); - this.deleted = deleted; } @Override @@ -169,8 +166,40 @@ public class EagerSessionTest { fail("Reference was deleted more than once"); } } + + boolean isDeleted() { + return deleted.get(); + } + + private final AtomicBoolean deleted = new AtomicBoolean(); + } + + private static class TestGarbageCollectorQueue extends ReferenceQueue { - private final AtomicBoolean deleted; + @Override + public Reference poll() { + return garbage.poll(); + } + + @Override + public Reference remove() throws InterruptedException { + return garbage.take(); + } + + @Override + public Reference remove(long timeout) throws IllegalArgumentException, InterruptedException { + return garbage.poll(timeout, TimeUnit.MILLISECONDS); + } + + void collect(TestReference ref) { + garbage.add(ref); + } + + boolean isEmpty() { + return garbage.isEmpty(); + } + + private final BlockingQueue garbage = new LinkedBlockingQueue<>(); } private static void buildOp(EagerSession s) { @@ -182,14 +211,6 @@ public class EagerSessionTest { } } - private static void runGC() { - // Warning: There is no way to force the garbage collector to run, so here we simply to our best - // to get it triggered but it might be sufficient on some platforms. Adjust accordingly if some - // cleanup tests start to fail. - System.gc(); - System.runFinalization(); - } - private static void sleep(int millis) { try { Thread.sleep(millis);