Merge pull request #30470 from karllessard:unit-test-fix

PiperOrigin-RevId: 257595934
This commit is contained in:
TensorFlower Gardener 2019-07-11 06:30:34 -07:00
commit a1501b002e
2 changed files with 87 additions and 56 deletions

View File

@ -179,7 +179,12 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
/** Builds an eager session with the selected options. */ /** Builds an eager session with the selected options. */
public EagerSession build() { public EagerSession build() {
return new EagerSession(this); return new EagerSession(this, new ReferenceQueue<Object>());
}
// For garbage-collection tests only
EagerSession buildForGcTest(ReferenceQueue<Object> gcQueue) {
return new EagerSession(this, gcQueue);
} }
private boolean async; private boolean async;
@ -344,6 +349,10 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
return nativeHandle; return nativeHandle;
} }
ResourceCleanupStrategy resourceCleanupStrategy() {
return resourceCleanupStrategy;
}
/** /**
* A reference to one or more allocated native resources. * A reference to one or more allocated native resources.
* *
@ -412,6 +421,10 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
*/ */
private static class NativeResourceCollector { private static class NativeResourceCollector {
NativeResourceCollector(ReferenceQueue<Object> garbageQueue) {
this.garbageQueue = garbageQueue;
}
void attach(NativeReference nativeRef) { void attach(NativeReference nativeRef) {
synchronized (nativeRefs) { synchronized (nativeRefs) {
nativeRefs.put(nativeRef, null); nativeRefs.put(nativeRef, null);
@ -484,17 +497,18 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
private final ExecutorService cleanupService = Executors.newSingleThreadExecutor(); private final ExecutorService cleanupService = Executors.newSingleThreadExecutor();
private final Map<NativeReference, Void> nativeRefs = new IdentityHashMap<>(); private final Map<NativeReference, Void> nativeRefs = new IdentityHashMap<>();
private final ReferenceQueue<Object> garbageQueue = new ReferenceQueue<>(); private final ReferenceQueue<Object> garbageQueue;
private volatile boolean cleanupInBackground = false; private volatile boolean cleanupInBackground = false;
} }
private static volatile EagerSession defaultSession = null; private static volatile EagerSession defaultSession = null;
private final NativeResourceCollector nativeResources = new NativeResourceCollector(); private final NativeResourceCollector nativeResources;
private final ResourceCleanupStrategy resourceCleanupStrategy; private final ResourceCleanupStrategy resourceCleanupStrategy;
private long nativeHandle; private long nativeHandle;
private EagerSession(Options options) { private EagerSession(Options options, ReferenceQueue<Object> garbageQueue) {
this.nativeResources = new NativeResourceCollector(garbageQueue);
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
this.resourceCleanupStrategy = options.resourceCleanupStrategy; this.resourceCleanupStrategy = options.resourceCleanupStrategy;
@ -509,11 +523,6 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
} }
} }
// For tests
ResourceCleanupStrategy resourceCleanupStrategy() {
return resourceCleanupStrategy;
}
private static native long allocate(boolean async, int devicePlacementPolicy, byte[] config); private static native long allocate(boolean async, int devicePlacementPolicy, byte[] config);
private static native void delete(long handle); private static native void delete(long handle);

View File

@ -21,8 +21,13 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -40,74 +45,67 @@ public class EagerSessionTest {
@Test @Test
public void cleanupResourceOnSessionClose() { public void cleanupResourceOnSessionClose() {
AtomicBoolean deleted = new AtomicBoolean(); TestReference ref;
try (EagerSession s = try (EagerSession s =
EagerSession.options() EagerSession.options()
.resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE) .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE)
.build()) { .build()) {
ref = new TestReference(s, new Object());
assertFalse(ref.isDeleted());
new TestReference(s, new Object(), deleted); // check that reaching safe point did not release resources
assertFalse(deleted.get());
runGC();
assertFalse(deleted.get());
buildOp(s); buildOp(s);
assertFalse(deleted.get()); // reaching safe point did not release resources assertFalse(ref.isDeleted());
} }
assertTrue(deleted.get()); assertTrue(ref.isDeleted());
} }
// TODO(b/135541743): Re-enable once fixed. @Test
// Disabled due to flakiness with -c opt --config=cuda
@Ignore
public void cleanupResourceOnSafePoints() { public void cleanupResourceOnSafePoints() {
AtomicBoolean deleted = new AtomicBoolean(); TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue();
try (EagerSession s = try (EagerSession s =
EagerSession.options() EagerSession.options()
.resourceCleanupStrategy(ResourceCleanupStrategy.ON_SAFE_POINTS) .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SAFE_POINTS)
.build()) { .buildForGcTest(gcQueue)) {
new TestReference(s, new Object(), deleted); TestReference ref = new TestReference(s, new Object());
assertFalse(ref.isDeleted());
assertFalse(deleted.get()); // garbage collecting the reference won't release until we reached safe point
runGC(); gcQueue.collect(ref);
assertFalse(deleted.get()); assertFalse(ref.isDeleted());
buildOp(s); // safe point
buildOp(s); assertTrue(ref.isDeleted());
assertTrue(deleted.get()); // reaching safe point released resources assertTrue(gcQueue.isEmpty());
} }
} }
@Test @Test
public void cleanupResourceInBackground() { public void cleanupResourceInBackground() {
AtomicBoolean deleted = new AtomicBoolean(); TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue();
try (EagerSession s = try (EagerSession s =
EagerSession.options() EagerSession.options()
.resourceCleanupStrategy(ResourceCleanupStrategy.IN_BACKGROUND) .resourceCleanupStrategy(ResourceCleanupStrategy.IN_BACKGROUND)
.build()) { .buildForGcTest(gcQueue)) {
new TestReference(s, new Object(), deleted); TestReference ref = new TestReference(s, new Object());
assertFalse(ref.isDeleted());
assertFalse(deleted.get()); gcQueue.collect(ref);
runGC();
sleep(50); // allow some time to the background thread for cleaning up resources sleep(50); // allow some time to the background thread for cleaning up resources
assertTrue(deleted.get()); assertTrue(ref.isDeleted());
assertTrue(gcQueue.isEmpty());
} }
} }
@Test @Test
public void clearedResourcesAreNotCleanedUp() { public void clearedResourcesAreNotCleanedUp() {
AtomicBoolean deleted = new AtomicBoolean(); TestReference ref;
try (EagerSession s = EagerSession.create()) { try (EagerSession s = EagerSession.create()) {
TestReference ref = new TestReference(s, new Object(), deleted); ref = new TestReference(s, new Object());
ref.clear(); ref.clear();
} }
assertFalse(deleted.get()); assertFalse(ref.isDeleted());
} }
@Test @Test
@ -127,7 +125,7 @@ public class EagerSessionTest {
EagerSession s = EagerSession.create(); EagerSession s = EagerSession.create();
s.close(); s.close();
try { try {
new TestReference(s, new Object(), new AtomicBoolean()); new TestReference(s, new Object());
fail(); fail();
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
// ok // ok
@ -158,9 +156,8 @@ public class EagerSessionTest {
private static class TestReference extends EagerSession.NativeReference { private static class TestReference extends EagerSession.NativeReference {
TestReference(EagerSession session, Object referent, AtomicBoolean deleted) { TestReference(EagerSession session, Object referent) {
super(session, referent); super(session, referent);
this.deleted = deleted;
} }
@Override @Override
@ -170,7 +167,40 @@ public class EagerSessionTest {
} }
} }
private final AtomicBoolean deleted; boolean isDeleted() {
return deleted.get();
}
private final AtomicBoolean deleted = new AtomicBoolean();
}
private static class TestGarbageCollectorQueue extends ReferenceQueue<Object> {
@Override
public Reference<? extends Object> poll() {
return garbage.poll();
}
@Override
public Reference<? extends Object> remove() throws InterruptedException {
return garbage.take();
}
@Override
public Reference<? extends Object> remove(long timeout)
throws IllegalArgumentException, InterruptedException {
return garbage.poll(timeout, TimeUnit.MILLISECONDS);
}
void collect(TestReference ref) {
garbage.add(ref);
}
boolean isEmpty() {
return garbage.isEmpty();
}
private final BlockingQueue<TestReference> garbage = new LinkedBlockingQueue<>();
} }
private static void buildOp(EagerSession s) { private static void buildOp(EagerSession s) {
@ -182,14 +212,6 @@ public class EagerSessionTest {
} }
} }
private static void runGC() {
// Warning: There is no way to force the garbage collector to run, so here we simply to our best
// to get it triggered but it might be sufficient on some platforms. Adjust accordingly if some
// cleanup tests start to fail.
System.gc();
System.runFinalization();
}
private static void sleep(int millis) { private static void sleep(int millis) {
try { try {
Thread.sleep(millis); Thread.sleep(millis);