Merge pull request #25949 from melissagrueter:while_loop

PiperOrigin-RevId: 235601480
This commit is contained in:
TensorFlower Gardener 2019-02-25 14:34:21 -08:00
commit 6799f6e990
4 changed files with 391 additions and 12 deletions

View File

@ -235,7 +235,116 @@ public final class Graph implements AutoCloseable {
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
return addGradients(null, new Output<?>[] {y}, x, null);
}
/**
 * Callback that fills in the conditional or the body subgraph of a while loop.
 *
 * <p>Implementations are handed to {@link #whileLoop(Output[],
 * org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)}.
 * The interface declares a single method, so on Java 8 and later a lambda can be supplied
 * instead of an anonymous class.
 *
 * <p>Example usage with an anonymous class (works before Java 8):
 *
 * <p>{@code WhileSubgraphBuilder bodyGraphBuilder = new WhileSubgraphBuilder() { @Override public
 * void buildSubgraph(Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) { // build
 * body subgraph } }; }
 *
 * <p>Example usage with a lambda (Java 8 and later):
 *
 * <p>{@code WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { //
 * build body subgraph };}
 */
public interface WhileSubgraphBuilder {
  /**
   * Builds the conditional or body subgraph of a while loop.
   *
   * <p>Implementations add operations to {@code g} and store the resulting outputs into the
   * slots of {@code outputs}; the array is read back once this method returns.
   *
   * @param g the subgraph to add operations to
   * @param inputs subgraph inputs
   * @param outputs subgraph outputs, to be assigned by the implementation
   */
  void buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs);
}
// Callback invoked from the while loop code in graph_jni.cc to construct the conditional or body
// subgraph. Wraps the native handles in Java objects, runs the user-supplied builder, and returns
// one array holding the op handle of every output followed by the index of every output.
private static long[] buildSubgraph(
    WhileSubgraphBuilder subgraphBuilder,
    long subgraphHandle,
    long[] inputHandles,
    int[] inputIndices,
    long[] outputHandles,
    int[] outputIndices) {
  final Graph subgraph = new Graph(subgraphHandle);
  final Output<?>[] inputs = new Output<?>[inputHandles.length];
  final Output<?>[] outputs = new Output<?>[outputHandles.length];
  final long[] packed = new long[2 * outputs.length];
  synchronized (subgraph.nativeHandleLock) {
    try (Reference ref = subgraph.ref()) {
      for (int k = 0; k < inputs.length; k++) {
        inputs[k] = new Output<>(new Operation(subgraph, inputHandles[k]), inputIndices[k]);
      }
      for (int k = 0; k < outputs.length; k++) {
        outputs[k] = new Output<>(new Operation(subgraph, outputHandles[k]), outputIndices[k]);
      }

      // The builder is expected to overwrite the slots of `outputs`.
      subgraphBuilder.buildSubgraph(subgraph, inputs, outputs);

      // First half: op handles. Second half: output indices.
      for (int k = 0; k < outputs.length; k++) {
        packed[k] = outputs[k].op().getUnsafeNativeHandle();
        packed[outputs.length + k] = (long) outputs[k].index();
      }
    }
    return packed;
  }
}
/**
 * Adds a while loop to this graph.
 *
 * <p>The loop variables start out as {@code inputs}. The two builders are called back to
 * construct the subgraph that decides whether to keep iterating and the subgraph that computes
 * the next value of every loop variable.
 *
 * @param inputs the loop inputs
 * @param cgBuilder WhileSubgraphBuilder to build the conditional subgraph
 * @param bgBuilder WhileSubgraphBuilder to build the body subgraph
 * @param name name for the loop
 * @return list of loop outputs, of the same length as {@code inputs}
 */
