From bf05237063d4a81f6b9ba064f2eb4157ef3c7921 Mon Sep 17 00:00:00 2001 From: Karl Lessard <karl@kubx.ca> Date: Fri, 3 May 2019 09:22:09 -0400 Subject: [PATCH] Eager operations implementation --- tensorflow/java/BUILD | 13 ++ .../org/tensorflow/AbstractOperation.java | 15 ++ .../java/org/tensorflow/EagerOperation.java | 128 ++++++++++++++++++ .../org/tensorflow/EagerOperationBuilder.java | 45 +++--- .../java/org/tensorflow/GraphOperation.java | 15 -- .../native/eager_operation_builder_jni.cc | 5 +- .../src/main/native/eager_operation_jni.cc | 126 +++++++++++++++++ .../src/main/native/eager_operation_jni.h | 84 ++++++++++++ .../tensorflow/EagerOperationBuilderTest.java | 89 +++++++++++- .../org/tensorflow/EagerOperationTest.java | 114 ++++++++++++++++ .../test/java/org/tensorflow/TestUtil.java | 10 +- 11 files changed, 604 insertions(+), 40 deletions(-) create mode 100644 tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java create mode 100644 tensorflow/java/src/main/native/eager_operation_jni.cc create mode 100644 tensorflow/java/src/main/native/eager_operation_jni.h create mode 100644 tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index a23ae9c60e9..6a71cd1e9da 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -159,6 +159,19 @@ tf_java_test( ], ) +tf_java_test( + name = "EagerOperationTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/EagerOperationTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.EagerOperationTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + tf_java_test( name = "GraphTest", size = "small", diff --git a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java index a1d95f246b2..0d4745fe0b7 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java @@ -23,6 +23,21 @@ package org.tensorflow; */ abstract class AbstractOperation implements Operation { + @Override + public Output<?>[] outputList(int idx, int length) { + Output<?>[] outputs = new Output<?>[length]; + for (int i = 0; i < length; ++i) { + outputs[i] = output(idx + i); + } + return outputs; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public <T> Output<T> output(int idx) { + return new Output(this, idx); + } + @Override public String toString() { return String.format("<%s '%s'>", type(), name()); diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java new file mode 100644 index 00000000000..a0530d7b9da --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java @@ -0,0 +1,128 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import java.util.Arrays; + +/** + * Implementation of an {@link Operation} executed eagerly. + * + * <p>EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of is + * valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the EagerOperation + * instance may fail with an {@code IllegalStateException}. + * + * <p>EagerOperation instances are thread-safe. + */ +class EagerOperation extends AbstractOperation { + + EagerOperation(EagerSession session, long opNativeHandle, long[] outputNativeHandles, String type, String name) { + this.session = session; + this.type = type; + this.name = name; + this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles); + } + + @Override + public String name() { + return name; + } + + @Override + public String type() { + return type; + } + + @Override + public int numOutputs() { + return nativeRef.outputHandles.length; + } + + @Override + public int outputListLength(final String name) { + return outputListLength(nativeRef.opHandle, name); + } + + @Override + public int inputListLength(final String name) { + return inputListLength(nativeRef.opHandle, name); + } + + @Override + public long getUnsafeNativeHandle(int outputIndex) { + return nativeRef.outputHandles[outputIndex]; + } + + @Override + public long[] shape(int outputIndex) { + long outputNativeHandle = getUnsafeNativeHandle(outputIndex); + long[] shape = new long[numDims(outputNativeHandle)]; + for (int i = 0; i < shape.length; ++i) { + shape[i] = dim(outputNativeHandle, i); + } + return shape; + } + + @Override + public DataType dtype(int outputIndex) { + long outputNativeHandle = getUnsafeNativeHandle(outputIndex); + return DataType.fromC(dataType(outputNativeHandle)); + } + + private static class NativeReference extends EagerSession.NativeReference { + + NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) { + super(session, operation); + this.opHandle = opHandle; + this.outputHandles = outputHandles; + } + + @Override + void delete() { + if (opHandle != 0L) { + for (long tensorHandle : outputHandles) { + if (tensorHandle != 0L) { + EagerOperation.deleteTensorHandle(tensorHandle); + } + } + EagerOperation.delete(opHandle); + opHandle = 0L; + Arrays.fill(outputHandles, 0L); + } + } + + private long opHandle; + private final long[] outputHandles; + } + + private final EagerSession session; + private final NativeReference nativeRef; + private final String type; + private final String name; + + private static native void delete(long handle); + + private static native void deleteTensorHandle(long handle); + + private static native int outputListLength(long handle, String name); + + private static native int inputListLength(long handle, String name); + + private static native int dataType(long handle); + + private static native int numDims(long handle); + + private static native long dim(long handle, int index); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java index a8cb5b5c318..2097f4ad4fa 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -31,10 +31,13 @@ final class EagerOperationBuilder implements OperationBuilder { } @Override - public Operation build() { - // TODO (karllessard) Execute the eager operation and pass output tensor handles to new - // EagerOperation class - throw new UnsupportedOperationException("Eager execution is not supported yet"); + public EagerOperation build() { + long[] tensorHandles = execute(nativeRef.opHandle); + EagerOperation operation = new EagerOperation(session, nativeRef.opHandle, tensorHandles, type, name); + // Release our reference to the native op handle now that we transferred its + // ownership to the EagerOperation + nativeRef.clear(); + return operation; } @Override @@ -44,7 +47,7 @@ final class EagerOperationBuilder implements OperationBuilder { } @Override - public OperationBuilder addInputList(Output<?>[] inputs) { + public EagerOperationBuilder addInputList(Output<?>[] inputs) { long[] inputHandles = new long[inputs.length]; for (int i = 0; i < inputs.length; ++i) { inputHandles[i] = inputs[i].getUnsafeNativeHandle(); @@ -60,18 +63,18 @@ final class EagerOperationBuilder implements OperationBuilder { } @Override - public OperationBuilder setDevice(String device) { + public EagerOperationBuilder setDevice(String device) { setDevice(nativeRef.opHandle, device); return this; } @Override - public OperationBuilder setAttr(String name, String value) { + public EagerOperationBuilder setAttr(String name, String value) { return setAttr(name, value.getBytes(StandardCharsets.UTF_8)); } @Override - public OperationBuilder setAttr(String name, String[] values) { + public EagerOperationBuilder setAttr(String name, String[] values) { Charset utf8 = StandardCharsets.UTF_8; Object[] objects = new Object[values.length]; for (int i = 0; i < values.length; ++i) { @@ -82,55 +85,55 @@ final class EagerOperationBuilder implements OperationBuilder { } @Override - public OperationBuilder setAttr(String name, byte[] values) { + public EagerOperationBuilder setAttr(String name, byte[] values) { setAttrString(nativeRef.opHandle, name, values); return this; } @Override - public OperationBuilder setAttr(String name, long value) { + public EagerOperationBuilder setAttr(String name, long value) { setAttrInt(nativeRef.opHandle, name, value); return this; } @Override - public OperationBuilder setAttr(String name, long[] values) { + public EagerOperationBuilder setAttr(String name, long[] values) { setAttrIntList(nativeRef.opHandle, name, values); return this; } @Override - public OperationBuilder setAttr(String name, float value) { + public EagerOperationBuilder setAttr(String name, float value) { setAttrFloat(nativeRef.opHandle, name, value); return this; } @Override - public OperationBuilder setAttr(String name, float[] values) { + public EagerOperationBuilder setAttr(String name, float[] values) { setAttrFloatList(nativeRef.opHandle, name, values); return this; } @Override - public OperationBuilder setAttr(String name, boolean value) { + public EagerOperationBuilder setAttr(String name, boolean value) { setAttrBool(nativeRef.opHandle, name, value); return this; } @Override - public OperationBuilder setAttr(String name, boolean[] values) { + public EagerOperationBuilder setAttr(String name, boolean[] values) { setAttrBoolList(nativeRef.opHandle, name, values); return this; } @Override - public OperationBuilder setAttr(String name, DataType value) { + public EagerOperationBuilder setAttr(String name, DataType value) { setAttrType(nativeRef.opHandle, name, value.c()); return this; } @Override - public OperationBuilder setAttr(String name, DataType[] values) { + public EagerOperationBuilder setAttr(String name, DataType[] values) { int[] c = new int[values.length]; for (int i = 0; i < values.length; ++i) { c[i] = values[i].c(); @@ -140,26 +143,26 @@ final class EagerOperationBuilder implements OperationBuilder { } @Override - public OperationBuilder setAttr(String name, Tensor<?> value) { + public EagerOperationBuilder setAttr(String name, Tensor<?> value) { setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle()); return this; } @Override - public OperationBuilder setAttr(String name, Tensor<?>[] values) { + public EagerOperationBuilder setAttr(String name, Tensor<?>[] values) { // TODO (karllessard) could be supported by adding this attribute type in the eager C API throw new UnsupportedOperationException( "Tensor list attributes are not supported in eager mode"); } @Override - public OperationBuilder setAttr(String name, Shape value) { + public EagerOperationBuilder setAttr(String name, Shape value) { setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions()); return this; } @Override - public OperationBuilder setAttr(String name, Shape[] values) { + public EagerOperationBuilder setAttr(String name, Shape[] values) { int[] numDimensions = new int[values.length]; int totalNumDimensions = 0; for (int idx = 0; idx < values.length; ++idx) { diff --git a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java index 31b80d306a2..0e43bc3eb43 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java @@ -75,21 +75,6 @@ public final class GraphOperation extends AbstractOperation { } } - @Override - public Output<?>[] outputList(int idx, int length) { - Output<?>[] outputs = new Output<?>[length]; - for (int i = 0; i < length; ++i) { - outputs[i] = output(idx + i); - } - return outputs; - } - - @Override - @SuppressWarnings({"rawtypes", "unchecked"}) - public <T> Output<T> output(int idx) { - return new Output(this, idx); - } - @Override public int hashCode() { return Long.valueOf(getUnsafeNativeHandle()).hashCode(); diff --git a/tensorflow/java/src/main/native/eager_operation_builder_jni.cc b/tensorflow/java/src/main/native/eager_operation_builder_jni.cc index 654c453176b..f8ed2072ba0 100644 --- a/tensorflow/java/src/main/native/eager_operation_builder_jni.cc +++ b/tensorflow/java/src/main/native/eager_operation_builder_jni.cc @@ -22,6 +22,9 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/java/src/main/native/exception_jni.h" +// This value should be >= to the maximum number of outputs in any op +#define MAX_OUTPUTS_PER_OP 8 + namespace { TFE_Op* requireOp(JNIEnv* env, jlong handle) { @@ -89,7 +92,7 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute( JNIEnv* env, jclass clazz, jlong op_handle) { TFE_Op* op = requireOp(env, op_handle); if (op == nullptr) return 0; - int num_retvals = 8; // should be >= than max number of outputs in any op + int num_retvals = MAX_OUTPUTS_PER_OP; std::unique_ptr<TFE_TensorHandle*[]> retvals( new TFE_TensorHandle*[num_retvals]); TF_Status* status = TF_NewStatus(); diff --git a/tensorflow/java/src/main/native/eager_operation_jni.cc b/tensorflow/java/src/main/native/eager_operation_jni.cc new file mode 100644 index 00000000000..3a5f6f90ddc --- /dev/null +++ b/tensorflow/java/src/main/native/eager_operation_jni.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <assert.h> +#include <stdlib.h> +#include <string.h> +#include <algorithm> +#include <memory> + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/java/src/main/native/eager_operation_jni.h" +#include "tensorflow/java/src/main/native/exception_jni.h" + +namespace { + +TFE_Op* requireOp(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "Eager session has been closed"); + return nullptr; + } + return reinterpret_cast<TFE_Op*>(handle); +} + +TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "EagerSession has been closed"); + return nullptr; + } + return reinterpret_cast<TFE_TensorHandle*>(handle); +} + +} // namespace + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete( + JNIEnv* env, jclass clazz, jlong handle) { + if (handle == 0) return; + TFE_DeleteOp(reinterpret_cast<TFE_Op*>(handle)); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle( + JNIEnv* env, jclass clazz, jlong handle) { + if (handle == 0) return; + TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle)); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength( + JNIEnv* env, jclass clazz, jlong handle, jstring name) { + TFE_Op* op = requireOp(env, handle); + if (op == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + const char* cname = env->GetStringUTFChars(name, nullptr); + int length = TFE_OpGetOutputLength(op, cname, status); + env->ReleaseStringUTFChars(name, cname); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return static_cast<jint>(length); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength( + JNIEnv* env, jclass clazz, jlong handle, jstring name) { + TFE_Op* op = requireOp(env, handle); + if (op == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + const char* cname = env->GetStringUTFChars(name, nullptr); + int length = TFE_OpGetInputLength(op, cname, status); + env->ReleaseStringUTFChars(name, cname); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return static_cast<jint>(length); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType( + JNIEnv* env, jclass clazz, jlong handle) { + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle); + if (tensor_handle == nullptr) return 0; + TF_DataType data_type = TFE_TensorHandleDataType(tensor_handle); + return static_cast<jint>(data_type); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims( + JNIEnv* env, jclass clazz, jlong handle) { + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle); + if (tensor_handle == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + int num_dims = TFE_TensorHandleNumDims(tensor_handle, status); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return static_cast<jint>(num_dims); +} + +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim( + JNIEnv* env, jclass clazz, jlong handle, jint dim_index) { + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle); + if (tensor_handle == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + int64_t dim = TFE_TensorHandleDim(tensor_handle, dim_index, status); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return static_cast<jlong>(dim); +} diff --git a/tensorflow/java/src/main/native/eager_operation_jni.h b/tensorflow/java/src/main/native/eager_operation_jni.h new file mode 100644 index 00000000000..f9684b0a26e --- /dev/null +++ b/tensorflow/java/src/main/native/eager_operation_jni.h @@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_ +#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_ + +#include <jni.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_tensorflow_EagerOperation + * Method: delete + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete( + JNIEnv *, jclass, jlong); + +/* + * Class: org_tensorflow_EagerOperation + * Method: deleteTensorHandle + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle( + JNIEnv *, jclass, jlong); + +/** + * Class: org_tensorflow_EagerOperation + * Method: outputListLength + * Signature: (JLjava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength( + JNIEnv *, jclass, jlong, jstring); + +/** + * Class: org_tensorflow_EagerOperation + * Method: inputListLength + * Signature: (JLjava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength( + JNIEnv *, jclass, jlong, jstring); + +/** + * Class: org_tensorflow_EagerOperation + * Method: dataType + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType( + JNIEnv *, jclass, jlong); + +/** + * Class: org_tensorflow_EagerOperation + * Method: numDims + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims( + JNIEnv *, jclass, jlong); + +/** + * Class: org_tensorflow_EagerOperation + * Method: dim + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim( + JNIEnv *, jclass, jlong, jint); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_ diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index 38534297e8c..696a2098f3e 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -48,5 +48,92 @@ public class EagerOperationBuilderTest { } } - // TODO (karllessard) add more tests when EagerOperation is implemented as well + @Test + public void addInputs() { + try (EagerSession session = EagerSession.create()) { + Operation asrt = opBuilder(session, "Assert", "assert") + .addInput(TestUtil.constant(session, "Cond", true)) + .addInputList(new Output<?>[] {TestUtil.constant(session, "Error", -1)}) + .build(); + try { + opBuilder(session, "Const", "var").addControlInput(asrt); + fail(); + } catch (UnsupportedOperationException e) {} + } + } + + @Test + public void setDevice() { + try (EagerSession session = EagerSession.create()) { + opBuilder(session, "Add", "SetDevice") + .setDevice("/job:localhost/replica:0/task:0/device:CPU:0") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); + } + } + + @Test + public void setAttrs() { + // The effect of setting an attribute may not easily be visible from the other parts of this + // package's API. Thus, for now, the test simply executes the various setAttr variants to see + // that there are no exceptions. + // + // This is a bit of an awkward test since it has to find operations with attributes of specific + // types that aren't inferred from the input arguments. + try (EagerSession session = EagerSession.create()) { + // dtype, tensor attributes. + try (Tensor<Integer> t = Tensors.create(1)) { + opBuilder(session, "Const", "DataTypeAndTensor") + .setAttr("dtype", DataType.INT32) + .setAttr("value", t) + .build(); + } + // type, int (TF "int" attributes are 64-bit signed, so a Java long). + opBuilder(session, "RandomUniform", "DataTypeAndInt") + .addInput(TestUtil.constant(session, "RandomUniformShape", new int[] {1})) + .setAttr("seed", 10) + .setAttr("dtype", DataType.FLOAT) + .build(); + // list(int), string + opBuilder(session, "MaxPool", "IntListAndString") + .addInput(TestUtil.constant(session, "MaxPoolInput", new float[2][2][2][2])) + .setAttr("ksize", new long[] {1, 1, 1, 1}) + .setAttr("strides", new long[] {1, 1, 1, 1}) + .setAttr("padding", "SAME") + .build(); + // list(float), device + opBuilder(session, "FractionalMaxPool", "FloatList") + .addInput(TestUtil.constant(session, "FractionalMaxPoolInput", new float[2][2][2][2])) + .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) + .build(); + // shape + opBuilder(session, "EnsureShape", "ShapeAttr") + .addInput(TestUtil.constant(session, "Const", new int[2][2])) + .setAttr("shape", Shape.make(2, 2)) + .build(); + // list(shape) + opBuilder(session, "FIFOQueue", "queue") + .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32}) + .setAttr("shapes", new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)}) + .build(); + // bool + opBuilder(session, "All", "Bool") + .addInput(TestUtil.constant(session, "Const", new boolean[] {true, true, false})) + .addInput(TestUtil.constant(session, "Axis", 0)) + .setAttr("keep_dims", false) + .build(); + // float + opBuilder(session, "ApproximateEqual", "Float") + .addInput(TestUtil.constant(session, "Const1", 10.00001f)) + .addInput(TestUtil.constant(session, "Const2", 10.00000f)) + .setAttr("tolerance", 0.1f) + .build(); + // Missing tests: list(string), list(byte), list(bool), list(type) + } + } + + private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) { + return new EagerOperationBuilder(session, type, name); + } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java new file mode 100644 index 00000000000..5a5f5508487 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java @@ -0,0 +1,114 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class EagerOperationTest { + + @Test + public void failToCreateIfSessionIsClosed() { + EagerSession session = EagerSession.create(); + session.close(); + try { + new EagerOperation(session, 1L, new long[] {1L}, "Add", "add"); + fail(); + } catch (IllegalStateException e) {} + } + + @Test + public void outputDataTypeAndShape() { + try (EagerSession session = EagerSession.create(); + Tensor<Integer> t = Tensors.create(new int[2][3])) { + EagerOperation op = opBuilder(session, "Const", "OutputAttrs") + .setAttr("dtype", DataType.INT32) + .setAttr("value", t) + .build(); + assertEquals(DataType.INT32, op.dtype(0)); + assertEquals(2, op.shape(0)[0]); + assertEquals(3, op.shape(0)[1]); + } + } + + @Test + public void inputAndOutputListLengths() { + try (EagerSession session = EagerSession.create()) { + Output<Float> c1 = TestUtil.constant(session, "Const1", new float[] {1f, 2f}); + Output<Float> c2 = TestUtil.constant(session, "Const2", new float[] {3f, 4f}); + + EagerOperation acc = opBuilder(session, "AddN", "InputListLength") + .addInputList(new Output<?>[] {c1, c2}) + .build(); + assertEquals(2, acc.inputListLength("inputs")); + assertEquals(1, acc.outputListLength("sum")); + + EagerOperation split = opBuilder(session, "Split", "OutputListLength") + .addInput(TestUtil.constant(session, "Axis", 0)) + .addInput(c1) + .setAttr("num_split", 2) + .build(); + assertEquals(1, split.inputListLength("split_dim")); + assertEquals(2, split.outputListLength("output")); + + try { + split.inputListLength("no_such_input"); + fail(); + } catch (IllegalArgumentException e) {} + + try { + split.outputListLength("no_such_output"); + fail(); + } catch (IllegalArgumentException e) {} + } + } + + @Test + public void numOutputs() { + try (EagerSession session = EagerSession.create()) { + EagerOperation op = opBuilder(session, "UniqueWithCountsV2", "unq") + .addInput(TestUtil.constant(session, "Const1", new int[] {1, 2, 1})) + .addInput(TestUtil.constant(session, "Axis", new int[] {0})) + .setAttr("out_idx", DataType.INT32) + .build(); + assertEquals(3, op.numOutputs()); + } + } + + @Test + public void opNotAccessibleIfSessionIsClosed() { + EagerSession session = EagerSession.create(); + EagerOperation add = opBuilder(session, "Add", "SetDevice") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); + assertEquals(1, add.outputListLength("z")); + session.close(); + try { + add.outputListLength("z"); + fail(); + } catch (IllegalStateException e) {} + } + + private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) { + return new EagerOperationBuilder(session, type, name); + } +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index 9b48f6a3dc9..c97bcaa3386 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -50,8 +50,14 @@ public class TestUtil { } } - public static <T> Output<T> constant(Graph g, String name, Object value) { - return constantOp(g, name, value).<T>output(0); + public static <T> Output<T> constant(ExecutionEnvironment env, String name, Object value) { + try (Tensor<?> t = Tensor.create(value)) { + return env.opBuilder("Const", name) + .setAttr("dtype", t.dataType()) + .setAttr("value", t) + .build() + .<T>output(0); + } } public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {