Add eager tensor support

This commit is contained in:
Karl Lessard 2019-05-09 00:32:51 -04:00
parent 4678439926
commit fceea7d090
11 changed files with 283 additions and 37 deletions

View File

@ -60,7 +60,7 @@ abstract class AbstractOperation implements Operation {
abstract long getUnsafeNativeHandle(int outputIdx); abstract long getUnsafeNativeHandle(int outputIdx);
/** /**
* Returns the shape of the tensor of the {code outputIdx}th output of this operation. * Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
* *
* @param outputIdx index of the output of this operation * @param outputIdx index of the output of this operation
* @return output tensor shape * @return output tensor shape
@ -68,10 +68,20 @@ abstract class AbstractOperation implements Operation {
abstract long[] shape(int outputIdx); abstract long[] shape(int outputIdx);
/** /**
* Returns the datatype of the tensor of the {code outputIdx}th output of this operation. * Returns the datatype of the tensor of the {@code outputIdx}th output of this operation.
* *
* @param outputIdx index of the output of this operation * @param outputIdx index of the output of this operation
* @return output tensor datatype * @return output tensor datatype
*/ */
abstract DataType dtype(int outputIdx); abstract DataType dtype(int outputIdx);
/**
* Returns the tensor of the {@code outputIdx}th output of this operation.
*
* <p>This is only supported in an eager execution environment.
*
* @param outputIdx index of the output of this operation
* @return output tensor
*/
abstract Tensor<?> tensor(int outputIdx);
} }

View File

