Merge pull request #28475 from karllessard:java-eager-operation
PiperOrigin-RevId: 247638740
This commit is contained in:
commit
b422174124
@ -159,6 +159,19 @@ tf_java_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_java_test(
|
||||||
|
name = "EagerOperationTest",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["src/test/java/org/tensorflow/EagerOperationTest.java"],
|
||||||
|
javacopts = JAVACOPTS,
|
||||||
|
test_class = "org.tensorflow.EagerOperationTest",
|
||||||
|
deps = [
|
||||||
|
":tensorflow",
|
||||||
|
":testutil",
|
||||||
|
"@junit",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_java_test(
|
tf_java_test(
|
||||||
name = "GraphTest",
|
name = "GraphTest",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -23,6 +23,21 @@ package org.tensorflow;
|
|||||||
*/
|
*/
|
||||||
abstract class AbstractOperation implements Operation {
|
abstract class AbstractOperation implements Operation {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Output<?>[] outputList(int idx, int length) {
|
||||||
|
Output<?>[] outputs = new Output<?>[length];
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
outputs[i] = output(idx + i);
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@SuppressWarnings({"rawtypes", "unchecked"})
|
||||||
|
public <T> Output<T> output(int idx) {
|
||||||
|
return new Output(this, idx);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return String.format("<%s '%s'>", type(), name());
|
return String.format("<%s '%s'>", type(), name());
|
||||||
|
134
tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
Normal file
134
tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implementation of an {@link Operation} executed eagerly.
|
||||||
|
*
|
||||||
|
* <p>EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of
|
||||||
|
* is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the
|
||||||
|
* EagerOperation instance may fail with an {@code IllegalStateException}.
|
||||||
|
*
|
||||||
|
* <p>EagerOperation instances are thread-safe.
|
||||||
|
*/
|
||||||
|
class EagerOperation extends AbstractOperation {
|
||||||
|
|
||||||
|
EagerOperation(
|
||||||
|
EagerSession session,
|
||||||
|
long opNativeHandle,
|
||||||
|
long[] outputNativeHandles,
|
||||||
|
String type,
|
||||||
|
String name) {
|
||||||
|
this.session = session;
|
||||||
|
this.type = type;
|
||||||
|
this.name = name;
|
||||||
|
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String name() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String type() {
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int numOutputs() {
|
||||||
|
return nativeRef.outputHandles.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int outputListLength(final String name) {
|
||||||
|
return outputListLength(nativeRef.opHandle, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int inputListLength(final String name) {
|
||||||
|
return inputListLength(nativeRef.opHandle, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getUnsafeNativeHandle(int outputIndex) {
|
||||||
|
return nativeRef.outputHandles[outputIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long[] shape(int outputIndex) {
|
||||||
|
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
||||||
|
long[] shape = new long[numDims(outputNativeHandle)];
|
||||||
|
for (int i = 0; i < shape.length; ++i) {
|
||||||
|
shape[i] = dim(outputNativeHandle, i);
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType dtype(int outputIndex) {
|
||||||
|
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
||||||
|
return DataType.fromC(dataType(outputNativeHandle));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class NativeReference extends EagerSession.NativeReference {
|
||||||
|
|
||||||
|
NativeReference(
|
||||||
|
EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
|
||||||
|
super(session, operation);
|
||||||
|
this.opHandle = opHandle;
|
||||||
|
this.outputHandles = outputHandles;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
void delete() {
|
||||||
|
if (opHandle != 0L) {
|
||||||
|
for (long tensorHandle : outputHandles) {
|
||||||
|
if (tensorHandle != 0L) {
|
||||||
|
EagerOperation.deleteTensorHandle(tensorHandle);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EagerOperation.delete(opHandle);
|
||||||
|
opHandle = 0L;
|
||||||
|
Arrays.fill(outputHandles, 0L);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private long opHandle;
|
||||||
|
private final long[] outputHandles;
|
||||||
|
}
|
||||||
|
|
||||||
|
private final EagerSession session;
|
||||||
|
private final NativeReference nativeRef;
|
||||||
|
private final String type;
|
||||||
|
private final String name;
|
||||||
|
|
||||||
|
private static native void delete(long handle);
|
||||||
|
|
||||||
|
private static native void deleteTensorHandle(long handle);
|
||||||
|
|
||||||
|
private static native int outputListLength(long handle, String name);
|
||||||
|
|
||||||
|
private static native int inputListLength(long handle, String name);
|
||||||
|
|
||||||
|
private static native int dataType(long handle);
|
||||||
|
|
||||||
|
private static native int numDims(long handle);
|
||||||
|
|
||||||
|
private static native long dim(long handle, int index);
|
||||||
|
}
|
@ -31,10 +31,14 @@ final class EagerOperationBuilder implements OperationBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Operation build() {
|
public EagerOperation build() {
|
||||||
// TODO (karllessard) Execute the eager operation and pass output tensor handles to new
|
long[] tensorHandles = execute(nativeRef.opHandle);
|
||||||
// EagerOperation class
|
EagerOperation operation =
|
||||||
throw new UnsupportedOperationException("Eager execution is not supported yet");
|
new EagerOperation(session, nativeRef.opHandle, tensorHandles, type, name);
|
||||||
|
// Release our reference to the native op handle now that we transferred its
|
||||||
|
// ownership to the EagerOperation
|
||||||
|
nativeRef.clear();
|
||||||
|
return operation;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -44,7 +48,7 @@ final class EagerOperationBuilder implements OperationBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder addInputList(Output<?>[] inputs) {
|
public EagerOperationBuilder addInputList(Output<?>[] inputs) {
|
||||||
long[] inputHandles = new long[inputs.length];
|
long[] inputHandles = new long[inputs.length];
|
||||||
for (int i = 0; i < inputs.length; ++i) {
|
for (int i = 0; i < inputs.length; ++i) {
|
||||||
inputHandles[i] = inputs[i].getUnsafeNativeHandle();
|
inputHandles[i] = inputs[i].getUnsafeNativeHandle();
|
||||||
@ -60,18 +64,18 @@ final class EagerOperationBuilder implements OperationBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setDevice(String device) {
|
public EagerOperationBuilder setDevice(String device) {
|
||||||
setDevice(nativeRef.opHandle, device);
|
setDevice(nativeRef.opHandle, device);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, String value) {
|
public EagerOperationBuilder setAttr(String name, String value) {
|
||||||
return setAttr(name, value.getBytes(StandardCharsets.UTF_8));
|
return setAttr(name, value.getBytes(StandardCharsets.UTF_8));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, String[] values) {
|
public EagerOperationBuilder setAttr(String name, String[] values) {
|
||||||
Charset utf8 = StandardCharsets.UTF_8;
|
Charset utf8 = StandardCharsets.UTF_8;
|
||||||
Object[] objects = new Object[values.length];
|
Object[] objects = new Object[values.length];
|
||||||
for (int i = 0; i < values.length; ++i) {
|
for (int i = 0; i < values.length; ++i) {
|
||||||
@ -82,55 +86,55 @@ final class EagerOperationBuilder implements OperationBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, byte[] values) {
|
public EagerOperationBuilder setAttr(String name, byte[] values) {
|
||||||
setAttrString(nativeRef.opHandle, name, values);
|
setAttrString(nativeRef.opHandle, name, values);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, long value) {
|
public EagerOperationBuilder setAttr(String name, long value) {
|
||||||
setAttrInt(nativeRef.opHandle, name, value);
|
setAttrInt(nativeRef.opHandle, name, value);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, long[] values) {
|
public EagerOperationBuilder setAttr(String name, long[] values) {
|
||||||
setAttrIntList(nativeRef.opHandle, name, values);
|
setAttrIntList(nativeRef.opHandle, name, values);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, float value) {
|
public EagerOperationBuilder setAttr(String name, float value) {
|
||||||
setAttrFloat(nativeRef.opHandle, name, value);
|
setAttrFloat(nativeRef.opHandle, name, value);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, float[] values) {
|
public EagerOperationBuilder setAttr(String name, float[] values) {
|
||||||
setAttrFloatList(nativeRef.opHandle, name, values);
|
setAttrFloatList(nativeRef.opHandle, name, values);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, boolean value) {
|
public EagerOperationBuilder setAttr(String name, boolean value) {
|
||||||
setAttrBool(nativeRef.opHandle, name, value);
|
setAttrBool(nativeRef.opHandle, name, value);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, boolean[] values) {
|
public EagerOperationBuilder setAttr(String name, boolean[] values) {
|
||||||
setAttrBoolList(nativeRef.opHandle, name, values);
|
setAttrBoolList(nativeRef.opHandle, name, values);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, DataType value) {
|
public EagerOperationBuilder setAttr(String name, DataType value) {
|
||||||
setAttrType(nativeRef.opHandle, name, value.c());
|
setAttrType(nativeRef.opHandle, name, value.c());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, DataType[] values) {
|
public EagerOperationBuilder setAttr(String name, DataType[] values) {
|
||||||
int[] c = new int[values.length];
|
int[] c = new int[values.length];
|
||||||
for (int i = 0; i < values.length; ++i) {
|
for (int i = 0; i < values.length; ++i) {
|
||||||
c[i] = values[i].c();
|
c[i] = values[i].c();
|
||||||
@ -140,26 +144,26 @@ final class EagerOperationBuilder implements OperationBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, Tensor<?> value) {
|
public EagerOperationBuilder setAttr(String name, Tensor<?> value) {
|
||||||
setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle());
|
setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, Tensor<?>[] values) {
|
public EagerOperationBuilder setAttr(String name, Tensor<?>[] values) {
|
||||||
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
|
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
|
||||||
throw new UnsupportedOperationException(
|
throw new UnsupportedOperationException(
|
||||||
"Tensor list attributes are not supported in eager mode");
|
"Tensor list attributes are not supported in eager mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, Shape value) {
|
public EagerOperationBuilder setAttr(String name, Shape value) {
|
||||||
setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions());
|
setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OperationBuilder setAttr(String name, Shape[] values) {
|
public EagerOperationBuilder setAttr(String name, Shape[] values) {
|
||||||
int[] numDimensions = new int[values.length];
|
int[] numDimensions = new int[values.length];
|
||||||
int totalNumDimensions = 0;
|
int totalNumDimensions = 0;
|
||||||
for (int idx = 0; idx < values.length; ++idx) {
|
for (int idx = 0; idx < values.length; ++idx) {
|
||||||
|
@ -75,21 +75,6 @@ public final class GraphOperation extends AbstractOperation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public Output<?>[] outputList(int idx, int length) {
|
|
||||||
Output<?>[] outputs = new Output<?>[length];
|
|
||||||
for (int i = 0; i < length; ++i) {
|
|
||||||
outputs[i] = output(idx + i);
|
|
||||||
}
|
|
||||||
return outputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
@SuppressWarnings({"rawtypes", "unchecked"})
|
|
||||||
public <T> Output<T> output(int idx) {
|
|
||||||
return new Output(this, idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Long.valueOf(getUnsafeNativeHandle()).hashCode();
|
return Long.valueOf(getUnsafeNativeHandle()).hashCode();
|
||||||
|
@ -22,6 +22,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/java/src/main/native/exception_jni.h"
|
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||||
|
|
||||||
|
// This value should be >= to the maximum number of outputs in any op
|
||||||
|
#define MAX_OUTPUTS_PER_OP 8
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
|
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
|
||||||
@ -89,7 +92,7 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
|
|||||||
JNIEnv* env, jclass clazz, jlong op_handle) {
|
JNIEnv* env, jclass clazz, jlong op_handle) {
|
||||||
TFE_Op* op = requireOp(env, op_handle);
|
TFE_Op* op = requireOp(env, op_handle);
|
||||||
if (op == nullptr) return 0;
|
if (op == nullptr) return 0;
|
||||||
int num_retvals = 8; // should be >= than max number of outputs in any op
|
int num_retvals = MAX_OUTPUTS_PER_OP;
|
||||||
std::unique_ptr<TFE_TensorHandle*[]> retvals(
|
std::unique_ptr<TFE_TensorHandle*[]> retvals(
|
||||||
new TFE_TensorHandle*[num_retvals]);
|
new TFE_TensorHandle*[num_retvals]);
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
|
130
tensorflow/java/src/main/native/eager_operation_jni.cc
Normal file
130
tensorflow/java/src/main/native/eager_operation_jni.cc
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/java/src/main/native/eager_operation_jni.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
|
||||||
|
if (handle == 0) {
|
||||||
|
throwException(env, kIllegalStateException,
|
||||||
|
"Eager session has been closed");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return reinterpret_cast<TFE_Op*>(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
|
||||||
|
if (handle == 0) {
|
||||||
|
throwException(env, kIllegalStateException, "EagerSession has been closed");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return reinterpret_cast<TFE_TensorHandle*>(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv* env,
|
||||||
|
jclass clazz,
|
||||||
|
jlong handle) {
|
||||||
|
if (handle == 0) return;
|
||||||
|
TFE_DeleteOp(reinterpret_cast<TFE_Op*>(handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
|
||||||
|
JNIEnv* env, jclass clazz, jlong handle) {
|
||||||
|
if (handle == 0) return;
|
||||||
|
TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
|
||||||
|
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
|
||||||
|
TFE_Op* op = requireOp(env, handle);
|
||||||
|
if (op == nullptr) return 0;
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
const char* cname = env->GetStringUTFChars(name, nullptr);
|
||||||
|
int length = TFE_OpGetOutputLength(op, cname, status);
|
||||||
|
env->ReleaseStringUTFChars(name, cname);
|
||||||
|
if (!throwExceptionIfNotOK(env, status)) {
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return static_cast<jint>(length);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
|
||||||
|
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
|
||||||
|
TFE_Op* op = requireOp(env, handle);
|
||||||
|
if (op == nullptr) return 0;
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
const char* cname = env->GetStringUTFChars(name, nullptr);
|
||||||
|
int length = TFE_OpGetInputLength(op, cname, status);
|
||||||
|
env->ReleaseStringUTFChars(name, cname);
|
||||||
|
if (!throwExceptionIfNotOK(env, status)) {
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return static_cast<jint>(length);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(
|
||||||
|
JNIEnv* env, jclass clazz, jlong handle) {
|
||||||
|
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
|
||||||
|
if (tensor_handle == nullptr) return 0;
|
||||||
|
TF_DataType data_type = TFE_TensorHandleDataType(tensor_handle);
|
||||||
|
return static_cast<jint>(data_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(
|
||||||
|
JNIEnv* env, jclass clazz, jlong handle) {
|
||||||
|
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
|
||||||
|
if (tensor_handle == nullptr) return 0;
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
int num_dims = TFE_TensorHandleNumDims(tensor_handle, status);
|
||||||
|
if (!throwExceptionIfNotOK(env, status)) {
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return static_cast<jint>(num_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv* env,
|
||||||
|
jclass clazz,
|
||||||
|
jlong handle,
|
||||||
|
jint dim_index) {
|
||||||
|
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
|
||||||
|
if (tensor_handle == nullptr) return 0;
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
int64_t dim = TFE_TensorHandleDim(tensor_handle, dim_index, status);
|
||||||
|
if (!throwExceptionIfNotOK(env, status)) {
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return static_cast<jlong>(dim);
|
||||||
|
}
|
86
tensorflow/java/src/main/native/eager_operation_jni.h
Normal file
86
tensorflow/java/src/main/native/eager_operation_jni.h
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
|
||||||
|
#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
|
||||||
|
|
||||||
|
#include <jni.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: org_tensorflow_EagerOperation
|
||||||
|
* Method: delete
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv *,
|
||||||
|
jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: org_tensorflow_EagerOperation
|
||||||
|
* Method: deleteTensorHandle
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL
|
||||||
|
Java_org_tensorflow_EagerOperation_deleteTensorHandle(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class: org_tensorflow_EagerOperation
|
||||||
|
* Method: outputListLength
|
||||||
|
* Signature: (JLjava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
|
||||||
|
JNIEnv *, jclass, jlong, jstring);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class: org_tensorflow_EagerOperation
|
||||||
|
* Method: inputListLength
|
||||||
|
* Signature: (JLjava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
|
||||||
|
JNIEnv *, jclass, jlong, jstring);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class: org_tensorflow_EagerOperation
|
||||||
|
* Method: dataType
|
||||||
|
* Signature: (J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(JNIEnv *,
|
||||||
|
jclass,
|
||||||
|
jlong);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class: org_tensorflow_EagerOperation
|
||||||
|
* Method: numDims
|
||||||
|
* Signature: (J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(JNIEnv *,
|
||||||
|
jclass,
|
||||||
|
jlong);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class: org_tensorflow_EagerOperation
|
||||||
|
* Method: dim
|
||||||
|
* Signature: (JI)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv *, jclass,
|
||||||
|
jlong, jint);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // extern "C"
|
||||||
|
#endif // __cplusplus
|
||||||
|
#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
|
@ -21,6 +21,7 @@ import org.junit.Test;
|
|||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.JUnit4;
|
import org.junit.runners.JUnit4;
|
||||||
|
|
||||||
|
/** Unit tests for {@link EagerOperationBuilder} class. */
|
||||||
@RunWith(JUnit4.class)
|
@RunWith(JUnit4.class)
|
||||||
public class EagerOperationBuilderTest {
|
public class EagerOperationBuilderTest {
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ public class EagerOperationBuilderTest {
|
|||||||
new EagerOperationBuilder(session, "Add", "add");
|
new EagerOperationBuilder(session, "Add", "add");
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
|
// expected
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,8 +47,99 @@ public class EagerOperationBuilderTest {
|
|||||||
opBuilder.setAttr("dtype", DataType.FLOAT);
|
opBuilder.setAttr("dtype", DataType.FLOAT);
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
|
// expected
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (karllessard) add more tests when EagerOperation is implemented as well
|
@Test
|
||||||
|
public void addInputs() {
|
||||||
|
try (EagerSession session = EagerSession.create()) {
|
||||||
|
Operation asrt =
|
||||||
|
opBuilder(session, "Assert", "assert")
|
||||||
|
.addInput(TestUtil.constant(session, "Cond", true))
|
||||||
|
.addInputList(new Output<?>[] {TestUtil.constant(session, "Error", -1)})
|
||||||
|
.build();
|
||||||
|
try {
|
||||||
|
opBuilder(session, "Const", "var").addControlInput(asrt);
|
||||||
|
fail();
|
||||||
|
} catch (UnsupportedOperationException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setDevice() {
|
||||||
|
try (EagerSession session = EagerSession.create()) {
|
||||||
|
opBuilder(session, "Add", "SetDevice")
|
||||||
|
.setDevice("/job:localhost/replica:0/task:0/device:CPU:0")
|
||||||
|
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||||
|
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setAttrs() {
|
||||||
|
// The effect of setting an attribute may not easily be visible from the other parts of this
|
||||||
|
// package's API. Thus, for now, the test simply executes the various setAttr variants to see
|
||||||
|
// that there are no exceptions.
|
||||||
|
//
|
||||||
|
// This is a bit of an awkward test since it has to find operations with attributes of specific
|
||||||
|
// types that aren't inferred from the input arguments.
|
||||||
|
try (EagerSession session = EagerSession.create()) {
|
||||||
|
// dtype, tensor attributes.
|
||||||
|
try (Tensor<Integer> t = Tensors.create(1)) {
|
||||||
|
opBuilder(session, "Const", "DataTypeAndTensor")
|
||||||
|
.setAttr("dtype", DataType.INT32)
|
||||||
|
.setAttr("value", t)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
// type, int (TF "int" attributes are 64-bit signed, so a Java long).
|
||||||
|
opBuilder(session, "RandomUniform", "DataTypeAndInt")
|
||||||
|
.addInput(TestUtil.constant(session, "RandomUniformShape", new int[] {1}))
|
||||||
|
.setAttr("seed", 10)
|
||||||
|
.setAttr("dtype", DataType.FLOAT)
|
||||||
|
.build();
|
||||||
|
// list(int), string
|
||||||
|
opBuilder(session, "MaxPool", "IntListAndString")
|
||||||
|
.addInput(TestUtil.constant(session, "MaxPoolInput", new float[2][2][2][2]))
|
||||||
|
.setAttr("ksize", new long[] {1, 1, 1, 1})
|
||||||
|
.setAttr("strides", new long[] {1, 1, 1, 1})
|
||||||
|
.setAttr("padding", "SAME")
|
||||||
|
.build();
|
||||||
|
// list(float), device
|
||||||
|
opBuilder(session, "FractionalMaxPool", "FloatList")
|
||||||
|
.addInput(TestUtil.constant(session, "FractionalMaxPoolInput", new float[2][2][2][2]))
|
||||||
|
.setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
|
||||||
|
.build();
|
||||||
|
// shape
|
||||||
|
opBuilder(session, "EnsureShape", "ShapeAttr")
|
||||||
|
.addInput(TestUtil.constant(session, "Const", new int[2][2]))
|
||||||
|
.setAttr("shape", Shape.make(2, 2))
|
||||||
|
.build();
|
||||||
|
// list(shape)
|
||||||
|
opBuilder(session, "FIFOQueue", "queue")
|
||||||
|
.setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32})
|
||||||
|
.setAttr("shapes", new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)})
|
||||||
|
.build();
|
||||||
|
// bool
|
||||||
|
opBuilder(session, "All", "Bool")
|
||||||
|
.addInput(TestUtil.constant(session, "Const", new boolean[] {true, true, false}))
|
||||||
|
.addInput(TestUtil.constant(session, "Axis", 0))
|
||||||
|
.setAttr("keep_dims", false)
|
||||||
|
.build();
|
||||||
|
// float
|
||||||
|
opBuilder(session, "ApproximateEqual", "Float")
|
||||||
|
.addInput(TestUtil.constant(session, "Const1", 10.00001f))
|
||||||
|
.addInput(TestUtil.constant(session, "Const2", 10.00000f))
|
||||||
|
.setAttr("tolerance", 0.1f)
|
||||||
|
.build();
|
||||||
|
// Missing tests: list(string), list(byte), list(bool), list(type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
|
||||||
|
return new EagerOperationBuilder(session, type, name);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,128 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.JUnit4;
|
||||||
|
|
||||||
|
/** Unit tests for {@link EagerOperation} class. */
|
||||||
|
@RunWith(JUnit4.class)
|
||||||
|
public class EagerOperationTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void failToCreateIfSessionIsClosed() {
|
||||||
|
EagerSession session = EagerSession.create();
|
||||||
|
session.close();
|
||||||
|
try {
|
||||||
|
new EagerOperation(session, 1L, new long[] {1L}, "Add", "add");
|
||||||
|
fail();
|
||||||
|
} catch (IllegalStateException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void outputDataTypeAndShape() {
|
||||||
|
try (EagerSession session = EagerSession.create();
|
||||||
|
Tensor<Integer> t = Tensors.create(new int[2][3])) {
|
||||||
|
EagerOperation op =
|
||||||
|
opBuilder(session, "Const", "OutputAttrs")
|
||||||
|
.setAttr("dtype", DataType.INT32)
|
||||||
|
.setAttr("value", t)
|
||||||
|
.build();
|
||||||
|
assertEquals(DataType.INT32, op.dtype(0));
|
||||||
|
assertEquals(2, op.shape(0)[0]);
|
||||||
|
assertEquals(3, op.shape(0)[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void inputAndOutputListLengths() {
|
||||||
|
try (EagerSession session = EagerSession.create()) {
|
||||||
|
Output<Float> c1 = TestUtil.constant(session, "Const1", new float[] {1f, 2f});
|
||||||
|
Output<Float> c2 = TestUtil.constant(session, "Const2", new float[] {3f, 4f});
|
||||||
|
|
||||||
|
EagerOperation acc =
|
||||||
|
opBuilder(session, "AddN", "InputListLength")
|
||||||
|
.addInputList(new Output<?>[] {c1, c2})
|
||||||
|
.build();
|
||||||
|
assertEquals(2, acc.inputListLength("inputs"));
|
||||||
|
assertEquals(1, acc.outputListLength("sum"));
|
||||||
|
|
||||||
|
EagerOperation split =
|
||||||
|
opBuilder(session, "Split", "OutputListLength")
|
||||||
|
.addInput(TestUtil.constant(session, "Axis", 0))
|
||||||
|
.addInput(c1)
|
||||||
|
.setAttr("num_split", 2)
|
||||||
|
.build();
|
||||||
|
assertEquals(1, split.inputListLength("split_dim"));
|
||||||
|
assertEquals(2, split.outputListLength("output"));
|
||||||
|
|
||||||
|
try {
|
||||||
|
split.inputListLength("no_such_input");
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
split.outputListLength("no_such_output");
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void numOutputs() {
|
||||||
|
try (EagerSession session = EagerSession.create()) {
|
||||||
|
EagerOperation op =
|
||||||
|
opBuilder(session, "UniqueWithCountsV2", "unq")
|
||||||
|
.addInput(TestUtil.constant(session, "Const1", new int[] {1, 2, 1}))
|
||||||
|
.addInput(TestUtil.constant(session, "Axis", new int[] {0}))
|
||||||
|
.setAttr("out_idx", DataType.INT32)
|
||||||
|
.build();
|
||||||
|
assertEquals(3, op.numOutputs());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void opNotAccessibleIfSessionIsClosed() {
|
||||||
|
EagerSession session = EagerSession.create();
|
||||||
|
EagerOperation add =
|
||||||
|
opBuilder(session, "Add", "SetDevice")
|
||||||
|
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||||
|
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||||
|
.build();
|
||||||
|
assertEquals(1, add.outputListLength("z"));
|
||||||
|
session.close();
|
||||||
|
try {
|
||||||
|
add.outputListLength("z");
|
||||||
|
fail();
|
||||||
|
} catch (IllegalStateException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
|
||||||
|
return new EagerOperationBuilder(session, type, name);
|
||||||
|
}
|
||||||
|
}
|
@ -50,8 +50,14 @@ public class TestUtil {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <T> Output<T> constant(Graph g, String name, Object value) {
|
public static <T> Output<T> constant(ExecutionEnvironment env, String name, Object value) {
|
||||||
return constantOp(g, name, value).<T>output(0);
|
try (Tensor<?> t = Tensor.create(value)) {
|
||||||
|
return env.opBuilder("Const", name)
|
||||||
|
.setAttr("dtype", t.dataType())
|
||||||
|
.setAttr("value", t)
|
||||||
|
.build()
|
||||||
|
.<T>output(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {
|
public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {
|
||||||
|
Loading…
Reference in New Issue
Block a user