public Output<?>[] whileLoop(
    Output<?>[] inputs,
    WhileSubgraphBuilder cgBuilder,
    WhileSubgraphBuilder bgBuilder,
    String name) {
  final int count = inputs.length;
  final long[] opHandles = new long[count];
  final int[] opIndices = new int[count];
  final Output<?>[] results = new Output<?>[count];
  synchronized (nativeHandleLock) {
    try (Reference ref = ref()) {
      // Flatten the inputs into parallel handle/index arrays for the native call.
      int k = 0;
      for (Output<?> input : inputs) {
        opHandles[k] = input.op().getUnsafeNativeHandle();
        opIndices[k] = input.index();
        k++;
      }

      final long[] packed =
          whileLoop(nativeHandle, opHandles, opIndices, name, cgBuilder, bgBuilder);

      // The native call returns all op handles first, then all output indices.
      for (k = 0; k < count; k++) {
        results[k] = new Output<>(new Operation(this, packed[k]), (int) packed[count + k]);
      }
    }
    return results;
  }
}
// Monitor that whileLoop and buildSubgraph synchronize on while they work with the native graph.
private final Object nativeHandleLock = new Object();
// Handle of the native graph; whileLoop passes it to the native whileLoop implementation.
private long nativeHandle;
// NOTE(review): presumably the number of outstanding Reference objects handed out by ref();
// the code that maintains it is not part of this hunk — confirm.
private int refcount = 0;
@ -357,6 +466,14 @@ public final class Graph implements AutoCloseable {
long[] gradInputHandles,
int[] gradInputIndices);
// Native half of whileLoop, implemented by Java_org_tensorflow_Graph_whileLoop in graph_jni.cc.
// The input outputs are passed as parallel arrays of op handles and output indices. The two
// builders are called back through the static buildSubgraph method to construct the conditional
// and the body subgraph. The returned array has twice as many elements as there are inputs: the
// op handles of the loop outputs come first, followed by their output indices.
private static native long[] whileLoop(
long handle,
long[] inputHandles,
int[] inputIndices,
String name,
WhileSubgraphBuilder condGraphBuilder,
WhileSubgraphBuilder bodyGraphBuilder);
// Runs TensorFlow.init() once, when this class is initialized.
// NOTE(review): presumably this loads the native library backing the native methods — confirm.
static {
TensorFlow.init();
}

View File

