From 06668ca549ad2534eaa5297151deb149e6256125 Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Thu, 16 May 2019 18:24:47 -0400 Subject: [PATCH] Add default eager session --- .../processor/OperatorProcessor.java | 24 ++- .../java/org/tensorflow/EagerSession.java | 154 +++++++++++++++--- .../java/org/tensorflow/EagerSessionTest.java | 23 +++ 3 files changed, 171 insertions(+), 30 deletions(-) diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index e3e40a17df5..20237c86ce1 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -53,11 +53,11 @@ import com.squareup.javapoet.FieldSpec; import com.squareup.javapoet.JavaFile; import com.squareup.javapoet.MethodSpec; import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.TypeName; import com.squareup.javapoet.ParameterizedTypeName; -import com.squareup.javapoet.WildcardTypeName; +import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeSpec; import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; /** * A compile-time Processor that aggregates classes annotated with {@link @@ -159,6 +159,8 @@ public final class OperatorProcessor extends AbstractProcessor { private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); private static final TypeName T_EXEC_ENV = ClassName.get("org.tensorflow", "ExecutionEnvironment"); + private static final TypeName T_EAGER_SESSION = + ClassName.get("org.tensorflow", "EagerSession"); private static final TypeName T_STRING = ClassName.get(String.class); // Operand private static final TypeName T_OPERAND = @@ -359,13 +361,13 @@ public final class OperatorProcessor extends AbstractProcessor { + " Operand four = ops.constant(4);\n" + " // Most builders are found within a group, and accept\n" + " // Operand types as operands\n" - + " Operand nine = ops.math().add(four, ops.constant(5));\n" + + " Operand nine = ops.math.add(four, ops.constant(5));\n" + " // Multi-result operations however offer methods to\n" + " // select a particular result for use.\n" + " Operand result = \n" - + " ops.math().add(ops.array().unique(s, a).y(), b);\n" + + " ops.math.add(ops.unique(s, a).y(), b);\n" + " // Optional attributes\n" - + " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + + " ops.linalg.matMul(a, b, MatMul.transposeA(true));\n" + " // Naming operators\n" + " ops.withName(\"foo\").constant(5); // name \"foo\"\n" + " // Names can exist in a hierarchy\n" @@ -446,7 +448,17 @@ public final class OperatorProcessor extends AbstractProcessor { .addParameter(T_EXEC_ENV, "env") .returns(T_OPS) .addStatement("return new Ops(new $T(env))", T_SCOPE) - .addJavadoc("Creates an API for building operations in the provided environment\n") + .addJavadoc("Creates an API for building operations in the provided execution environment\n") + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("create") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .returns(T_OPS) + .addStatement("return new Ops(new $T($T.getDefault()))", T_SCOPE, T_EAGER_SESSION) + .addJavadoc( + "Creates an API for building operations in the default eager execution environment\n\n" + + "

Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.\n") .build()); return opsBuilder.build(); diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java index 7f36da173e6..1e7e97146ce 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java @@ -34,19 +34,6 @@ import java.util.concurrent.Executors; * standard programming library. * *

Instances of a {@code EagerSession} are thread-safe. - * - *

WARNING: Resources consumed by an {@code EagerSession} object must be explicitly freed - * by invoking the {@link #close()} method when it is no longer needed. This could be achieve using - * the `try-with-resources` technique as the example below: - * - *

{@code
- * try (EagerSession s = EagerSession.create()) {
- *    // execute operations eagerly
- * }
- * }
- * - * In addition, {@code EagerSession} objects clean up unused resources during the session, working - * in pair with the JVM garbage collector. See {@link ResourceCleanupStrategy} for more details. */ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { @@ -208,27 +195,130 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { } } - /** Returns an object that configures and builds a {@code EagerSession} with custom options. */ - public static EagerSession.Options options() { - return new Options(); + /** + * Initializes the default eager session, which remains active for the lifetime of the application. + * + *

This method is implicitly invoked on the first call to {@link #getDefault()}, + * but can also be invoked explicitly to override default options. + * + *

Note that calling this method more than once will throw an {@code IllegalArgumentException} + * as the default session cannot be modified once it has been created. Therefore, it is important + * to explicitly initialize it before {@link #getDefault()} is invoked for the first time from + * any thread. + * + *

Example usage: + * + *

{@code
+   * // Initializing default session to override default options is valid but 
+   * // is optional
+   * EagerSession.initDefault(EagerSession.options().async(true)); 
+   * 
+   * // Starting to build eager operations using default session, by calling 
+   * // EagerSession.getDefault() implicitly
+   * Ops tf = Ops.create();
+   * 
+   * // Initializing default session more than once or after using it is not
+   * // permitted and throws an exception
+   * EagerSession.initDefault(EagerSession.options().async(true));  // throws
+   * }
+ * + * @param options options to use to build default session + * @return default eager session + * @throws IllegalStateException if the default session is already initialized + * @see {@link #getDefault()} + */ + public static EagerSession initDefault(Options options) { + synchronized(EagerSession.class) { + if (defaultSession != null) { + throw new IllegalStateException("Default eager session is already initialized"); + } + defaultSession = options.build(); + } + return defaultSession; } - /** Returns an {@code EagerSession} configured with default options. */ + /** + * Returns the default eager session + * + *

Once initialized, the default eager session remains active for the whole life of the + * application, as opposed to sessions obtained from {@link #create()} or {@link Options#build()} + * which should be closed after their usage. + * + *

The default set of {@link Options} is used to initialize the session on the first call. To + * override this behavior, it is possible to invoke {@link #initDefault(Options)} with a different + * set of options prior to this first call. + * + *

Example usage: + * + *

{@code
+   * // Starting to build eager operations using default session, by calling 
+   * // EagerSession.getDefault() implicitly
+   * Ops tf = Ops.create();
+   * 
+   * // Starting to build eager operations using default session, by calling 
+   * // EagerSession.getDefault() explictly
+   * Ops tf = Ops.create(EagerSession.getDefault());
+   * }
+ * + * @return default eager session + * @see {@link #initDefault(Options)} + */ + public static EagerSession getDefault() { + if (defaultSession == null) { + synchronized(EagerSession.class) { + if (defaultSession == null) { + defaultSession = options().build(); + } + } + } + return defaultSession; + } + + /** + * Returns an {@code EagerSession} configured with default options. + * + *

WARNING:Instances of {@code EagerSession} returned by this method must be explicitly freed + * by invoking {@link #close()} when they are no longer needed. This could be achieve using + * the `try-with-resources` technique. + * + *

Example usage: + * + *

{@code
+   * try (EagerSession session = EagerSession.create()) {
+   *   Ops tf = Ops.create(session);
+   *   // build execute operations eagerly...
+   * }
+   * }
+ */ public static EagerSession create() { return options().build(); } - private EagerSession(Options options) { - this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); - this.resourceCleanupStrategy = options.resourceCleanupStrategy; - - if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) { - nativeResources.startCleanupThread(); - } + /** + * Returns an object that configures and builds a {@code EagerSession} with custom options. + * + *

WARNING:Instances of {@code EagerSession} returned by this method must be explicitly freed + * by invoking {@link #close()} when they are no longer needed. This could be achieve using + * the `try-with-resources` technique. + * + *

Example usage: + * + *

{@code
+   * try (EagerSession session = EagerSession.options().async(true).build()) {
+   *   Ops tf = Ops.create(session);
+   *   // build execute operations eagerly and asynchronously...
+   * }
+   * }
+ */ + public static EagerSession.Options options() { + return new Options(); } @Override public synchronized void close() { + if (this == defaultSession) { + throw new IllegalStateException("Default eager session cannot be closed"); + } if (nativeHandle != 0L) { if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) { nativeResources.stopCleanupThread(); @@ -396,16 +486,32 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { private final ReferenceQueue garbageQueue = new ReferenceQueue<>(); private volatile boolean cleanupInBackground = false; } + + private static volatile EagerSession defaultSession = null; private final NativeResourceCollector nativeResources = new NativeResourceCollector(); private final ResourceCleanupStrategy resourceCleanupStrategy; private long nativeHandle; + private EagerSession(Options options) { + this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); + this.resourceCleanupStrategy = options.resourceCleanupStrategy; + + if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) { + nativeResources.startCleanupThread(); + } + } + private void checkSession() { if (nativeHandle == 0L) { throw new IllegalStateException("Eager session has been closed"); } } + + // For tests + ResourceCleanupStrategy resourceCleanupStrategy() { + return resourceCleanupStrategy; + } private static native long allocate(boolean async, int devicePlacementPolicy, byte[] config); diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java index 77f38bb6160..78ce6540318 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java @@ -15,7 +15,9 @@ limitations under the License. package org.tensorflow; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -130,6 +132,27 @@ public class EagerSessionTest { } } + @Test + public void defaultSession() throws Exception { + EagerSession.Options options = EagerSession.options().resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE); + EagerSession.initDefault(options); + EagerSession session = EagerSession.getDefault(); + assertNotNull(session); + assertEquals(ResourceCleanupStrategy.ON_SESSION_CLOSE, session.resourceCleanupStrategy()); + try { + EagerSession.initDefault(options); + fail(); + } catch (IllegalStateException e) { + // expected + } + try { + session.close(); + fail(); + } catch (IllegalStateException e) { + // expected + } + } + private static class TestReference extends EagerSession.NativeReference { TestReference(EagerSession session, Object referent, AtomicBoolean deleted) {