@ -15,7 +15,7 @@ limitations under the License.
package org.tensorflow; package org.tensorflow;
import java.util.Arrays; import java.util.concurrent.atomic.AtomicReferenceArray;
/** /**
* Implementation of an {@link Operation} executed eagerly. * Implementation of an {@link Operation} executed eagerly.
@ -33,6 +33,7 @@ class EagerOperation extends AbstractOperation {
this.type = type; this.type = type;
this.name = name; this.name = name;
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles); this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
this.outputTensors = new AtomicReferenceArray<Tensor<?>>(outputNativeHandles.length);
} }
@Override @Override
@ -67,6 +68,12 @@ class EagerOperation extends AbstractOperation {
@Override @Override
public long[] shape(int outputIndex) { public long[] shape(int outputIndex) {
// If the tensor of this output has already been resolved, return its shape.
// Otherwise, retrieve the tensor shape from the native library.
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor != null) {
return tensor.shape();
}
long outputNativeHandle = getUnsafeNativeHandle(outputIndex); long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
long[] shape = new long[numDims(outputNativeHandle)]; long[] shape = new long[numDims(outputNativeHandle)];
for (int i = 0; i < shape.length; ++i) { for (int i = 0; i < shape.length; ++i) {
@ -77,10 +84,43 @@ class EagerOperation extends AbstractOperation {
@Override @Override
public DataType dtype(int outputIndex) { public DataType dtype(int outputIndex) {
// If the tensor of this output has already been resolved, return its datatype.
// Otherwise, retrieve the tensor datatype from the native library.
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor != null) {
return tensor.dataType();
}
long outputNativeHandle = getUnsafeNativeHandle(outputIndex); long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
return DataType.fromC(dataType(outputNativeHandle)); return DataType.fromC(dataType(outputNativeHandle));
} }
@Override
public Tensor<?> tensor(int outputIndex) {
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor == null) {
tensor = resolveTensor(outputIndex);
}
return tensor;
}
private final EagerSession session;
private final NativeReference nativeRef;
private final String type;
private final String name;
private final AtomicReferenceArray<Tensor<?>> outputTensors;
private Tensor<?> resolveTensor(int outputIndex) {
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
// If another thread has resolved it meanwhile, release our copy and reuse the existing one instead.
long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
tensor.close();
tensor = outputTensors.get(outputIndex);
}
return tensor;
}
private static class NativeReference extends EagerSession.NativeReference { private static class NativeReference extends EagerSession.NativeReference {
NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) { NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
@ -92,14 +132,14 @@ class EagerOperation extends AbstractOperation {
@Override @Override
void delete() { void delete() {
if (opHandle != 0L) { if (opHandle != 0L) {
for (long tensorHandle : outputHandles) { for (int i = 0; i < outputHandles.length; ++i) {
if (tensorHandle != 0L) { if (outputHandles[i] != 0L) {
EagerOperation.deleteTensorHandle(tensorHandle); EagerOperation.deleteTensorHandle(outputHandles[i]);
outputHandles[i] = 0L;
} }
} }
EagerOperation.delete(opHandle); EagerOperation.delete(opHandle);
opHandle = 0L; opHandle = 0L;
Arrays.fill(outputHandles, 0L);
} }
} }
@ -107,15 +147,12 @@ class EagerOperation extends AbstractOperation {
private final long[] outputHandles; private final long[] outputHandles;
} }
private final EagerSession session;
private final NativeReference nativeRef;
private final String type;
private final String name;
private static native void delete(long handle); private static native void delete(long handle);
private static native void deleteTensorHandle(long handle); private static native void deleteTensorHandle(long handle);
private static native long resolveTensorHandle(long handle);
private static native int outputListLength(long handle, String name); private static native int outputListLength(long handle, String name);
private static native int inputListLength(long handle, String name); private static native int inputListLength(long handle, String name);

View File

@ -139,6 +139,11 @@ public final class GraphOperation extends AbstractOperation {
} }
} }
@Override
Tensor<?> tensor(int outputIdx) {
throw new IllegalStateException("Graph tensors must be fetched by running a session");
}
long getUnsafeNativeHandle() { long getUnsafeNativeHandle() {
return unsafeNativeHandle; return unsafeNativeHandle;
} }

View File

@ -48,6 +48,22 @@ public final class Output<T> implements Operand<T> {
return operation.dtype(index); return operation.dtype(index);
} }
/**
* Returns the tensor at this output.
*
* <p>This operation is only supported on the outputs of an operation executed eagerly.
* For graph environments, output tensors must be fetched by running a session, using
* {@link Session.Runner#fetch(Output)}.
*
* @return tensor
* @throws IllegalStateException if this output results from a graph
* @see EagerSession
*/
@SuppressWarnings("unchecked")
public Tensor<T> tensor() {
return (Tensor<T>)operation.tensor(index);
}
@Override @Override
public Output<T> asOutput() { public Output<T> asOutput() {
return this; return this;

View File

@ -140,15 +140,17 @@ public final class Tensor<T> implements AutoCloseable {
Tensor<?> t = new Tensor(dtype); Tensor<?> t = new Tensor(dtype);
t.shapeCopy = new long[numDimensions(obj, dtype)]; t.shapeCopy = new long[numDimensions(obj, dtype)];
fillShape(obj, 0, t.shapeCopy); fillShape(obj, 0, t.shapeCopy);
long nativeHandle;
if (t.dtype != DataType.STRING) { if (t.dtype != DataType.STRING) {
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy); int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize); nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
setValue(t.nativeHandle, obj); setValue(nativeHandle, obj);
} else if (t.shapeCopy.length != 0) { } else if (t.shapeCopy.length != 0) {
t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj); nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
} else { } else {
t.nativeHandle = allocateScalarBytes((byte[]) obj); nativeHandle = allocateScalarBytes((byte[]) obj);
} }
t.nativeRef = new NativeReference(nativeHandle);
return t; return t;
} }
@ -314,23 +316,22 @@ public final class Tensor<T> implements AutoCloseable {
} }
Tensor<T> t = new Tensor<T>(dataType); Tensor<T> t = new Tensor<T>(dataType);
t.shapeCopy = Arrays.copyOf(shape, shape.length); t.shapeCopy = Arrays.copyOf(shape, shape.length);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes); long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
t.nativeRef = new NativeReference(nativeHandle);
return t; return t;
} }
/** /**
* Release resources associated with the Tensor. * Release resources associated with the Tensor.
* *
* <p><b>WARNING:</b>If not invoked, memory will be leaked. * <p><b>WARNING:</b>This must be invoked for all tensors that were not been produced by an eager
* operation or memory will be leaked.
* *
* <p>The Tensor object is no longer usable after {@code close} returns. * <p>The Tensor object is no longer usable after {@code close} returns.
*/ */
@Override @Override
public void close() { public void close() {
if (nativeHandle != 0) { nativeRef.release();
delete(nativeHandle);
nativeHandle = 0;
}
} }
/** Returns the {@link DataType} of elements stored in the Tensor. */ /** Returns the {@link DataType} of elements stored in the Tensor. */
@ -374,7 +375,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a float scalar. * @throws IllegalArgumentException if the Tensor does not represent a float scalar.
*/ */
public float floatValue() { public float floatValue() {
return scalarFloat(nativeHandle); return scalarFloat(getNativeHandle());
} }
/** /**
@ -383,7 +384,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a double scalar. * @throws IllegalArgumentException if the Tensor does not represent a double scalar.
*/ */
public double doubleValue() { public double doubleValue() {
return scalarDouble(nativeHandle); return scalarDouble(getNativeHandle());
} }
/** /**
@ -392,7 +393,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a int scalar. * @throws IllegalArgumentException if the Tensor does not represent a int scalar.
*/ */
public int intValue() { public int intValue() {
return scalarInt(nativeHandle); return scalarInt(getNativeHandle());
} }
/** /**
@ -401,7 +402,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a long scalar. * @throws IllegalArgumentException if the Tensor does not represent a long scalar.
*/ */
public long longValue() { public long longValue() {
return scalarLong(nativeHandle); return scalarLong(getNativeHandle());
} }
/** /**
@ -410,7 +411,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/ */
public boolean booleanValue() { public boolean booleanValue() {
return scalarBoolean(nativeHandle); return scalarBoolean(getNativeHandle());
} }
/** /**
@ -419,7 +420,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/ */
public byte[] bytesValue() { public byte[] bytesValue() {
return scalarBytes(nativeHandle); return scalarBytes(getNativeHandle());
} }
/** /**
@ -448,7 +449,7 @@ public final class Tensor<T> implements AutoCloseable {
*/ */
public <U> U copyTo(U dst) { public <U> U copyTo(U dst) {
throwExceptionIfTypeIsIncompatible(dst); throwExceptionIfTypeIsIncompatible(dst);
readNDArray(nativeHandle, dst); readNDArray(getNativeHandle(), dst);
return dst; return dst;
} }
@ -553,16 +554,27 @@ public final class Tensor<T> implements AutoCloseable {
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
Tensor<?> t = new Tensor(DataType.fromC(dtype(handle))); Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
t.shapeCopy = shape(handle); t.shapeCopy = shape(handle);
t.nativeHandle = handle; t.nativeRef = new NativeReference(handle);
return t;
}
/**
* Create an eager Tensor object from a handle to the C TF_Tensor object.
*
* <p>Takes ownership of the handle.
*/
static Tensor<?> fromHandle(long handle, EagerSession session) {
Tensor<?> t = fromHandle(handle);
t.nativeRef.eager(session, t);
return t; return t;
} }
long getNativeHandle() { long getNativeHandle() {
return nativeHandle; return nativeRef.tensorHandle;
} }
private long nativeHandle; private NativeReference nativeRef = null;
private DataType dtype; private final DataType dtype;
private long[] shapeCopy = null; private long[] shapeCopy = null;
private Tensor(DataType t) { private Tensor(DataType t) {
@ -570,7 +582,7 @@ public final class Tensor<T> implements AutoCloseable {
} }
private ByteBuffer buffer() { private ByteBuffer buffer() {
return buffer(nativeHandle).order(ByteOrder.nativeOrder()); return buffer(getNativeHandle()).order(ByteOrder.nativeOrder());
} }
private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) { private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
@ -609,6 +621,66 @@ public final class Tensor<T> implements AutoCloseable {
} }
} }
/**
* Reference to the underlying native tensor
*
* <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get automatically
* released after executing the last line of the `try` block they were declared in.
*
* <p>They can also be attached to an eager session, where in this case their lifetime ends either when
* this session is closed or when the Tensor instance is no longer referenced and have been garbage-collected.
*
* <p>This helper class wraps the tensor native handle and support both situations; If an eager reference to
* the tensor exists, it will take care of releasing the tensor at the end of its life. If the tensor is
* being explicetly closed before this happens, it will take cake of clearing its association with any eager
* session before cleaning up the resources.
*/
private static class NativeReference {
/**
* Attaches this reference to an eager session
*/
private class EagerReference extends EagerSession.NativeReference {
EagerReference(EagerSession session, Tensor<?> tensor) {
super(session, tensor);
}
@Override
void delete() {
// Mark this eager reference as cleared since it has been deleted by the session
NativeReference.this.eagerRef = null;
NativeReference.this.release();
}
}
NativeReference(long tensorHandle) {
this.tensorHandle = tensorHandle;
}
void eager(EagerSession session, Tensor<?> tensor) {
if (eagerRef != null) {
throw new IllegalStateException("The tensor is already attached to an eager session");
}
eagerRef = new EagerReference(session, tensor);
}
synchronized void release() {
if (tensorHandle != 0L) {
// Clear any remaining eager reference to this tensor
if (eagerRef != null) {
eagerRef.clear();
eagerRef = null;
}
Tensor.delete(tensorHandle);
tensorHandle = 0L;
}
}
private long tensorHandle;
private EagerReference eagerRef;
}
private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>(); private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
static { static {

View File

@ -57,6 +57,22 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle)); TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
} }
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
JNIEnv* env, jclass clazz, jlong handle) {
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
if (tensor_handle == nullptr) return 0;
TF_Status* status = TF_NewStatus();
TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return 0;
}
TF_DeleteStatus(status);
static_assert(sizeof(jlong) >= sizeof(TF_Tensor*),
"Cannot represent a C TF_Tensor as a Java long");
return reinterpret_cast<jlong>(tensor);
}
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength( JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
JNIEnv* env, jclass clazz, jlong handle, jstring name) { JNIEnv* env, jclass clazz, jlong handle, jstring name) {
TFE_Op* op = requireOp(env, handle); TFE_Op* op = requireOp(env, handle);

View File

@ -38,6 +38,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle( JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
JNIEnv *, jclass, jlong); JNIEnv *, jclass, jlong);
/**
* Class: org_tensorflow_EagerOperation
* Method: resolveTensorHandle
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
JNIEnv *, jclass, jlong);
/** /**
* Class: org_tensorflow_EagerOperation * Class: org_tensorflow_EagerOperation
* Method: outputListLength * Method: outputListLength

View File

@ -54,6 +54,22 @@ public class EagerOperationTest {
} }
} }
@Test
public void outputTensor() {
try (EagerSession session = EagerSession.create()) {
EagerOperation add = opBuilder(session, "Add", "CompareResult")
.addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4))
.build();
assertEquals(6, add.tensor(0).intValue());
// Validate that we retrieve the right shape and datatype from the tensor
// that has been resolved
assertEquals(0, add.shape(0).length);
assertEquals(DataType.INT32, add.dtype(0));
}
}
@Test @Test
public void inputAndOutputListLengths() { public void inputAndOutputListLengths() {
try (EagerSession session = EagerSession.create()) { try (EagerSession session = EagerSession.create()) {
@ -105,7 +121,7 @@ public class EagerOperationTest {
@Test @Test
public void opNotAccessibleIfSessionIsClosed() { public void opNotAccessibleIfSessionIsClosed() {
EagerSession session = EagerSession.create(); EagerSession session = EagerSession.create();
EagerOperation add = opBuilder(session, "Add", "SetDevice") EagerOperation add = opBuilder(session, "Add", "SessionClosed")
.addInput(TestUtil.constant(session, "Const1", 2)) .addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4)) .addInput(TestUtil.constant(session, "Const2", 4))
.build(); .build();
@ -119,6 +135,40 @@ public class EagerOperationTest {
} }
} }
@Test
public void outputIndexOutOfBounds() {
try (EagerSession session = EagerSession.create()) {
EagerOperation add = opBuilder(session, "Add", "OutOfRange")
.addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4))
.build();
try {
add.getUnsafeNativeHandle(1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
add.shape(1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
add.dtype(1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
add.tensor(1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
}
}
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) { private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
return new EagerOperationBuilder(session, type, name); return new EagerOperationBuilder(session, type, name);
} }

View File

@ -167,6 +167,17 @@ public class GraphOperationTest {
} }
} }
@Test
public void outputTensorNotSupported() {
try (Graph g = new Graph()) {
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
try {
split.output(0).tensor();
fail();
} catch (IllegalStateException e) {}
}
}
private static int split(int[] values, int num_split) { private static int split(int[] values, int num_split) {
try (Graph g = new Graph()) { try (Graph g = new Graph()) {
return g.opBuilder("Split", "Split") return g.opBuilder("Split", "Split")

View File

@ -18,6 +18,7 @@ package org.tensorflow;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
@ -28,6 +29,7 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer; import java.nio.FloatBuffer;
import java.nio.IntBuffer; import java.nio.IntBuffer;
import java.nio.LongBuffer; import java.nio.LongBuffer;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -520,6 +522,25 @@ public class TensorTest {
} }
} }
@Test
public void eagerTensorIsReleasedAfterSessionIsClosed() {
Tensor<Integer> sum;
try (EagerSession session = EagerSession.create()) {
Output<?> x = TestUtil.constant(session, "Const1", 10);
Output<?> y = TestUtil.constant(session, "Const2", 20);
sum = TestUtil.<Integer>addN(session, x, y).tensor();
assertNotEquals(0L, sum.getNativeHandle());
assertEquals(30, sum.intValue());
}
assertEquals(0L, sum.getNativeHandle());
try {
sum.intValue();
fail();
} catch (NullPointerException e) {
// expected.
}
}
@Test @Test
public void fromHandle() { public void fromHandle() {
// fromHandle is a package-visible method intended for use when the C TF_Tensor object has been // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been

View File

@ -67,8 +67,8 @@ public class TestUtil {
.<T>output(0); .<T>output(0);
} }
public static <T> Output<T> addN(Graph g, Output<?>... inputs) { public static <T> Output<T> addN(ExecutionEnvironment env, Output<?>... inputs) {
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); return env.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
} }
public static <T> Output<T> matmul( public static <T> Output<T> matmul(