From bf05237063d4a81f6b9ba064f2eb4157ef3c7921 Mon Sep 17 00:00:00 2001
From: Karl Lessard <karl@kubx.ca>
Date: Fri, 3 May 2019 09:22:09 -0400
Subject: [PATCH 1/2] Eager operations implementation

---
 tensorflow/java/BUILD                         |  13 ++
 .../org/tensorflow/AbstractOperation.java     |  15 ++
 .../java/org/tensorflow/EagerOperation.java   | 128 ++++++++++++++++++
 .../org/tensorflow/EagerOperationBuilder.java |  45 +++---
 .../java/org/tensorflow/GraphOperation.java   |  15 --
 .../native/eager_operation_builder_jni.cc     |   5 +-
 .../src/main/native/eager_operation_jni.cc    | 126 +++++++++++++++++
 .../src/main/native/eager_operation_jni.h     |  84 ++++++++++++
 .../tensorflow/EagerOperationBuilderTest.java |  89 +++++++++++-
 .../org/tensorflow/EagerOperationTest.java    | 114 ++++++++++++++++
 .../test/java/org/tensorflow/TestUtil.java    |  10 +-
 11 files changed, 604 insertions(+), 40 deletions(-)
 create mode 100644 tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
 create mode 100644 tensorflow/java/src/main/native/eager_operation_jni.cc
 create mode 100644 tensorflow/java/src/main/native/eager_operation_jni.h
 create mode 100644 tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java

diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index a23ae9c60e9..6a71cd1e9da 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -159,6 +159,19 @@ tf_java_test(
     ],
 )
 
