Java API to get the size of output list operations (#9640)
* Java API to get the size of output list operations * Java API to get the size of output list operations
This commit is contained in:
parent
9340f27ae1
commit
3273cf4f4d
@ -54,6 +54,18 @@ java_test(
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "OperationTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/OperationTest.java"],
|
||||
test_class = "org.tensorflow.OperationTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "SavedModelBundleTest",
|
||||
size = "small",
|
||||
|
@ -70,6 +70,28 @@ public final class Operation {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the size of the list of Tensors produced by this operation.
|
||||
*
|
||||
* <p>An Operation has multiple named outputs, each of which produces either
|
||||
* a single tensor or a list of tensors. This method returns the size of
|
||||
* the list of tensors for a specific named output of the operation.
|
||||
*
|
||||
* @param name identifier of the list of tensors (of which there may
|
||||
* be many) produced by this operation.
|
||||
* @returns the size of the list of Tensors produced by this named output.
|
||||
* @throws IllegalArgumentException if this operation has no output
|
||||
* with the provided name.
|
||||
*/
|
||||
public int outputListLength(final String name) {
|
||||
Graph.Reference r = graph.ref();
|
||||
try {
|
||||
return outputListLength(unsafeNativeHandle, name);
|
||||
} finally {
|
||||
r.close();
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns a symbolic handle to one of the tensors produced by this operation. */
|
||||
public Output output(int idx) {
|
||||
return new Output(this, idx);
|
||||
@ -108,6 +130,8 @@ public final class Operation {
|
||||
|
||||
private static native int numOutputs(long handle);
|
||||
|
||||
private static native int outputListLength(long handle, String name);
|
||||
|
||||
private static native long[] shape(long graphHandle, long opHandle, int output);
|
||||
|
||||
private static native int dtype(long graphHandle, long opHandle, int output);
|
||||
|
@ -66,6 +66,24 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env,
|
||||
return TF_OperationNumOutputs(op);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle,
|
||||
jstring name) {
|
||||
TF_Operation* op = requireHandle(env, handle);
|
||||
if (op == nullptr) return 0;
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
const char* cname = env->GetStringUTFChars(name, nullptr);
|
||||
int result = TF_OperationOutputListLength(op, cname, status);
|
||||
env->ReleaseStringUTFChars(name, cname);
|
||||
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
return result;
|
||||
}
|
||||
|
||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape(
|
||||
JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle,
|
||||
jint output_index) {
|
||||
|
@ -46,6 +46,16 @@ JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv *, jclass,
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv *,
|
||||
jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_Operation
|
||||
* Method: outputListLength
|
||||
* Signature: (JLjava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv *,
|
||||
jclass,
|
||||
jlong,
|
||||
jstring);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_Operation
|
||||
* Method: shape
|
||||
|
@ -0,0 +1,66 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.Operation}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class OperationTest {
|
||||
|
||||
@Test
|
||||
public void outputListLengthFailsOnInvalidName() {
|
||||
try (Graph g = new Graph()) {
|
||||
Operation op = g.opBuilder("Add", "Add")
|
||||
.addInput(TestUtil.constant(g, "x", 1))
|
||||
.addInput(TestUtil.constant(g, "y", 2))
|
||||
.build();
|
||||
assertEquals(1, op.outputListLength("z"));
|
||||
|
||||
try {
|
||||
op.outputListLength("unknown");
|
||||
fail("Did not catch bad name");
|
||||
} catch (IllegalArgumentException iae) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void outputListLength() {
|
||||
assertEquals(1, split(new int[]{0, 1}, 1));
|
||||
assertEquals(2, split(new int[]{0, 1}, 2));
|
||||
assertEquals(3, split(new int[]{0, 1, 2}, 3));
|
||||
}
|
||||
|
||||
private int split(int[] values, int num_split) {
|
||||
try (Graph g = new Graph()) {
|
||||
return g.opBuilder("Split", "Split")
|
||||
.addInput(TestUtil.constant(g, "split_dim", 0))
|
||||
.addInput(TestUtil.constant(g, "values", values))
|
||||
.setAttr("num_split", num_split)
|
||||
.build()
|
||||
.outputListLength("output");
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user