diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index e3e40a17df5..7d22a78bcb4 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -53,11 +53,11 @@ import com.squareup.javapoet.FieldSpec; import com.squareup.javapoet.JavaFile; import com.squareup.javapoet.MethodSpec; import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.TypeName; import com.squareup.javapoet.ParameterizedTypeName; -import com.squareup.javapoet.WildcardTypeName; +import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeSpec; import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; /** * A compile-time Processor that aggregates classes annotated with {@link @@ -159,6 +159,7 @@ public final class OperatorProcessor extends AbstractProcessor { private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); private static final TypeName T_EXEC_ENV = ClassName.get("org.tensorflow", "ExecutionEnvironment"); + private static final TypeName T_EAGER_SESSION = ClassName.get("org.tensorflow", "EagerSession"); private static final TypeName T_STRING = ClassName.get(String.class); // Operand<?> private static final TypeName T_OPERAND = @@ -359,13 +360,13 @@ public final class OperatorProcessor extends AbstractProcessor { + " Operand four = ops.constant(4);\n" + " // Most builders are found within a group, and accept\n" + " // Operand types as operands\n" - + " Operand nine = ops.math().add(four, ops.constant(5));\n" + + " Operand nine = ops.math.add(four, ops.constant(5));\n" + " // Multi-result operations however offer methods to\n" + " // select a particular result for use.\n" + " Operand result = \n" - + " ops.math().add(ops.array().unique(s, a).y(), b);\n" + + " ops.math.add(ops.unique(s, a).y(), b);\n" + " // Optional attributes\n" - + " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + + " ops.linalg.matMul(a, b, MatMul.transposeA(true));\n" + " // Naming operators\n" + " ops.withName(\"foo\").constant(5); // name \"foo\"\n" + " // Names can exist in a hierarchy\n" @@ -446,7 +447,18 @@ public final class OperatorProcessor extends AbstractProcessor { .addParameter(T_EXEC_ENV, "env") .returns(T_OPS) .addStatement("return new Ops(new $T(env))", T_SCOPE) - .addJavadoc("Creates an API for building operations in the provided environment\n") + .addJavadoc( + "Creates an API for building operations in the provided execution environment\n") + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("create") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .returns(T_OPS) + .addStatement("return new Ops(new $T($T.getDefault()))", T_SCOPE, T_EAGER_SESSION) + .addJavadoc( + "Creates an API for building operations in the default eager execution environment\n\n" + + "<p>Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.\n") .build()); return opsBuilder.build(); diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java index 7f36da173e6..d3bb43a8958 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java @@ -34,19 +34,6 @@ import java.util.concurrent.Executors; * standard programming library. * * <p>Instances of a {@code EagerSession} are thread-safe. - * - * <p><b>WARNING:</b> Resources consumed by an {@code EagerSession} object must be explicitly freed - * by invoking the {@link #close()} method when it is no longer needed. This could be achieve using - * the `try-with-resources` technique as the example below: - * - * <pre>{@code - * try (EagerSession s = EagerSession.create()) { - * // execute operations eagerly - * } - * }</pre> - * - * In addition, {@code EagerSession} objects clean up unused resources during the session, working - * in pair with the JVM garbage collector. See {@link ResourceCleanupStrategy} for more details. */ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { @@ -208,27 +195,131 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { } } - /** Returns an object that configures and builds a {@code EagerSession} with custom options. */ - public static EagerSession.Options options() { - return new Options(); + /** + * Initializes the default eager session, which remains active for the lifetime of the + * application. + * + * <p>This method is implicitly invoked on the first call to {@link #getDefault()}, but can also + * be invoked explicitly to override default options. + * + * <p>Note that calling this method more than once will throw an {@code IllegalArgumentException} + * as the default session cannot be modified once it has been created. Therefore, it is important + * to explicitly initialize it before {@link #getDefault()} is invoked for the first time from any + * thread. + * + * <p>Example usage: + * + * <pre>{@code + * // Initializing default session to override default options is valid but + * // is optional + * EagerSession.initDefault(EagerSession.options().async(true)); + * + * // Starting to build eager operations using default session, by calling + * // EagerSession.getDefault() implicitly + * Ops tf = Ops.create(); + * + * // Initializing default session more than once or after using it is not + * // permitted and throws an exception + * EagerSession.initDefault(EagerSession.options().async(true)); // throws + * }</pre> + * + * @param options options to use to build default session + * @return default eager session + * @throws IllegalStateException if the default session is already initialized + * @see {@link #getDefault()} + */ + public static EagerSession initDefault(Options options) { + synchronized (EagerSession.class) { + if (defaultSession != null) { + throw new IllegalStateException("Default eager session is already initialized"); + } + defaultSession = options.build(); + } + return defaultSession; } - /** Returns an {@code EagerSession} configured with default options. */ + /** + * Returns the default eager session + * + * <p>Once initialized, the default eager session remains active for the whole life of the + * application, as opposed to sessions obtained from {@link #create()} or {@link Options#build()} + * which should be closed after their usage. + * + * <p>The default set of {@link Options} is used to initialize the session on the first call. To + * override this behavior, it is possible to invoke {@link #initDefault(Options)} with a different + * set of options prior to this first call. + * + * <p>Example usage: + * + * <pre>{@code + * // Starting to build eager operations using default session, by calling + * // EagerSession.getDefault() implicitly + * Ops tf = Ops.create(); + * + * // Starting to build eager operations using default session, by calling + * // EagerSession.getDefault() explictly + * Ops tf = Ops.create(EagerSession.getDefault()); + * }</pre> + * + * @return default eager session + * @see {@link #initDefault(Options)} + */ + public static EagerSession getDefault() { + if (defaultSession == null) { + synchronized (EagerSession.class) { + if (defaultSession == null) { + defaultSession = options().build(); + } + } + } + return defaultSession; + } + + /** + * Returns an {@code EagerSession} configured with default options. + * + * <p><b>WARNING:</b>Instances of {@code EagerSession} returned by this method must be explicitly + * freed by invoking {@link #close()} when they are no longer needed. This could be achieve using + * the `try-with-resources` technique. + * + * <p>Example usage: + * + * <pre>{@code + * try (EagerSession session = EagerSession.create()) { + * Ops tf = Ops.create(session); + * // build execute operations eagerly... + * } + * }</pre> + */ public static EagerSession create() { return options().build(); } - private EagerSession(Options options) { - this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); - this.resourceCleanupStrategy = options.resourceCleanupStrategy; - - if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) { - nativeResources.startCleanupThread(); - } + /** + * Returns an object that configures and builds a {@code EagerSession} with custom options. + * + * <p><b>WARNING:</b>Instances of {@code EagerSession} returned by this method must be explicitly + * freed by invoking {@link #close()} when they are no longer needed. This could be achieve using + * the `try-with-resources` technique. + * + * <p>Example usage: + * + * <pre>{@code + * try (EagerSession session = EagerSession.options().async(true).build()) { + * Ops tf = Ops.create(session); + * // build execute operations eagerly and asynchronously... + * } + * }</pre> + */ + public static EagerSession.Options options() { + return new Options(); } @Override public synchronized void close() { + if (this == defaultSession) { + throw new IllegalStateException("Default eager session cannot be closed"); + } if (nativeHandle != 0L) { if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) { nativeResources.stopCleanupThread(); @@ -397,16 +488,32 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { private volatile boolean cleanupInBackground = false; } + private static volatile EagerSession defaultSession = null; + private final NativeResourceCollector nativeResources = new NativeResourceCollector(); private final ResourceCleanupStrategy resourceCleanupStrategy; private long nativeHandle; + private EagerSession(Options options) { + this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); + this.resourceCleanupStrategy = options.resourceCleanupStrategy; + + if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) { + nativeResources.startCleanupThread(); + } + } + private void checkSession() { if (nativeHandle == 0L) { throw new IllegalStateException("Eager session has been closed"); } } + // For tests + ResourceCleanupStrategy resourceCleanupStrategy() { + return resourceCleanupStrategy; + } + private static native long allocate(boolean async, int devicePlacementPolicy, byte[] config); private static native void delete(long handle); diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java index 77f38bb6160..99133cbd0fc 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java @@ -15,7 +15,9 @@ limitations under the License. package org.tensorflow; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -130,6 +132,28 @@ public class EagerSessionTest { } } + @Test + public void defaultSession() throws Exception { + EagerSession.Options options = + EagerSession.options().resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE); + EagerSession.initDefault(options); + EagerSession session = EagerSession.getDefault(); + assertNotNull(session); + assertEquals(ResourceCleanupStrategy.ON_SESSION_CLOSE, session.resourceCleanupStrategy()); + try { + EagerSession.initDefault(options); + fail(); + } catch (IllegalStateException e) { + // expected + } + try { + session.close(); + fail(); + } catch (IllegalStateException e) { + // expected + } + } + private static class TestReference extends EagerSession.NativeReference { TestReference(EagerSession session, Object referent, AtomicBoolean deleted) {