Add eager tensor support
This commit is contained in:
parent
4678439926
commit
fceea7d090
tensorflow/java/src
main
java/org/tensorflow
native
test/java/org/tensorflow
@ -60,7 +60,7 @@ abstract class AbstractOperation implements Operation {
|
|||||||
abstract long getUnsafeNativeHandle(int outputIdx);
|
abstract long getUnsafeNativeHandle(int outputIdx);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the shape of the tensor of the {code outputIdx}th output of this operation.
|
* Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
|
||||||
*
|
*
|
||||||
* @param outputIdx index of the output of this operation
|
* @param outputIdx index of the output of this operation
|
||||||
* @return output tensor shape
|
* @return output tensor shape
|
||||||
@ -68,10 +68,20 @@ abstract class AbstractOperation implements Operation {
|
|||||||
abstract long[] shape(int outputIdx);
|
abstract long[] shape(int outputIdx);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the datatype of the tensor of the {code outputIdx}th output of this operation.
|
* Returns the datatype of the tensor of the {@code outputIdx}th output of this operation.
|
||||||
*
|
*
|
||||||
* @param outputIdx index of the output of this operation
|
* @param outputIdx index of the output of this operation
|
||||||
* @return output tensor datatype
|
* @return output tensor datatype
|
||||||
*/
|
*/
|
||||||
abstract DataType dtype(int outputIdx);
|
abstract DataType dtype(int outputIdx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the tensor of the {@code outputIdx}th output of this operation.
|
||||||
|
*
|
||||||
|
* <p>This is only supported in an eager execution environment.
|
||||||
|
*
|
||||||
|
* @param outputIdx index of the output of this operation
|
||||||
|
* @return output tensor
|
||||||
|
*/
|
||||||
|
abstract Tensor<?> tensor(int outputIdx);
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
package org.tensorflow;
|
package org.tensorflow;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.concurrent.atomic.AtomicReferenceArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implementation of an {@link Operation} executed eagerly.
|
* Implementation of an {@link Operation} executed eagerly.
|
||||||
@ -33,6 +33,7 @@ class EagerOperation extends AbstractOperation {
|
|||||||
this.type = type;
|
this.type = type;
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
|
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
|
||||||
|
this.outputTensors = new AtomicReferenceArray<Tensor<?>>(outputNativeHandles.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -67,6 +68,12 @@ class EagerOperation extends AbstractOperation {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long[] shape(int outputIndex) {
|
public long[] shape(int outputIndex) {
|
||||||
|
// If the tensor of this output has already been resolved, return its shape.
|
||||||
|
// Otherwise, retrieve the tensor shape from the native library.
|
||||||
|
Tensor<?> tensor = outputTensors.get(outputIndex);
|
||||||
|
if (tensor != null) {
|
||||||
|
return tensor.shape();
|
||||||
|
}
|
||||||
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
||||||
long[] shape = new long[numDims(outputNativeHandle)];
|
long[] shape = new long[numDims(outputNativeHandle)];
|
||||||
for (int i = 0; i < shape.length; ++i) {
|
for (int i = 0; i < shape.length; ++i) {
|
||||||
@ -77,10 +84,43 @@ class EagerOperation extends AbstractOperation {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType dtype(int outputIndex) {
|
public DataType dtype(int outputIndex) {
|
||||||
|
// If the tensor of this output has already been resolved, return its datatype.
|
||||||
|
// Otherwise, retrieve the tensor datatype from the native library.
|
||||||
|
Tensor<?> tensor = outputTensors.get(outputIndex);
|
||||||
|
if (tensor != null) {
|
||||||
|
return tensor.dataType();
|
||||||
|
}
|
||||||
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
||||||
return DataType.fromC(dataType(outputNativeHandle));
|
return DataType.fromC(dataType(outputNativeHandle));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Tensor<?> tensor(int outputIndex) {
|
||||||
|
Tensor<?> tensor = outputTensors.get(outputIndex);
|
||||||
|
if (tensor == null) {
|
||||||
|
tensor = resolveTensor(outputIndex);
|
||||||
|
}
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
private final EagerSession session;
|
||||||
|
private final NativeReference nativeRef;
|
||||||
|
private final String type;
|
||||||
|
private final String name;
|
||||||
|
private final AtomicReferenceArray<Tensor<?>> outputTensors;
|
||||||
|
|
||||||
|
private Tensor<?> resolveTensor(int outputIndex) {
|
||||||
|
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
|
||||||
|
// If another thread has resolved it meanwhile, release our copy and reuse the existing one instead.
|
||||||
|
long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
|
||||||
|
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
|
||||||
|
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
|
||||||
|
tensor.close();
|
||||||
|
tensor = outputTensors.get(outputIndex);
|
||||||
|
}
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
private static class NativeReference extends EagerSession.NativeReference {
|
private static class NativeReference extends EagerSession.NativeReference {
|
||||||
|
|
||||||
NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
|
NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
|
||||||
@ -92,14 +132,14 @@ class EagerOperation extends AbstractOperation {
|
|||||||
@Override
|
@Override
|
||||||
void delete() {
|
void delete() {
|
||||||
if (opHandle != 0L) {
|
if (opHandle != 0L) {
|
||||||
for (long tensorHandle : outputHandles) {
|
for (int i = 0; i < outputHandles.length; ++i) {
|
||||||
if (tensorHandle != 0L) {
|
if (outputHandles[i] != 0L) {
|
||||||
EagerOperation.deleteTensorHandle(tensorHandle);
|
EagerOperation.deleteTensorHandle(outputHandles[i]);
|
||||||
|
outputHandles[i] = 0L;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EagerOperation.delete(opHandle);
|
EagerOperation.delete(opHandle);
|
||||||
opHandle = 0L;
|
opHandle = 0L;
|
||||||
Arrays.fill(outputHandles, 0L);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,15 +147,12 @@ class EagerOperation extends AbstractOperation {
|
|||||||
private final long[] outputHandles;
|
private final long[] outputHandles;
|
||||||
}
|
}
|
||||||
|
|
||||||
private final EagerSession session;
|
|
||||||
private final NativeReference nativeRef;
|
|
||||||
private final String type;
|
|
||||||
private final String name;
|
|
||||||
|
|
||||||
private static native void delete(long handle);
|
private static native void delete(long handle);
|
||||||
|
|
||||||
private static native void deleteTensorHandle(long handle);
|
private static native void deleteTensorHandle(long handle);
|
||||||
|
|
||||||
|
private static native long resolveTensorHandle(long handle);
|
||||||
|
|
||||||
private static native int outputListLength(long handle, String name);
|
private static native int outputListLength(long handle, String name);
|
||||||
|
|
||||||
private static native int inputListLength(long handle, String name);
|
private static native int inputListLength(long handle, String name);
|
||||||
|
@ -139,6 +139,11 @@ public final class GraphOperation extends AbstractOperation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Tensor<?> tensor(int outputIdx) {
|
||||||
|
throw new IllegalStateException("Graph tensors must be fetched by running a session");
|
||||||
|
}
|
||||||
|
|
||||||
long getUnsafeNativeHandle() {
|
long getUnsafeNativeHandle() {
|
||||||
return unsafeNativeHandle;
|
return unsafeNativeHandle;
|
||||||
}
|
}
|
||||||
|
@ -48,6 +48,22 @@ public final class Output<T> implements Operand<T> {
|
|||||||
return operation.dtype(index);
|
return operation.dtype(index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the tensor at this output.
|
||||||
|
*
|
||||||
|
* <p>This operation is only supported on the outputs of an operation executed eagerly.
|
||||||
|
* For graph environments, output tensors must be fetched by running a session, using
|
||||||
|
* {@link Session.Runner#fetch(Output)}.
|
||||||
|
*
|
||||||
|
* @return tensor
|
||||||
|
* @throws IllegalStateException if this output results from a graph
|
||||||
|
* @see EagerSession
|
||||||
|
*/
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public Tensor<T> tensor() {
|
||||||
|
return (Tensor<T>)operation.tensor(index);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Output<T> asOutput() {
|
public Output<T> asOutput() {
|
||||||
return this;
|
return this;
|
||||||
|
@ -140,15 +140,17 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
Tensor<?> t = new Tensor(dtype);
|
Tensor<?> t = new Tensor(dtype);
|
||||||
t.shapeCopy = new long[numDimensions(obj, dtype)];
|
t.shapeCopy = new long[numDimensions(obj, dtype)];
|
||||||
fillShape(obj, 0, t.shapeCopy);
|
fillShape(obj, 0, t.shapeCopy);
|
||||||
|
long nativeHandle;
|
||||||
if (t.dtype != DataType.STRING) {
|
if (t.dtype != DataType.STRING) {
|
||||||
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
|
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
|
||||||
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
|
nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
|
||||||
setValue(t.nativeHandle, obj);
|
setValue(nativeHandle, obj);
|
||||||
} else if (t.shapeCopy.length != 0) {
|
} else if (t.shapeCopy.length != 0) {
|
||||||
t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
|
nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
|
||||||
} else {
|
} else {
|
||||||
t.nativeHandle = allocateScalarBytes((byte[]) obj);
|
nativeHandle = allocateScalarBytes((byte[]) obj);
|
||||||
}
|
}
|
||||||
|
t.nativeRef = new NativeReference(nativeHandle);
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,23 +316,22 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
}
|
}
|
||||||
Tensor<T> t = new Tensor<T>(dataType);
|
Tensor<T> t = new Tensor<T>(dataType);
|
||||||
t.shapeCopy = Arrays.copyOf(shape, shape.length);
|
t.shapeCopy = Arrays.copyOf(shape, shape.length);
|
||||||
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
|
long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
|
||||||
|
t.nativeRef = new NativeReference(nativeHandle);
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Release resources associated with the Tensor.
|
* Release resources associated with the Tensor.
|
||||||
*
|
*
|
||||||
* <p><b>WARNING:</b>If not invoked, memory will be leaked.
|
* <p><b>WARNING:</b>This must be invoked for all tensors that were not been produced by an eager
|
||||||
|
* operation or memory will be leaked.
|
||||||
*
|
*
|
||||||
* <p>The Tensor object is no longer usable after {@code close} returns.
|
* <p>The Tensor object is no longer usable after {@code close} returns.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void close() {
|
public void close() {
|
||||||
if (nativeHandle != 0) {
|
nativeRef.release();
|
||||||
delete(nativeHandle);
|
|
||||||
nativeHandle = 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns the {@link DataType} of elements stored in the Tensor. */
|
/** Returns the {@link DataType} of elements stored in the Tensor. */
|
||||||
@ -374,7 +375,7 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
* @throws IllegalArgumentException if the Tensor does not represent a float scalar.
|
* @throws IllegalArgumentException if the Tensor does not represent a float scalar.
|
||||||
*/
|
*/
|
||||||
public float floatValue() {
|
public float floatValue() {
|
||||||
return scalarFloat(nativeHandle);
|
return scalarFloat(getNativeHandle());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -383,7 +384,7 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
* @throws IllegalArgumentException if the Tensor does not represent a double scalar.
|
* @throws IllegalArgumentException if the Tensor does not represent a double scalar.
|
||||||
*/
|
*/
|
||||||
public double doubleValue() {
|
public double doubleValue() {
|
||||||
return scalarDouble(nativeHandle);
|
return scalarDouble(getNativeHandle());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -392,7 +393,7 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
* @throws IllegalArgumentException if the Tensor does not represent a int scalar.
|
* @throws IllegalArgumentException if the Tensor does not represent a int scalar.
|
||||||
*/
|
*/
|
||||||
public int intValue() {
|
public int intValue() {
|
||||||
return scalarInt(nativeHandle);
|
return scalarInt(getNativeHandle());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -401,7 +402,7 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
* @throws IllegalArgumentException if the Tensor does not represent a long scalar.
|
* @throws IllegalArgumentException if the Tensor does not represent a long scalar.
|
||||||
*/
|
*/
|
||||||
public long longValue() {
|
public long longValue() {
|
||||||
return scalarLong(nativeHandle);
|
return scalarLong(getNativeHandle());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -410,7 +411,7 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
|
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
|
||||||
*/
|
*/
|
||||||
public boolean booleanValue() {
|
public boolean booleanValue() {
|
||||||
return scalarBoolean(nativeHandle);
|
return scalarBoolean(getNativeHandle());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -419,7 +420,7 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
|
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
|
||||||
*/
|
*/
|
||||||
public byte[] bytesValue() {
|
public byte[] bytesValue() {
|
||||||
return scalarBytes(nativeHandle);
|
return scalarBytes(getNativeHandle());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -448,7 +449,7 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
*/
|
*/
|
||||||
public <U> U copyTo(U dst) {
|
public <U> U copyTo(U dst) {
|
||||||
throwExceptionIfTypeIsIncompatible(dst);
|
throwExceptionIfTypeIsIncompatible(dst);
|
||||||
readNDArray(nativeHandle, dst);
|
readNDArray(getNativeHandle(), dst);
|
||||||
return dst;
|
return dst;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -553,16 +554,27 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
@SuppressWarnings("rawtypes")
|
@SuppressWarnings("rawtypes")
|
||||||
Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
|
Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
|
||||||
t.shapeCopy = shape(handle);
|
t.shapeCopy = shape(handle);
|
||||||
t.nativeHandle = handle;
|
t.nativeRef = new NativeReference(handle);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create an eager Tensor object from a handle to the C TF_Tensor object.
|
||||||
|
*
|
||||||
|
* <p>Takes ownership of the handle.
|
||||||
|
*/
|
||||||
|
static Tensor<?> fromHandle(long handle, EagerSession session) {
|
||||||
|
Tensor<?> t = fromHandle(handle);
|
||||||
|
t.nativeRef.eager(session, t);
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
long getNativeHandle() {
|
long getNativeHandle() {
|
||||||
return nativeHandle;
|
return nativeRef.tensorHandle;
|
||||||
}
|
}
|
||||||
|
|
||||||
private long nativeHandle;
|
private NativeReference nativeRef = null;
|
||||||
private DataType dtype;
|
private final DataType dtype;
|
||||||
private long[] shapeCopy = null;
|
private long[] shapeCopy = null;
|
||||||
|
|
||||||
private Tensor(DataType t) {
|
private Tensor(DataType t) {
|
||||||
@ -570,7 +582,7 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private ByteBuffer buffer() {
|
private ByteBuffer buffer() {
|
||||||
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
|
return buffer(getNativeHandle()).order(ByteOrder.nativeOrder());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
|
private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
|
||||||
@ -609,6 +621,66 @@ public final class Tensor<T> implements AutoCloseable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reference to the underlying native tensor
|
||||||
|
*
|
||||||
|
* <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get automatically
|
||||||
|
* released after executing the last line of the `try` block they were declared in.
|
||||||
|
*
|
||||||
|
* <p>They can also be attached to an eager session, where in this case their lifetime ends either when
|
||||||
|
* this session is closed or when the Tensor instance is no longer referenced and have been garbage-collected.
|
||||||
|
*
|
||||||
|
* <p>This helper class wraps the tensor native handle and support both situations; If an eager reference to
|
||||||
|
* the tensor exists, it will take care of releasing the tensor at the end of its life. If the tensor is
|
||||||
|
* being explicetly closed before this happens, it will take cake of clearing its association with any eager
|
||||||
|
* session before cleaning up the resources.
|
||||||
|
*/
|
||||||
|
private static class NativeReference {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attaches this reference to an eager session
|
||||||
|
*/
|
||||||
|
private class EagerReference extends EagerSession.NativeReference {
|
||||||
|
|
||||||
|
EagerReference(EagerSession session, Tensor<?> tensor) {
|
||||||
|
super(session, tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
void delete() {
|
||||||
|
// Mark this eager reference as cleared since it has been deleted by the session
|
||||||
|
NativeReference.this.eagerRef = null;
|
||||||
|
NativeReference.this.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NativeReference(long tensorHandle) {
|
||||||
|
this.tensorHandle = tensorHandle;
|
||||||
|
}
|
||||||
|
|
||||||
|
void eager(EagerSession session, Tensor<?> tensor) {
|
||||||
|
if (eagerRef != null) {
|
||||||
|
throw new IllegalStateException("The tensor is already attached to an eager session");
|
||||||
|
}
|
||||||
|
eagerRef = new EagerReference(session, tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
synchronized void release() {
|
||||||
|
if (tensorHandle != 0L) {
|
||||||
|
// Clear any remaining eager reference to this tensor
|
||||||
|
if (eagerRef != null) {
|
||||||
|
eagerRef.clear();
|
||||||
|
eagerRef = null;
|
||||||
|
}
|
||||||
|
Tensor.delete(tensorHandle);
|
||||||
|
tensorHandle = 0L;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private long tensorHandle;
|
||||||
|
private EagerReference eagerRef;
|
||||||
|
}
|
||||||
|
|
||||||
private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
|
private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
|
||||||
|
|
||||||
static {
|
static {
|
||||||
|
@ -57,6 +57,22 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
|
|||||||
TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
|
TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
|
||||||
|
JNIEnv* env, jclass clazz, jlong handle) {
|
||||||
|
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
|
||||||
|
if (tensor_handle == nullptr) return 0;
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status);
|
||||||
|
if (!throwExceptionIfNotOK(env, status)) {
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
static_assert(sizeof(jlong) >= sizeof(TF_Tensor*),
|
||||||
|
"Cannot represent a C TF_Tensor as a Java long");
|
||||||
|
return reinterpret_cast<jlong>(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
|
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
|
||||||
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
|
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
|
||||||
TFE_Op* op = requireOp(env, handle);
|
TFE_Op* op = requireOp(env, handle);
|
||||||
|
@ -38,6 +38,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(
|
|||||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
|
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
|
||||||
JNIEnv *, jclass, jlong);
|
JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class: org_tensorflow_EagerOperation
|
||||||
|
* Method: resolveTensorHandle
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
|
||||||
|
JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Class: org_tensorflow_EagerOperation
|
* Class: org_tensorflow_EagerOperation
|
||||||
* Method: outputListLength
|
* Method: outputListLength
|
||||||
|
@ -54,6 +54,22 @@ public class EagerOperationTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void outputTensor() {
|
||||||
|
try (EagerSession session = EagerSession.create()) {
|
||||||
|
EagerOperation add = opBuilder(session, "Add", "CompareResult")
|
||||||
|
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||||
|
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||||
|
.build();
|
||||||
|
assertEquals(6, add.tensor(0).intValue());
|
||||||
|
|
||||||
|
// Validate that we retrieve the right shape and datatype from the tensor
|
||||||
|
// that has been resolved
|
||||||
|
assertEquals(0, add.shape(0).length);
|
||||||
|
assertEquals(DataType.INT32, add.dtype(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void inputAndOutputListLengths() {
|
public void inputAndOutputListLengths() {
|
||||||
try (EagerSession session = EagerSession.create()) {
|
try (EagerSession session = EagerSession.create()) {
|
||||||
@ -105,7 +121,7 @@ public class EagerOperationTest {
|
|||||||
@Test
|
@Test
|
||||||
public void opNotAccessibleIfSessionIsClosed() {
|
public void opNotAccessibleIfSessionIsClosed() {
|
||||||
EagerSession session = EagerSession.create();
|
EagerSession session = EagerSession.create();
|
||||||
EagerOperation add = opBuilder(session, "Add", "SetDevice")
|
EagerOperation add = opBuilder(session, "Add", "SessionClosed")
|
||||||
.addInput(TestUtil.constant(session, "Const1", 2))
|
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||||
.addInput(TestUtil.constant(session, "Const2", 4))
|
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||||
.build();
|
.build();
|
||||||
@ -119,6 +135,40 @@ public class EagerOperationTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void outputIndexOutOfBounds() {
|
||||||
|
try (EagerSession session = EagerSession.create()) {
|
||||||
|
EagerOperation add = opBuilder(session, "Add", "OutOfRange")
|
||||||
|
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||||
|
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||||
|
.build();
|
||||||
|
try {
|
||||||
|
add.getUnsafeNativeHandle(1);
|
||||||
|
fail();
|
||||||
|
} catch (IndexOutOfBoundsException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
add.shape(1);
|
||||||
|
fail();
|
||||||
|
} catch (IndexOutOfBoundsException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
add.dtype(1);
|
||||||
|
fail();
|
||||||
|
} catch (IndexOutOfBoundsException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
add.tensor(1);
|
||||||
|
fail();
|
||||||
|
} catch (IndexOutOfBoundsException e) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
|
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
|
||||||
return new EagerOperationBuilder(session, type, name);
|
return new EagerOperationBuilder(session, type, name);
|
||||||
}
|
}
|
||||||
|
@ -167,6 +167,17 @@ public class GraphOperationTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void outputTensorNotSupported() {
|
||||||
|
try (Graph g = new Graph()) {
|
||||||
|
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
|
||||||
|
try {
|
||||||
|
split.output(0).tensor();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalStateException e) {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static int split(int[] values, int num_split) {
|
private static int split(int[] values, int num_split) {
|
||||||
try (Graph g = new Graph()) {
|
try (Graph g = new Graph()) {
|
||||||
return g.opBuilder("Split", "Split")
|
return g.opBuilder("Split", "Split")
|
||||||
|
@ -18,6 +18,7 @@ package org.tensorflow;
|
|||||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertNotEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
@ -28,6 +29,7 @@ import java.nio.DoubleBuffer;
|
|||||||
import java.nio.FloatBuffer;
|
import java.nio.FloatBuffer;
|
||||||
import java.nio.IntBuffer;
|
import java.nio.IntBuffer;
|
||||||
import java.nio.LongBuffer;
|
import java.nio.LongBuffer;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.JUnit4;
|
import org.junit.runners.JUnit4;
|
||||||
@ -520,6 +522,25 @@ public class TensorTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void eagerTensorIsReleasedAfterSessionIsClosed() {
|
||||||
|
Tensor<Integer> sum;
|
||||||
|
try (EagerSession session = EagerSession.create()) {
|
||||||
|
Output<?> x = TestUtil.constant(session, "Const1", 10);
|
||||||
|
Output<?> y = TestUtil.constant(session, "Const2", 20);
|
||||||
|
sum = TestUtil.<Integer>addN(session, x, y).tensor();
|
||||||
|
assertNotEquals(0L, sum.getNativeHandle());
|
||||||
|
assertEquals(30, sum.intValue());
|
||||||
|
}
|
||||||
|
assertEquals(0L, sum.getNativeHandle());
|
||||||
|
try {
|
||||||
|
sum.intValue();
|
||||||
|
fail();
|
||||||
|
} catch (NullPointerException e) {
|
||||||
|
// expected.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void fromHandle() {
|
public void fromHandle() {
|
||||||
// fromHandle is a package-visible method intended for use when the C TF_Tensor object has been
|
// fromHandle is a package-visible method intended for use when the C TF_Tensor object has been
|
||||||
|
@ -67,8 +67,8 @@ public class TestUtil {
|
|||||||
.<T>output(0);
|
.<T>output(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
|
public static <T> Output<T> addN(ExecutionEnvironment env, Output<?>... inputs) {
|
||||||
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
|
return env.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <T> Output<T> matmul(
|
public static <T> Output<T> matmul(
|
||||||
|
Loading…
Reference in New Issue
Block a user