Add eager tensor support
This commit is contained in:
parent
4678439926
commit
fceea7d090
@ -60,7 +60,7 @@ abstract class AbstractOperation implements Operation {
|
||||
abstract long getUnsafeNativeHandle(int outputIdx);
|
||||
|
||||
/**
|
||||
* Returns the shape of the tensor of the {code outputIdx}th output of this operation.
|
||||
* Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
|
||||
*
|
||||
* @param outputIdx index of the output of this operation
|
||||
* @return output tensor shape
|
||||
@ -68,10 +68,20 @@ abstract class AbstractOperation implements Operation {
|
||||
abstract long[] shape(int outputIdx);
|
||||
|
||||
/**
|
||||
* Returns the datatype of the tensor of the {code outputIdx}th output of this operation.
|
||||
* Returns the datatype of the tensor of the {@code outputIdx}th output of this operation.
|
||||
*
|
||||
* @param outputIdx index of the output of this operation
|
||||
* @return output tensor datatype
|
||||
*/
|
||||
abstract DataType dtype(int outputIdx);
|
||||
|
||||
/**
|
||||
* Returns the tensor of the {@code outputIdx}th output of this operation.
|
||||
*
|
||||
* <p>This is only supported in an eager execution environment.
|
||||
*
|
||||
* @param outputIdx index of the output of this operation
|
||||
* @return output tensor
|
||||
*/
|
||||
abstract Tensor<?> tensor(int outputIdx);
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.atomic.AtomicReferenceArray;
|
||||
|
||||
/**
|
||||
* Implementation of an {@link Operation} executed eagerly.
|
||||
@ -33,6 +33,7 @@ class EagerOperation extends AbstractOperation {
|
||||
this.type = type;
|
||||
this.name = name;
|
||||
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
|
||||
this.outputTensors = new AtomicReferenceArray<Tensor<?>>(outputNativeHandles.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -67,6 +68,12 @@ class EagerOperation extends AbstractOperation {
|
||||
|
||||
@Override
|
||||
public long[] shape(int outputIndex) {
|
||||
// If the tensor of this output has already been resolved, return its shape.
|
||||
// Otherwise, retrieve the tensor shape from the native library.
|
||||
Tensor<?> tensor = outputTensors.get(outputIndex);
|
||||
if (tensor != null) {
|
||||
return tensor.shape();
|
||||
}
|
||||
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
||||
long[] shape = new long[numDims(outputNativeHandle)];
|
||||
for (int i = 0; i < shape.length; ++i) {
|
||||
@ -77,10 +84,43 @@ class EagerOperation extends AbstractOperation {
|
||||
|
||||
@Override
|
||||
public DataType dtype(int outputIndex) {
|
||||
// If the tensor of this output has already been resolved, return its datatype.
|
||||
// Otherwise, retrieve the tensor datatype from the native library.
|
||||
Tensor<?> tensor = outputTensors.get(outputIndex);
|
||||
if (tensor != null) {
|
||||
return tensor.dataType();
|
||||
}
|
||||
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
|
||||
return DataType.fromC(dataType(outputNativeHandle));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tensor<?> tensor(int outputIndex) {
|
||||
Tensor<?> tensor = outputTensors.get(outputIndex);
|
||||
if (tensor == null) {
|
||||
tensor = resolveTensor(outputIndex);
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
private final EagerSession session;
|
||||
private final NativeReference nativeRef;
|
||||
private final String type;
|
||||
private final String name;
|
||||
private final AtomicReferenceArray<Tensor<?>> outputTensors;
|
||||
|
||||
private Tensor<?> resolveTensor(int outputIndex) {
|
||||
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
|
||||
// If another thread has resolved it meanwhile, release our copy and reuse the existing one instead.
|
||||
long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
|
||||
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
|
||||
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
|
||||
tensor.close();
|
||||
tensor = outputTensors.get(outputIndex);
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
private static class NativeReference extends EagerSession.NativeReference {
|
||||
|
||||
NativeReference(EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
|
||||
@ -92,30 +132,27 @@ class EagerOperation extends AbstractOperation {
|
||||
@Override
|
||||
void delete() {
|
||||
if (opHandle != 0L) {
|
||||
for (long tensorHandle : outputHandles) {
|
||||
if (tensorHandle != 0L) {
|
||||
EagerOperation.deleteTensorHandle(tensorHandle);
|
||||
for (int i = 0; i < outputHandles.length; ++i) {
|
||||
if (outputHandles[i] != 0L) {
|
||||
EagerOperation.deleteTensorHandle(outputHandles[i]);
|
||||
outputHandles[i] = 0L;
|
||||
}
|
||||
}
|
||||
EagerOperation.delete(opHandle);
|
||||
opHandle = 0L;
|
||||
Arrays.fill(outputHandles, 0L);
|
||||
}
|
||||
}
|
||||
|
||||
private long opHandle;
|
||||
private final long[] outputHandles;
|
||||
}
|
||||
|
||||
private final EagerSession session;
|
||||
private final NativeReference nativeRef;
|
||||
private final String type;
|
||||
private final String name;
|
||||
|
||||
private static native void delete(long handle);
|
||||
|
||||
private static native void deleteTensorHandle(long handle);
|
||||
|
||||
private static native long resolveTensorHandle(long handle);
|
||||
|
||||
private static native int outputListLength(long handle, String name);
|
||||
|
||||
private static native int inputListLength(long handle, String name);
|
||||
|
@ -138,6 +138,11 @@ public final class GraphOperation extends AbstractOperation {
|
||||
r.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
Tensor<?> tensor(int outputIdx) {
|
||||
throw new IllegalStateException("Graph tensors must be fetched by running a session");
|
||||
}
|
||||
|
||||
long getUnsafeNativeHandle() {
|
||||
return unsafeNativeHandle;
|
||||
|
@ -47,6 +47,22 @@ public final class Output<T> implements Operand<T> {
|
||||
public DataType dataType() {
|
||||
return operation.dtype(index);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the tensor at this output.
|
||||
*
|
||||
* <p>This operation is only supported on the outputs of an operation executed eagerly.
|
||||
* For graph environments, output tensors must be fetched by running a session, using
|
||||
* {@link Session.Runner#fetch(Output)}.
|
||||
*
|
||||
* @return tensor
|
||||
* @throws IllegalStateException if this output results from a graph
|
||||
* @see EagerSession
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public Tensor<T> tensor() {
|
||||
return (Tensor<T>)operation.tensor(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Output<T> asOutput() {
|
||||
|
@ -140,15 +140,17 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
Tensor<?> t = new Tensor(dtype);
|
||||
t.shapeCopy = new long[numDimensions(obj, dtype)];
|
||||
fillShape(obj, 0, t.shapeCopy);
|
||||
long nativeHandle;
|
||||
if (t.dtype != DataType.STRING) {
|
||||
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
|
||||
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
|
||||
setValue(t.nativeHandle, obj);
|
||||
nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
|
||||
setValue(nativeHandle, obj);
|
||||
} else if (t.shapeCopy.length != 0) {
|
||||
t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
|
||||
nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
|
||||
} else {
|
||||
t.nativeHandle = allocateScalarBytes((byte[]) obj);
|
||||
nativeHandle = allocateScalarBytes((byte[]) obj);
|
||||
}
|
||||
t.nativeRef = new NativeReference(nativeHandle);
|
||||
return t;
|
||||
}
|
||||
|
||||
@ -314,23 +316,22 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
}
|
||||
Tensor<T> t = new Tensor<T>(dataType);
|
||||
t.shapeCopy = Arrays.copyOf(shape, shape.length);
|
||||
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
|
||||
long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
|
||||
t.nativeRef = new NativeReference(nativeHandle);
|
||||
return t;
|
||||
}
|
||||
|
||||
/**
|
||||
* Release resources associated with the Tensor.
|
||||
*
|
||||
* <p><b>WARNING:</b>If not invoked, memory will be leaked.
|
||||
* <p><b>WARNING:</b>This must be invoked for all tensors that were not been produced by an eager
|
||||
* operation or memory will be leaked.
|
||||
*
|
||||
* <p>The Tensor object is no longer usable after {@code close} returns.
|
||||
*/
|
||||
@Override
|
||||
public void close() {
|
||||
if (nativeHandle != 0) {
|
||||
delete(nativeHandle);
|
||||
nativeHandle = 0;
|
||||
}
|
||||
nativeRef.release();
|
||||
}
|
||||
|
||||
/** Returns the {@link DataType} of elements stored in the Tensor. */
|
||||
@ -374,7 +375,7 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
* @throws IllegalArgumentException if the Tensor does not represent a float scalar.
|
||||
*/
|
||||
public float floatValue() {
|
||||
return scalarFloat(nativeHandle);
|
||||
return scalarFloat(getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -383,7 +384,7 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
* @throws IllegalArgumentException if the Tensor does not represent a double scalar.
|
||||
*/
|
||||
public double doubleValue() {
|
||||
return scalarDouble(nativeHandle);
|
||||
return scalarDouble(getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -392,7 +393,7 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
* @throws IllegalArgumentException if the Tensor does not represent a int scalar.
|
||||
*/
|
||||
public int intValue() {
|
||||
return scalarInt(nativeHandle);
|
||||
return scalarInt(getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -401,7 +402,7 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
* @throws IllegalArgumentException if the Tensor does not represent a long scalar.
|
||||
*/
|
||||
public long longValue() {
|
||||
return scalarLong(nativeHandle);
|
||||
return scalarLong(getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -410,7 +411,7 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
|
||||
*/
|
||||
public boolean booleanValue() {
|
||||
return scalarBoolean(nativeHandle);
|
||||
return scalarBoolean(getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -419,7 +420,7 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
|
||||
*/
|
||||
public byte[] bytesValue() {
|
||||
return scalarBytes(nativeHandle);
|
||||
return scalarBytes(getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -448,7 +449,7 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
*/
|
||||
public <U> U copyTo(U dst) {
|
||||
throwExceptionIfTypeIsIncompatible(dst);
|
||||
readNDArray(nativeHandle, dst);
|
||||
readNDArray(getNativeHandle(), dst);
|
||||
return dst;
|
||||
}
|
||||
|
||||
@ -553,16 +554,27 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
@SuppressWarnings("rawtypes")
|
||||
Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
|
||||
t.shapeCopy = shape(handle);
|
||||
t.nativeHandle = handle;
|
||||
t.nativeRef = new NativeReference(handle);
|
||||
return t;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an eager Tensor object from a handle to the C TF_Tensor object.
|
||||
*
|
||||
* <p>Takes ownership of the handle.
|
||||
*/
|
||||
static Tensor<?> fromHandle(long handle, EagerSession session) {
|
||||
Tensor<?> t = fromHandle(handle);
|
||||
t.nativeRef.eager(session, t);
|
||||
return t;
|
||||
}
|
||||
|
||||
long getNativeHandle() {
|
||||
return nativeHandle;
|
||||
return nativeRef.tensorHandle;
|
||||
}
|
||||
|
||||
private long nativeHandle;
|
||||
private DataType dtype;
|
||||
private NativeReference nativeRef = null;
|
||||
private final DataType dtype;
|
||||
private long[] shapeCopy = null;
|
||||
|
||||
private Tensor(DataType t) {
|
||||
@ -570,7 +582,7 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
}
|
||||
|
||||
private ByteBuffer buffer() {
|
||||
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
|
||||
return buffer(getNativeHandle()).order(ByteOrder.nativeOrder());
|
||||
}
|
||||
|
||||
private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
|
||||
@ -609,6 +621,66 @@ public final class Tensor<T> implements AutoCloseable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reference to the underlying native tensor
|
||||
*
|
||||
* <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get automatically
|
||||
* released after executing the last line of the `try` block they were declared in.
|
||||
*
|
||||
* <p>They can also be attached to an eager session, where in this case their lifetime ends either when
|
||||
* this session is closed or when the Tensor instance is no longer referenced and have been garbage-collected.
|
||||
*
|
||||
* <p>This helper class wraps the tensor native handle and support both situations; If an eager reference to
|
||||
* the tensor exists, it will take care of releasing the tensor at the end of its life. If the tensor is
|
||||
* being explicetly closed before this happens, it will take cake of clearing its association with any eager
|
||||
* session before cleaning up the resources.
|
||||
*/
|
||||
private static class NativeReference {
|
||||
|
||||
/**
|
||||
* Attaches this reference to an eager session
|
||||
*/
|
||||
private class EagerReference extends EagerSession.NativeReference {
|
||||
|
||||
EagerReference(EagerSession session, Tensor<?> tensor) {
|
||||
super(session, tensor);
|
||||
}
|
||||
|
||||
@Override
|
||||
void delete() {
|
||||
// Mark this eager reference as cleared since it has been deleted by the session
|
||||
NativeReference.this.eagerRef = null;
|
||||
NativeReference.this.release();
|
||||
}
|
||||
}
|
||||
|
||||
NativeReference(long tensorHandle) {
|
||||
this.tensorHandle = tensorHandle;
|
||||
}
|
||||
|
||||
void eager(EagerSession session, Tensor<?> tensor) {
|
||||
if (eagerRef != null) {
|
||||
throw new IllegalStateException("The tensor is already attached to an eager session");
|
||||
}
|
||||
eagerRef = new EagerReference(session, tensor);
|
||||
}
|
||||
|
||||
synchronized void release() {
|
||||
if (tensorHandle != 0L) {
|
||||
// Clear any remaining eager reference to this tensor
|
||||
if (eagerRef != null) {
|
||||
eagerRef.clear();
|
||||
eagerRef = null;
|
||||
}
|
||||
Tensor.delete(tensorHandle);
|
||||
tensorHandle = 0L;
|
||||
}
|
||||
}
|
||||
|
||||
private long tensorHandle;
|
||||
private EagerReference eagerRef;
|
||||
}
|
||||
|
||||
private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
|
||||
|
||||
static {
|
||||
|
@ -57,6 +57,22 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
|
||||
TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
|
||||
JNIEnv* env, jclass clazz, jlong handle) {
|
||||
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
|
||||
if (tensor_handle == nullptr) return 0;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return 0;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
static_assert(sizeof(jlong) >= sizeof(TF_Tensor*),
|
||||
"Cannot represent a C TF_Tensor as a Java long");
|
||||
return reinterpret_cast<jlong>(tensor);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
|
||||
TFE_Op* op = requireOp(env, handle);
|
||||
|
@ -38,6 +38,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
|
||||
JNIEnv *, jclass, jlong);
|
||||
|
||||
/**
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: resolveTensorHandle
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle(
|
||||
JNIEnv *, jclass, jlong);
|
||||
|
||||
/**
|
||||
* Class: org_tensorflow_EagerOperation
|
||||
* Method: outputListLength
|
||||
|
@ -54,6 +54,22 @@ public class EagerOperationTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void outputTensor() {
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
EagerOperation add = opBuilder(session, "Add", "CompareResult")
|
||||
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||
.build();
|
||||
assertEquals(6, add.tensor(0).intValue());
|
||||
|
||||
// Validate that we retrieve the right shape and datatype from the tensor
|
||||
// that has been resolved
|
||||
assertEquals(0, add.shape(0).length);
|
||||
assertEquals(DataType.INT32, add.dtype(0));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inputAndOutputListLengths() {
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
@ -105,7 +121,7 @@ public class EagerOperationTest {
|
||||
@Test
|
||||
public void opNotAccessibleIfSessionIsClosed() {
|
||||
EagerSession session = EagerSession.create();
|
||||
EagerOperation add = opBuilder(session, "Add", "SetDevice")
|
||||
EagerOperation add = opBuilder(session, "Add", "SessionClosed")
|
||||
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||
.build();
|
||||
@ -119,6 +135,40 @@ public class EagerOperationTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void outputIndexOutOfBounds() {
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
EagerOperation add = opBuilder(session, "Add", "OutOfRange")
|
||||
.addInput(TestUtil.constant(session, "Const1", 2))
|
||||
.addInput(TestUtil.constant(session, "Const2", 4))
|
||||
.build();
|
||||
try {
|
||||
add.getUnsafeNativeHandle(1);
|
||||
fail();
|
||||
} catch (IndexOutOfBoundsException e) {
|
||||
// expected
|
||||
}
|
||||
try {
|
||||
add.shape(1);
|
||||
fail();
|
||||
} catch (IndexOutOfBoundsException e) {
|
||||
// expected
|
||||
}
|
||||
try {
|
||||
add.dtype(1);
|
||||
fail();
|
||||
} catch (IndexOutOfBoundsException e) {
|
||||
// expected
|
||||
}
|
||||
try {
|
||||
add.tensor(1);
|
||||
fail();
|
||||
} catch (IndexOutOfBoundsException e) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
|
||||
return new EagerOperationBuilder(session, type, name);
|
||||
}
|
||||
|
@ -166,6 +166,17 @@ public class GraphOperationTest {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void outputTensorNotSupported() {
|
||||
try (Graph g = new Graph()) {
|
||||
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
|
||||
try {
|
||||
split.output(0).tensor();
|
||||
fail();
|
||||
} catch (IllegalStateException e) {}
|
||||
}
|
||||
}
|
||||
|
||||
private static int split(int[] values, int num_split) {
|
||||
try (Graph g = new Graph()) {
|
||||
|
@ -18,6 +18,7 @@ package org.tensorflow;
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
@ -28,6 +29,7 @@ import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
@ -519,6 +521,25 @@ public class TensorTest {
|
||||
// The expected exception.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void eagerTensorIsReleasedAfterSessionIsClosed() {
|
||||
Tensor<Integer> sum;
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
Output<?> x = TestUtil.constant(session, "Const1", 10);
|
||||
Output<?> y = TestUtil.constant(session, "Const2", 20);
|
||||
sum = TestUtil.<Integer>addN(session, x, y).tensor();
|
||||
assertNotEquals(0L, sum.getNativeHandle());
|
||||
assertEquals(30, sum.intValue());
|
||||
}
|
||||
assertEquals(0L, sum.getNativeHandle());
|
||||
try {
|
||||
sum.intValue();
|
||||
fail();
|
||||
} catch (NullPointerException e) {
|
||||
// expected.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void fromHandle() {
|
||||
|
@ -67,8 +67,8 @@ public class TestUtil {
|
||||
.<T>output(0);
|
||||
}
|
||||
|
||||
public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
|
||||
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
|
||||
public static <T> Output<T> addN(ExecutionEnvironment env, Output<?>... inputs) {
|
||||
return env.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
|
||||
}
|
||||
|
||||
public static <T> Output<T> matmul(
|
||||
|
Loading…
Reference in New Issue
Block a user