From fceea7d0902ea7dde7625d618410e9ddb99eaf50 Mon Sep 17 00:00:00 2001
From: Karl Lessard <karl@kubx.ca>
Date: Thu, 9 May 2019 00:32:51 -0400
Subject: [PATCH] Add eager tensor support

---
 .../org/tensorflow/AbstractOperation.java     |  14 ++-
 .../java/org/tensorflow/EagerOperation.java   |  57 +++++++--
 .../java/org/tensorflow/GraphOperation.java   |   5 +
 .../src/main/java/org/tensorflow/Output.java  |  16 +++
 .../src/main/java/org/tensorflow/Tensor.java  | 116 ++++++++++++++----
 .../src/main/native/eager_operation_jni.cc    |  16 +++
 .../src/main/native/eager_operation_jni.h     |   8 ++
 .../org/tensorflow/EagerOperationTest.java    |  52 +++++++-
 .../org/tensorflow/GraphOperationTest.java    |  11 ++
 .../test/java/org/tensorflow/TensorTest.java  |  21 ++++
 .../test/java/org/tensorflow/TestUtil.java    |   4 +-
 11 files changed, 283 insertions(+), 37 deletions(-)

diff --git a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
index 0d4745fe0b7..f586dae73e0 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
@@ -60,7 +60,7 @@ abstract class AbstractOperation implements Operation {
   abstract long getUnsafeNativeHandle(int outputIdx);
 
   /**
-   * Returns the shape of the tensor of the {code outputIdx}th output of this operation.
+   * Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
    *
    * @param outputIdx index of the output of this operation
    * @return output tensor shape
@@ -68,10 +68,20 @@ abstract class AbstractOperation implements Operation {
   abstract long[] shape(int outputIdx);
 
   /**
-   * Returns the datatype of the tensor of the {code outputIdx}th output of this operation.
+   * Returns the datatype of the tensor of the {@code outputIdx}th output of this operation.
    *
    * @param outputIdx index of the output of this operation
    * @return output tensor datatype
    */
   abstract DataType dtype(int outputIdx);
+
+  /**
+   * Returns the tensor of the {@code outputIdx}th output of this operation.
+   * 
+   * <p>This is only supported in an eager execution environment.
+   * 
+   * @param outputIdx index of the output of this operation
+   * @return output tensor
+   */
+  abstract Tensor<?> tensor(int outputIdx);
 }
diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
index a0530d7b9da..2c1df4cdc40 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
@@ -15,7 +15,7 @@ limitations under the License.
 
 package org.tensorflow;
 