+tf_java_test(
+    name = "EagerOperationTest",
+    size = "small",
+    srcs = ["src/test/java/org/tensorflow/EagerOperationTest.java"],
+    javacopts = JAVACOPTS,
+    test_class = "org.tensorflow.EagerOperationTest",
+    deps = [
+        ":tensorflow",
+        ":testutil",
+        "@junit",
+    ],
+)
+
 tf_java_test(
     name = "GraphTest",
     size = "small",
diff --git a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
index a1d95f246b2..0d4745fe0b7 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java
@@ -23,6 +23,21 @@ package org.tensorflow;
  */
 abstract class AbstractOperation implements Operation {
 
+  @Override
+  public Output<?>[] outputList(int idx, int length) {
+    Output<?>[] outputs = new Output<?>[length];
+    for (int i = 0; i < length; ++i) {
+      outputs[i] = output(idx + i);
+    }
+    return outputs;
+  }
+
+  @Override
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public <T> Output<T> output(int idx) {
+    return new Output(this, idx);
+  }
+
   @Override
   public String toString() {
     return String.format("<%s '%s'>", type(), name());
diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
new file mode 100644
index 00000000000..a0530d7b9da
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
@@ -0,0 +1,128 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+import java.util.Arrays;
+
+/**
+ * Implementation of an {@link Operation} executed eagerly.
+ * 
+ * <p>EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of is
+ * valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the EagerOperation
+ * instance may fail with an {@code IllegalStateException}.
+ *
+ * <p>EagerOperation instances are thread-safe.
+ */
+class EagerOperation extends AbstractOperation {
+  
+  EagerOperation(EagerSession session, long opNativeHandle, long[] outputNativeHandles, String type, String name) {
+    this.session = session;
+    this.type = type;
+    this.name = name;
+    this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
+  }
+
+  @Override
+  public String name() {
+    return name;
+  }
+
+  @Override
+  public String type() {
+    return type;
+  }
+
+  @Override
+  public int numOutputs() {
+    return nativeRef.outputHandles.length;
+  }
+
+  @Override
+  public int outputListLength(final String name) {
+    return outputListLength(nativeRef.opHandle, name);
+  }
+
+  @Override
+  public int inputListLength(final String name) {
+    return inputListLength(nativeRef.opHandle, name);
+  }
+
+  @Override
+  public long getUnsafeNativeHandle(int outputIndex) {
+    return nativeRef.outputHandles[outputIndex];
+  }
+
+  @Override
+  public long[] shape(int outputIndex) {
+    long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
+    long[] shape = new long[numDims(outputNativeHandle)];
+    for (int i = 0; i < shape.length; ++i) {
+      shape[i] = dim(outputNativeHandle, i);
+    }
+    return shape;
+  }
+
+  @Override
+  public DataType dtype(int outputIndex) {
+    long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
+    return DataType.fromC(dataType(outputNativeHandle));
+  }
+
+  private static class NativeReference extends EagerSession.NativeReference {
+
+    NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
+      super(session, operation);
+      this.opHandle = opHandle;
+      this.outputHandles = outputHandles;
+    }
+
+    @Override
+    void delete() {
+      if (opHandle != 0L) {
+        for (long tensorHandle : outputHandles) {
+          if (tensorHandle != 0L) {
+            EagerOperation.deleteTensorHandle(tensorHandle);
+          }
+        }
+        EagerOperation.delete(opHandle);
+        opHandle = 0L;
+        Arrays.fill(outputHandles, 0L);
+      }
+    }
+    
+    private long opHandle;
+    private final long[] outputHandles;
+  }
+
+  private final EagerSession session;
+  private final NativeReference nativeRef;
+  private final String type;
+  private final String name;
+  
+  private static native void delete(long handle);
+
+  private static native void deleteTensorHandle(long handle);
+  
+  private static native int outputListLength(long handle, String name);
+
+  private static native int inputListLength(long handle, String name);
+
+  private static native int dataType(long handle);
+
+  private static native int numDims(long handle);
+
+  private static native long dim(long handle, int index);
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java
index a8cb5b5c318..2097f4ad4fa 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java
@@ -31,10 +31,13 @@ final class EagerOperationBuilder implements OperationBuilder {
   }
 
   @Override
-  public Operation build() {
-    // TODO (karllessard) Execute the eager operation and pass output tensor handles to new
-    // EagerOperation class
-    throw new UnsupportedOperationException("Eager execution is not supported yet");
+  public EagerOperation build() {
+    long[] tensorHandles = execute(nativeRef.opHandle);
+    EagerOperation operation = new EagerOperation(session, nativeRef.opHandle, tensorHandles, type, name);
+    // Release our reference to the native op handle now that we transferred its 
+    // ownership to the EagerOperation
+    nativeRef.clear();
+    return operation;
   }
 
   @Override
@@ -44,7 +47,7 @@ final class EagerOperationBuilder implements OperationBuilder {
   }
 
   @Override
-  public OperationBuilder addInputList(Output<?>[] inputs) {
+  public EagerOperationBuilder addInputList(Output<?>[] inputs) {
     long[] inputHandles = new long[inputs.length];
     for (int i = 0; i < inputs.length; ++i) {
       inputHandles[i] = inputs[i].getUnsafeNativeHandle();
@@ -60,18 +63,18 @@ final class EagerOperationBuilder implements OperationBuilder {
   }
 
   @Override
-  public OperationBuilder setDevice(String device) {
+  public EagerOperationBuilder setDevice(String device) {
     setDevice(nativeRef.opHandle, device);
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, String value) {
+  public EagerOperationBuilder setAttr(String name, String value) {
     return setAttr(name, value.getBytes(StandardCharsets.UTF_8));
   }
 
   @Override
-  public OperationBuilder setAttr(String name, String[] values) {
+  public EagerOperationBuilder setAttr(String name, String[] values) {
     Charset utf8 = StandardCharsets.UTF_8;
     Object[] objects = new Object[values.length];
     for (int i = 0; i < values.length; ++i) {
@@ -82,55 +85,55 @@ final class EagerOperationBuilder implements OperationBuilder {
   }
 
   @Override
-  public OperationBuilder setAttr(String name, byte[] values) {
+  public EagerOperationBuilder setAttr(String name, byte[] values) {
     setAttrString(nativeRef.opHandle, name, values);
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, long value) {
+  public EagerOperationBuilder setAttr(String name, long value) {
     setAttrInt(nativeRef.opHandle, name, value);
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, long[] values) {
+  public EagerOperationBuilder setAttr(String name, long[] values) {
     setAttrIntList(nativeRef.opHandle, name, values);
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, float value) {
+  public EagerOperationBuilder setAttr(String name, float value) {
     setAttrFloat(nativeRef.opHandle, name, value);
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, float[] values) {
+  public EagerOperationBuilder setAttr(String name, float[] values) {
     setAttrFloatList(nativeRef.opHandle, name, values);
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, boolean value) {
+  public EagerOperationBuilder setAttr(String name, boolean value) {
     setAttrBool(nativeRef.opHandle, name, value);
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, boolean[] values) {
+  public EagerOperationBuilder setAttr(String name, boolean[] values) {
     setAttrBoolList(nativeRef.opHandle, name, values);
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, DataType value) {
+  public EagerOperationBuilder setAttr(String name, DataType value) {
     setAttrType(nativeRef.opHandle, name, value.c());
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, DataType[] values) {
+  public EagerOperationBuilder setAttr(String name, DataType[] values) {
     int[] c = new int[values.length];
     for (int i = 0; i < values.length; ++i) {
       c[i] = values[i].c();
@@ -140,26 +143,26 @@ final class EagerOperationBuilder implements OperationBuilder {
   }
 
   @Override
-  public OperationBuilder setAttr(String name, Tensor<?> value) {
+  public EagerOperationBuilder setAttr(String name, Tensor<?> value) {
     setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle());
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, Tensor<?>[] values) {
+  public EagerOperationBuilder setAttr(String name, Tensor<?>[] values) {
     // TODO (karllessard) could be supported by adding this attribute type in the eager C API
     throw new UnsupportedOperationException(
         "Tensor list attributes are not supported in eager mode");
   }
 
   @Override
-  public OperationBuilder setAttr(String name, Shape value) {
+  public EagerOperationBuilder setAttr(String name, Shape value) {
     setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions());
     return this;
   }
 
   @Override
-  public OperationBuilder setAttr(String name, Shape[] values) {
+  public EagerOperationBuilder setAttr(String name, Shape[] values) {
     int[] numDimensions = new int[values.length];
     int totalNumDimensions = 0;
     for (int idx = 0; idx < values.length; ++idx) {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
index 31b80d306a2..0e43bc3eb43 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java
@@ -75,21 +75,6 @@ public final class GraphOperation extends AbstractOperation {
     }
   }
 
-  @Override
-  public Output<?>[] outputList(int idx, int length) {
-    Output<?>[] outputs = new Output<?>[length];
-    for (int i = 0; i < length; ++i) {
-      outputs[i] = output(idx + i);
-    }
-    return outputs;
-  }
-
-  @Override
-  @SuppressWarnings({"rawtypes", "unchecked"})
-  public <T> Output<T> output(int idx) {
-    return new Output(this, idx);
-  }
-
   @Override
   public int hashCode() {
     return Long.valueOf(getUnsafeNativeHandle()).hashCode();
diff --git a/tensorflow/java/src/main/native/eager_operation_builder_jni.cc b/tensorflow/java/src/main/native/eager_operation_builder_jni.cc
index 654c453176b..f8ed2072ba0 100644
--- a/tensorflow/java/src/main/native/eager_operation_builder_jni.cc
+++ b/tensorflow/java/src/main/native/eager_operation_builder_jni.cc
@@ -22,6 +22,9 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/java/src/main/native/exception_jni.h"
 
+// This value should be >= to the maximum number of outputs in any op
+#define MAX_OUTPUTS_PER_OP 8
+
 namespace {
 
 TFE_Op* requireOp(JNIEnv* env, jlong handle) {
@@ -89,7 +92,7 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
     JNIEnv* env, jclass clazz, jlong op_handle) {
   TFE_Op* op = requireOp(env, op_handle);
   if (op == nullptr) return 0;
-  int num_retvals = 8;  // should be >= than max number of outputs in any op
+  int num_retvals = MAX_OUTPUTS_PER_OP;
   std::unique_ptr<TFE_TensorHandle*[]> retvals(
       new TFE_TensorHandle*[num_retvals]);
   TF_Status* status = TF_NewStatus();
diff --git a/tensorflow/java/src/main/native/eager_operation_jni.cc b/tensorflow/java/src/main/native/eager_operation_jni.cc
new file mode 100644
index 00000000000..3a5f6f90ddc
--- /dev/null
+++ b/tensorflow/java/src/main/native/eager_operation_jni.cc
@@ -0,0 +1,126 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <assert.h>
+#include <stdlib.h>
+#include <string.h>
+#include <algorithm>
+#include <memory>
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/java/src/main/native/eager_operation_jni.h"
+#include "tensorflow/java/src/main/native/exception_jni.h"
+
+namespace {
+
+TFE_Op* requireOp(JNIEnv* env, jlong handle) {
+  if (handle == 0) {
+    throwException(env, kIllegalStateException,
+                   "Eager session has been closed");
+    return nullptr;
+  }
+  return reinterpret_cast<TFE_Op*>(handle);
+}
+
+TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
+  if (handle == 0) {
+    throwException(env, kIllegalStateException,
+                   "EagerSession has been closed");
+    return nullptr;
+  }
+  return reinterpret_cast<TFE_TensorHandle*>(handle);
+}
+
+}  // namespace
+
+JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  if (handle == 0) return;
+  TFE_DeleteOp(reinterpret_cast<TFE_Op*>(handle));
+}
+
+JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  if (handle == 0) return;
+  TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
+}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
+    JNIEnv* env, jclass clazz, jlong handle, jstring name) {
+  TFE_Op* op = requireOp(env, handle);
+  if (op == nullptr) return 0;
+  TF_Status* status = TF_NewStatus();
+  const char* cname = env->GetStringUTFChars(name, nullptr);
+  int length = TFE_OpGetOutputLength(op, cname, status);
+  env->ReleaseStringUTFChars(name, cname);
+  if (!throwExceptionIfNotOK(env, status)) {
+    TF_DeleteStatus(status);
+    return 0;
+  }
+  TF_DeleteStatus(status);
+  return static_cast<jint>(length);
+}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
+    JNIEnv* env, jclass clazz, jlong handle, jstring name) {
+  TFE_Op* op = requireOp(env, handle);
+  if (op == nullptr) return 0;
+  TF_Status* status = TF_NewStatus();
+  const char* cname = env->GetStringUTFChars(name, nullptr);
+  int length = TFE_OpGetInputLength(op, cname, status);
+  env->ReleaseStringUTFChars(name, cname);
+  if (!throwExceptionIfNotOK(env, status)) {
+    TF_DeleteStatus(status);
+    return 0;
+  }
+  TF_DeleteStatus(status);
+  return static_cast<jint>(length);
+}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
+  if (tensor_handle == nullptr) return 0;
+  TF_DataType data_type = TFE_TensorHandleDataType(tensor_handle);
+  return static_cast<jint>(data_type);
+}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
+  if (tensor_handle == nullptr) return 0;
+  TF_Status* status = TF_NewStatus();
+  int num_dims = TFE_TensorHandleNumDims(tensor_handle, status);
+  if (!throwExceptionIfNotOK(env, status)) {
+    TF_DeleteStatus(status);
+    return 0;
+  }
+  TF_DeleteStatus(status);
+  return static_cast<jint>(num_dims);
+}
+
+JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(
+    JNIEnv* env, jclass clazz, jlong handle, jint dim_index) {
+  TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
+  if (tensor_handle == nullptr) return 0;
+  TF_Status* status = TF_NewStatus();
+  int64_t dim = TFE_TensorHandleDim(tensor_handle, dim_index, status);
+  if (!throwExceptionIfNotOK(env, status)) {
+    TF_DeleteStatus(status);
+    return 0;
+  }
+  TF_DeleteStatus(status);
+  return static_cast<jlong>(dim);
+}
diff --git a/tensorflow/java/src/main/native/eager_operation_jni.h b/tensorflow/java/src/main/native/eager_operation_jni.h
new file mode 100644
index 00000000000..f9684b0a26e
--- /dev/null
+++ b/tensorflow/java/src/main/native/eager_operation_jni.h
@@ -0,0 +1,84 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
+
+#include <jni.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * Class:     org_tensorflow_EagerOperation
+ * Method:    delete
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(
+    JNIEnv *, jclass, jlong);
+
+/*
+ * Class:     org_tensorflow_EagerOperation
+ * Method:    deleteTensorHandle
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
+    JNIEnv *, jclass, jlong);
+
+/**
+ * Class:     org_tensorflow_EagerOperation
+ * Method:    outputListLength
+ * Signature: (JLjava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
+    JNIEnv *, jclass, jlong, jstring);
+
+/**
+ * Class:     org_tensorflow_EagerOperation
+ * Method:    inputListLength
+ * Signature: (JLjava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
+    JNIEnv *, jclass, jlong, jstring);
+
+/**
+ * Class:     org_tensorflow_EagerOperation
+ * Method:    dataType
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(
+    JNIEnv *, jclass, jlong);
+
+/**
+ * Class:     org_tensorflow_EagerOperation
+ * Method:    numDims
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(
+    JNIEnv *, jclass, jlong);
+
+/**
+ * Class:     org_tensorflow_EagerOperation
+ * Method:    dim
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(
+    JNIEnv *, jclass, jlong, jint);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+#endif  // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
index 38534297e8c..696a2098f3e 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
@@ -48,5 +48,92 @@ public class EagerOperationBuilderTest {
     }
   }
 
-  // TODO (karllessard) add more tests when EagerOperation is implemented as well
+  @Test
+  public void addInputs() {
+    try (EagerSession session = EagerSession.create()) {
+      Operation asrt = opBuilder(session, "Assert", "assert")
+          .addInput(TestUtil.constant(session, "Cond", true))
+          .addInputList(new Output<?>[] {TestUtil.constant(session, "Error", -1)})
+          .build();
+      try {
+        opBuilder(session, "Const", "var").addControlInput(asrt);
+        fail();
+      } catch (UnsupportedOperationException e) {}
+    }
+  }
+  
+  @Test
+  public void setDevice() {
+    try (EagerSession session = EagerSession.create()) {
+      opBuilder(session, "Add", "SetDevice")
+          .setDevice("/job:localhost/replica:0/task:0/device:CPU:0")
+          .addInput(TestUtil.constant(session, "Const1", 2))
+          .addInput(TestUtil.constant(session, "Const2", 4))
+          .build();
+    }
+  }
+
+  @Test
+  public void setAttrs() {
+    // The effect of setting an attribute may not easily be visible from the other parts of this
+    // package's API. Thus, for now, the test simply executes the various setAttr variants to see
+    // that there are no exceptions.
+    //
+    // This is a bit of an awkward test since it has to find operations with attributes of specific
+    // types that aren't inferred from the input arguments.
+    try (EagerSession session = EagerSession.create()) {
+      // dtype, tensor attributes.
+      try (Tensor<Integer> t = Tensors.create(1)) {
+        opBuilder(session, "Const", "DataTypeAndTensor")
+            .setAttr("dtype", DataType.INT32)
+            .setAttr("value", t)
+            .build();
+      }
+      // type, int (TF "int" attributes are 64-bit signed, so a Java long).
+      opBuilder(session, "RandomUniform", "DataTypeAndInt")
+          .addInput(TestUtil.constant(session, "RandomUniformShape", new int[] {1}))
+          .setAttr("seed", 10)
+          .setAttr("dtype", DataType.FLOAT)
+          .build();
+      // list(int), string
+      opBuilder(session, "MaxPool", "IntListAndString")
+          .addInput(TestUtil.constant(session, "MaxPoolInput", new float[2][2][2][2]))
+          .setAttr("ksize", new long[] {1, 1, 1, 1})
+          .setAttr("strides", new long[] {1, 1, 1, 1})
+          .setAttr("padding", "SAME")
+          .build();
+      // list(float), device
+      opBuilder(session, "FractionalMaxPool", "FloatList")
+          .addInput(TestUtil.constant(session, "FractionalMaxPoolInput", new float[2][2][2][2]))
+          .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
+          .build();
+      // shape
+      opBuilder(session, "EnsureShape", "ShapeAttr")
+          .addInput(TestUtil.constant(session, "Const", new int[2][2]))
+          .setAttr("shape", Shape.make(2, 2))
+          .build();
+      // list(shape)
+      opBuilder(session, "FIFOQueue", "queue")
+          .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32})
+          .setAttr("shapes", new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)})
+          .build();
+      // bool
+      opBuilder(session, "All", "Bool")
+          .addInput(TestUtil.constant(session, "Const", new boolean[] {true, true, false}))
+          .addInput(TestUtil.constant(session, "Axis", 0))
+          .setAttr("keep_dims", false)
+          .build();
+      // float
+      opBuilder(session, "ApproximateEqual", "Float")
+          .addInput(TestUtil.constant(session, "Const1", 10.00001f))
+          .addInput(TestUtil.constant(session, "Const2", 10.00000f))
+          .setAttr("tolerance", 0.1f)    
+          .build();
+      // Missing tests: list(string), list(byte), list(bool), list(type)
+    }
+  }
+  
+  private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
+    return new EagerOperationBuilder(session, type, name);
+  }
 }
diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
new file mode 100644
index 00000000000..5a5f5508487
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
@@ -0,0 +1,114 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class EagerOperationTest {
+  
+  @Test
+  public void failToCreateIfSessionIsClosed() {
+    EagerSession session = EagerSession.create();
+    session.close();
+    try {
+      new EagerOperation(session, 1L, new long[] {1L}, "Add", "add");
+      fail();
+    } catch (IllegalStateException e) {}
+  }
+  
+  @Test
+  public void outputDataTypeAndShape() {
+    try (EagerSession session = EagerSession.create();
+        Tensor<Integer> t = Tensors.create(new int[2][3])) {
+      EagerOperation op = opBuilder(session, "Const", "OutputAttrs")
+          .setAttr("dtype", DataType.INT32)
+          .setAttr("value", t)
+          .build();
+      assertEquals(DataType.INT32, op.dtype(0));
+      assertEquals(2, op.shape(0)[0]);
+      assertEquals(3, op.shape(0)[1]);
+    }
+  }
+  
+  @Test
+  public void inputAndOutputListLengths() {
+    try (EagerSession session = EagerSession.create()) {
+      Output<Float> c1 = TestUtil.constant(session, "Const1", new float[] {1f, 2f});
+      Output<Float> c2 = TestUtil.constant(session, "Const2", new float[] {3f, 4f});
+
+      EagerOperation acc = opBuilder(session, "AddN", "InputListLength")
+          .addInputList(new Output<?>[] {c1, c2})
+          .build();
+      assertEquals(2, acc.inputListLength("inputs"));
+      assertEquals(1, acc.outputListLength("sum"));
+
+      EagerOperation split = opBuilder(session, "Split", "OutputListLength")
+          .addInput(TestUtil.constant(session, "Axis", 0))
+          .addInput(c1)
+          .setAttr("num_split", 2)
+          .build();
+      assertEquals(1, split.inputListLength("split_dim"));
+      assertEquals(2, split.outputListLength("output"));
+
+      try {
+        split.inputListLength("no_such_input");
+        fail();
+      } catch (IllegalArgumentException e) {}
+
+      try {
+        split.outputListLength("no_such_output");
+        fail();
+      } catch (IllegalArgumentException e) {}
+    }
+  }
+  
+  @Test
+  public void numOutputs() {
+    try (EagerSession session = EagerSession.create()) {
+      EagerOperation op = opBuilder(session, "UniqueWithCountsV2", "unq")
+          .addInput(TestUtil.constant(session, "Const1", new int[] {1, 2, 1}))
+          .addInput(TestUtil.constant(session, "Axis", new int[] {0}))
+          .setAttr("out_idx", DataType.INT32)
+          .build();
+      assertEquals(3, op.numOutputs());
+    }    
+  }
+  
+  @Test
+  public void opNotAccessibleIfSessionIsClosed() {
+    EagerSession session = EagerSession.create();
+    EagerOperation add = opBuilder(session, "Add", "SetDevice")
+        .addInput(TestUtil.constant(session, "Const1", 2))
+        .addInput(TestUtil.constant(session, "Const2", 4))
+        .build();
+    assertEquals(1, add.outputListLength("z"));
+    session.close();
+    try {
+      add.outputListLength("z");
+      fail();
+    } catch (IllegalStateException e) {}
+  }
+  
+  private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
+    return new EagerOperationBuilder(session, type, name);
+  }
+}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 9b48f6a3dc9..c97bcaa3386 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -50,8 +50,14 @@ public class TestUtil {
     }
   }
 
-  public static <T> Output<T> constant(Graph g, String name, Object value) {
-    return constantOp(g, name, value).<T>output(0);
+  public static <T> Output<T> constant(ExecutionEnvironment env, String name, Object value) {
+    try (Tensor<?> t = Tensor.create(value)) {
+      return env.opBuilder("Const", name)
+          .setAttr("dtype", t.dataType())
+          .setAttr("value", t)
+          .build()
+          .<T>output(0);
+    }
   }
 
   public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {

From 4678439926b69c8435213f0827e62f3c103012fb Mon Sep 17 00:00:00 2001
From: Karl Lessard <karl@kubx.ca>
Date: Thu, 9 May 2019 22:23:54 -0400
Subject: [PATCH 2/2] Fix internal lint errors

---
 .../tensorflow/EagerOperationBuilderTest.java |  9 ++++++++-
 .../org/tensorflow/EagerOperationTest.java    | 19 +++++++++++++++----
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
index 696a2098f3e..83b683dde6e 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
@@ -21,6 +21,9 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
+/**
+ * Unit tests for {@link EagerOperationBuilder} class.
+ */
 @RunWith(JUnit4.class)
 public class EagerOperationBuilderTest {
 
@@ -32,6 +35,7 @@ public class EagerOperationBuilderTest {
       new EagerOperationBuilder(session, "Add", "add");
       fail();
     } catch (IllegalStateException e) {
+      // expected
     }
   }
 
@@ -45,6 +49,7 @@ public class EagerOperationBuilderTest {
       opBuilder.setAttr("dtype", DataType.FLOAT);
       fail();
     } catch (IllegalStateException e) {
+      // expected
     }
   }
 
@@ -58,7 +63,9 @@ public class EagerOperationBuilderTest {
       try {
         opBuilder(session, "Const", "var").addControlInput(asrt);
         fail();
-      } catch (UnsupportedOperationException e) {}
+      } catch (UnsupportedOperationException e) {
+        // expected
+      }
     }
   }
   
diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
index 5a5f5508487..d0256435f48 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java
@@ -22,6 +22,9 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
+/**
+ * Unit tests for {@link EagerOperation} class.
+ */
 @RunWith(JUnit4.class)
 public class EagerOperationTest {
   
@@ -32,7 +35,9 @@ public class EagerOperationTest {
     try {
       new EagerOperation(session, 1L, new long[] {1L}, "Add", "add");
       fail();
-    } catch (IllegalStateException e) {}
+    } catch (IllegalStateException e) {
+      // expected
+    }
   }
   
   @Test
@@ -72,12 +77,16 @@ public class EagerOperationTest {
       try {
         split.inputListLength("no_such_input");
         fail();
-      } catch (IllegalArgumentException e) {}
+      } catch (IllegalArgumentException e) {
+        // expected
+      }
 
       try {
         split.outputListLength("no_such_output");
         fail();
-      } catch (IllegalArgumentException e) {}
+      } catch (IllegalArgumentException e) {
+        // expected
+      }
     }
   }
   
@@ -105,7 +114,9 @@ public class EagerOperationTest {
     try {
       add.outputListLength("z");
       fail();
-    } catch (IllegalStateException e) {}
+    } catch (IllegalStateException e) {
+      // expected
+    }
   }
   
   private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {