Merge pull request #28781 from karllessard:java-eager-default-session
PiperOrigin-RevId: 251058736
This commit is contained in:
commit
270305c6d5
@ -53,11 +53,11 @@ import com.squareup.javapoet.FieldSpec;
|
||||
import com.squareup.javapoet.JavaFile;
|
||||
import com.squareup.javapoet.MethodSpec;
|
||||
import com.squareup.javapoet.ParameterSpec;
|
||||
import com.squareup.javapoet.TypeName;
|
||||
import com.squareup.javapoet.ParameterizedTypeName;
|
||||
import com.squareup.javapoet.WildcardTypeName;
|
||||
import com.squareup.javapoet.TypeName;
|
||||
import com.squareup.javapoet.TypeSpec;
|
||||
import com.squareup.javapoet.TypeVariableName;
|
||||
import com.squareup.javapoet.WildcardTypeName;
|
||||
|
||||
/**
|
||||
* A compile-time Processor that aggregates classes annotated with {@link
|
||||
@ -159,6 +159,7 @@ public final class OperatorProcessor extends AbstractProcessor {
|
||||
private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
|
||||
private static final TypeName T_EXEC_ENV =
|
||||
ClassName.get("org.tensorflow", "ExecutionEnvironment");
|
||||
private static final TypeName T_EAGER_SESSION = ClassName.get("org.tensorflow", "EagerSession");
|
||||
private static final TypeName T_STRING = ClassName.get(String.class);
|
||||
// Operand<?>
|
||||
private static final TypeName T_OPERAND =
|
||||
@ -359,13 +360,13 @@ public final class OperatorProcessor extends AbstractProcessor {
|
||||
+ " Operand four = ops.constant(4);\n"
|
||||
+ " // Most builders are found within a group, and accept\n"
|
||||
+ " // Operand types as operands\n"
|
||||
+ " Operand nine = ops.math().add(four, ops.constant(5));\n"
|
||||
+ " Operand nine = ops.math.add(four, ops.constant(5));\n"
|
||||
+ " // Multi-result operations however offer methods to\n"
|
||||
+ " // select a particular result for use.\n"
|
||||
+ " Operand result = \n"
|
||||
+ " ops.math().add(ops.array().unique(s, a).y(), b);\n"
|
||||
+ " ops.math.add(ops.unique(s, a).y(), b);\n"
|
||||
+ " // Optional attributes\n"
|
||||
+ " ops.math().matMul(a, b, MatMul.transposeA(true));\n"
|
||||
+ " ops.linalg.matMul(a, b, MatMul.transposeA(true));\n"
|
||||
+ " // Naming operators\n"
|
||||
+ " ops.withName(\"foo\").constant(5); // name \"foo\"\n"
|
||||
+ " // Names can exist in a hierarchy\n"
|
||||
@ -446,7 +447,18 @@ public final class OperatorProcessor extends AbstractProcessor {
|
||||
.addParameter(T_EXEC_ENV, "env")
|
||||
.returns(T_OPS)
|
||||
.addStatement("return new Ops(new $T(env))", T_SCOPE)
|
||||
.addJavadoc("Creates an API for building operations in the provided environment\n")
|
||||
.addJavadoc(
|
||||
"Creates an API for building operations in the provided execution environment\n")
|
||||
.build());
|
||||
|
||||
opsBuilder.addMethod(
|
||||
MethodSpec.methodBuilder("create")
|
||||
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
|
||||
.returns(T_OPS)
|
||||
.addStatement("return new Ops(new $T($T.getDefault()))", T_SCOPE, T_EAGER_SESSION)
|
||||
.addJavadoc(
|
||||
"Creates an API for building operations in the default eager execution environment\n\n"
|
||||
+ "<p>Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.\n")
|
||||
.build());
|
||||
|
||||
return opsBuilder.build();
|
||||
|
@ -34,19 +34,6 @@ import java.util.concurrent.Executors;
|
||||
* standard programming library.
|
||||
*
|
||||
* <p>Instances of a {@code EagerSession} are thread-safe.
|
||||
*
|
||||
* <p><b>WARNING:</b> Resources consumed by an {@code EagerSession} object must be explicitly freed
|
||||
* by invoking the {@link #close()} method when it is no longer needed. This could be achieve using
|
||||
* the `try-with-resources` technique as the example below:
|
||||
*
|
||||
* <pre>{@code
|
||||
* try (EagerSession s = EagerSession.create()) {
|
||||
* // execute operations eagerly
|
||||
* }
|
||||
* }</pre>
|
||||
*
|
||||
* In addition, {@code EagerSession} objects clean up unused resources during the session, working
|
||||
* in pair with the JVM garbage collector. See {@link ResourceCleanupStrategy} for more details.
|
||||
*/
|
||||
public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
|
||||
|
||||
@ -208,27 +195,131 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns an object that configures and builds a {@code EagerSession} with custom options. */
|
||||
public static EagerSession.Options options() {
|
||||
return new Options();
|
||||
/**
|
||||
* Initializes the default eager session, which remains active for the lifetime of the
|
||||
* application.
|
||||
*
|
||||
* <p>This method is implicitly invoked on the first call to {@link #getDefault()}, but can also
|
||||
* be invoked explicitly to override default options.
|
||||
*
|
||||
* <p>Note that calling this method more than once will throw an {@code IllegalArgumentException}
|
||||
* as the default session cannot be modified once it has been created. Therefore, it is important
|
||||
* to explicitly initialize it before {@link #getDefault()} is invoked for the first time from any
|
||||
* thread.
|
||||
*
|
||||
* <p>Example usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* // Initializing default session to override default options is valid but
|
||||
* // is optional
|
||||
* EagerSession.initDefault(EagerSession.options().async(true));
|
||||
*
|
||||
* // Starting to build eager operations using default session, by calling
|
||||
* // EagerSession.getDefault() implicitly
|
||||
* Ops tf = Ops.create();
|
||||
*
|
||||
* // Initializing default session more than once or after using it is not
|
||||
* // permitted and throws an exception
|
||||
* EagerSession.initDefault(EagerSession.options().async(true)); // throws
|
||||
* }</pre>
|
||||
*
|
||||
* @param options options to use to build default session
|
||||
* @return default eager session
|
||||
* @throws IllegalStateException if the default session is already initialized
|
||||
* @see {@link #getDefault()}
|
||||
*/
|
||||
public static EagerSession initDefault(Options options) {
|
||||
synchronized (EagerSession.class) {
|
||||
if (defaultSession != null) {
|
||||
throw new IllegalStateException("Default eager session is already initialized");
|
||||
}
|
||||
defaultSession = options.build();
|
||||
}
|
||||
return defaultSession;
|
||||
}
|
||||
|
||||
/** Returns an {@code EagerSession} configured with default options. */
|
||||
/**
|
||||
* Returns the default eager session
|
||||
*
|
||||
* <p>Once initialized, the default eager session remains active for the whole life of the
|
||||
* application, as opposed to sessions obtained from {@link #create()} or {@link Options#build()}
|
||||
* which should be closed after their usage.
|
||||
*
|
||||
* <p>The default set of {@link Options} is used to initialize the session on the first call. To
|
||||
* override this behavior, it is possible to invoke {@link #initDefault(Options)} with a different
|
||||
* set of options prior to this first call.
|
||||
*
|
||||
* <p>Example usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* // Starting to build eager operations using default session, by calling
|
||||
* // EagerSession.getDefault() implicitly
|
||||
* Ops tf = Ops.create();
|
||||
*
|
||||
* // Starting to build eager operations using default session, by calling
|
||||
* // EagerSession.getDefault() explictly
|
||||
* Ops tf = Ops.create(EagerSession.getDefault());
|
||||
* }</pre>
|
||||
*
|
||||
* @return default eager session
|
||||
* @see {@link #initDefault(Options)}
|
||||
*/
|
||||
public static EagerSession getDefault() {
|
||||
if (defaultSession == null) {
|
||||
synchronized (EagerSession.class) {
|
||||
if (defaultSession == null) {
|
||||
defaultSession = options().build();
|
||||
}
|
||||
}
|
||||
}
|
||||
return defaultSession;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an {@code EagerSession} configured with default options.
|
||||
*
|
||||
* <p><b>WARNING:</b>Instances of {@code EagerSession} returned by this method must be explicitly
|
||||
* freed by invoking {@link #close()} when they are no longer needed. This could be achieve using
|
||||
* the `try-with-resources` technique.
|
||||
*
|
||||
* <p>Example usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* try (EagerSession session = EagerSession.create()) {
|
||||
* Ops tf = Ops.create(session);
|
||||
* // build execute operations eagerly...
|
||||
* }
|
||||
* }</pre>
|
||||
*/
|
||||
public static EagerSession create() {
|
||||
return options().build();
|
||||
}
|
||||
|
||||
private EagerSession(Options options) {
|
||||
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
|
||||
this.resourceCleanupStrategy = options.resourceCleanupStrategy;
|
||||
|
||||
if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) {
|
||||
nativeResources.startCleanupThread();
|
||||
}
|
||||
/**
|
||||
* Returns an object that configures and builds a {@code EagerSession} with custom options.
|
||||
*
|
||||
* <p><b>WARNING:</b>Instances of {@code EagerSession} returned by this method must be explicitly
|
||||
* freed by invoking {@link #close()} when they are no longer needed. This could be achieve using
|
||||
* the `try-with-resources` technique.
|
||||
*
|
||||
* <p>Example usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* try (EagerSession session = EagerSession.options().async(true).build()) {
|
||||
* Ops tf = Ops.create(session);
|
||||
* // build execute operations eagerly and asynchronously...
|
||||
* }
|
||||
* }</pre>
|
||||
*/
|
||||
public static EagerSession.Options options() {
|
||||
return new Options();
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void close() {
|
||||
if (this == defaultSession) {
|
||||
throw new IllegalStateException("Default eager session cannot be closed");
|
||||
}
|
||||
if (nativeHandle != 0L) {
|
||||
if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) {
|
||||
nativeResources.stopCleanupThread();
|
||||
@ -397,16 +488,32 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
|
||||
private volatile boolean cleanupInBackground = false;
|
||||
}
|
||||
|
||||
private static volatile EagerSession defaultSession = null;
|
||||
|
||||
private final NativeResourceCollector nativeResources = new NativeResourceCollector();
|
||||
private final ResourceCleanupStrategy resourceCleanupStrategy;
|
||||
private long nativeHandle;
|
||||
|
||||
private EagerSession(Options options) {
|
||||
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
|
||||
this.resourceCleanupStrategy = options.resourceCleanupStrategy;
|
||||
|
||||
if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) {
|
||||
nativeResources.startCleanupThread();
|
||||
}
|
||||
}
|
||||
|
||||
private void checkSession() {
|
||||
if (nativeHandle == 0L) {
|
||||
throw new IllegalStateException("Eager session has been closed");
|
||||
}
|
||||
}
|
||||
|
||||
// For tests
|
||||
ResourceCleanupStrategy resourceCleanupStrategy() {
|
||||
return resourceCleanupStrategy;
|
||||
}
|
||||
|
||||
private static native long allocate(boolean async, int devicePlacementPolicy, byte[] config);
|
||||
|
||||
private static native void delete(long handle);
|
||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
@ -130,6 +132,28 @@ public class EagerSessionTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void defaultSession() throws Exception {
|
||||
EagerSession.Options options =
|
||||
EagerSession.options().resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE);
|
||||
EagerSession.initDefault(options);
|
||||
EagerSession session = EagerSession.getDefault();
|
||||
assertNotNull(session);
|
||||
assertEquals(ResourceCleanupStrategy.ON_SESSION_CLOSE, session.resourceCleanupStrategy());
|
||||
try {
|
||||
EagerSession.initDefault(options);
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
// expected
|
||||
}
|
||||
try {
|
||||
session.close();
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
private static class TestReference extends EagerSession.NativeReference {
|
||||
|
||||
TestReference(EagerSession session, Object referent, AtomicBoolean deleted) {
|
||||
|
Loading…
Reference in New Issue
Block a user