Merge pull request #25949 from melissagrueter:while_loop
PiperOrigin-RevId: 235601480
This commit is contained in:
commit
6799f6e990
@ -236,6 +236,115 @@ public final class Graph implements AutoCloseable {
|
|||||||
return addGradients(null, new Output<?>[] {y}, x, null);
|
return addGradients(null, new Output<?>[] {y}, x, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to instantiate an abstract class which overrides the buildSubgraph method to build a
|
||||||
|
* conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
|
||||||
|
* create a lambda for the same purpose.
|
||||||
|
*
|
||||||
|
* <p>To be used when calling {@link #whileLoop(Output[],
|
||||||
|
* org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)}
|
||||||
|
*
|
||||||
|
* <p>Example usage (prior to Java 8):
|
||||||
|
*
|
||||||
|
* <p>{@code WhileSubgraphBuilder bodyGraphBuilder = new WhileSubgraphBuilder() { @Override public
|
||||||
|
* void buildSubgraph(Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) { // build
|
||||||
|
* body subgraph } }; }
|
||||||
|
*
|
||||||
|
* <p>Example usage (after Java 8):
|
||||||
|
*
|
||||||
|
* <p>{@code WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { //
|
||||||
|
* build body subgraph };}
|
||||||
|
*/
|
||||||
|
public interface WhileSubgraphBuilder {
|
||||||
|
/**
|
||||||
|
* To be overridden by user with code to build conditional or body subgraph for a while loop
|
||||||
|
*
|
||||||
|
* @param g the subgraph
|
||||||
|
* @param inputs subgraph inputs
|
||||||
|
* @param outputs subgraph outputs
|
||||||
|
*/
|
||||||
|
public void buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// called by while loop code in graph_jni.cc to construct conditional/body subgraphs
|
||||||
|
private static long[] buildSubgraph(
|
||||||
|
WhileSubgraphBuilder subgraphBuilder,
|
||||||
|
long subgraphHandle,
|
||||||
|
long[] inputHandles,
|
||||||
|
int[] inputIndices,
|
||||||
|
long[] outputHandles,
|
||||||
|
int[] outputIndices) {
|
||||||
|
Graph subgraph = new Graph(subgraphHandle);
|
||||||
|
|
||||||
|
int ninputs = inputHandles.length;
|
||||||
|
int noutputs = outputHandles.length;
|
||||||
|
Output<?>[] inputs = new Output<?>[ninputs];
|
||||||
|
Output<?>[] outputs = new Output<?>[noutputs];
|
||||||
|
long[] outputHandlesAndIndices = new long[noutputs * 2];
|
||||||
|
|
||||||
|
synchronized (subgraph.nativeHandleLock) {
|
||||||
|
try (Reference ref = subgraph.ref()) {
|
||||||
|
|
||||||
|
for (int i = 0; i < ninputs; i++) {
|
||||||
|
Operation op = new Operation(subgraph, inputHandles[i]);
|
||||||
|
inputs[i] = new Output<>(op, inputIndices[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < noutputs; i++) {
|
||||||
|
Operation op = new Operation(subgraph, outputHandles[i]);
|
||||||
|
outputs[i] = new Output<>(op, outputIndices[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
subgraphBuilder.buildSubgraph(subgraph, inputs, outputs);
|
||||||
|
|
||||||
|
for (int i = 0, j = noutputs; i < noutputs; i++, j++) {
|
||||||
|
outputHandlesAndIndices[i] = outputs[i].op().getUnsafeNativeHandle();
|
||||||
|
outputHandlesAndIndices[j] = (long) outputs[i].index();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return outputHandlesAndIndices;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds a while loop.
|
||||||
|
*
|
||||||
|
* @param inputs the loop inputs
|
||||||
|
* @param cgBuilder WhileSubgraphBuilder to build the conditional subgraph
|
||||||
|
* @param bgBuilder WhileSubgraphBuilder to build the body subgraph
|
||||||
|
* @param name name for the loop
|
||||||
|
* @return list of loop outputs, of the same length as {@code inputs}
|
||||||
|
*/
|
||||||
|
public Output<?>[] whileLoop(
|
||||||
|
Output<?>[] inputs,
|
||||||
|
WhileSubgraphBuilder cgBuilder,
|
||||||
|
WhileSubgraphBuilder bgBuilder,
|
||||||
|
String name) {
|
||||||
|
int ninputs = inputs.length;
|
||||||
|
long[] inputHandles = new long[ninputs];
|
||||||
|
int[] inputIndices = new int[ninputs];
|
||||||
|
Output<?>[] outputs = new Output<?>[ninputs];
|
||||||
|
|
||||||
|
synchronized (nativeHandleLock) {
|
||||||
|
try (Reference ref = ref()) {
|
||||||
|
|
||||||
|
for (int i = 0; i < ninputs; i++) {
|
||||||
|
inputHandles[i] = inputs[i].op().getUnsafeNativeHandle();
|
||||||
|
inputIndices[i] = inputs[i].index();
|
||||||
|
}
|
||||||
|
|
||||||
|
long[] outputHandlesAndIndices =
|
||||||
|
whileLoop(nativeHandle, inputHandles, inputIndices, name, cgBuilder, bgBuilder);
|
||||||
|
|
||||||
|
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
|
||||||
|
Operation op = new Operation(this, outputHandlesAndIndices[i]);
|
||||||
|
outputs[i] = new Output<>(op, (int) outputHandlesAndIndices[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private final Object nativeHandleLock = new Object();
|
private final Object nativeHandleLock = new Object();
|
||||||
private long nativeHandle;
|
private long nativeHandle;
|
||||||
private int refcount = 0;
|
private int refcount = 0;
|
||||||
@ -357,6 +466,14 @@ public final class Graph implements AutoCloseable {
|
|||||||
long[] gradInputHandles,
|
long[] gradInputHandles,
|
||||||
int[] gradInputIndices);
|
int[] gradInputIndices);
|
||||||
|
|
||||||
|
private static native long[] whileLoop(
|
||||||
|
long handle,
|
||||||
|
long[] inputHandles,
|
||||||
|
int[] inputIndices,
|
||||||
|
String name,
|
||||||
|
WhileSubgraphBuilder condGraphBuilder,
|
||||||
|
WhileSubgraphBuilder bodyGraphBuilder);
|
||||||
|
|
||||||
static {
|
static {
|
||||||
TensorFlow.init();
|
TensorFlow.init();
|
||||||
}
|
}
|
||||||
|
@ -18,19 +18,28 @@ limitations under the License.
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/java/src/main/native/utils_jni.h"
|
|
||||||
#include "tensorflow/java/src/main/native/exception_jni.h"
|
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||||
|
#include "tensorflow/java/src/main/native/utils_jni.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
|
template <class T>
|
||||||
static_assert(sizeof(jlong) >= sizeof(TF_Graph*),
|
T* requireHandleImpl(JNIEnv* env, jlong handle) {
|
||||||
|
static_assert(sizeof(jlong) >= sizeof(T*),
|
||||||
"Cannot package C object pointers as a Java long");
|
"Cannot package C object pointers as a Java long");
|
||||||
if (handle == 0) {
|
if (handle == 0) {
|
||||||
throwException(env, kIllegalStateException,
|
throwException(env, kIllegalStateException,
|
||||||
"close() has been called on the Graph");
|
"close() has been called on the Graph");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return reinterpret_cast<TF_Graph*>(handle);
|
return reinterpret_cast<T*>(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
|
||||||
|
return requireHandleImpl<TF_Graph>(env, handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) {
|
||||||
|
return requireHandleImpl<TF_Operation>(env, handle);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -56,10 +65,8 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
|
|||||||
return reinterpret_cast<jlong>(op);
|
return reinterpret_cast<jlong>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv* env,
|
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(
|
||||||
jclass clazz,
|
JNIEnv* env, jclass clazz, jlong handle, jint position) {
|
||||||
jlong handle,
|
|
||||||
jint position) {
|
|
||||||
TF_Graph* g = requireHandle(env, handle);
|
TF_Graph* g = requireHandle(env, handle);
|
||||||
if (g == nullptr) return nullptr;
|
if (g == nullptr) return nullptr;
|
||||||
|
|
||||||
@ -189,3 +196,140 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
|
|||||||
|
|
||||||
return dy_handles_and_indices;
|
return dy_handles_and_indices;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// helper function for while loop -- constructs conditional or body subgraph
|
||||||
|
jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder,
|
||||||
|
TF_Graph* const subgraph,
|
||||||
|
const TF_Output* const inputs,
|
||||||
|
const TF_Output* const outputs, const int ninputs,
|
||||||
|
const int noutputs) {
|
||||||
|
jmethodID build_subgraph_method_id = env->GetStaticMethodID(
|
||||||
|
clazz, "buildSubgraph",
|
||||||
|
"(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J");
|
||||||
|
if (build_subgraph_method_id == 0) return nullptr;
|
||||||
|
|
||||||
|
jlong subgraph_handle = reinterpret_cast<jlong>(subgraph);
|
||||||
|
|
||||||
|
jlongArray input_handles = env->NewLongArray(ninputs);
|
||||||
|
jintArray input_indices = env->NewIntArray(ninputs);
|
||||||
|
jlongArray output_handles = env->NewLongArray(noutputs);
|
||||||
|
jintArray output_indices = env->NewIntArray(noutputs);
|
||||||
|
|
||||||
|
jlong* input_handles_elems =
|
||||||
|
env->GetLongArrayElements(input_handles, nullptr);
|
||||||
|
jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr);
|
||||||
|
jlong* output_handles_elems =
|
||||||
|
env->GetLongArrayElements(output_handles, nullptr);
|
||||||
|
jint* output_indices_elems =
|
||||||
|
env->GetIntArrayElements(output_indices, nullptr);
|
||||||
|
|
||||||
|
for (int i = 0; i < ninputs; ++i) {
|
||||||
|
input_handles_elems[i] = reinterpret_cast<jlong>((inputs[i]).oper);
|
||||||
|
input_indices_elems[i] = static_cast<jint>((inputs[i]).index);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < noutputs; ++i) {
|
||||||
|
output_handles_elems[i] = reinterpret_cast<jlong>((outputs[i]).oper);
|
||||||
|
output_indices_elems[i] = static_cast<jint>((outputs[i]).index);
|
||||||
|
}
|
||||||
|
|
||||||
|
env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0);
|
||||||
|
env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0);
|
||||||
|
env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0);
|
||||||
|
env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0);
|
||||||
|
|
||||||
|
// call Java code to construct the subgraph
|
||||||
|
jlongArray output_handles_and_indices =
|
||||||
|
(jlongArray)env->CallStaticObjectMethod(
|
||||||
|
clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle,
|
||||||
|
input_handles, input_indices, output_handles, output_indices);
|
||||||
|
|
||||||
|
if (env->ExceptionOccurred()) {
|
||||||
|
env->ExceptionDescribe();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// returned array contains both op handles and output indices, in pair
|
||||||
|
return output_handles_and_indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
|
||||||
|
JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles,
|
||||||
|
jintArray input_indices, jstring name, jobject cond_graph_builder,
|
||||||
|
jobject body_graph_builder) {
|
||||||
|
TF_Graph* g = requireHandle(env, handle);
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
if (g == nullptr) return nullptr;
|
||||||
|
|
||||||
|
int ninputs = env->GetArrayLength(input_handles);
|
||||||
|
|
||||||
|
std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
|
||||||
|
resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(),
|
||||||
|
ninputs);
|
||||||
|
if (env->ExceptionCheck()) return nullptr;
|
||||||
|
|
||||||
|
// initialize while params
|
||||||
|
TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status);
|
||||||
|
throwExceptionIfNotOK(env, status);
|
||||||
|
|
||||||
|
// build conditional subgraph
|
||||||
|
jlongArray cond_output_handles_and_indices =
|
||||||
|
buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph,
|
||||||
|
params.cond_inputs, ¶ms.cond_output, params.ninputs, 1);
|
||||||
|
|
||||||
|
// build body subgraph
|
||||||
|
jlongArray body_output_handles_and_indices = buildSubgraph(
|
||||||
|
env, clazz, body_graph_builder, params.body_graph, params.body_inputs,
|
||||||
|
params.body_outputs, params.ninputs, params.ninputs);
|
||||||
|
|
||||||
|
if (cond_output_handles_and_indices == nullptr ||
|
||||||
|
body_output_handles_and_indices == nullptr)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
// set cond_output param to output of the conditional subgraph
|
||||||
|
jlong* cond_output_elems =
|
||||||
|
env->GetLongArrayElements(cond_output_handles_and_indices, nullptr);
|
||||||
|
TF_Operation* cond_output_op =
|
||||||
|
requireOperationHandle(env, cond_output_elems[0]);
|
||||||
|
params.cond_output = {cond_output_op,
|
||||||
|
static_cast<jint>(cond_output_elems[1])};
|
||||||
|
env->ReleaseLongArrayElements(cond_output_handles_and_indices,
|
||||||
|
cond_output_elems, 0);
|
||||||
|
|
||||||
|
// set body_outputs param to outputs of the body subgraph
|
||||||
|
jlong* body_output_elems =
|
||||||
|
env->GetLongArrayElements(body_output_handles_and_indices, nullptr);
|
||||||
|
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
|
||||||
|
TF_Operation* body_output_op =
|
||||||
|
requireOperationHandle(env, body_output_elems[i]);
|
||||||
|
params.body_outputs[i] = {body_output_op,
|
||||||
|
static_cast<jint>(body_output_elems[j])};
|
||||||
|
}
|
||||||
|
env->ReleaseLongArrayElements(body_output_handles_and_indices,
|
||||||
|
body_output_elems, 0);
|
||||||
|
|
||||||
|
// set loop name param
|
||||||
|
params.name = env->GetStringUTFChars(name, 0);
|
||||||
|
|
||||||
|
// build the while loop, storing loop outputs in `outputs`
|
||||||
|
std::unique_ptr<TF_Output[]> outputs(new TF_Output[ninputs]);
|
||||||
|
TF_FinishWhile(¶ms, status, outputs.get());
|
||||||
|
|
||||||
|
throwExceptionIfNotOK(env, status);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
|
||||||
|
env->ReleaseStringUTFChars(name, params.name);
|
||||||
|
|
||||||
|
// returned array contains both op handles and output indices, in pair
|
||||||
|
jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2);
|
||||||
|
jlong* output_elems =
|
||||||
|
env->GetLongArrayElements(output_handles_and_indices, nullptr);
|
||||||
|
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
|
||||||
|
TF_Output output = outputs.get()[i];
|
||||||
|
output_elems[i] = reinterpret_cast<jlong>(output.oper);
|
||||||
|
output_elems[j] = static_cast<jlong>(output.index);
|
||||||
|
}
|
||||||
|
env->ReleaseLongArrayElements(output_handles_and_indices, output_elems, 0);
|
||||||
|
|
||||||
|
return output_handles_and_indices;
|
||||||
|
}
|
||||||
|
@ -51,8 +51,9 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass,
|
|||||||
* Signature: (JI)[J
|
* Signature: (JI)[J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv *,
|
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv *,
|
||||||
jclass, jlong,
|
jclass,
|
||||||
jint);
|
jlong,
|
||||||
|
jint);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_tensorflow_Graph
|
* Class: org_tensorflow_Graph
|
||||||
@ -82,6 +83,15 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
|
|||||||
JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
|
JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
|
||||||
jintArray, jlongArray, jintArray);
|
jintArray, jlongArray, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: org_tensorflow_Graph
|
||||||
|
* Method: whileLoop
|
||||||
|
* Signature:
|
||||||
|
* (J[J[IILjava/lang/String;Lorg/tensorflow/Graph/WhileSubgraphBuilder;Lorg/tensorflow/Graph/WhileSubgraphBuilder;)[J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
|
||||||
|
JNIEnv *, jclass, jlong, jlongArray, jintArray, jstring, jobject, jobject);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
@ -255,6 +255,114 @@ public class GraphTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void buildWhileLoopSingleInput() {
|
||||||
|
try (Graph g = new Graph();
|
||||||
|
Session s = new Session(g)) {
|
||||||
|
|
||||||
|
Output<?> input = TestUtil.placeholder(g, "input1", Integer.class);
|
||||||
|
|
||||||
|
// could write this using lambda after Java 8
|
||||||
|
Graph.WhileSubgraphBuilder condGraphBuilder =
|
||||||
|
new Graph.WhileSubgraphBuilder() {
|
||||||
|
@Override
|
||||||
|
public void buildSubgraph(
|
||||||
|
Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
|
||||||
|
Output<Integer> sixteen = TestUtil.constant(condGraph, "sixteen", 16);
|
||||||
|
// condInputs[0] < 16
|
||||||
|
Output<?> condOutput =
|
||||||
|
condGraph
|
||||||
|
.opBuilder("Less", "cond")
|
||||||
|
.addInput(condInputs[0])
|
||||||
|
.addInput(sixteen)
|
||||||
|
.build()
|
||||||
|
.output(0);
|
||||||
|
|
||||||
|
condOutputs[0] = condOutput;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// could write this using lambda after Java 8
|
||||||
|
Graph.WhileSubgraphBuilder bodyGraphBuilder =
|
||||||
|
new Graph.WhileSubgraphBuilder() {
|
||||||
|
@Override
|
||||||
|
public void buildSubgraph(
|
||||||
|
Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
|
||||||
|
bodyOutputs[0] = TestUtil.square(bodyGraph, "square", bodyInputs[0]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Output<?>[] loopOutputs =
|
||||||
|
g.whileLoop(toArray(input), condGraphBuilder, bodyGraphBuilder, "test_loop");
|
||||||
|
|
||||||
|
try (Tensor<Integer> c = Tensors.create(2);
|
||||||
|
Tensor<?> output = s.runner().feed(input, c).fetch(loopOutputs[0]).run().get(0)) {
|
||||||
|
|
||||||
|
assertEquals(16, output.intValue()); // ((2^2)^2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void buildWhileLoopMultipleInputs() {
|
||||||
|
try (Graph g = new Graph();
|
||||||
|
Session s = new Session(g)) {
|
||||||
|
|
||||||
|
Output<?> input1 = TestUtil.placeholder(g, "input1", Integer.class);
|
||||||
|
Output<?> input2 = TestUtil.placeholder(g, "input2", Integer.class);
|
||||||
|
Output<?>[] inputs = toArray(input1, input2);
|
||||||
|
|
||||||
|
// could write this using lambda after Java 8
|
||||||
|
Graph.WhileSubgraphBuilder condGraphBuilder =
|
||||||
|
new Graph.WhileSubgraphBuilder() {
|
||||||
|
@Override
|
||||||
|
public void buildSubgraph(
|
||||||
|
Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
|
||||||
|
Output<Integer> sixteen = TestUtil.constant(condGraph, "sixteen", 16);
|
||||||
|
Output<?> condOutput =
|
||||||
|
condGraph
|
||||||
|
.opBuilder("Less", "cond")
|
||||||
|
.addInput(condInputs[0])
|
||||||
|
.addInput(sixteen)
|
||||||
|
.build()
|
||||||
|
.output(0); // condInputs[0] < 16
|
||||||
|
|
||||||
|
condOutputs[0] = condOutput;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// could write this using lambda after Java 8
|
||||||
|
Graph.WhileSubgraphBuilder bodyGraphBuilder =
|
||||||
|
new Graph.WhileSubgraphBuilder() {
|
||||||
|
@Override
|
||||||
|
public void buildSubgraph(
|
||||||
|
Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
|
||||||
|
bodyOutputs[0] = TestUtil.square(bodyGraph, "square1", bodyInputs[0]);
|
||||||
|
bodyOutputs[1] = TestUtil.square(bodyGraph, "square2", bodyInputs[1]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Output<?>[] loopOutputs =
|
||||||
|
g.whileLoop(inputs, condGraphBuilder, bodyGraphBuilder, "test_loop");
|
||||||
|
|
||||||
|
try (Tensor<Integer> c1 = Tensors.create(2);
|
||||||
|
Tensor<Integer> c2 = Tensors.create(5);
|
||||||
|
TestUtil.AutoCloseableList<Tensor<?>> outputs =
|
||||||
|
new TestUtil.AutoCloseableList<>(
|
||||||
|
s.runner()
|
||||||
|
.feed(input1, c1)
|
||||||
|
.feed(input2, c2)
|
||||||
|
.fetch(loopOutputs[0])
|
||||||
|
.fetch(loopOutputs[1])
|
||||||
|
.run())) {
|
||||||
|
|
||||||
|
assertEquals(2, outputs.size());
|
||||||
|
assertEquals(16, outputs.get(0).intValue()); // ((2^2)^2)
|
||||||
|
assertEquals(625, outputs.get(1).intValue()); // ((5^2)^2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static Output<?>[] toArray(Output<?>... outputs) {
|
private static Output<?>[] toArray(Output<?>... outputs) {
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user