-import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 
 /**
  * Implementation of an {@link Operation} executed eagerly.
@@ -33,6 +33,7 @@ class EagerOperation extends AbstractOperation {
     this.type = type;
     this.name = name;
     this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
+    this.outputTensors = new AtomicReferenceArray<Tensor<?>>(outputNativeHandles.length);
   }
 
   @Override
@@ -67,6 +68,12 @@ class EagerOperation extends AbstractOperation {
 
   @Override
   public long[] shape(int outputIndex) {
+    // If the tensor of this output has already been resolved, return its shape.
+    // Otherwise, retrieve the tensor shape from the native library.
+    Tensor<?> tensor = outputTensors.get(outputIndex);
+    if (tensor != null) {
+      return tensor.shape();
+    }
     long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
     long[] shape = new long[numDims(outputNativeHandle)];
     for (int i = 0; i < shape.length; ++i) {
@@ -77,10 +84,43 @@ class EagerOperation extends AbstractOperation {
 
   @Override
   public DataType dtype(int outputIndex) {
+    // If the tensor of this output has already been resolved, return its datatype.
+    // Otherwise, retrieve the tensor datatype from the native library.
+    Tensor<?> tensor = outputTensors.get(outputIndex);
+    if (tensor != null) {
+      return tensor.dataType();
+    }
     long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
     return DataType.fromC(dataType(outputNativeHandle));
   }
 
+  @Override
+  public Tensor<?> tensor(int outputIndex) {
+    Tensor<?> tensor = outputTensors.get(outputIndex);
+    if (tensor == null) {
+      tensor = resolveTensor(outputIndex);
+    }
+    return tensor;
+  }
+
+  private final EagerSession session;
+  private final NativeReference nativeRef;
+  private final String type;
+  private final String name;
+  private final AtomicReferenceArray<Tensor<?>> outputTensors;
+  
+  private Tensor<?> resolveTensor(int outputIndex) {
+    // Take an optimistic approach, where we attempt to resolve the output tensor without locking.
+    // If another thread has resolved it meanwhile, release our copy and reuse the existing one instead.
+    long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
+    Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
+    if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
+      tensor.close();
+      tensor = outputTensors.get(outputIndex);
+    }
+    return tensor;
+  }
+  
   private static class NativeReference extends EagerSession.NativeReference {
 
     NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
@@ -92,30 +132,27 @@ class EagerOperation extends AbstractOperation {
     @Override
     void delete() {
       if (opHandle != 0L) {
-        for (long tensorHandle : outputHandles) {
-          if (tensorHandle != 0L) {
-            EagerOperation.deleteTensorHandle(tensorHandle);
+        for (int i = 0; i < outputHandles.length; ++i) {
+          if (outputHandles[i] != 0L) {
+            EagerOperation.deleteTensorHandle(outputHandles[i]);
+            outputHandles[i] = 0L;
           }
         }
         EagerOperation.delete(opHandle);
         opHandle = 0L;
-        Arrays.fill(outputHandles, 0L);
       }
     }
     
     private long opHandle;
     private final long[] outputHandles;
   }
-
-  private final EagerSession session;
-  private final NativeReference nativeRef;
-  private final String type;
-  private final String name;
   
   private static native void delete(long handle);
 
   private static native void deleteTensorHandle(long handle);
   
+  private static native long resolveTensorHandle(long handle);
+  
   private static native int outputListLength(long handle, String name);
 
   private static native int inputListLength(long handle, String name);
diff --git a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
index 0e43bc3eb43..590eff8a83e 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
@@ -138,6 +138,11 @@ public final class GraphOperation extends AbstractOperation {
       r.close();
     }
   }
+  
+  @Override
+  Tensor<?> tensor(int outputIdx) {
+    throw new IllegalStateException("Graph tensors must be fetched by running a session");
+  }
 
   long getUnsafeNativeHandle() {
     return unsafeNativeHandle;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java
index 15bb2e89e8d..90668bb7ad3 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Output.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java
@@ -47,6 +47,22 @@ public final class Output<T> implements Operand<T> {
   public DataType dataType() {
     return operation.dtype(index);
   }
+  
+  /** 
+   * Returns the tensor at this output.
+   * 
+   * <p>This operation is only supported on the outputs of an operation executed eagerly. 
+   * For graph environments, output tensors must be fetched by running a session, using 
+   * {@link Session.Runner#fetch(Output)}.
+   * 
+   * @return tensor
+   * @throws IllegalStateException if this output results from a graph
+   * @see EagerSession
+   */
+  @SuppressWarnings("unchecked")
+  public Tensor<T> tensor() {
+    return (Tensor<T>)operation.tensor(index);
+  }
 
   @Override
   public Output<T> asOutput() {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 89872537689..253ceb65781 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -140,15 +140,17 @@ public final class Tensor<T> implements AutoCloseable {
     Tensor<?> t = new Tensor(dtype);
     t.shapeCopy = new long[numDimensions(obj, dtype)];
     fillShape(obj, 0, t.shapeCopy);
+    long nativeHandle;
     if (t.dtype != DataType.STRING) {
       int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
-      t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
-      setValue(t.nativeHandle, obj);
+      nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
+      setValue(nativeHandle, obj);
     } else if (t.shapeCopy.length != 0) {
-      t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
+      nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
     } else {
-      t.nativeHandle = allocateScalarBytes((byte[]) obj);
+      nativeHandle = allocateScalarBytes((byte[]) obj);
     }
+    t.nativeRef = new NativeReference(nativeHandle);
     return t;
   }
 
@@ -314,23 +316,22 @@ public final class Tensor<T> implements AutoCloseable {
     }
     Tensor<T> t = new Tensor<T>(dataType);
     t.shapeCopy = Arrays.copyOf(shape, shape.length);
-    t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
+    long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
+    t.nativeRef = new NativeReference(nativeHandle);
     return t;
   }
 
   /**
    * Release resources associated with the Tensor.
    *
-   * <p><b>WARNING:</b>If not invoked, memory will be leaked.
+   * <p><b>WARNING:</b>This must be invoked for all tensors that were not been produced by an eager 
+   * operation or memory will be leaked.
    *
    * <p>The Tensor object is no longer usable after {@code close} returns.
    */
   @Override
   public void close() {
-    if (nativeHandle != 0) {
-      delete(nativeHandle);
-      nativeHandle = 0;
-    }
+    nativeRef.release();
   }
 
   /** Returns the {@link DataType} of elements stored in the Tensor. */