@ -18,19 +18,28 @@ limitations under the License.
#include <limits>
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
#include "tensorflow/java/src/main/native/utils_jni.h"
namespace {
TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
static_assert(sizeof(jlong) >= sizeof(TF_Graph*),
template <class T>
T* requireHandleImpl(JNIEnv* env, jlong handle) {
static_assert(sizeof(jlong) >= sizeof(T*),
"Cannot package C object pointers as a Java long");
if (handle == 0) {
throwException(env, kIllegalStateException,
"close() has been called on the Graph");
return nullptr;
}
return reinterpret_cast<TF_Graph*>(handle);
return reinterpret_cast<T*>(handle);
}
// Converts a graph handle received from Java into a TF_Graph*. Delegates to
// requireHandleImpl, which throws IllegalStateException and returns nullptr
// when the handle is 0.
TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
return requireHandleImpl<TF_Graph>(env, handle);
}
// Converts an operation handle received from Java into a TF_Operation*.
// Delegates to requireHandleImpl, which throws IllegalStateException and
// returns nullptr when the handle is 0.
// NOTE(review): the message thrown by requireHandleImpl talks about close()
// having been called on the Graph, which is misleading for a zero operation
// handle.
TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) {
return requireHandleImpl<TF_Operation>(env, handle);
}
} // namespace
@ -56,10 +65,8 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
return reinterpret_cast<jlong>(op);
}
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv* env,
jclass clazz,
jlong handle,
jint position) {
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(
JNIEnv* env, jclass clazz, jlong handle, jint position) {
TF_Graph* g = requireHandle(env, handle);
if (g == nullptr) return nullptr;
@ -189,3 +196,140 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
return dy_handles_and_indices;
}
// helper function for while loop -- constructs conditional or body subgraph
jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder,
TF_Graph* const subgraph,
const TF_Output* const inputs,
const TF_Output* const outputs, const int ninputs,
const int noutputs) {
jmethodID build_subgraph_method_id = env->GetStaticMethodID(
clazz, "buildSubgraph",
"(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J");
if (build_subgraph_method_id == 0) return nullptr;
jlong subgraph_handle = reinterpret_cast<jlong>(subgraph);
jlongArray input_handles = env->NewLongArray(ninputs);
jintArray input_indices = env->NewIntArray(ninputs);
jlongArray output_handles = env->NewLongArray(noutputs);
jintArray output_indices = env->NewIntArray(noutputs);
jlong* input_handles_elems =
env->GetLongArrayElements(input_handles, nullptr);
jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr);
jlong* output_handles_elems =
env->GetLongArrayElements(output_handles, nullptr);
jint* output_indices_elems =
env->GetIntArrayElements(output_indices, nullptr);
for (int i = 0; i < ninputs; ++i) {
input_handles_elems[i] = reinterpret_cast<jlong>((inputs[i]).oper);
input_indices_elems[i] = static_cast<jint>((inputs[i]).index);
}
for (int i = 0; i < noutputs; ++i) {
output_handles_elems[i] = reinterpret_cast<jlong>((outputs[i]).oper);
output_indices_elems[i] = static_cast<jint>((outputs[i]).index);
}
env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0);
env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0);
env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0);
env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0);
// call Java code to construct the subgraph
jlongArray output_handles_and_indices =
(jlongArray)env->CallStaticObjectMethod(
clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle,
input_handles, input_indices, output_handles, output_indices);
if (env->ExceptionOccurred()) {
env->ExceptionDescribe();
return nullptr;
}
// returned array contains both op handles and output indices, in pair
return output_handles_and_indices;
}
// Native implementation of Graph.whileLoop.
//
// `input_handles`/`input_indices` are parallel arrays describing the loop
// inputs. `cond_graph_builder` and `body_graph_builder` are the Java
// WhileSubgraphBuilder objects used to populate the conditional and body
// subgraphs; they are invoked through buildSubgraph().
//
// Returns a long[] with twice as many elements as there are inputs: the op
// handles of the loop outputs followed by their output indices. Returns
// nullptr with a Java exception pending on any failure.
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
    JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles,
    jintArray input_indices, jstring name, jobject cond_graph_builder,
    jobject body_graph_builder) {
  TF_Graph* g = requireHandle(env, handle);
  if (g == nullptr) return nullptr;

  const int ninputs = env->GetArrayLength(input_handles);
  std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
  resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(),
                 ninputs);
  if (env->ExceptionCheck()) return nullptr;

  // Owned by a smart pointer so that every return path deletes it.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  // initialize while params
  TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status.get());
  if (TF_GetCode(status.get()) != TF_OK) {
    throwExceptionIfNotOK(env, status.get());
    return nullptr;
  }
  // From here on `params` owns resources: every path must end in exactly one
  // call to TF_FinishWhile or TF_AbortWhile.

  // build conditional subgraph
  jlongArray cond_output_handles_and_indices =
      buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph,
                    params.cond_inputs, &params.cond_output, params.ninputs, 1);

  // build body subgraph, but only if the conditional subgraph succeeded
  jlongArray body_output_handles_and_indices = nullptr;
  if (cond_output_handles_and_indices != nullptr) {
    body_output_handles_and_indices = buildSubgraph(
        env, clazz, body_graph_builder, params.body_graph, params.body_inputs,
        params.body_outputs, params.ninputs, params.ninputs);
  }

  if (cond_output_handles_and_indices == nullptr ||
      body_output_handles_and_indices == nullptr) {
    TF_AbortWhile(&params);
    // Make sure the Java caller sees an exception rather than a null result.
    if (!env->ExceptionCheck()) {
      throwException(
          env, kIllegalStateException,
          "failed to build the conditional or body subgraph of the while loop");
    }
    return nullptr;
  }

  if (env->GetArrayLength(cond_output_handles_and_indices) < 2 ||
      env->GetArrayLength(body_output_handles_and_indices) < 2 * ninputs) {
    TF_AbortWhile(&params);
    throwException(
        env, kIllegalStateException,
        "subgraph builder returned too few output handles and indices");
    return nullptr;
  }

  // set cond_output param to output of the conditional subgraph
  jlong cond_output_elems[2];
  env->GetLongArrayRegion(cond_output_handles_and_indices, 0, 2,
                          cond_output_elems);
  TF_Operation* cond_output_op =
      requireOperationHandle(env, cond_output_elems[0]);
  if (cond_output_op == nullptr) {
    TF_AbortWhile(&params);
    return nullptr;
  }
  params.cond_output = {cond_output_op,
                        static_cast<int>(cond_output_elems[1])};

  // set body_outputs param to outputs of the body subgraph
  std::unique_ptr<jlong[]> body_output_elems(new jlong[2 * ninputs]);
  env->GetLongArrayRegion(body_output_handles_and_indices, 0, 2 * ninputs,
                          body_output_elems.get());
  for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
    TF_Operation* body_output_op =
        requireOperationHandle(env, body_output_elems[i]);
    if (body_output_op == nullptr) {
      // An exception is pending; stop before making further JNI calls.
      TF_AbortWhile(&params);
      return nullptr;
    }
    params.body_outputs[i] = {body_output_op,
                              static_cast<int>(body_output_elems[j])};
  }

  // set loop name param
  const char* loop_name = env->GetStringUTFChars(name, nullptr);
  if (loop_name == nullptr) {
    TF_AbortWhile(&params);
    return nullptr;  // OutOfMemoryError is pending
  }
  params.name = loop_name;

  // build the while loop, storing loop outputs in `outputs`
  std::unique_ptr<TF_Output[]> outputs(new TF_Output[ninputs]);
  TF_FinishWhile(&params, status.get(), outputs.get());
  // `params` must not be read after TF_FinishWhile, so release the string
  // through the local pointer instead of params.name.
  env->ReleaseStringUTFChars(name, loop_name);
  if (TF_GetCode(status.get()) != TF_OK) {
    throwExceptionIfNotOK(env, status.get());
    return nullptr;
  }

  // returned array contains both op handles and output indices, in pair
  jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2);
  if (output_handles_and_indices == nullptr) return nullptr;
  std::unique_ptr<jlong[]> output_elems(new jlong[2 * ninputs]);
  for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
    const TF_Output& output = outputs[i];
    output_elems[i] = reinterpret_cast<jlong>(output.oper);
    output_elems[j] = static_cast<jlong>(output.index);
  }
  env->SetLongArrayRegion(output_handles_and_indices, 0, 2 * ninputs,
                          output_elems.get());
  return output_handles_and_indices;
}

