Java API to get the size of specified input list of operations. (#10865)
* Java API to get the size of specified input list of operations * remove unnecessary explain to avoid bring a new term to users.
This commit is contained in:
parent
e525ed179e
commit
3c19462307
@ -97,6 +97,28 @@ public final class Operation {
|
||||
return new Output(this, idx);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the size of the given inputs list of Tensors for this operation.
|
||||
*
|
||||
* <p>An Operation has multiple named inputs, each of which contains either
|
||||
* a single tensor or a list of tensors. This method returns the size of
|
||||
* the list of tensors for a specific named input of the operation.
|
||||
*
|
||||
* @param name identifier of the list of tensors (of which there may
|
||||
* be many) inputs to this operation.
|
||||
* @returns the size of the list of Tensors produced by this named input.
|
||||
* @throws IllegalArgumentException if this operation has no input
|
||||
* with the provided name.
|
||||
*/
|
||||
public int inputListLength(final String name) {
|
||||
Graph.Reference r = graph.ref();
|
||||
try {
|
||||
return inputListLength(unsafeNativeHandle, name);
|
||||
} finally {
|
||||
r.close();
|
||||
}
|
||||
}
|
||||
|
||||
long getUnsafeNativeHandle() {
|
||||
return unsafeNativeHandle;
|
||||
}
|
||||
@ -132,6 +154,8 @@ public final class Operation {
|
||||
|
||||
private static native int outputListLength(long handle, String name);
|
||||
|
||||
private static native int inputListLength(long handle, String name);
|
||||
|
||||
private static native long[] shape(long graphHandle, long opHandle, int output);
|
||||
|
||||
private static native int dtype(long graphHandle, long opHandle, int output);
|
||||
|
@ -156,3 +156,21 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv* env,
|
||||
|
||||
return static_cast<jint>(TF_OperationOutputType(TF_Output{op, output_index}));
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle,
|
||||
jstring name) {
|
||||
TF_Operation* op = requireHandle(env, handle);
|
||||
if (op == nullptr) return 0;
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
const char* cname = env->GetStringUTFChars(name, nullptr);
|
||||
int result = TF_OperationInputListLength(op, cname, status);
|
||||
env->ReleaseStringUTFChars(name, cname);
|
||||
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
return result;
|
||||
}
|
||||
|
@ -73,6 +73,17 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape(JNIEnv *,
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv *, jclass,
|
||||
jlong, jlong, jint);
|
||||
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_Operation
|
||||
* Method: inputListLength
|
||||
* Signature: (JLjava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv *,
|
||||
jclass,
|
||||
jlong,
|
||||
jstring);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
@ -52,6 +52,16 @@ public class OperationTest {
|
||||
assertEquals(3, split(new int[] {0, 1, 2}, 3));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inputListLength() {
|
||||
assertEquals(1, splitWithInputList(new int[] {0, 1}, 1, "split_dim"));
|
||||
try {
|
||||
splitWithInputList(new int[] {0, 1}, 2, "inputs");
|
||||
} catch (IllegalArgumentException iae) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
private static int split(int[] values, int num_split) {
|
||||
try (Graph g = new Graph()) {
|
||||
return g.opBuilder("Split", "Split")
|
||||
@ -62,4 +72,15 @@ public class OperationTest {
|
||||
.outputListLength("output");
|
||||
}
|
||||
}
|
||||
|
||||
private static int splitWithInputList(int[] values, int num_split, String name) {
|
||||
try (Graph g = new Graph()) {
|
||||
return g.opBuilder("Split", "Split")
|
||||
.addInput(TestUtil.constant(g, "split_dim", 0))
|
||||
.addInput(TestUtil.constant(g, "values", values))
|
||||
.setAttr("num_split", num_split)
|
||||
.build()
|
||||
.inputListLength(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user