Merge pull request #28475 from karllessard:java-eager-operation

PiperOrigin-RevId: 247638740
This commit is contained in:
TensorFlower Gardener 2019-05-10 14:04:37 -07:00
commit b422174124
11 changed files with 637 additions and 40 deletions

View File

@ -159,6 +159,19 @@ tf_java_test(
],
)
tf_java_test(
name = "EagerOperationTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/EagerOperationTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.EagerOperationTest",
deps = [
":tensorflow",
":testutil",
"@junit",
],
)
tf_java_test(
name = "GraphTest",
size = "small",

View File

@ -23,6 +23,21 @@ package org.tensorflow;
*/
abstract class AbstractOperation implements Operation {
@Override
public Output<?>[] outputList(int idx, int length) {
Output<?>[] outputs = new Output<?>[length];
for (int i = 0; i < length; ++i) {
outputs[i] = output(idx + i);
}
return outputs;
}
@Override
@SuppressWarnings({"rawtypes", "unchecked"})
public <T> Output<T> output(int idx) {
return new Output(this, idx);
}
@Override
public String toString() {
return String.format("<%s '%s'>", type(), name());

View File

@ -0,0 +1,134 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import java.util.Arrays;
/**
* Implementation of an {@link Operation} executed eagerly.
*
* <p>EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of
* is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the
* EagerOperation instance may fail with an {@code IllegalStateException}.
*
* <p>EagerOperation instances are thread-safe.
*/
class EagerOperation extends AbstractOperation {
EagerOperation(
EagerSession session,
long opNativeHandle,
long[] outputNativeHandles,
String type,
String name) {
this.session = session;
this.type = type;
this.name = name;
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
}
@Override
public String name() {
return name;
}
@Override
public String type() {
return type;
}
@Override
public int numOutputs() {
return nativeRef.outputHandles.length;
}
@Override
public int outputListLength(final String name) {
return outputListLength(nativeRef.opHandle, name);
}
@Override
public int inputListLength(final String name) {
return inputListLength(nativeRef.opHandle, name);
}
@Override
public long getUnsafeNativeHandle(int outputIndex) {
return nativeRef.outputHandles[outputIndex];
}
@Override
public long[] shape(int outputIndex) {
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
long[] shape = new long[numDims(outputNativeHandle)];
for (int i = 0; i < shape.length; ++i) {
shape[i] = dim(outputNativeHandle, i);
}
return shape;
}
@Override
public DataType dtype(int outputIndex) {
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
return DataType.fromC(dataType(outputNativeHandle));
}
private static class NativeReference extends EagerSession.NativeReference {
NativeReference(
EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
super(session, operation);
this.opHandle = opHandle;
this.outputHandles = outputHandles;
}
@Override
void delete() {
if (opHandle != 0L) {
for (long tensorHandle : outputHandles) {
if (tensorHandle != 0L) {
EagerOperation.deleteTensorHandle(tensorHandle);
}
}
EagerOperation.delete(opHandle);
opHandle = 0L;
Arrays.fill(outputHandles, 0L);
}
}
private long opHandle;
private final long[] outputHandles;
}
private final EagerSession session;
private final NativeReference nativeRef;
private final String type;
private final String name;
private static native void delete(long handle);
private static native void deleteTensorHandle(long handle);
private static native int outputListLength(long handle, String name);
private static native int inputListLength(long handle, String name);
private static native int dataType(long handle);
private static native int numDims(long handle);
private static native long dim(long handle, int index);
}

View File

@ -31,10 +31,14 @@ final class EagerOperationBuilder implements OperationBuilder {
}
@Override
public Operation build() {
// TODO (karllessard) Execute the eager operation and pass output tensor handles to new
// EagerOperation class
throw new UnsupportedOperationException("Eager execution is not supported yet");
public EagerOperation build() {
long[] tensorHandles = execute(nativeRef.opHandle);
EagerOperation operation =
new EagerOperation(session, nativeRef.opHandle, tensorHandles, type, name);
// Release our reference to the native op handle now that we transferred its
// ownership to the EagerOperation
nativeRef.clear();
return operation;
}
@Override
@ -44,7 +48,7 @@ final class EagerOperationBuilder implements OperationBuilder {
}
@Override
public OperationBuilder addInputList(Output<?>[] inputs) {
public EagerOperationBuilder addInputList(Output<?>[] inputs) {
long[] inputHandles = new long[inputs.length];
for (int i = 0; i < inputs.length; ++i) {
inputHandles[i] = inputs[i].getUnsafeNativeHandle();
@ -60,18 +64,18 @@ final class EagerOperationBuilder implements OperationBuilder {
}
@Override
public OperationBuilder setDevice(String device) {
public EagerOperationBuilder setDevice(String device) {
setDevice(nativeRef.opHandle, device);
return this;
}
@Override
public OperationBuilder setAttr(String name, String value) {
public EagerOperationBuilder setAttr(String name, String value) {
return setAttr(name, value.getBytes(StandardCharsets.UTF_8));
}
@Override
public OperationBuilder setAttr(String name, String[] values) {
public EagerOperationBuilder setAttr(String name, String[] values) {
Charset utf8 = StandardCharsets.UTF_8;
Object[] objects = new Object[values.length];
for (int i = 0; i < values.length; ++i) {
@ -82,55 +86,55 @@ final class EagerOperationBuilder implements OperationBuilder {
}
@Override
public OperationBuilder setAttr(String name, byte[] values) {
public EagerOperationBuilder setAttr(String name, byte[] values) {
setAttrString(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, long value) {
public EagerOperationBuilder setAttr(String name, long value) {
setAttrInt(nativeRef.opHandle, name, value);
return this;
}
@Override
public OperationBuilder setAttr(String name, long[] values) {
public EagerOperationBuilder setAttr(String name, long[] values) {
setAttrIntList(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, float value) {
public EagerOperationBuilder setAttr(String name, float value) {
setAttrFloat(nativeRef.opHandle, name, value);
return this;
}
@Override
public OperationBuilder setAttr(String name, float[] values) {
public EagerOperationBuilder setAttr(String name, float[] values) {
setAttrFloatList(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, boolean value) {
public EagerOperationBuilder setAttr(String name, boolean value) {
setAttrBool(nativeRef.opHandle, name, value);
return this;
}
@Override
public OperationBuilder setAttr(String name, boolean[] values) {
public EagerOperationBuilder setAttr(String name, boolean[] values) {
setAttrBoolList(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, DataType value) {
public EagerOperationBuilder setAttr(String name, DataType value) {
setAttrType(nativeRef.opHandle, name, value.c());
return this;
}
@Override
public OperationBuilder setAttr(String name, DataType[] values) {
public EagerOperationBuilder setAttr(String name, DataType[] values) {
int[] c = new int[values.length];
for (int i = 0; i < values.length; ++i) {
c[i] = values[i].c();
@ -140,26 +144,26 @@ final class EagerOperationBuilder implements OperationBuilder {
}
@Override
public OperationBuilder setAttr(String name, Tensor<?> value) {
public EagerOperationBuilder setAttr(String name, Tensor<?> value) {
setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle());
return this;
}
@Override
public OperationBuilder setAttr(String name, Tensor<?>[] values) {
public EagerOperationBuilder setAttr(String name, Tensor<?>[] values) {
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
throw new UnsupportedOperationException(
"Tensor list attributes are not supported in eager mode");
}
@Override
public OperationBuilder setAttr(String name, Shape value) {
public EagerOperationBuilder setAttr(String name, Shape value) {
setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions());
return this;
}
@Override
public OperationBuilder setAttr(String name, Shape[] values) {
public EagerOperationBuilder setAttr(String name, Shape[] values) {
int[] numDimensions = new int[values.length];
int totalNumDimensions = 0;
for (int idx = 0; idx < values.length; ++idx) {

View File

@ -75,21 +75,6 @@ public final class GraphOperation extends AbstractOperation {
}
}
@Override
public Output<?>[] outputList(int idx, int length) {
Output<?>[] outputs = new Output<?>[length];
for (int i = 0; i < length; ++i) {
outputs[i] = output(idx + i);
}
return outputs;
}
@Override
@SuppressWarnings({"rawtypes", "unchecked"})
public <T> Output<T> output(int idx) {
return new Output(this, idx);
}
@Override
public int hashCode() {
return Long.valueOf(getUnsafeNativeHandle()).hashCode();

View File

@ -22,6 +22,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
// This value should be >= to the maximum number of outputs in any op
#define MAX_OUTPUTS_PER_OP 8
namespace {
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
@ -89,7 +92,7 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
JNIEnv* env, jclass clazz, jlong op_handle) {
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return 0;
int num_retvals = 8; // should be >= than max number of outputs in any op
int num_retvals = MAX_OUTPUTS_PER_OP;
std::unique_ptr<TFE_TensorHandle*[]> retvals(
new TFE_TensorHandle*[num_retvals]);
TF_Status* status = TF_NewStatus();

View File

@ -0,0 +1,130 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/java/src/main/native/eager_operation_jni.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <memory>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
namespace {
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalStateException,
"Eager session has been closed");
return nullptr;
}
return reinterpret_cast<TFE_Op*>(handle);
}
TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalStateException, "EagerSession has been closed");
return nullptr;
}
return reinterpret_cast<TFE_TensorHandle*>(handle);
}
} // namespace
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv* env,
jclass clazz,
jlong handle) {
if (handle == 0) return;
TFE_DeleteOp(reinterpret_cast<TFE_Op*>(handle));
}
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(
JNIEnv* env, jclass clazz, jlong handle) {
if (handle == 0) return;
TFE_DeleteTensorHandle(reinterpret_cast<TFE_TensorHandle*>(handle));
}
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
TFE_Op* op = requireOp(env, handle);
if (op == nullptr) return 0;
TF_Status* status = TF_NewStatus();
const char* cname = env->GetStringUTFChars(name, nullptr);
int length = TFE_OpGetOutputLength(op, cname, status);
env->ReleaseStringUTFChars(name, cname);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return 0;
}
TF_DeleteStatus(status);
return static_cast<jint>(length);
}
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
TFE_Op* op = requireOp(env, handle);
if (op == nullptr) return 0;
TF_Status* status = TF_NewStatus();
const char* cname = env->GetStringUTFChars(name, nullptr);
int length = TFE_OpGetInputLength(op, cname, status);
env->ReleaseStringUTFChars(name, cname);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return 0;
}
TF_DeleteStatus(status);
return static_cast<jint>(length);
}
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(
JNIEnv* env, jclass clazz, jlong handle) {
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
if (tensor_handle == nullptr) return 0;
TF_DataType data_type = TFE_TensorHandleDataType(tensor_handle);
return static_cast<jint>(data_type);
}
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(
JNIEnv* env, jclass clazz, jlong handle) {
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
if (tensor_handle == nullptr) return 0;
TF_Status* status = TF_NewStatus();
int num_dims = TFE_TensorHandleNumDims(tensor_handle, status);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return 0;
}
TF_DeleteStatus(status);
return static_cast<jint>(num_dims);
}
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv* env,
jclass clazz,
jlong handle,
jint dim_index) {
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle);
if (tensor_handle == nullptr) return 0;
TF_Status* status = TF_NewStatus();
int64_t dim = TFE_TensorHandleDim(tensor_handle, dim_index, status);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return 0;
}
TF_DeleteStatus(status);
return static_cast<jlong>(dim);
}

View File

@ -0,0 +1,86 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_tensorflow_EagerOperation
* Method: delete
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv *,
jclass, jlong);
/*
* Class: org_tensorflow_EagerOperation
* Method: deleteTensorHandle
* Signature: (J)V
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperation_deleteTensorHandle(JNIEnv *, jclass, jlong);
/**
* Class: org_tensorflow_EagerOperation
* Method: outputListLength
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength(
JNIEnv *, jclass, jlong, jstring);
/**
* Class: org_tensorflow_EagerOperation
* Method: inputListLength
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength(
JNIEnv *, jclass, jlong, jstring);
/**
* Class: org_tensorflow_EagerOperation
* Method: dataType
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(JNIEnv *,
jclass,
jlong);
/**
* Class: org_tensorflow_EagerOperation
* Method: numDims
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(JNIEnv *,
jclass,
jlong);
/**
* Class: org_tensorflow_EagerOperation
* Method: dim
* Signature: (JI)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv *, jclass,
jlong, jint);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_

View File

@ -21,6 +21,7 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link EagerOperationBuilder} class. */
@RunWith(JUnit4.class)
public class EagerOperationBuilderTest {
@ -32,6 +33,7 @@ public class EagerOperationBuilderTest {
new EagerOperationBuilder(session, "Add", "add");
fail();
} catch (IllegalStateException e) {
// expected
}
}
@ -45,8 +47,99 @@ public class EagerOperationBuilderTest {
opBuilder.setAttr("dtype", DataType.FLOAT);
fail();
} catch (IllegalStateException e) {
// expected
}
}
// TODO (karllessard) add more tests when EagerOperation is implemented as well
@Test
public void addInputs() {
try (EagerSession session = EagerSession.create()) {
Operation asrt =
opBuilder(session, "Assert", "assert")
.addInput(TestUtil.constant(session, "Cond", true))
.addInputList(new Output<?>[] {TestUtil.constant(session, "Error", -1)})
.build();
try {
opBuilder(session, "Const", "var").addControlInput(asrt);
fail();
} catch (UnsupportedOperationException e) {
// expected
}
}
}
@Test
public void setDevice() {
try (EagerSession session = EagerSession.create()) {
opBuilder(session, "Add", "SetDevice")
.setDevice("/job:localhost/replica:0/task:0/device:CPU:0")
.addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4))
.build();
}
}
@Test
public void setAttrs() {
// The effect of setting an attribute may not easily be visible from the other parts of this
// package's API. Thus, for now, the test simply executes the various setAttr variants to see
// that there are no exceptions.
//
// This is a bit of an awkward test since it has to find operations with attributes of specific
// types that aren't inferred from the input arguments.
try (EagerSession session = EagerSession.create()) {
// dtype, tensor attributes.
try (Tensor<Integer> t = Tensors.create(1)) {
opBuilder(session, "Const", "DataTypeAndTensor")
.setAttr("dtype", DataType.INT32)
.setAttr("value", t)
.build();
}
// type, int (TF "int" attributes are 64-bit signed, so a Java long).
opBuilder(session, "RandomUniform", "DataTypeAndInt")
.addInput(TestUtil.constant(session, "RandomUniformShape", new int[] {1}))
.setAttr("seed", 10)
.setAttr("dtype", DataType.FLOAT)
.build();
// list(int), string
opBuilder(session, "MaxPool", "IntListAndString")
.addInput(TestUtil.constant(session, "MaxPoolInput", new float[2][2][2][2]))
.setAttr("ksize", new long[] {1, 1, 1, 1})
.setAttr("strides", new long[] {1, 1, 1, 1})
.setAttr("padding", "SAME")
.build();
// list(float), device
opBuilder(session, "FractionalMaxPool", "FloatList")
.addInput(TestUtil.constant(session, "FractionalMaxPoolInput", new float[2][2][2][2]))
.setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
.build();
// shape
opBuilder(session, "EnsureShape", "ShapeAttr")
.addInput(TestUtil.constant(session, "Const", new int[2][2]))
.setAttr("shape", Shape.make(2, 2))
.build();
// list(shape)
opBuilder(session, "FIFOQueue", "queue")
.setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32})
.setAttr("shapes", new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)})
.build();
// bool
opBuilder(session, "All", "Bool")
.addInput(TestUtil.constant(session, "Const", new boolean[] {true, true, false}))
.addInput(TestUtil.constant(session, "Axis", 0))
.setAttr("keep_dims", false)
.build();
// float
opBuilder(session, "ApproximateEqual", "Float")
.addInput(TestUtil.constant(session, "Const1", 10.00001f))
.addInput(TestUtil.constant(session, "Const2", 10.00000f))
.setAttr("tolerance", 0.1f)
.build();
// Missing tests: list(string), list(byte), list(bool), list(type)
}
}
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
return new EagerOperationBuilder(session, type, name);
}
}

View File

@ -0,0 +1,128 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link EagerOperation} class. */
@RunWith(JUnit4.class)
public class EagerOperationTest {
@Test
public void failToCreateIfSessionIsClosed() {
EagerSession session = EagerSession.create();
session.close();
try {
new EagerOperation(session, 1L, new long[] {1L}, "Add", "add");
fail();
} catch (IllegalStateException e) {
// expected
}
}
@Test
public void outputDataTypeAndShape() {
try (EagerSession session = EagerSession.create();
Tensor<Integer> t = Tensors.create(new int[2][3])) {
EagerOperation op =
opBuilder(session, "Const", "OutputAttrs")
.setAttr("dtype", DataType.INT32)
.setAttr("value", t)
.build();
assertEquals(DataType.INT32, op.dtype(0));
assertEquals(2, op.shape(0)[0]);
assertEquals(3, op.shape(0)[1]);
}
}
@Test
public void inputAndOutputListLengths() {
try (EagerSession session = EagerSession.create()) {
Output<Float> c1 = TestUtil.constant(session, "Const1", new float[] {1f, 2f});
Output<Float> c2 = TestUtil.constant(session, "Const2", new float[] {3f, 4f});
EagerOperation acc =
opBuilder(session, "AddN", "InputListLength")
.addInputList(new Output<?>[] {c1, c2})
.build();
assertEquals(2, acc.inputListLength("inputs"));
assertEquals(1, acc.outputListLength("sum"));
EagerOperation split =
opBuilder(session, "Split", "OutputListLength")
.addInput(TestUtil.constant(session, "Axis", 0))
.addInput(c1)
.setAttr("num_split", 2)
.build();
assertEquals(1, split.inputListLength("split_dim"));
assertEquals(2, split.outputListLength("output"));
try {
split.inputListLength("no_such_input");
fail();
} catch (IllegalArgumentException e) {
// expected
}
try {
split.outputListLength("no_such_output");
fail();
} catch (IllegalArgumentException e) {
// expected
}
}
}
@Test
public void numOutputs() {
try (EagerSession session = EagerSession.create()) {
EagerOperation op =
opBuilder(session, "UniqueWithCountsV2", "unq")
.addInput(TestUtil.constant(session, "Const1", new int[] {1, 2, 1}))
.addInput(TestUtil.constant(session, "Axis", new int[] {0}))
.setAttr("out_idx", DataType.INT32)
.build();
assertEquals(3, op.numOutputs());
}
}
@Test
public void opNotAccessibleIfSessionIsClosed() {
EagerSession session = EagerSession.create();
EagerOperation add =
opBuilder(session, "Add", "SetDevice")
.addInput(TestUtil.constant(session, "Const1", 2))
.addInput(TestUtil.constant(session, "Const2", 4))
.build();
assertEquals(1, add.outputListLength("z"));
session.close();
try {
add.outputListLength("z");
fail();
} catch (IllegalStateException e) {
// expected
}
}
private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
return new EagerOperationBuilder(session, type, name);
}
}

View File

@ -50,8 +50,14 @@ public class TestUtil {
}
}
public static <T> Output<T> constant(Graph g, String name, Object value) {
return constantOp(g, name, value).<T>output(0);
public static <T> Output<T> constant(ExecutionEnvironment env, String name, Object value) {
try (Tensor<?> t = Tensor.create(value)) {
return env.opBuilder("Const", name)
.setAttr("dtype", t.dataType())
.setAttr("value", t)
.build()
.<T>output(0);
}
}
public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {