Java API to get the size of output list operations (#9640)

* Java API to get the size of output list operations

* Java API to get the size of output list operations
This commit is contained in:
KB Sriram 2017-05-04 13:55:20 -07:00 committed by Vijay Vasudevan
parent 9340f27ae1
commit 3273cf4f4d
5 changed files with 130 additions and 0 deletions

View File

@ -54,6 +54,18 @@ java_test(
],
)
java_test(
name = "OperationTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/OperationTest.java"],
test_class = "org.tensorflow.OperationTest",
deps = [
":tensorflow",
":testutil",
"@junit",
],
)
java_test(
name = "SavedModelBundleTest",
size = "small",

View File

@ -70,6 +70,28 @@ public final class Operation {
}
}
/**
* Returns the size of the list of Tensors produced by this operation.
*
* <p>An Operation has multiple named outputs, each of which produces either
* a single tensor or a list of tensors. This method returns the size of
* the list of tensors for a specific named output of the operation.
*
* @param name identifier of the list of tensors (of which there may
* be many) produced by this operation.
* @returns the size of the list of Tensors produced by this named output.
* @throws IllegalArgumentException if this operation has no output
* with the provided name.
*/
public int outputListLength(final String name) {
Graph.Reference r = graph.ref();
try {
return outputListLength(unsafeNativeHandle, name);
} finally {
r.close();
}
}
/** Returns a symbolic handle to one of the tensors produced by this operation. */
public Output output(int idx) {
return new Output(this, idx);
@ -108,6 +130,8 @@ public final class Operation {
private static native int numOutputs(long handle);
private static native int outputListLength(long handle, String name);
private static native long[] shape(long graphHandle, long opHandle, int output);
private static native int dtype(long graphHandle, long opHandle, int output);

View File

@ -66,6 +66,24 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env,
return TF_OperationNumOutputs(op);
}
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv* env,
jclass clazz,
jlong handle,
jstring name) {
TF_Operation* op = requireHandle(env, handle);
if (op == nullptr) return 0;
TF_Status* status = TF_NewStatus();
const char* cname = env->GetStringUTFChars(name, nullptr);
int result = TF_OperationOutputListLength(op, cname, status);
env->ReleaseStringUTFChars(name, cname);
throwExceptionIfNotOK(env, status);
TF_DeleteStatus(status);
return result;
}
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape(
JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle,
jint output_index) {

View File

@ -46,6 +46,16 @@ JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv *, jclass,
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv *,
jclass, jlong);
/*
* Class: org_tensorflow_Operation
* Method: outputListLength
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv *,
jclass,
jlong,
jstring);
/*
* Class: org_tensorflow_Operation
* Method: shape

View File

@ -0,0 +1,66 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import java.util.List;
/** Unit tests for {@link org.tensorflow.Operation}. */
@RunWith(JUnit4.class)
public class OperationTest {
@Test
public void outputListLengthFailsOnInvalidName() {
try (Graph g = new Graph()) {
Operation op = g.opBuilder("Add", "Add")
.addInput(TestUtil.constant(g, "x", 1))
.addInput(TestUtil.constant(g, "y", 2))
.build();
assertEquals(1, op.outputListLength("z"));
try {
op.outputListLength("unknown");
fail("Did not catch bad name");
} catch (IllegalArgumentException iae) {
// expected
}
}
}
@Test
public void outputListLength() {
assertEquals(1, split(new int[]{0, 1}, 1));
assertEquals(2, split(new int[]{0, 1}, 2));
assertEquals(3, split(new int[]{0, 1, 2}, 3));
}
private int split(int[] values, int num_split) {
try (Graph g = new Graph()) {
return g.opBuilder("Split", "Split")
.addInput(TestUtil.constant(g, "split_dim", 0))
.addInput(TestUtil.constant(g, "values", values))
.setAttr("num_split", num_split)
.build()
.outputListLength("output");
}
}
}