diff --git a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
index 0d4745fe0b7..f586dae73e0 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
@@ -60,7 +60,7 @@ abstract class AbstractOperation implements Operation {
abstract long getUnsafeNativeHandle(int outputIdx);
/**
- * Returns the shape of the tensor of the {code outputIdx}th output of this operation.
+ * Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor shape
@@ -68,10 +68,20 @@ abstract class AbstractOperation implements Operation {
abstract long[] shape(int outputIdx);
/**
- * Returns the datatype of the tensor of the {code outputIdx}th output of this operation.
+ * Returns the datatype of the tensor of the {@code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor datatype
*/
abstract DataType dtype(int outputIdx);
+
+ /**
+ * Returns the tensor of the {@code outputIdx}th output of this operation.
+ *
+ *
This is only supported in an eager execution environment.
+ *
+ * @param outputIdx index of the output of this operation
+ * @return output tensor
+ */
+ abstract Tensor> tensor(int outputIdx);
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
index a0530d7b9da..2c1df4cdc40 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
@@ -15,7 +15,7 @@ limitations under the License.
package org.tensorflow;
-import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicReferenceArray;
/**
* Implementation of an {@link Operation} executed eagerly.
@@ -33,6 +33,7 @@ class EagerOperation extends AbstractOperation {
this.type = type;
this.name = name;
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
+ this.outputTensors = new AtomicReferenceArray>(outputNativeHandles.length);
}
@Override
@@ -67,6 +68,12 @@ class EagerOperation extends AbstractOperation {
@Override
public long[] shape(int outputIndex) {
+ // If the tensor of this output has already been resolved, return its shape.
+ // Otherwise, retrieve the tensor shape from the native library.
+ Tensor> tensor = outputTensors.get(outputIndex);
+ if (tensor != null) {
+ return tensor.shape();
+ }
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
long[] shape = new long[numDims(outputNativeHandle)];
for (int i = 0; i < shape.length; ++i) {
@@ -77,10 +84,43 @@ class EagerOperation extends AbstractOperation {
@Override
public DataType dtype(int outputIndex) {
+ // If the tensor of this output has already been resolved, return its datatype.
+ // Otherwise, retrieve the tensor datatype from the native library.
+ Tensor> tensor = outputTensors.get(outputIndex);
+ if (tensor != null) {
+ return tensor.dataType();
+ }
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
return DataType.fromC(dataType(outputNativeHandle));
}
+ @Override
+ public Tensor> tensor(int outputIndex) {
+ Tensor> tensor = outputTensors.get(outputIndex);
+ if (tensor == null) {
+ tensor = resolveTensor(outputIndex);
+ }
+ return tensor;
+ }
+
+ private final EagerSession session;
+ private final NativeReference nativeRef;
+ private final String type;
+ private final String name;
+ private final AtomicReferenceArray> outputTensors;
+
+ private Tensor> resolveTensor(int outputIndex) {
+ // Take an optimistic approach, where we attempt to resolve the output tensor without locking.
+ // If another thread has resolved it meanwhile, release our copy and reuse the existing one instead.
+ long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
+ Tensor> tensor = Tensor.fromHandle(tensorNativeHandle, session);
+ if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
+ tensor.close();
+ tensor = outputTensors.get(outputIndex);
+ }
+ return tensor;
+ }
+
private static class NativeReference extends EagerSession.NativeReference {
NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
@@ -92,30 +132,27 @@ class EagerOperation extends AbstractOperation {
@Override
void delete() {
if (opHandle != 0L) {
- for (long tensorHandle : outputHandles) {
- if (tensorHandle != 0L) {
- EagerOperation.deleteTensorHandle(tensorHandle);
+ for (int i = 0; i < outputHandles.length; ++i) {
+ if (outputHandles[i] != 0L) {
+ EagerOperation.deleteTensorHandle(outputHandles[i]);
+ outputHandles[i] = 0L;
}
}
EagerOperation.delete(opHandle);
opHandle = 0L;
- Arrays.fill(outputHandles, 0L);
}
}
private long opHandle;
private final long[] outputHandles;
}
-
- private final EagerSession session;
- private final NativeReference nativeRef;
- private final String type;
- private final String name;
private static native void delete(long handle);
private static native void deleteTensorHandle(long handle);
+ private static native long resolveTensorHandle(long handle);
+
private static native int outputListLength(long handle, String name);
private static native int inputListLength(long handle, String name);
diff --git a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
index 0e43bc3eb43..590eff8a83e 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
@@ -138,6 +138,11 @@ public final class GraphOperation extends AbstractOperation {
r.close();
}
}
+
+ @Override
+ Tensor> tensor(int outputIdx) {
+ throw new IllegalStateException("Graph tensors must be fetched by running a session");
+ }
long getUnsafeNativeHandle() {
return unsafeNativeHandle;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java
index 15bb2e89e8d..90668bb7ad3 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Output.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java
@@ -47,6 +47,22 @@ public final class Output implements Operand {
public DataType dataType() {
return operation.dtype(index);
}
+
+ /**
+ * Returns the tensor at this output.
+ *
+ * This operation is only supported on the outputs of an operation executed eagerly.
+ * For graph environments, output tensors must be fetched by running a session, using
+ * {@link Session.Runner#fetch(Output)}.
+ *
+ * @return tensor
+ * @throws IllegalStateException if this output results from a graph
+ * @see EagerSession
+ */
+ @SuppressWarnings("unchecked")
+ public Tensor tensor() {
+ return (Tensor)operation.tensor(index);
+ }
@Override
public Output asOutput() {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 89872537689..253ceb65781 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -140,15 +140,17 @@ public final class Tensor implements AutoCloseable {
Tensor> t = new Tensor(dtype);
t.shapeCopy = new long[numDimensions(obj, dtype)];
fillShape(obj, 0, t.shapeCopy);
+ long nativeHandle;
if (t.dtype != DataType.STRING) {
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
- t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
- setValue(t.nativeHandle, obj);
+ nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
+ setValue(nativeHandle, obj);
} else if (t.shapeCopy.length != 0) {
- t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
+ nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
} else {
- t.nativeHandle = allocateScalarBytes((byte[]) obj);
+ nativeHandle = allocateScalarBytes((byte[]) obj);
}
+ t.nativeRef = new NativeReference(nativeHandle);
return t;
}
@@ -314,23 +316,22 @@ public final class Tensor implements AutoCloseable {
}
Tensor t = new Tensor(dataType);
t.shapeCopy = Arrays.copyOf(shape, shape.length);
- t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
+ long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
+ t.nativeRef = new NativeReference(nativeHandle);
return t;
}
/**
* Release resources associated with the Tensor.
*
- * WARNING:If not invoked, memory will be leaked.
+ *
WARNING:This must be invoked for all tensors that were not been produced by an eager
+ * operation or memory will be leaked.
*
*
The Tensor object is no longer usable after {@code close} returns.
*/
@Override
public void close() {
- if (nativeHandle != 0) {
- delete(nativeHandle);
- nativeHandle = 0;
- }
+ nativeRef.release();
}
/** Returns the {@link DataType} of elements stored in the Tensor. */
@@ -374,7 +375,7 @@ public final class Tensor implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a float scalar.
*/
public float floatValue() {
- return scalarFloat(nativeHandle);
+ return scalarFloat(getNativeHandle());
}
/**
@@ -383,7 +384,7 @@ public final class Tensor implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a double scalar.
*/
public double doubleValue() {
- return scalarDouble(nativeHandle);
+ return scalarDouble(getNativeHandle());
}
/**
@@ -392,7 +393,7 @@ public final class Tensor implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a int scalar.
*/
public int intValue() {
- return scalarInt(nativeHandle);
+ return scalarInt(getNativeHandle());
}
/**
@@ -401,7 +402,7 @@ public final class Tensor implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a long scalar.
*/
public long longValue() {
- return scalarLong(nativeHandle);
+ return scalarLong(getNativeHandle());
}
/**
@@ -410,7 +411,7 @@ public final class Tensor implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
public boolean booleanValue() {
- return scalarBoolean(nativeHandle);
+ return scalarBoolean(getNativeHandle());
}
/**
@@ -419,7 +420,7 @@ public final class Tensor implements AutoCloseable {
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
public byte[] bytesValue() {
- return scalarBytes(nativeHandle);
+ return scalarBytes(getNativeHandle());
}
/**
@@ -448,7 +449,7 @@ public final class Tensor implements AutoCloseable {
*/
public U copyTo(U dst) {
throwExceptionIfTypeIsIncompatible(dst);
- readNDArray(nativeHandle, dst);
+ readNDArray(getNativeHandle(), dst);
return dst;
}
@@ -553,16 +554,27 @@ public final class Tensor implements AutoCloseable {
@SuppressWarnings("rawtypes")
Tensor> t = new Tensor(DataType.fromC(dtype(handle)));
t.shapeCopy = shape(handle);
- t.nativeHandle = handle;
+ t.nativeRef = new NativeReference(handle);
+ return t;
+ }
+
+ /**
+ * Create an eager Tensor object from a handle to the C TF_Tensor object.
+ *
+ * Takes ownership of the handle.
+ */
+ static Tensor> fromHandle(long handle, EagerSession session) {
+ Tensor> t = fromHandle(handle);
+ t.nativeRef.eager(session, t);
return t;
}
long getNativeHandle() {
- return nativeHandle;
+ return nativeRef.tensorHandle;
}
- private long nativeHandle;
- private DataType dtype;
+ private NativeReference nativeRef = null;
+ private final DataType dtype;
private long[] shapeCopy = null;
private Tensor(DataType t) {
@@ -570,7 +582,7 @@ public final class Tensor implements AutoCloseable {
}
private ByteBuffer buffer() {
- return buffer(nativeHandle).order(ByteOrder.nativeOrder());
+ return buffer(getNativeHandle()).order(ByteOrder.nativeOrder());
}
private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
@@ -609,6 +621,66 @@ public final class Tensor implements AutoCloseable {
}
}
+ /**
+ * Reference to the underlying native tensor
+ *
+ * Tensors are commonly allocated in a `try-with-resources` statement, where they get automatically
+ * released after executing the last line of the `try` block they were declared in.
+ *
+ *
They can also be attached to an eager session, where in this case their lifetime ends either when
+ * this session is closed or when the Tensor instance is no longer referenced and have been garbage-collected.
+ *
+ *
This helper class wraps the tensor native handle and support both situations; If an eager reference to
+ * the tensor exists, it will take care of releasing the tensor at the end of its life. If the tensor is
+ * being explicetly closed before this happens, it will take cake of clearing its association with any eager
+ * session before cleaning up the resources.
+ */
+ private static class NativeReference {
+
+ /**
+ * Attaches this reference to an eager session
+ */
+ private class EagerReference extends EagerSession.NativeReference {
+
+ EagerReference(EagerSession session, Tensor> tensor) {
+ super(session, tensor);
+ }
+
+ @Override
+ void delete() {
+ // Mark this eager reference as cleared since it has been deleted by the session
+ NativeReference.this.eagerRef = null;
+ NativeReference.this.release();
+ }
+ }
+
+ NativeReference(long tensorHandle) {
+ this.tensorHandle = tensorHandle;
+ }
+
+ void eager(EagerSession session, Tensor> tensor) {
+ if (eagerRef != null) {
+ throw new IllegalStateException("The tensor is already attached to an eager session");
+ }
+ eagerRef = new EagerReference(session, tensor);
+ }
+
+ synchronized void release() {
+ if (tensorHandle != 0L) {
+ // Clear any remaining eager reference to this tensor
+ if (eagerRef != null) {
+ eagerRef.clear();
+ eagerRef = null;
+ }
+ Tensor.delete(tensorHandle);
+ tensorHandle = 0L;
+ }
+ }
+
+ private long tensorHandle;
+ private EagerReference eagerRef;
+ }
+
private static HashMap, DataType> classDataTypes = new HashMap<>();
static {
diff --git a/tensorflow/java/src/main/native/eager_operation_jni.cc b/tensorflow/java/src/main/native/eager_operation_jni.cc
index 3a5f6f90ddc..15f98905796 100644
--- a/tensorflow/java/src/main/native/eager_operation_jni.cc
+++ b/tensorflow/java/src/main/native/eager_operation_jni.cc
@@ -57,6 +57,22 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
TFE_DeleteTensorHandle(reinterpret_cast(handle));
}
+JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
+ JNIEnv* env, jclass clazz, jlong handle) {
+ TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
+ if (tensor_handle == nullptr) return 0;
+ TF_Status* status = TF_NewStatus();
+ TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status);
+ if (!throwExceptionIfNotOK(env, status)) {
+ TF_DeleteStatus(status);
+ return 0;
+ }
+ TF_DeleteStatus(status);
+ static_assert(sizeof(jlong) >= sizeof(TF_Tensor*),
+ "Cannot represent a C TF_Tensor as a Java long");
+ return reinterpret_cast(tensor);
+}
+
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
TFE_Op* op = requireOp(env, handle);
diff --git a/tensorflow/java/src/main/native/eager_operation_jni.h b/tensorflow/java/src/main/native/eager_operation_jni.h
index f9684b0a26e..c1d52bf9393 100644
--- a/tensorflow/java/src/main/native/eager_operation_jni.h
+++ b/tensorflow/java/src/main/native/eager_operation_jni.h
@@ -38,6 +38,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
JNIEnv *, jclass, jlong);
+/**
+ * Class: org_tensorflow_EagerOperation
+ * Method: resolveTensorHandle
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
+ JNIEnv *, jclass, jlong);
+
/**
* Class: org_tensorflow_EagerOperation
* Method: outputListLength
diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
index d0256435f48..4b7fdc8ccf8 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
@@ -54,6 +54,22 @@ public class EagerOperationTest {
}
}
+ @Test
+ public void outputTensor() {
+ try (EagerSession session = EagerSession.create()) {
+ EagerOperation add = opBuilder(session, "Add", "CompareResult")
+ .addInput(TestUtil.constant(session, "Const1", 2))
+ .addInput(TestUtil.constant(session, "Const2", 4))
+ .build();
+ assertEquals(6, add.tensor(0).intValue());
+
+ // Validate that we retrieve the right shape and datatype from the tensor
+ // that has been resolved
+ assertEquals(0, add.shape(0).length);
+ assertEquals(DataType.INT32, add.dtype(0));
+ }
+ }
+
@Test
public void inputAndOutputListLengths() {
try (EagerSession session = EagerSession.create()) {
@@ -105,7 +121,7 @@ public class EagerOperationTest {
@Test
public void opNotAccessibleIfSessionIsClosed() {
EagerSession session = EagerSession.create();
- EagerOperation add = opBuilder(session, "Add", "SetDevice")
+ EagerOperation add = opBuilder(session, "Add", "SessionClosed")
.addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4))
.build();
@@ -119,6 +135,40 @@ public class EagerOperationTest {
}
}
+ @Test
+ public void outputIndexOutOfBounds() {
+ try (EagerSession session = EagerSession.create()) {
+ EagerOperation add = opBuilder(session, "Add", "OutOfRange")
+ .addInput(TestUtil.constant(session, "Const1", 2))
+ .addInput(TestUtil.constant(session, "Const2", 4))
+ .build();
+ try {
+ add.getUnsafeNativeHandle(1);
+ fail();
+ } catch (IndexOutOfBoundsException e) {
+ // expected
+ }
+ try {
+ add.shape(1);
+ fail();
+ } catch (IndexOutOfBoundsException e) {
+ // expected
+ }
+ try {
+ add.dtype(1);
+ fail();
+ } catch (IndexOutOfBoundsException e) {
+ // expected
+ }
+ try {
+ add.tensor(1);
+ fail();
+ } catch (IndexOutOfBoundsException e) {
+ // expected
+ }
+ }
+ }
+
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
return new EagerOperationBuilder(session, type, name);
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java
index 7331ad50e51..bfbf5385b48 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java
@@ -166,6 +166,17 @@ public class GraphOperationTest {
}
}
}
+
+ @Test
+ public void outputTensorNotSupported() {
+ try (Graph g = new Graph()) {
+ Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
+ try {
+ split.output(0).tensor();
+ fail();
+ } catch (IllegalStateException e) {}
+ }
+ }
private static int split(int[] values, int num_split) {
try (Graph g = new Graph()) {
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
index 3229cce2776..21f4e25f5ab 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
@@ -18,6 +18,7 @@ package org.tensorflow;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -28,6 +29,7 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -519,6 +521,25 @@ public class TensorTest {
// The expected exception.
}
}
+
+ @Test
+ public void eagerTensorIsReleasedAfterSessionIsClosed() {
+ Tensor sum;
+ try (EagerSession session = EagerSession.create()) {
+ Output> x = TestUtil.constant(session, "Const1", 10);
+ Output> y = TestUtil.constant(session, "Const2", 20);
+ sum = TestUtil.addN(session, x, y).tensor();
+ assertNotEquals(0L, sum.getNativeHandle());
+ assertEquals(30, sum.intValue());
+ }
+ assertEquals(0L, sum.getNativeHandle());
+ try {
+ sum.intValue();
+ fail();
+ } catch (NullPointerException e) {
+ // expected.
+ }
+ }
@Test
public void fromHandle() {
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index c97bcaa3386..6e24d88a310 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -67,8 +67,8 @@ public class TestUtil {
.output(0);
}
- public static Output addN(Graph g, Output>... inputs) {
- return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
+ public static Output addN(ExecutionEnvironment env, Output>... inputs) {
+ return env.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
}
public static Output matmul(