diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index f2904ad5a69..a8910248c13 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -54,6 +54,18 @@ java_test( ], ) +java_test( + name = "OperationTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/OperationTest.java"], + test_class = "org.tensorflow.OperationTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + java_test( name = "SavedModelBundleTest", size = "small", diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java index 48db554e072..43dbaf125c9 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Operation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java @@ -70,6 +70,28 @@ public final class Operation { } } + /** + * Returns the size of the list of Tensors produced by this operation. + * + *
An Operation has multiple named outputs, each of which produces either + * a single tensor or a list of tensors. This method returns the size of + * the list of tensors for a specific named output of the operation. + * + * @param name identifier of the list of tensors (of which there may + * be many) produced by this operation. + * @returns the size of the list of Tensors produced by this named output. + * @throws IllegalArgumentException if this operation has no output + * with the provided name. + */ + public int outputListLength(final String name) { + Graph.Reference r = graph.ref(); + try { + return outputListLength(unsafeNativeHandle, name); + } finally { + r.close(); + } + } + /** Returns a symbolic handle to one of the tensors produced by this operation. */ public Output output(int idx) { return new Output(this, idx); @@ -108,6 +130,8 @@ public final class Operation { private static native int numOutputs(long handle); + private static native int outputListLength(long handle, String name); + private static native long[] shape(long graphHandle, long opHandle, int output); private static native int dtype(long graphHandle, long opHandle, int output); diff --git a/tensorflow/java/src/main/native/operation_jni.cc b/tensorflow/java/src/main/native/operation_jni.cc index 32e59bc0aed..b3d5fc4ec37 100644 --- a/tensorflow/java/src/main/native/operation_jni.cc +++ b/tensorflow/java/src/main/native/operation_jni.cc @@ -66,6 +66,24 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env, return TF_OperationNumOutputs(op); } +JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv* env, + jclass clazz, + jlong handle, + jstring name) { + TF_Operation* op = requireHandle(env, handle); + if (op == nullptr) return 0; + + TF_Status* status = TF_NewStatus(); + + const char* cname = env->GetStringUTFChars(name, nullptr); + int result = TF_OperationOutputListLength(op, cname, status); + env->ReleaseStringUTFChars(name, cname); + + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + return result; +} + JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape( JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle, jint output_index) { diff --git a/tensorflow/java/src/main/native/operation_jni.h b/tensorflow/java/src/main/native/operation_jni.h index 6292a48069c..b5d156f7c27 100644 --- a/tensorflow/java/src/main/native/operation_jni.h +++ b/tensorflow/java/src/main/native/operation_jni.h @@ -46,6 +46,16 @@ JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv *, jclass, JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv *, jclass, jlong); +/* + * Class: org_tensorflow_Operation + * Method: outputListLength + * Signature: (JLjava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv *, + jclass, + jlong, + jstring); + /* * Class: org_tensorflow_Operation * Method: shape diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java new file mode 100644 index 00000000000..53bd511b5bc --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** Unit tests for {@link org.tensorflow.Operation}. */ +@RunWith(JUnit4.class) +public class OperationTest { + + @Test + public void outputListLengthFailsOnInvalidName() { + try (Graph g = new Graph()) { + Operation op = g.opBuilder("Add", "Add") + .addInput(TestUtil.constant(g, "x", 1)) + .addInput(TestUtil.constant(g, "y", 2)) + .build(); + assertEquals(1, op.outputListLength("z")); + + try { + op.outputListLength("unknown"); + fail("Did not catch bad name"); + } catch (IllegalArgumentException iae) { + // expected + } + } + } + + @Test + public void outputListLength() { + assertEquals(1, split(new int[]{0, 1}, 1)); + assertEquals(2, split(new int[]{0, 1}, 2)); + assertEquals(3, split(new int[]{0, 1, 2}, 3)); + } + + private int split(int[] values, int num_split) { + try (Graph g = new Graph()) { + return g.opBuilder("Split", "Split") + .addInput(TestUtil.constant(g, "split_dim", 0)) + .addInput(TestUtil.constant(g, "values", values)) + .setAttr("num_split", num_split) + .build() + .outputListLength("output"); + } + } +}