View File

@ -51,8 +51,9 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass,
* Signature: (JI)[J
*/
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv *,
jclass, jlong,
jint);
jclass,
jlong,
jint);
/*
* Class: org_tensorflow_Graph
@ -82,6 +83,15 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
jintArray, jlongArray, jintArray);
/*
* Class: org_tensorflow_Graph
* Method: whileLoop
* Signature:
* (J[J[ILjava/lang/String;Lorg/tensorflow/Graph$WhileSubgraphBuilder;Lorg/tensorflow/Graph$WhileSubgraphBuilder;)[J
*/
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
JNIEnv *, jclass, jlong, jlongArray, jintArray, jstring, jobject, jobject);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -254,7 +254,115 @@ public class GraphTest {
}
}
}
@Test
public void buildWhileLoopSingleInput() {
  try (Graph g = new Graph();
      Session s = new Session(g)) {
    final Output<?> seed = TestUtil.placeholder(g, "input1", Integer.class);

    // Condition: keep looping while the loop variable is below 16.
    // (Anonymous classes are used here; lambdas would work on Java 8 and later.)
    final Graph.WhileSubgraphBuilder whileBelowSixteen =
        new Graph.WhileSubgraphBuilder() {
          @Override
          public void buildSubgraph(
              Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
            Output<Integer> limit = TestUtil.constant(condGraph, "sixteen", 16);
            // condInputs[0] < 16
            condOutputs[0] =
                condGraph
                    .opBuilder("Less", "cond")
                    .addInput(condInputs[0])
                    .addInput(limit)
                    .build()
                    .output(0);
          }
        };

    // Body: square the loop variable on every iteration.
    final Graph.WhileSubgraphBuilder squareOnce =
        new Graph.WhileSubgraphBuilder() {
          @Override
          public void buildSubgraph(
              Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
            bodyOutputs[0] = TestUtil.square(bodyGraph, "square", bodyInputs[0]);
          }
        };

    final Output<?>[] loopResults =
        g.whileLoop(toArray(seed), whileBelowSixteen, squareOnce, "test_loop");

    try (Tensor<Integer> two = Tensors.create(2);
        Tensor<?> result = s.runner().feed(seed, two).fetch(loopResults[0]).run().get(0)) {
      // 2 -> 4 -> 16, then the condition 16 < 16 fails.
      assertEquals(16, result.intValue());
    }
  }
}
@Test
public void buildWhileLoopMultipleInputs() {
  try (Graph g = new Graph();
      Session s = new Session(g)) {
    final Output<?> first = TestUtil.placeholder(g, "input1", Integer.class);
    final Output<?> second = TestUtil.placeholder(g, "input2", Integer.class);
    final Output<?>[] loopVars = toArray(first, second);

    // Condition: only the first loop variable decides; loop while it is below 16.
    // (Anonymous classes are used here; lambdas would work on Java 8 and later.)
    final Graph.WhileSubgraphBuilder whileFirstBelowSixteen =
        new Graph.WhileSubgraphBuilder() {
          @Override
          public void buildSubgraph(
              Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
            Output<Integer> limit = TestUtil.constant(condGraph, "sixteen", 16);
            // condInputs[0] < 16
            condOutputs[0] =
                condGraph
                    .opBuilder("Less", "cond")
                    .addInput(condInputs[0])
                    .addInput(limit)
                    .build()
                    .output(0);
          }
        };

    // Body: square both loop variables on every iteration.
    final Graph.WhileSubgraphBuilder squareBoth =
        new Graph.WhileSubgraphBuilder() {
          @Override
          public void buildSubgraph(
              Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
            bodyOutputs[0] = TestUtil.square(bodyGraph, "square1", bodyInputs[0]);
            bodyOutputs[1] = TestUtil.square(bodyGraph, "square2", bodyInputs[1]);
          }
        };

    final Output<?>[] loopResults =
        g.whileLoop(loopVars, whileFirstBelowSixteen, squareBoth, "test_loop");

    try (Tensor<Integer> two = Tensors.create(2);
        Tensor<Integer> five = Tensors.create(5);
        TestUtil.AutoCloseableList<Tensor<?>> fetched =
            new TestUtil.AutoCloseableList<>(
                s.runner()
                    .feed(first, two)
                    .feed(second, five)
                    .fetch(loopResults[0])
                    .fetch(loopResults[1])
                    .run())) {
      assertEquals(2, fetched.size());
      // Two iterations run: 2 -> 4 -> 16 and 5 -> 25 -> 625.
      assertEquals(16, fetched.get(0).intValue());
      assertEquals(625, fetched.get(1).intValue());
    }
  }
}
// Returns its varargs arguments as an array, so tests can build an Output<?>[] inline.
private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
}