Merge pull request #25949 from melissagrueter:while_loop
PiperOrigin-RevId: 235601480
This commit is contained in:
commit
6799f6e990
@ -235,7 +235,116 @@ public final class Graph implements AutoCloseable {
|
||||
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
|
||||
return addGradients(null, new Output<?>[] {y}, x, null);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Used to instantiate an abstract class which overrides the buildSubgraph method to build a
|
||||
* conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
|
||||
* create a lambda for the same purpose.
|
||||
*
|
||||
* <p>To be used when calling {@link #whileLoop(Output[],
|
||||
* org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)}
|
||||
*
|
||||
* <p>Example usage (prior to Java 8):
|
||||
*
|
||||
* <p>{@code WhileSubgraphBuilder bodyGraphBuilder = new WhileSubgraphBuilder() { @Override public
|
||||
* void buildSubgraph(Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) { // build
|
||||
* body subgraph } }; }
|
||||
*
|
||||
* <p>Example usage (after Java 8):
|
||||
*
|
||||
* <p>{@code WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { //
|
||||
* build body subgraph };}
|
||||
*/
|
||||
public interface WhileSubgraphBuilder {
|
||||
/**
|
||||
* To be overridden by user with code to build conditional or body subgraph for a while loop
|
||||
*
|
||||
* @param g the subgraph
|
||||
* @param inputs subgraph inputs
|
||||
* @param outputs subgraph outputs
|
||||
*/
|
||||
public void buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs);
|
||||
}
|
||||
|
||||
// called by while loop code in graph_jni.cc to construct conditional/body subgraphs
|
||||
private static long[] buildSubgraph(
|
||||
WhileSubgraphBuilder subgraphBuilder,
|
||||
long subgraphHandle,
|
||||
long[] inputHandles,
|
||||
int[] inputIndices,
|
||||
long[] outputHandles,
|
||||
int[] outputIndices) {
|
||||
Graph subgraph = new Graph(subgraphHandle);
|
||||
|
||||
int ninputs = inputHandles.length;
|
||||
int noutputs = outputHandles.length;
|
||||
Output<?>[] inputs = new Output<?>[ninputs];
|
||||
Output<?>[] outputs = new Output<?>[noutputs];
|
||||
long[] outputHandlesAndIndices = new long[noutputs * 2];
|
||||
|
||||
synchronized (subgraph.nativeHandleLock) {
|
||||
try (Reference ref = subgraph.ref()) {
|
||||
|
||||
for (int i = 0; i < ninputs; i++) {
|
||||
Operation op = new Operation(subgraph, inputHandles[i]);
|
||||
inputs[i] = new Output<>(op, inputIndices[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < noutputs; i++) {
|
||||
Operation op = new Operation(subgraph, outputHandles[i]);
|
||||
outputs[i] = new Output<>(op, outputIndices[i]);
|
||||
}
|
||||
|
||||
subgraphBuilder.buildSubgraph(subgraph, inputs, outputs);
|
||||
|
||||
for (int i = 0, j = noutputs; i < noutputs; i++, j++) {
|
||||
outputHandlesAndIndices[i] = outputs[i].op().getUnsafeNativeHandle();
|
||||
outputHandlesAndIndices[j] = (long) outputs[i].index();
|
||||
}
|
||||
}
|
||||
return outputHandlesAndIndices;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a while loop.
|
||||
*
|
||||
* @param inputs the loop inputs
|
||||
* @param cgBuilder WhileSubgraphBuilder to build the conditional subgraph
|
||||
* @param bgBuilder WhileSubgraphBuilder to build the body subgraph
|
||||
* @param name name for the loop
|
||||
* @return list of loop outputs, of the same length as {@code inputs}
|
||||
*/
|
||||
public Output<?>[] whileLoop(
|
||||
Output<?>[] inputs,
|
||||
WhileSubgraphBuilder cgBuilder,
|
||||
WhileSubgraphBuilder bgBuilder,
|
||||
String name) {
|
||||
int ninputs = inputs.length;
|
||||
long[] inputHandles = new long[ninputs];
|
||||
int[] inputIndices = new int[ninputs];
|
||||
Output<?>[] outputs = new Output<?>[ninputs];
|
||||
|
||||
synchronized (nativeHandleLock) {
|
||||
try (Reference ref = ref()) {
|
||||
|
||||
for (int i = 0; i < ninputs; i++) {
|
||||
inputHandles[i] = inputs[i].op().getUnsafeNativeHandle();
|
||||
inputIndices[i] = inputs[i].index();
|
||||
}
|
||||
|
||||
long[] outputHandlesAndIndices =
|
||||
whileLoop(nativeHandle, inputHandles, inputIndices, name, cgBuilder, bgBuilder);
|
||||
|
||||
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
|
||||
Operation op = new Operation(this, outputHandlesAndIndices[i]);
|
||||
outputs[i] = new Output<>(op, (int) outputHandlesAndIndices[j]);
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
}
|
||||
|
||||
private final Object nativeHandleLock = new Object();
|
||||
private long nativeHandle;
|
||||
private int refcount = 0;
|
||||
@ -357,6 +466,14 @@ public final class Graph implements AutoCloseable {
|
||||
long[] gradInputHandles,
|
||||
int[] gradInputIndices);
|
||||
|
||||
private static native long[] whileLoop(
|
||||
long handle,
|
||||
long[] inputHandles,
|
||||
int[] inputIndices,
|
||||
String name,
|
||||
WhileSubgraphBuilder condGraphBuilder,
|
||||
WhileSubgraphBuilder bodyGraphBuilder);
|
||||
|
||||
static {
|
||||
TensorFlow.init();
|
||||
}
|
||||
|
@ -18,19 +18,28 @@ limitations under the License.
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/java/src/main/native/utils_jni.h"
|
||||
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||
#include "tensorflow/java/src/main/native/utils_jni.h"
|
||||
|
||||
namespace {
|
||||
TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
|
||||
static_assert(sizeof(jlong) >= sizeof(TF_Graph*),
|
||||
template <class T>
|
||||
T* requireHandleImpl(JNIEnv* env, jlong handle) {
|
||||
static_assert(sizeof(jlong) >= sizeof(T*),
|
||||
"Cannot package C object pointers as a Java long");
|
||||
if (handle == 0) {
|
||||
throwException(env, kIllegalStateException,
|
||||
"close() has been called on the Graph");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TF_Graph*>(handle);
|
||||
return reinterpret_cast<T*>(handle);
|
||||
}
|
||||
|
||||
TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
|
||||
return requireHandleImpl<TF_Graph>(env, handle);
|
||||
}
|
||||
|
||||
TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) {
|
||||
return requireHandleImpl<TF_Operation>(env, handle);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -56,10 +65,8 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
|
||||
return reinterpret_cast<jlong>(op);
|
||||
}
|
||||
|
||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle,
|
||||
jint position) {
|
||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jint position) {
|
||||
TF_Graph* g = requireHandle(env, handle);
|
||||
if (g == nullptr) return nullptr;
|
||||
|
||||
@ -189,3 +196,140 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
|
||||
|
||||
return dy_handles_and_indices;
|
||||
}
|
||||
|
||||
// helper function for while loop -- constructs conditional or body subgraph
|
||||
jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder,
|
||||
TF_Graph* const subgraph,
|
||||
const TF_Output* const inputs,
|
||||
const TF_Output* const outputs, const int ninputs,
|
||||
const int noutputs) {
|
||||
jmethodID build_subgraph_method_id = env->GetStaticMethodID(
|
||||
clazz, "buildSubgraph",
|
||||
"(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J");
|
||||
if (build_subgraph_method_id == 0) return nullptr;
|
||||
|
||||
jlong subgraph_handle = reinterpret_cast<jlong>(subgraph);
|
||||
|
||||
jlongArray input_handles = env->NewLongArray(ninputs);
|
||||
jintArray input_indices = env->NewIntArray(ninputs);
|
||||
jlongArray output_handles = env->NewLongArray(noutputs);
|
||||
jintArray output_indices = env->NewIntArray(noutputs);
|
||||
|
||||
jlong* input_handles_elems =
|
||||
env->GetLongArrayElements(input_handles, nullptr);
|
||||
jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr);
|
||||
jlong* output_handles_elems =
|
||||
env->GetLongArrayElements(output_handles, nullptr);
|
||||
jint* output_indices_elems =
|
||||
env->GetIntArrayElements(output_indices, nullptr);
|
||||
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_handles_elems[i] = reinterpret_cast<jlong>((inputs[i]).oper);
|
||||
input_indices_elems[i] = static_cast<jint>((inputs[i]).index);
|
||||
}
|
||||
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_handles_elems[i] = reinterpret_cast<jlong>((outputs[i]).oper);
|
||||
output_indices_elems[i] = static_cast<jint>((outputs[i]).index);
|
||||
}
|
||||
|
||||
env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0);
|
||||
env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0);
|
||||
env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0);
|
||||
env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0);
|
||||
|
||||
// call Java code to construct the subgraph
|
||||
jlongArray output_handles_and_indices =
|
||||
(jlongArray)env->CallStaticObjectMethod(
|
||||
clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle,
|
||||
input_handles, input_indices, output_handles, output_indices);
|
||||
|
||||
if (env->ExceptionOccurred()) {
|
||||
env->ExceptionDescribe();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// returned array contains both op handles and output indices, in pair
|
||||
return output_handles_and_indices;
|
||||
}
|
||||
|
||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles,
|
||||
jintArray input_indices, jstring name, jobject cond_graph_builder,
|
||||
jobject body_graph_builder) {
|
||||
TF_Graph* g = requireHandle(env, handle);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
if (g == nullptr) return nullptr;
|
||||
|
||||
int ninputs = env->GetArrayLength(input_handles);
|
||||
|
||||
std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
|
||||
resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(),
|
||||
ninputs);
|
||||
if (env->ExceptionCheck()) return nullptr;
|
||||
|
||||
// initialize while params
|
||||
TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status);
|
||||
throwExceptionIfNotOK(env, status);
|
||||
|
||||
// build conditional subgraph
|
||||
jlongArray cond_output_handles_and_indices =
|
||||
buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph,
|
||||
params.cond_inputs, ¶ms.cond_output, params.ninputs, 1);
|
||||
|
||||
// build body subgraph
|
||||
jlongArray body_output_handles_and_indices = buildSubgraph(
|
||||
env, clazz, body_graph_builder, params.body_graph, params.body_inputs,
|
||||
params.body_outputs, params.ninputs, params.ninputs);
|
||||
|
||||
if (cond_output_handles_and_indices == nullptr ||
|
||||
body_output_handles_and_indices == nullptr)
|
||||
return nullptr;
|
||||
|
||||
// set cond_output param to output of the conditional subgraph
|
||||
jlong* cond_output_elems =
|
||||
env->GetLongArrayElements(cond_output_handles_and_indices, nullptr);
|
||||
TF_Operation* cond_output_op =
|
||||
requireOperationHandle(env, cond_output_elems[0]);
|
||||
params.cond_output = {cond_output_op,
|
||||
static_cast<jint>(cond_output_elems[1])};
|
||||
env->ReleaseLongArrayElements(cond_output_handles_and_indices,
|
||||
cond_output_elems, 0);
|
||||
|
||||
// set body_outputs param to outputs of the body subgraph
|
||||
jlong* body_output_elems =
|
||||
env->GetLongArrayElements(body_output_handles_and_indices, nullptr);
|
||||
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
|
||||
TF_Operation* body_output_op =
|
||||
requireOperationHandle(env, body_output_elems[i]);
|
||||
params.body_outputs[i] = {body_output_op,
|
||||
static_cast<jint>(body_output_elems[j])};
|
||||
}
|
||||
env->ReleaseLongArrayElements(body_output_handles_and_indices,
|
||||
body_output_elems, 0);
|
||||
|
||||
// set loop name param
|
||||
params.name = env->GetStringUTFChars(name, 0);
|
||||
|
||||
// build the while loop, storing loop outputs in `outputs`
|
||||
std::unique_ptr<TF_Output[]> outputs(new TF_Output[ninputs]);
|
||||
TF_FinishWhile(¶ms, status, outputs.get());
|
||||
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
env->ReleaseStringUTFChars(name, params.name);
|
||||
|
||||
// returned array contains both op handles and output indices, in pair
|
||||
jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2);
|
||||
jlong* output_elems =
|
||||
env->GetLongArrayElements(output_handles_and_indices, nullptr);
|
||||
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
|
||||
TF_Output output = outputs.get()[i];
|
||||
output_elems[i] = reinterpret_cast<jlong>(output.oper);
|
||||
output_elems[j] = static_cast<jlong>(output.index);
|
||||
}
|
||||
env->ReleaseLongArrayElements(output_handles_and_indices, output_elems, 0);
|
||||
|
||||
return output_handles_and_indices;
|
||||
}
|
||||
|
@ -51,8 +51,9 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass,
|
||||
* Signature: (JI)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv *,
|
||||
jclass, jlong,
|
||||
jint);
|
||||
jclass,
|
||||
jlong,
|
||||
jint);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_Graph
|
||||
@ -82,6 +83,15 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
|
||||
JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
|
||||
jintArray, jlongArray, jintArray);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_Graph
|
||||
* Method: whileLoop
|
||||
* Signature:
|
||||
* (J[J[IILjava/lang/String;Lorg/tensorflow/Graph/WhileSubgraphBuilder;Lorg/tensorflow/Graph/WhileSubgraphBuilder;)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
|
||||
JNIEnv *, jclass, jlong, jlongArray, jintArray, jstring, jobject, jobject);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
@ -254,7 +254,115 @@ public class GraphTest {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void buildWhileLoopSingleInput() {
|
||||
try (Graph g = new Graph();
|
||||
Session s = new Session(g)) {
|
||||
|
||||
Output<?> input = TestUtil.placeholder(g, "input1", Integer.class);
|
||||
|
||||
// could write this using lambda after Java 8
|
||||
Graph.WhileSubgraphBuilder condGraphBuilder =
|
||||
new Graph.WhileSubgraphBuilder() {
|
||||
@Override
|
||||
public void buildSubgraph(
|
||||
Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
|
||||
Output<Integer> sixteen = TestUtil.constant(condGraph, "sixteen", 16);
|
||||
// condInputs[0] < 16
|
||||
Output<?> condOutput =
|
||||
condGraph
|
||||
.opBuilder("Less", "cond")
|
||||
.addInput(condInputs[0])
|
||||
.addInput(sixteen)
|
||||
.build()
|
||||
.output(0);
|
||||
|
||||
condOutputs[0] = condOutput;
|
||||
}
|
||||
};
|
||||
|
||||
// could write this using lambda after Java 8
|
||||
Graph.WhileSubgraphBuilder bodyGraphBuilder =
|
||||
new Graph.WhileSubgraphBuilder() {
|
||||
@Override
|
||||
public void buildSubgraph(
|
||||
Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
|
||||
bodyOutputs[0] = TestUtil.square(bodyGraph, "square", bodyInputs[0]);
|
||||
}
|
||||
};
|
||||
|
||||
Output<?>[] loopOutputs =
|
||||
g.whileLoop(toArray(input), condGraphBuilder, bodyGraphBuilder, "test_loop");
|
||||
|
||||
try (Tensor<Integer> c = Tensors.create(2);
|
||||
Tensor<?> output = s.runner().feed(input, c).fetch(loopOutputs[0]).run().get(0)) {
|
||||
|
||||
assertEquals(16, output.intValue()); // ((2^2)^2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void buildWhileLoopMultipleInputs() {
|
||||
try (Graph g = new Graph();
|
||||
Session s = new Session(g)) {
|
||||
|
||||
Output<?> input1 = TestUtil.placeholder(g, "input1", Integer.class);
|
||||
Output<?> input2 = TestUtil.placeholder(g, "input2", Integer.class);
|
||||
Output<?>[] inputs = toArray(input1, input2);
|
||||
|
||||
// could write this using lambda after Java 8
|
||||
Graph.WhileSubgraphBuilder condGraphBuilder =
|
||||
new Graph.WhileSubgraphBuilder() {
|
||||
@Override
|
||||
public void buildSubgraph(
|
||||
Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
|
||||
Output<Integer> sixteen = TestUtil.constant(condGraph, "sixteen", 16);
|
||||
Output<?> condOutput =
|
||||
condGraph
|
||||
.opBuilder("Less", "cond")
|
||||
.addInput(condInputs[0])
|
||||
.addInput(sixteen)
|
||||
.build()
|
||||
.output(0); // condInputs[0] < 16
|
||||
|
||||
condOutputs[0] = condOutput;
|
||||
}
|
||||
};
|
||||
|
||||
// could write this using lambda after Java 8
|
||||
Graph.WhileSubgraphBuilder bodyGraphBuilder =
|
||||
new Graph.WhileSubgraphBuilder() {
|
||||
@Override
|
||||
public void buildSubgraph(
|
||||
Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
|
||||
bodyOutputs[0] = TestUtil.square(bodyGraph, "square1", bodyInputs[0]);
|
||||
bodyOutputs[1] = TestUtil.square(bodyGraph, "square2", bodyInputs[1]);
|
||||
}
|
||||
};
|
||||
|
||||
Output<?>[] loopOutputs =
|
||||
g.whileLoop(inputs, condGraphBuilder, bodyGraphBuilder, "test_loop");
|
||||
|
||||
try (Tensor<Integer> c1 = Tensors.create(2);
|
||||
Tensor<Integer> c2 = Tensors.create(5);
|
||||
TestUtil.AutoCloseableList<Tensor<?>> outputs =
|
||||
new TestUtil.AutoCloseableList<>(
|
||||
s.runner()
|
||||
.feed(input1, c1)
|
||||
.feed(input2, c2)
|
||||
.fetch(loopOutputs[0])
|
||||
.fetch(loopOutputs[1])
|
||||
.run())) {
|
||||
|
||||
assertEquals(2, outputs.size());
|
||||
assertEquals(16, outputs.get(0).intValue()); // ((2^2)^2)
|
||||
assertEquals(625, outputs.get(1).intValue()); // ((5^2)^2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static Output<?>[] toArray(Output<?>... outputs) {
|
||||
return outputs;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user