Merge pull request #28475 from karllessard:java-eager-operation
PiperOrigin-RevId: 247638740
This commit is contained in:
commit
b422174124
@ -159,6 +159,19 @@ tf_java_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_java_test(
|
||||
name = "EagerOperationTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/EagerOperationTest.java"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.EagerOperationTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
tf_java_test(
|
||||
name = "GraphTest",
|
||||
size = "small",
|
||||
|
@ -23,6 +23,21 @@ package org.tensorflow;
|
||||
*/
|
||||
abstract class AbstractOperation implements Operation {
|
||||
|
||||
@Override
|
||||
public Output<?>[] outputList(int idx, int length) {
|
||||
Output<?>[] outputs = new Output<?>[length];
|
||||
for (int i = 0; i < length; ++i) {
|
||||
outputs[i] = output(idx + i);
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings({"rawtypes", "unchecked"})
|
||||
public <T> Output<T> output(int idx) {
|
||||
return new Output(this, idx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("<%s '%s'>", type(), name());
|
||||
|
134
tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
Normal file
134
tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
Normal file
@ -0,0 +1,134 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* Implementation of an {@link Operation} executed eagerly.
|
||||
*
|
||||
* <p>EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of
|
||||
* is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the
|
||||
* EagerOperation instance may fail with an {@code IllegalStateException}.
|
||||
*
|
||||
* <p>EagerOperation instances are thread-safe.
|
||||
*/
|
||||
class EagerOperation extends AbstractOperation {
|
||||
|
||||
EagerOperation(
|
||||
EagerSession session,
|
||||
long opNativeHandle,
|
||||
long[] outputNativeHandles,
|
||||
String type,
|
||||
String name) {
|
||||
this.session = session;
|
||||
this.type = type;
|
||||
this.name = name;
|
||||
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String type() {
|
||||
return type;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numOutputs() {
|
||||
return nativeRef.outputHandles.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int outputListLength(final String name) {
|
||||
return outputListLength(nativeRef.opHandle, name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int inputListLength(final String name) {
|
||||
return inputListLength(nativeRef.opHandle, name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getUnsafeNativeHandle(int outputIndex) {
|
||||
return nativeRef.outputHandles[outputIndex];
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] shape(int outputIndex) {
|
||||
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
||||
long[] shape = new long[numDims(outputNativeHandle)];
|
||||
for (int i = 0; i < shape.length; ++i) {
|
||||
shape[i] = dim(outputNativeHandle, i);
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType dtype(int outputIndex) {
|
||||
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
||||
return DataType.fromC(dataType(outputNativeHandle));
|
||||
}
|
||||
|
||||
private static class NativeReference extends EagerSession.NativeReference {
|
||||
|
||||
NativeReference(
|
||||
EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
|
||||
super(session, operation);
|
||||
this.opHandle = opHandle;
|
||||
this.outputHandles = outputHandles;
|
||||
}
|
||||
|
||||
@Override
|
||||
void delete() {
|
||||
if (opHandle != 0L) {
|
||||
for (long tensorHandle : outputHandles) {
|
||||
if (tensorHandle != 0L) {
|
||||
EagerOperation.deleteTensorHandle(tensorHandle);
|
||||
}
|
||||
}
|
||||
EagerOperation.delete(opHandle);
|
||||
opHandle = 0L;
|
||||
Arrays.fill(outputHandles, 0L);
|
||||
}
|
||||
}
|
||||
|
||||
private long opHandle;
|
||||
private final long[] outputHandles;
|
||||
}
|
||||
|
||||
private final EagerSession session;
|
||||
private final NativeReference nativeRef;
|
||||
private final String type;
|
||||
private final String name;
|
||||
|
||||
private static native void delete(long handle);
|
||||
|
||||
private static native void deleteTensorHandle(long handle);
|
||||
|
||||
private static native int outputListLength(long handle, String name);
|
||||
|
||||
private static native int inputListLength(long handle, String name);
|
||||
|
||||
private static native int dataType(long handle);
|
||||
|
||||
private static native int numDims(long handle);
|
||||
|
||||
private static native long dim(long handle, int index);
|
||||
}
|
@ -31,10 +31,14 @@ final class EagerOperationBuilder implements OperationBuilder {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Operation build() {
|
||||
// TODO (karllessard) Execute the eager operation and pass output tensor handles to new
|
||||
// EagerOperation class
|
||||
throw new UnsupportedOperationException("Eager execution is not supported yet");
|
||||
public EagerOperation build() {
|
||||
long[] tensorHandles = execute(nativeRef.opHandle);
|
||||
EagerOperation operation =
|
||||
new EagerOperation(session, nativeRef.opHandle, tensorHandles, type, name);
|
||||
// Release our reference to the native op handle now that we transferred its
|
||||
// ownership to the EagerOperation
|
||||
nativeRef.clear();
|
||||
return operation;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -44,7 +48,7 @@ final class EagerOperationBuilder implements OperationBuilder {
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder addInputList(Output<?>[] inputs) {
|
||||
public EagerOperationBuilder addInputList(Output<?>[] inputs) {
|
||||
long[] inputHandles = new long[inputs.length];
|
||||
for (int i = 0; i < inputs.length; ++i) {
|
||||
inputHandles[i] = inputs[i].getUnsafeNativeHandle();
|
||||
@ -60,18 +64,18 @@ final class EagerOperationBuilder implements OperationBuilder {
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setDevice(String device) {
|
||||
public EagerOperationBuilder setDevice(String device) {
|
||||
setDevice(nativeRef.opHandle, device);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, String value) {
|
||||
public EagerOperationBuilder setAttr(String name, String value) {
|
||||
return setAttr(name, value.getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, String[] values) {
|
||||
public EagerOperationBuilder setAttr(String name, String[] values) {
|
||||
Charset utf8 = StandardCharsets.UTF_8;
|
||||
Object[] objects = new Object[values.length];
|
||||
for (int i = 0; i < values.length; ++i) {
|
||||
@ -82,55 +86,55 @@ final class EagerOperationBuilder implements OperationBuilder {
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, byte[] values) {
|
||||
public EagerOperationBuilder setAttr(String name, byte[] values) {
|
||||
setAttrString(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, long value) {
|
||||
public EagerOperationBuilder setAttr(String name, long value) {
|
||||
setAttrInt(nativeRef.opHandle, name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, long[] values) {
|
||||
public EagerOperationBuilder setAttr(String name, long[] values) {
|
||||
setAttrIntList(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, float value) {
|
||||
public EagerOperationBuilder setAttr(String name, float value) {
|
||||
setAttrFloat(nativeRef.opHandle, name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, float[] values) {
|
||||
public EagerOperationBuilder setAttr(String name, float[] values) {
|
||||
setAttrFloatList(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, boolean value) {
|
||||
public EagerOperationBuilder setAttr(String name, boolean value) {
|
||||
setAttrBool(nativeRef.opHandle, name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, boolean[] values) {
|
||||
public EagerOperationBuilder setAttr(String name, boolean[] values) {
|
||||
setAttrBoolList(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, DataType value) {
|
||||
public EagerOperationBuilder setAttr(String name, DataType value) {
|
||||
setAttrType(nativeRef.opHandle, name, value.c());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, DataType[] values) {
|
||||
public EagerOperationBuilder setAttr(String name, DataType[] values) {
|
||||
int[] c = new int[values.length];
|
||||
for (int i = 0; i < values.length; ++i) {
|
||||
c[i] = values[i].c();
|
||||
@ -140,26 +144,26 @@ final class EagerOperationBuilder implements OperationBuilder {
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, Tensor<?> value) {
|
||||
public EagerOperationBuilder setAttr(String name, Tensor<?> value) {
|
||||
setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, Tensor<?>[] values) {
|
||||
public EagerOperationBuilder setAttr(String name, Tensor<?>[] values) {
|
||||
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
|
||||
throw new UnsupportedOperationException(
|
||||
"Tensor list attributes are not supported in eager mode");
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, Shape value) {
|
||||
public EagerOperationBuilder setAttr(String name, Shape value) {
|
||||
setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, Shape[] values) {
|
||||
public EagerOperationBuilder setAttr(String name, Shape[] values) {
|
||||
int[] numDimensions = new int[values.length];
|
||||
int totalNumDimensions = 0;
|
||||
for (int idx = 0; idx < values.length; ++idx) {
|
||||
|
@ -75,21 +75,6 @@ public final class GraphOperation extends AbstractOperation {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Output<?>[] outputList(int idx, int length) {
|
||||
Output<?>[] outputs = new Output<?>[length];
|
||||
for (int i = 0; i < length; ++i) {
|
||||
outputs[i] = output(idx + i);
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings({"rawtypes", "unchecked"})
|
||||
public <T> Output<T> output(int idx) {
|
||||
return new Output(this, idx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Long.valueOf(getUnsafeNativeHandle()).hashCode();
|
||||
|
@ -22,6 +22,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||
|
||||
// This value should be >= to the maximum number of outputs in any op
|
||||
#define MAX_OUTPUTS_PER_OP 8
|
||||
|
||||
namespace {
|
||||
|
||||
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
|
||||
@ -89,7 +92,7 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle) {
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return 0;
|
||||
int num_retvals = 8; // should be >= than max number of outputs in any op
|
||||
int num_retvals = MAX_OUTPUTS_PER_OP;
|
||||
std::unique_ptr<TFE_TensorHandle*[]> retvals(
|
||||
new TFE_TensorHandle*[num_retvals]);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
130
tensorflow/java/src/main/native/eager_operation_jni.cc
Normal file
130
tensorflow/java/src/main/native/eager_operation_jni.cc
Normal file
@ -0,0 +1,130 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/java/src/main/native/eager_operation_jni.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||
|
||||
namespace {
|
||||
|
||||
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
throwException(env, kIllegalStateException,
|
||||
"Eager session has been closed");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TFE_Op*>(handle);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
throwException(env, kIllegalStateException, "EagerSession has been closed");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TFE_TensorHandle*>(handle);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
if (handle == 0) return;
|
||||
TFE_DeleteOp(reinterpret_cast<TFE_Op*>(handle));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
|
||||
JNIEnv* env, jclass clazz, jlong handle) {
|
||||
if (handle == 0) return;
|
||||
TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
|
||||
TFE_Op* op = requireOp(env, handle);
|
||||
if (op == nullptr) return 0;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
const char* cname = env->GetStringUTFChars(name, nullptr);
|
||||
int length = TFE_OpGetOutputLength(op, cname, status);
|
||||
env->ReleaseStringUTFChars(name, cname);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return 0;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<jint>(length);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
|
||||
TFE_Op* op = requireOp(env, handle);
|
||||
if (op == nullptr) return 0;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
const char* cname = env->GetStringUTFChars(name, nullptr);
|
||||
int length = TFE_OpGetInputLength(op, cname, status);
|
||||
env->ReleaseStringUTFChars(name, cname);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return 0;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<jint>(length);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(
|
||||
JNIEnv* env, jclass clazz, jlong handle) {
|
||||
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
|
||||
if (tensor_handle == nullptr) return 0;
|
||||
TF_DataType data_type = TFE_TensorHandleDataType(tensor_handle);
|
||||
return static_cast<jint>(data_type);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(
|
||||
JNIEnv* env, jclass clazz, jlong handle) {
|
||||
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
|
||||
if (tensor_handle == nullptr) return 0;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
int num_dims = TFE_TensorHandleNumDims(tensor_handle, status);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return 0;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<jint>(num_dims);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle,
|
||||
jint dim_index) {
|
||||
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
|
||||
if (tensor_handle == nullptr) return 0;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
int64_t dim = TFE_TensorHandleDim(tensor_handle, dim_index, status);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return 0;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<jlong>(dim);
|
||||
}
|
86
tensorflow/java/src/main/native/eager_operation_jni.h
Normal file
86
tensorflow/java/src/main/native/eager_operation_jni.h
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
|
||||
#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: delete
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv *,
|
||||
jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: deleteTensorHandle
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperation_deleteTensorHandle(JNIEnv *, jclass, jlong);
|
||||
|
||||
/**
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: outputListLength
|
||||
* Signature: (JLjava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
|
||||
JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/**
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: inputListLength
|
||||
* Signature: (JLjava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
|
||||
JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/**
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: dataType
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(JNIEnv *,
|
||||
jclass,
|
||||
jlong);
|
||||
|
||||
/**
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: numDims
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(JNIEnv *,
|
||||
jclass,
|
||||
jlong);
|
||||
|
||||
/**
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: dim
|
||||
* Signature: (JI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv *, jclass,
|
||||
jlong, jint);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
|
@ -21,6 +21,7 @@ import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link EagerOperationBuilder} class. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class EagerOperationBuilderTest {
|
||||
|
||||
@ -32,6 +33,7 @@ public class EagerOperationBuilderTest {
|
||||
new EagerOperationBuilder(session, "Add", "add");
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,8 +47,99 @@ public class EagerOperationBuilderTest {
|
||||
opBuilder.setAttr("dtype", DataType.FLOAT);
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
// TODO (karllessard) add more tests when EagerOperation is implemented as well
|
||||
@Test
|
||||
public void addInputs() {
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
Operation asrt =
|
||||
opBuilder(session, "Assert", "assert")
|
||||
.addInput(TestUtil.constant(session, "Cond", true))
|
||||
.addInputList(new Output<?>[] {TestUtil.constant(session, "Error", -1)})
|
||||
.build();
|
||||
try {
|
||||
opBuilder(session, "Const", "var").addControlInput(asrt);
|
||||
fail();
|
||||
} catch (UnsupportedOperationException e) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setDevice() {
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
opBuilder(session, "Add", "SetDevice")
|
||||
.setDevice("/job:localhost/replica:0/task:0/device:CPU:0")
|
||||
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAttrs() {
|
||||
// The effect of setting an attribute may not easily be visible from the other parts of this
|
||||
// package's API. Thus, for now, the test simply executes the various setAttr variants to see
|
||||
// that there are no exceptions.
|
||||
//
|
||||
// This is a bit of an awkward test since it has to find operations with attributes of specific
|
||||
// types that aren't inferred from the input arguments.
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
// dtype, tensor attributes.
|
||||
try (Tensor<Integer> t = Tensors.create(1)) {
|
||||
opBuilder(session, "Const", "DataTypeAndTensor")
|
||||
.setAttr("dtype", DataType.INT32)
|
||||
.setAttr("value", t)
|
||||
.build();
|
||||
}
|
||||
// type, int (TF "int" attributes are 64-bit signed, so a Java long).
|
||||
opBuilder(session, "RandomUniform", "DataTypeAndInt")
|
||||
.addInput(TestUtil.constant(session, "RandomUniformShape", new int[] {1}))
|
||||
.setAttr("seed", 10)
|
||||
.setAttr("dtype", DataType.FLOAT)
|
||||
.build();
|
||||
// list(int), string
|
||||
opBuilder(session, "MaxPool", "IntListAndString")
|
||||
.addInput(TestUtil.constant(session, "MaxPoolInput", new float[2][2][2][2]))
|
||||
.setAttr("ksize", new long[] {1, 1, 1, 1})
|
||||
.setAttr("strides", new long[] {1, 1, 1, 1})
|
||||
.setAttr("padding", "SAME")
|
||||
.build();
|
||||
// list(float), device
|
||||
opBuilder(session, "FractionalMaxPool", "FloatList")
|
||||
.addInput(TestUtil.constant(session, "FractionalMaxPoolInput", new float[2][2][2][2]))
|
||||
.setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
|
||||
.build();
|
||||
// shape
|
||||
opBuilder(session, "EnsureShape", "ShapeAttr")
|
||||
.addInput(TestUtil.constant(session, "Const", new int[2][2]))
|
||||
.setAttr("shape", Shape.make(2, 2))
|
||||
.build();
|
||||
// list(shape)
|
||||
opBuilder(session, "FIFOQueue", "queue")
|
||||
.setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32})
|
||||
.setAttr("shapes", new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)})
|
||||
.build();
|
||||
// bool
|
||||
opBuilder(session, "All", "Bool")
|
||||
.addInput(TestUtil.constant(session, "Const", new boolean[] {true, true, false}))
|
||||
.addInput(TestUtil.constant(session, "Axis", 0))
|
||||
.setAttr("keep_dims", false)
|
||||
.build();
|
||||
// float
|
||||
opBuilder(session, "ApproximateEqual", "Float")
|
||||
.addInput(TestUtil.constant(session, "Const1", 10.00001f))
|
||||
.addInput(TestUtil.constant(session, "Const2", 10.00000f))
|
||||
.setAttr("tolerance", 0.1f)
|
||||
.build();
|
||||
// Missing tests: list(string), list(byte), list(bool), list(type)
|
||||
}
|
||||
}
|
||||
|
||||
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
|
||||
return new EagerOperationBuilder(session, type, name);
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,128 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link EagerOperation} class. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class EagerOperationTest {
|
||||
|
||||
@Test
|
||||
public void failToCreateIfSessionIsClosed() {
|
||||
EagerSession session = EagerSession.create();
|
||||
session.close();
|
||||
try {
|
||||
new EagerOperation(session, 1L, new long[] {1L}, "Add", "add");
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void outputDataTypeAndShape() {
|
||||
try (EagerSession session = EagerSession.create();
|
||||
Tensor<Integer> t = Tensors.create(new int[2][3])) {
|
||||
EagerOperation op =
|
||||
opBuilder(session, "Const", "OutputAttrs")
|
||||
.setAttr("dtype", DataType.INT32)
|
||||
.setAttr("value", t)
|
||||
.build();
|
||||
assertEquals(DataType.INT32, op.dtype(0));
|
||||
assertEquals(2, op.shape(0)[0]);
|
||||
assertEquals(3, op.shape(0)[1]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inputAndOutputListLengths() {
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
Output<Float> c1 = TestUtil.constant(session, "Const1", new float[] {1f, 2f});
|
||||
Output<Float> c2 = TestUtil.constant(session, "Const2", new float[] {3f, 4f});
|
||||
|
||||
EagerOperation acc =
|
||||
opBuilder(session, "AddN", "InputListLength")
|
||||
.addInputList(new Output<?>[] {c1, c2})
|
||||
.build();
|
||||
assertEquals(2, acc.inputListLength("inputs"));
|
||||
assertEquals(1, acc.outputListLength("sum"));
|
||||
|
||||
EagerOperation split =
|
||||
opBuilder(session, "Split", "OutputListLength")
|
||||
.addInput(TestUtil.constant(session, "Axis", 0))
|
||||
.addInput(c1)
|
||||
.setAttr("num_split", 2)
|
||||
.build();
|
||||
assertEquals(1, split.inputListLength("split_dim"));
|
||||
assertEquals(2, split.outputListLength("output"));
|
||||
|
||||
try {
|
||||
split.inputListLength("no_such_input");
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
// expected
|
||||
}
|
||||
|
||||
try {
|
||||
split.outputListLength("no_such_output");
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void numOutputs() {
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
EagerOperation op =
|
||||
opBuilder(session, "UniqueWithCountsV2", "unq")
|
||||
.addInput(TestUtil.constant(session, "Const1", new int[] {1, 2, 1}))
|
||||
.addInput(TestUtil.constant(session, "Axis", new int[] {0}))
|
||||
.setAttr("out_idx", DataType.INT32)
|
||||
.build();
|
||||
assertEquals(3, op.numOutputs());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void opNotAccessibleIfSessionIsClosed() {
|
||||
EagerSession session = EagerSession.create();
|
||||
EagerOperation add =
|
||||
opBuilder(session, "Add", "SetDevice")
|
||||
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||
.build();
|
||||
assertEquals(1, add.outputListLength("z"));
|
||||
session.close();
|
||||
try {
|
||||
add.outputListLength("z");
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
|
||||
return new EagerOperationBuilder(session, type, name);
|
||||
}
|
||||
}
|
@ -50,8 +50,14 @@ public class TestUtil {
|
||||
}
|
||||
}
|
||||
|
||||
public static <T> Output<T> constant(Graph g, String name, Object value) {
|
||||
return constantOp(g, name, value).<T>output(0);
|
||||
public static <T> Output<T> constant(ExecutionEnvironment env, String name, Object value) {
|
||||
try (Tensor<?> t = Tensor.create(value)) {
|
||||
return env.opBuilder("Const", name)
|
||||
.setAttr("dtype", t.dataType())
|
||||
.setAttr("value", t)
|
||||
.build()
|
||||
.<T>output(0);
|
||||
}
|
||||
}
|
||||
|
||||
public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {
|
||||
|
Loading…
Reference in New Issue
Block a user