Add eager tensor support

This commit is contained in:
Karl Lessard 2019-05-09 00:32:51 -04:00
parent 4678439926
commit fceea7d090
11 changed files with 283 additions and 37 deletions

View File

@ -60,7 +60,7 @@ abstract class AbstractOperation implements Operation {
abstract long getUnsafeNativeHandle(int outputIdx);
/**
* Returns the shape of the tensor of the {code outputIdx}th output of this operation.
* Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor shape
@ -68,10 +68,20 @@ abstract class AbstractOperation implements Operation {
abstract long[] shape(int outputIdx);
/**
* Returns the datatype of the tensor of the {code outputIdx}th output of this operation.
* Returns the datatype of the tensor of the {@code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor datatype
*/
abstract DataType dtype(int outputIdx);
/**
* Returns the tensor of the {@code outputIdx}th output of this operation.
*
* <p>This is only supported in an eager execution environment.
*
* @param outputIdx index of the output of this operation
* @return output tensor
*/
abstract Tensor<?> tensor(int outputIdx);
}

View File

@ -15,7 +15,7 @@ limitations under the License.
package org.tensorflow;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReferenceArray;
/**
* Implementation of an {@link Operation} executed eagerly.
@ -33,6 +33,7 @@ class EagerOperation extends AbstractOperation {
this.type = type;
this.name = name;
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
this.outputTensors = new AtomicReferenceArray<Tensor<?>>(outputNativeHandles.length);
}
@Override
@ -67,6 +68,12 @@ class EagerOperation extends AbstractOperation {
@Override
public long[] shape(int outputIndex) {
// If the tensor of this output has already been resolved, return its shape.
// Otherwise, retrieve the tensor shape from the native library.
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor != null) {
return tensor.shape();
}
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
long[] shape = new long[numDims(outputNativeHandle)];
for (int i = 0; i < shape.length; ++i) {
@ -77,10 +84,43 @@ class EagerOperation extends AbstractOperation {
@Override
public DataType dtype(int outputIndex) {
// If the tensor of this output has already been resolved, return its datatype.
// Otherwise, retrieve the tensor datatype from the native library.
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor != null) {
return tensor.dataType();
}
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
return DataType.fromC(dataType(outputNativeHandle));
}
@Override
public Tensor<?> tensor(int outputIndex) {
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor == null) {
tensor = resolveTensor(outputIndex);
}
return tensor;
}
private final EagerSession session;
private final NativeReference nativeRef;
private final String type;
private final String name;
private final AtomicReferenceArray<Tensor<?>> outputTensors;
private Tensor<?> resolveTensor(int outputIndex) {
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
// If another thread has resolved it meanwhile, release our copy and reuse the existing one instead.
long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
tensor.close();
tensor = outputTensors.get(outputIndex);
}
return tensor;
}
private static class NativeReference extends EagerSession.NativeReference {
NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
@ -92,30 +132,27 @@ class EagerOperation extends AbstractOperation {
@Override
void delete() {
if (opHandle != 0L) {
for (long tensorHandle : outputHandles) {
if (tensorHandle != 0L) {
EagerOperation.deleteTensorHandle(tensorHandle);
for (int i = 0; i < outputHandles.length; ++i) {
if (outputHandles[i] != 0L) {
EagerOperation.deleteTensorHandle(outputHandles[i]);
outputHandles[i] = 0L;
}
}
EagerOperation.delete(opHandle);
opHandle = 0L;
Arrays.fill(outputHandles, 0L);
}
}
private long opHandle;
private final long[] outputHandles;
}
private final EagerSession session;
private final NativeReference nativeRef;
private final String type;
private final String name;
private static native void delete(long handle);
private static native void deleteTensorHandle(long handle);
private static native long resolveTensorHandle(long handle);
private static native int outputListLength(long handle, String name);
private static native int inputListLength(long handle, String name);

View File

@ -138,6 +138,11 @@ public final class GraphOperation extends AbstractOperation {
r.close();
}
}
@Override
Tensor<?> tensor(int outputIdx) {
throw new IllegalStateException("Graph tensors must be fetched by running a session");
}
long getUnsafeNativeHandle() {
return unsafeNativeHandle;

View File

@ -47,6 +47,22 @@ public final class Output<T> implements Operand<T> {
public DataType dataType() {
return operation.dtype(index);
}
/**
* Returns the tensor at this output.
*
* <p>This operation is only supported on the outputs of an operation executed eagerly.
* For graph environments, output tensors must be fetched by running a session, using
* {@link Session.Runner#fetch(Output)}.
*
* @return tensor
* @throws IllegalStateException if this output results from a graph
* @see EagerSession
*/
@SuppressWarnings("unchecked")
public Tensor<T> tensor() {
return (Tensor<T>)operation.tensor(index);
}
@Override
public Output<T> asOutput() {

View File

@ -140,15 +140,17 @@ public final class Tensor<T> implements AutoCloseable {
Tensor<?> t = new Tensor(dtype);
t.shapeCopy = new long[numDimensions(obj, dtype)];
fillShape(obj, 0, t.shapeCopy);
long nativeHandle;
if (t.dtype != DataType.STRING) {
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
setValue(t.nativeHandle, obj);
nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
setValue(nativeHandle, obj);
} else if (t.shapeCopy.length != 0) {
t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
} else {
t.nativeHandle = allocateScalarBytes((byte[]) obj);
nativeHandle = allocateScalarBytes((byte[]) obj);
}
t.nativeRef = new NativeReference(nativeHandle);
return t;
}
@ -314,23 +316,22 @@ public final class Tensor<T> implements AutoCloseable {
}
Tensor<T> t = new Tensor<T>(dataType);
t.shapeCopy = Arrays.copyOf(shape, shape.length);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
t.nativeRef = new NativeReference(nativeHandle);
return t;
}
/**
* Release resources associated with the Tensor.
*
* <p><b>WARNING:</b>If not invoked, memory will be leaked.
* <p><b>WARNING:</b>This must be invoked for all tensors that were not been produced by an eager
* operation or memory will be leaked.
*
* <p>The Tensor object is no longer usable after {@code close} returns.
*/
@Override
public void close() {
if (nativeHandle != 0) {
delete(nativeHandle);
nativeHandle = 0;
}
nativeRef.release();
}
/** Returns the {@link DataType} of elements stored in the Tensor. */
@ -374,7 +375,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a float scalar.
*/
public float floatValue() {
return scalarFloat(nativeHandle);
return scalarFloat(getNativeHandle());
}
/**
@ -383,7 +384,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a double scalar.
*/
public double doubleValue() {
return scalarDouble(nativeHandle);
return scalarDouble(getNativeHandle());
}
/**
@ -392,7 +393,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a int scalar.
*/
public int intValue() {
return scalarInt(nativeHandle);
return scalarInt(getNativeHandle());
}
/**
@ -401,7 +402,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a long scalar.
*/
public long longValue() {
return scalarLong(nativeHandle);
return scalarLong(getNativeHandle());
}
/**
@ -410,7 +411,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
public boolean booleanValue() {
return scalarBoolean(nativeHandle);
return scalarBoolean(getNativeHandle());
}
/**
@ -419,7 +420,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
public byte[] bytesValue() {
return scalarBytes(nativeHandle);
return scalarBytes(getNativeHandle());
}
/**
@ -448,7 +449,7 @@ public final class Tensor<T> implements AutoCloseable {
*/
public <U> U copyTo(U dst) {
throwExceptionIfTypeIsIncompatible(dst);
readNDArray(nativeHandle, dst);
readNDArray(getNativeHandle(), dst);
return dst;
}
@ -553,16 +554,27 @@ public final class Tensor<T> implements AutoCloseable {
@SuppressWarnings("rawtypes")
Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
t.shapeCopy = shape(handle);
t.nativeHandle = handle;
t.nativeRef = new NativeReference(handle);
return t;
}
/**
* Create an eager Tensor object from a handle to the C TF_Tensor object.
*
* <p>Takes ownership of the handle.
*/
static Tensor<?> fromHandle(long handle, EagerSession session) {
Tensor<?> t = fromHandle(handle);
t.nativeRef.eager(session, t);
return t;
}
long getNativeHandle() {
return nativeHandle;
return nativeRef.tensorHandle;
}
private long nativeHandle;
private DataType dtype;
private NativeReference nativeRef = null;
private final DataType dtype;
private long[] shapeCopy = null;
private Tensor(DataType t) {
@ -570,7 +582,7 @@ public final class Tensor<T> implements AutoCloseable {
}
private ByteBuffer buffer() {
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
return buffer(getNativeHandle()).order(ByteOrder.nativeOrder());
}
private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
@ -609,6 +621,66 @@ public final class Tensor<T> implements AutoCloseable {
}
}
/**
* Reference to the underlying native tensor
*
* <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get automatically
* released after executing the last line of the `try` block they were declared in.
*
* <p>They can also be attached to an eager session, where in this case their lifetime ends either when
* this session is closed or when the Tensor instance is no longer referenced and have been garbage-collected.
*
* <p>This helper class wraps the tensor native handle and support both situations; If an eager reference to
* the tensor exists, it will take care of releasing the tensor at the end of its life. If the tensor is
* being explicetly closed before this happens, it will take cake of clearing its association with any eager
* session before cleaning up the resources.
*/
private static class NativeReference {
/**
* Attaches this reference to an eager session
*/
private class EagerReference extends EagerSession.NativeReference {
EagerReference(EagerSession session, Tensor<?> tensor) {
super(session, tensor);
}
@Override
void delete() {
// Mark this eager reference as cleared since it has been deleted by the session
NativeReference.this.eagerRef = null;
NativeReference.this.release();
}
}
NativeReference(long tensorHandle) {
this.tensorHandle = tensorHandle;
}
void eager(EagerSession session, Tensor<?> tensor) {
if (eagerRef != null) {
throw new IllegalStateException("The tensor is already attached to an eager session");
}
eagerRef = new EagerReference(session, tensor);
}
synchronized void release() {
if (tensorHandle != 0L) {
// Clear any remaining eager reference to this tensor
if (eagerRef != null) {
eagerRef.clear();
eagerRef = null;
}
Tensor.delete(tensorHandle);
tensorHandle = 0L;
}
}
private long tensorHandle;
private EagerReference eagerRef;
}
private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
static {

View File

@ -57,6 +57,22 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
}
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
JNIEnv* env, jclass clazz, jlong handle) {
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
if (tensor_handle == nullptr) return 0;
TF_Status* status = TF_NewStatus();
TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return 0;
}
TF_DeleteStatus(status);
static_assert(sizeof(jlong) >= sizeof(TF_Tensor*),
"Cannot represent a C TF_Tensor as a Java long");
return reinterpret_cast<jlong>(tensor);
}
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
TFE_Op* op = requireOp(env, handle);

View File

@ -38,6 +38,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
JNIEnv *, jclass, jlong);
/**
* Class: org_tensorflow_EagerOperation
* Method: resolveTensorHandle
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
JNIEnv *, jclass, jlong);
/**
* Class: org_tensorflow_EagerOperation
* Method: outputListLength

View File

@ -54,6 +54,22 @@ public class EagerOperationTest {
}
}
@Test
public void outputTensor() {
try (EagerSession session = EagerSession.create()) {
EagerOperation add = opBuilder(session, "Add", "CompareResult")
.addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4))
.build();
assertEquals(6, add.tensor(0).intValue());
// Validate that we retrieve the right shape and datatype from the tensor
// that has been resolved
assertEquals(0, add.shape(0).length);
assertEquals(DataType.INT32, add.dtype(0));
}
}
@Test
public void inputAndOutputListLengths() {
try (EagerSession session = EagerSession.create()) {
@ -105,7 +121,7 @@ public class EagerOperationTest {
@Test
public void opNotAccessibleIfSessionIsClosed() {
EagerSession session = EagerSession.create();
EagerOperation add = opBuilder(session, "Add", "SetDevice")
EagerOperation add = opBuilder(session, "Add", "SessionClosed")
.addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4))
.build();
@ -119,6 +135,40 @@ public class EagerOperationTest {
}
}
@Test
public void outputIndexOutOfBounds() {
try (EagerSession session = EagerSession.create()) {
EagerOperation add = opBuilder(session, "Add", "OutOfRange")
.addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4))
.build();
try {
add.getUnsafeNativeHandle(1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
add.shape(1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
add.dtype(1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
add.tensor(1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
}
}
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
return new EagerOperationBuilder(session, type, name);
}

View File

@ -166,6 +166,17 @@ public class GraphOperationTest {
}
}
}
@Test
public void outputTensorNotSupported() {
try (Graph g = new Graph()) {
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
try {
split.output(0).tensor();
fail();
} catch (IllegalStateException e) {}
}
}
private static int split(int[] values, int num_split) {
try (Graph g = new Graph()) {

View File

@ -18,6 +18,7 @@ package org.tensorflow;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -28,6 +29,7 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@ -519,6 +521,25 @@ public class TensorTest {
// The expected exception.
}
}
@Test
public void eagerTensorIsReleasedAfterSessionIsClosed() {
Tensor<Integer> sum;
try (EagerSession session = EagerSession.create()) {
Output<?> x = TestUtil.constant(session, "Const1", 10);
Output<?> y = TestUtil.constant(session, "Const2", 20);
sum = TestUtil.<Integer>addN(session, x, y).tensor();
assertNotEquals(0L, sum.getNativeHandle());
assertEquals(30, sum.intValue());
}
assertEquals(0L, sum.getNativeHandle());
try {
sum.intValue();
fail();
} catch (NullPointerException e) {
// expected.
}
}
@Test
public void fromHandle() {

View File

@ -67,8 +67,8 @@ public class TestUtil {
.<T>output(0);
}
public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
public static <T> Output<T> addN(ExecutionEnvironment env, Output<?>... inputs) {
return env.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
}
public static <T> Output<T> matmul(