@@ -374,7 +375,7 @@ public final class Tensor<T> implements AutoCloseable {
    * @throws IllegalArgumentException if the Tensor does not represent a float scalar.
    */
   public float floatValue() {
-    return scalarFloat(nativeHandle);
+    return scalarFloat(getNativeHandle());
   }
 
   /**
@@ -383,7 +384,7 @@ public final class Tensor<T> implements AutoCloseable {
    * @throws IllegalArgumentException if the Tensor does not represent a double scalar.
    */
   public double doubleValue() {
-    return scalarDouble(nativeHandle);
+    return scalarDouble(getNativeHandle());
   }
 
   /**
@@ -392,7 +393,7 @@ public final class Tensor<T> implements AutoCloseable {
    * @throws IllegalArgumentException if the Tensor does not represent a int scalar.
    */
   public int intValue() {
-    return scalarInt(nativeHandle);
+    return scalarInt(getNativeHandle());
   }
 
   /**
@@ -401,7 +402,7 @@ public final class Tensor<T> implements AutoCloseable {
    * @throws IllegalArgumentException if the Tensor does not represent a long scalar.
    */
   public long longValue() {
-    return scalarLong(nativeHandle);
+    return scalarLong(getNativeHandle());
   }
 
   /**
@@ -410,7 +411,7 @@ public final class Tensor<T> implements AutoCloseable {
    * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
    */
   public boolean booleanValue() {
-    return scalarBoolean(nativeHandle);
+    return scalarBoolean(getNativeHandle());
   }
 
   /**
@@ -419,7 +420,7 @@ public final class Tensor<T> implements AutoCloseable {
    * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
    */
   public byte[] bytesValue() {
-    return scalarBytes(nativeHandle);
+    return scalarBytes(getNativeHandle());
   }
 
   /**
@@ -448,7 +449,7 @@ public final class Tensor<T> implements AutoCloseable {
    */
   public <U> U copyTo(U dst) {
     throwExceptionIfTypeIsIncompatible(dst);
-    readNDArray(nativeHandle, dst);
+    readNDArray(getNativeHandle(), dst);
     return dst;
   }
 
@@ -553,16 +554,27 @@ public final class Tensor<T> implements AutoCloseable {
     @SuppressWarnings("rawtypes")
     Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
     t.shapeCopy = shape(handle);
-    t.nativeHandle = handle;
+    t.nativeRef = new NativeReference(handle);
+    return t;
+  }
+
+  /**
+   * Create an eager Tensor object from a handle to the C TF_Tensor object.
+   *
+   * <p>Takes ownership of the handle.
+   */
+  static Tensor<?> fromHandle(long handle, EagerSession session) {
+    Tensor<?> t = fromHandle(handle);
+    t.nativeRef.eager(session, t);
     return t;
   }
 
   long getNativeHandle() {
-    return nativeHandle;
+    return nativeRef.tensorHandle;
   }
 
-  private long nativeHandle;
-  private DataType dtype;
+  private NativeReference nativeRef = null;
+  private final DataType dtype;
   private long[] shapeCopy = null;
 
   private Tensor(DataType t) {
@@ -570,7 +582,7 @@ public final class Tensor<T> implements AutoCloseable {
   }
 
   private ByteBuffer buffer() {
-    return buffer(nativeHandle).order(ByteOrder.nativeOrder());
+    return buffer(getNativeHandle()).order(ByteOrder.nativeOrder());
   }
 
   private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
@@ -609,6 +621,66 @@ public final class Tensor<T> implements AutoCloseable {
     }
   }
 
+  /**
+   * Reference to the underlying native tensor
+   *
+   * <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get automatically 
+   * released after executing the last line of the `try` block they were declared in.
+   *  
+   * <p>They can also be attached to an eager session, where in this case their lifetime ends either when
+   * this session is closed or when the Tensor instance is no longer referenced and have been garbage-collected.
+   * 
+   * <p>This helper class wraps the tensor native handle and support both situations; If an eager reference to 
+   * the tensor exists, it will take care of releasing the tensor at the end of its life. If the tensor is
+   * being explicetly closed before this happens, it will take cake of clearing its association with any eager
+   * session before cleaning up the resources.
+   */
+  private static class NativeReference {
+
+    /**
+     * Attaches this reference to an eager session
+     */
+    private class EagerReference extends EagerSession.NativeReference {
+
+      EagerReference(EagerSession session, Tensor<?> tensor) {
+        super(session, tensor);
+      }
+
+      @Override
+      void delete() {
+        // Mark this eager reference as cleared since it has been deleted by the session
+        NativeReference.this.eagerRef = null;
+        NativeReference.this.release();
+      }
+    }
+    
+    NativeReference(long tensorHandle) {
+      this.tensorHandle = tensorHandle;
+    }
+    
+    void eager(EagerSession session, Tensor<?> tensor) {
+      if (eagerRef != null) {
+        throw new IllegalStateException("The tensor is already attached to an eager session");
+      }
+      eagerRef = new EagerReference(session, tensor);
+    }
+
+    synchronized void release() {
+      if (tensorHandle != 0L) {
+        // Clear any remaining eager reference to this tensor
+        if (eagerRef != null) {
+          eagerRef.clear();
+          eagerRef = null;
+        }
+        Tensor.delete(tensorHandle);
+        tensorHandle = 0L;
+      }
+    }
+    
+    private long tensorHandle;
+    private EagerReference eagerRef;
+  }
+
   private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
 
   static {
diff --git a/tensorflow/java/src/main/native/eager_operation_jni.cc b/tensorflow/java/src/main/native/eager_operation_jni.cc
index 3a5f6f90ddc..15f98905796 100644
--- a/tensorflow/java/src/main/native/eager_operation_jni.cc
+++ b/tensorflow/java/src/main/native/eager_operation_jni.cc
@@ -57,6 +57,22 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
   TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
 }
 
+JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
+  if (tensor_handle == nullptr) return 0;
+  TF_Status* status = TF_NewStatus();
+  TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status);
+  if (!throwExceptionIfNotOK(env, status)) {
+    TF_DeleteStatus(status);
+    return 0;
+  }
+  TF_DeleteStatus(status);
+  static_assert(sizeof(jlong) >= sizeof(TF_Tensor*),
+                "Cannot represent a C TF_Tensor as a Java long");
+  return reinterpret_cast<jlong>(tensor);
+}
+
 JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
     JNIEnv* env, jclass clazz, jlong handle, jstring name) {
   TFE_Op* op = requireOp(env, handle);
diff --git a/tensorflow/java/src/main/native/eager_operation_jni.h b/tensorflow/java/src/main/native/eager_operation_jni.h
index f9684b0a26e..c1d52bf9393 100644
--- a/tensorflow/java/src/main/native/eager_operation_jni.h
+++ b/tensorflow/java/src/main/native/eager_operation_jni.h
@@ -38,6 +38,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(
 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
     JNIEnv *, jclass, jlong);
 
+/**
+ * Class:     org_tensorflow_EagerOperation
+ * Method:    resolveTensorHandle
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
+    JNIEnv *, jclass, jlong);
+
 /**
  * Class:     org_tensorflow_EagerOperation
  * Method:    outputListLength
diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
index d0256435f48..4b7fdc8ccf8 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
@@ -54,6 +54,22 @@ public class EagerOperationTest {
     }
   }
   
+  @Test
+  public void outputTensor() {
+    try (EagerSession session = EagerSession.create()) {
+      EagerOperation add = opBuilder(session, "Add", "CompareResult")
+          .addInput(TestUtil.constant(session, "Const1", 2))
+          .addInput(TestUtil.constant(session, "Const2", 4))
+          .build();
+      assertEquals(6, add.tensor(0).intValue());
+      
+      // Validate that we retrieve the right shape and datatype from the tensor
+      // that has been resolved
+      assertEquals(0, add.shape(0).length);
+      assertEquals(DataType.INT32, add.dtype(0));
+    }    
+  }
+  
   @Test
   public void inputAndOutputListLengths() {
     try (EagerSession session = EagerSession.create()) {
@@ -105,7 +121,7 @@ public class EagerOperationTest {
   @Test
   public void opNotAccessibleIfSessionIsClosed() {
     EagerSession session = EagerSession.create();
-    EagerOperation add = opBuilder(session, "Add", "SetDevice")
+    EagerOperation add = opBuilder(session, "Add", "SessionClosed")
         .addInput(TestUtil.constant(session, "Const1", 2))
         .addInput(TestUtil.constant(session, "Const2", 4))
         .build();
@@ -119,6 +135,40 @@ public class EagerOperationTest {
     }
   }
   
+  @Test
+  public void outputIndexOutOfBounds() {
+    try (EagerSession session = EagerSession.create()) {
+      EagerOperation add = opBuilder(session, "Add", "OutOfRange")
+          .addInput(TestUtil.constant(session, "Const1", 2))
+          .addInput(TestUtil.constant(session, "Const2", 4))
+          .build();
+      try {
+          add.getUnsafeNativeHandle(1);
+          fail();
+      } catch (IndexOutOfBoundsException e) {
+        // expected
+      }    
+      try {
+          add.shape(1);
+          fail();
+      } catch (IndexOutOfBoundsException e) {
+        // expected
+      }    
+      try {
+          add.dtype(1);
+          fail();
+      } catch (IndexOutOfBoundsException e) {
+        // expected
+      }    
+      try {
+          add.tensor(1);
+          fail();
+      } catch (IndexOutOfBoundsException e) {
+        // expected
+      }    
+    }
+  }
+  
   private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
     return new EagerOperationBuilder(session, type, name);
   }
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java
index 7331ad50e51..bfbf5385b48 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java
@@ -166,6 +166,17 @@ public class GraphOperationTest {
       }
     }
   }
+  
+  @Test
+  public void outputTensorNotSupported() {
+    try (Graph g = new Graph()) {
+      Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
+      try {
+        split.output(0).tensor();
+        fail();
+      } catch (IllegalStateException e) {}
+    }
+  }
 
   private static int split(int[] values, int num_split) {
     try (Graph g = new Graph()) {
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
index 3229cce2776..21f4e25f5ab 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
@@ -18,6 +18,7 @@ package org.tensorflow;
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -28,6 +29,7 @@ import java.nio.DoubleBuffer;
 import java.nio.FloatBuffer;
 import java.nio.IntBuffer;
 import java.nio.LongBuffer;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -519,6 +521,25 @@ public class TensorTest {
       // The expected exception.
     }
   }
+  
+  @Test
+  public void eagerTensorIsReleasedAfterSessionIsClosed() {
+    Tensor<Integer> sum;
+    try (EagerSession session = EagerSession.create()) {
+       Output<?> x = TestUtil.constant(session, "Const1", 10);
+       Output<?> y = TestUtil.constant(session, "Const2", 20);
+       sum = TestUtil.<Integer>addN(session, x, y).tensor();
+       assertNotEquals(0L, sum.getNativeHandle());
+       assertEquals(30, sum.intValue());
+    }
+    assertEquals(0L, sum.getNativeHandle());
+    try {
+      sum.intValue();
+      fail();
+    } catch (NullPointerException e) {
+      // expected.
+    }
+  }
 
   @Test
   public void fromHandle() {
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index c97bcaa3386..6e24d88a310 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -67,8 +67,8 @@ public class TestUtil {
         .<T>output(0);
   }
 
-  public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
-    return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
+  public static <T> Output<T> addN(ExecutionEnvironment env, Output<?>... inputs) {
+    return env.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
   }
 
   public static <T> Output<T> matmul(