Java API to get the size of specified input list of operations. (#10865)

* Java API to get the size of specified input list of operations

* remove unnecessary explain to avoid bring a new term to users.
This commit is contained in:
myPrecious 2017-06-22 22:11:45 +08:00 committed by Shanqing Cai
parent e525ed179e
commit 3c19462307
4 changed files with 74 additions and 0 deletions

View File

@ -97,6 +97,28 @@ public final class Operation {
return new Output(this, idx);
}
/**
* Returns the size of the given inputs list of Tensors for this operation.
*
* <p>An Operation has multiple named inputs, each of which contains either
* a single tensor or a list of tensors. This method returns the size of
* the list of tensors for a specific named input of the operation.
*
* @param name identifier of the list of tensors (of which there may
* be many) inputs to this operation.
* @returns the size of the list of Tensors produced by this named input.
* @throws IllegalArgumentException if this operation has no input
* with the provided name.
*/
public int inputListLength(final String name) {
Graph.Reference r = graph.ref();
try {
return inputListLength(unsafeNativeHandle, name);
} finally {
r.close();
}
}
long getUnsafeNativeHandle() {
return unsafeNativeHandle;
}
@ -132,6 +154,8 @@ public final class Operation {
private static native int outputListLength(long handle, String name);
private static native int inputListLength(long handle, String name);
private static native long[] shape(long graphHandle, long opHandle, int output);
private static native int dtype(long graphHandle, long opHandle, int output);

View File

@ -156,3 +156,21 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv* env,
return static_cast<jint>(TF_OperationOutputType(TF_Output{op, output_index}));
}
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv* env,
jclass clazz,
jlong handle,
jstring name) {
TF_Operation* op = requireHandle(env, handle);
if (op == nullptr) return 0;
TF_Status* status = TF_NewStatus();
const char* cname = env->GetStringUTFChars(name, nullptr);
int result = TF_OperationInputListLength(op, cname, status);
env->ReleaseStringUTFChars(name, cname);
throwExceptionIfNotOK(env, status);
TF_DeleteStatus(status);
return result;
}

View File

@ -73,6 +73,17 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape(JNIEnv *,
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv *, jclass,
jlong, jlong, jint);
/*
* Class: org_tensorflow_Operation
* Method: inputListLength
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv *,
jclass,
jlong,
jstring);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -52,6 +52,16 @@ public class OperationTest {
assertEquals(3, split(new int[] {0, 1, 2}, 3));
}
@Test
public void inputListLength() {
assertEquals(1, splitWithInputList(new int[] {0, 1}, 1, "split_dim"));
try {
splitWithInputList(new int[] {0, 1}, 2, "inputs");
} catch (IllegalArgumentException iae) {
// expected
}
}
private static int split(int[] values, int num_split) {
try (Graph g = new Graph()) {
return g.opBuilder("Split", "Split")
@ -62,4 +72,15 @@ public class OperationTest {
.outputListLength("output");
}
}
private static int splitWithInputList(int[] values, int num_split, String name) {
try (Graph g = new Graph()) {
return g.opBuilder("Split", "Split")
.addInput(TestUtil.constant(g, "split_dim", 0))
.addInput(TestUtil.constant(g, "values", values))
.setAttr("num_split", num_split)
.build()
.inputListLength(name);
}
}
}