Merge pull request #28212 from karllessard:java-eager-operation-builder

PiperOrigin-RevId: 246208940
This commit is contained in:
TensorFlower Gardener 2019-05-01 17:51:35 -07:00
commit cba813e758
10 changed files with 850 additions and 45 deletions

View File

@ -146,6 +146,19 @@ tf_java_test(
],
)
tf_java_test(
name = "EagerOperationBuilderTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/EagerOperationBuilderTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.EagerOperationBuilderTest",
deps = [
":tensorflow",
":testutil",
"@junit",
],
)
tf_java_test(
name = "GraphTest",
size = "small",

View File

@ -0,0 +1,254 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
/**
* An {@link OperationBuilder} for building {@link Operation Operations} that are executed eagerly.
*/
final class EagerOperationBuilder implements OperationBuilder {
EagerOperationBuilder(EagerSession session, String type, String name) {
this.session = session;
this.type = type;
this.name = name;
this.nativeRef = new NativeReference(session, this, allocate(session.nativeHandle(), type));
}
@Override
public Operation build() {
// TODO (karllessard) Execute the eager operation and pass output tensor handles to new
// EagerOperation class
throw new UnsupportedOperationException("Eager execution is not supported yet");
}
@Override
public EagerOperationBuilder addInput(Output<?> input) {
addInput(nativeRef.opHandle, input.getUnsafeNativeHandle());
return this;
}
@Override
public OperationBuilder addInputList(Output<?>[] inputs) {
long[] inputHandles = new long[inputs.length];
for (int i = 0; i < inputs.length; ++i) {
inputHandles[i] = inputs[i].getUnsafeNativeHandle();
}
addInputList(nativeRef.opHandle, inputHandles);
return this;
}
@Override
public OperationBuilder addControlInput(Operation control) {
throw new UnsupportedOperationException(
"Control inputs are not supported in an eager execution environment");
}
@Override
public OperationBuilder setDevice(String device) {
setDevice(nativeRef.opHandle, device);
return this;
}
@Override
public OperationBuilder setAttr(String name, String value) {
return setAttr(name, value.getBytes(StandardCharsets.UTF_8));
}
@Override
public OperationBuilder setAttr(String name, String[] values) {
Charset utf8 = StandardCharsets.UTF_8;
Object[] objects = new Object[values.length];
for (int i = 0; i < values.length; ++i) {
objects[i] = values[i].getBytes(utf8);
}
setAttrStringList(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, byte[] values) {
setAttrString(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, long value) {
setAttrInt(nativeRef.opHandle, name, value);
return this;
}
@Override
public OperationBuilder setAttr(String name, long[] values) {
setAttrIntList(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, float value) {
setAttrFloat(nativeRef.opHandle, name, value);
return this;
}
@Override
public OperationBuilder setAttr(String name, float[] values) {
setAttrFloatList(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, boolean value) {
setAttrBool(nativeRef.opHandle, name, value);
return this;
}
@Override
public OperationBuilder setAttr(String name, boolean[] values) {
setAttrBoolList(nativeRef.opHandle, name, values);
return this;
}
@Override
public OperationBuilder setAttr(String name, DataType value) {
setAttrType(nativeRef.opHandle, name, value.c());
return this;
}
@Override
public OperationBuilder setAttr(String name, DataType[] values) {
int[] c = new int[values.length];
for (int i = 0; i < values.length; ++i) {
c[i] = values[i].c();
}
setAttrTypeList(nativeRef.opHandle, name, c);
return this;
}
@Override
public OperationBuilder setAttr(String name, Tensor<?> value) {
setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle());
return this;
}
@Override
public OperationBuilder setAttr(String name, Tensor<?>[] values) {
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
throw new UnsupportedOperationException(
"Tensor list attributes are not supported in eager mode");
}
@Override
public OperationBuilder setAttr(String name, Shape value) {
setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions());
return this;
}
@Override
public OperationBuilder setAttr(String name, Shape[] values) {
int[] numDimensions = new int[values.length];
int totalNumDimensions = 0;
for (int idx = 0; idx < values.length; ++idx) {
int n = values[idx].numDimensions();
numDimensions[idx] = n;
if (n > 0) {
totalNumDimensions += n;
}
}
// Flatten the shapes into a single array to avoid too much overhead in the
// native part
long[] shapes = new long[totalNumDimensions];
int shapeIdx = 0;
for (Shape shape : values) {
if (shape.numDimensions() > 0) {
for (long dim : shape.asArray()) {
shapes[shapeIdx++] = dim;
}
}
}
setAttrShapeList(nativeRef.opHandle, name, shapes, numDimensions);
return this;
}
private static class NativeReference extends EagerSession.NativeReference {
NativeReference(EagerSession session, EagerOperationBuilder operation, long opHandle) {
super(session, operation);
this.opHandle = opHandle;
}
@Override
public void clear() {
super.clear();
opHandle = 0L;
}
@Override
synchronized void delete() {
if (opHandle != 0L) {
EagerOperationBuilder.delete(opHandle);
opHandle = 0L;
}
}
private long opHandle;
}
private final EagerSession session;
private final String type;
private final String name;
private final NativeReference nativeRef;
private static native long allocate(long ctxHandle, String type);
private static native void delete(long opHandle);
private static native long[] execute(long opHandle);
private static native void addInput(long opHandle, long tensorHandle);
private static native void addInputList(long opHandle, long[] tensorHandles);
private static native void setDevice(long opHandle, String device);
private static native void setAttrString(long opHandle, String name, byte[] value);
private static native void setAttrStringList(long opHandle, String name, Object[] value);
private static native void setAttrInt(long opHandle, String name, long value);
private static native void setAttrIntList(long opHandle, String name, long[] values);
private static native void setAttrFloat(long opHandle, String name, float value);
private static native void setAttrFloatList(long opHandle, String name, float[] values);
private static native void setAttrBool(long opHandle, String name, boolean value);
private static native void setAttrBoolList(long opHandle, String name, boolean[] values);
private static native void setAttrType(long opHandle, String name, int type);
private static native void setAttrTypeList(long opHandle, String name, int[] types);
private static native void setAttrTensor(long opHandle, String name, long tensorHandle);
private static native void setAttrShape(long opHandle, String name, long[] shape, int numDims);
private static native void setAttrShapeList(
long opHandle, String name, long[] shapes, int[] numDims);
}

View File

@ -245,9 +245,12 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
nativeResources.tryCleanup();
}
checkSession();
// TODO (karllessard) create a new EagerOperationBuilder
throw new UnsupportedOperationException(
"Eager execution mode is not yet supported in this version of TensorFlow");
return new EagerOperationBuilder(this, type, name);
}
long nativeHandle() {
checkSession();
return nativeHandle;
}
/**
@ -408,8 +411,6 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
private static native void delete(long handle);
private static native long allocateOperation(long contextHandle, String name);
static {
TensorFlow.init();
}

View File

@ -523,7 +523,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*/
public static Constant<String> create(Scope scope, String data, Charset charset) {
try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) {
return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class));
return createWithTensor(scope, value);
}
}

View File

@ -0,0 +1,332 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/java/src/main/native/eager_operation_builder_jni.h"
#include <cstring>
#include <memory>
#include <set>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
namespace {
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalStateException,
"Operation has already been built");
return nullptr;
}
return reinterpret_cast<TFE_Op*>(handle);
}
TFE_Context* requireContext(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalStateException, "Context has been deleted");
return nullptr;
}
return reinterpret_cast<TFE_Context*>(handle);
}
TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalStateException,
"close() has been called on the Tensor");
return nullptr;
}
return reinterpret_cast<TF_Tensor*>(handle);
}
TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalStateException,
"Tensor handle has been deleted");
return nullptr;
}
return reinterpret_cast<TFE_TensorHandle*>(handle);
}
} // namespace
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate(
JNIEnv* env, jclass clazz, jlong context_handle, jstring name) {
TFE_Context* context = requireContext(env, context_handle);
if (context == nullptr) return 0;
const char* op_or_function_name = env->GetStringUTFChars(name, nullptr);
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(context, op_or_function_name, status);
env->ReleaseStringUTFChars(name, op_or_function_name);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return 0;
}
TF_DeleteStatus(status);
static_assert(sizeof(jlong) >= sizeof(TFE_Op*),
"Cannot represent a C TFE_Op as a Java long");
return reinterpret_cast<jlong>(op);
}
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_delete(
JNIEnv* env, jclass clazz, jlong op_handle) {
if (op_handle == 0) return;
TFE_DeleteOp(reinterpret_cast<TFE_Op*>(op_handle));
}
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
JNIEnv* env, jclass clazz, jlong op_handle) {
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return 0;
int num_retvals = 8; // should be >= than max number of outputs in any op
std::unique_ptr<TFE_TensorHandle*[]> retvals(
new TFE_TensorHandle*[num_retvals]);
TF_Status* status = TF_NewStatus();
TFE_Execute(op, retvals.get(), &num_retvals, status);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return nullptr;
}
TF_DeleteStatus(status);
jlongArray rethandles = env->NewLongArray(num_retvals);
if (num_retvals > 0) {
jlong* retval = env->GetLongArrayElements(rethandles, nullptr);
for (int i = 0; i < num_retvals; ++i) {
retval[i] = reinterpret_cast<jlong>(retvals[i]);
}
env->ReleaseLongArrayElements(rethandles, retval, 0);
}
return rethandles;
}
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice(
JNIEnv* env, jclass clazz, jlong op_handle, jstring device_name) {
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return;
const char* cname = env->GetStringUTFChars(device_name, nullptr);
TF_Status* status = TF_NewStatus();
TFE_OpSetDevice(op, cname, status);
throwExceptionIfNotOK(env, status);
TF_DeleteStatus(status);
env->ReleaseStringUTFChars(device_name, cname);
}
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput(
JNIEnv* env, jclass clazz, jlong op_handle, jlong input_handle) {
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return;
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, input_handle);
if (tensor_handle == nullptr) return;
TF_Status* status = TF_NewStatus();
TFE_OpAddInput(op, tensor_handle, status);
throwExceptionIfNotOK(env, status);
TF_DeleteStatus(status);
}
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList(
JNIEnv* env, jclass clazz, jlong op_handle, jlongArray input_handles) {
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return;
jlong* cinput_handles = env->GetLongArrayElements(input_handles, nullptr);
size_t num_inputs = static_cast<size_t>(env->GetArrayLength(input_handles));
std::unique_ptr<TFE_TensorHandle*[]> tensor_handles(
new TFE_TensorHandle*[num_inputs]);
for (int i = 0; i < num_inputs; ++i) {
tensor_handles[i] = requireTensorHandle(env, cinput_handles[i]);
if (tensor_handles[i] == nullptr) {
env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
return;
}
}
env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
TF_Status* status = TF_NewStatus();
TFE_OpAddInputList(op, tensor_handles.get(), num_inputs, status);
throwExceptionIfNotOK(env, status);
TF_DeleteStatus(status);
}
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString(
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
jbyteArray value) {
static_assert(sizeof(jbyte) == 1,
"Require Java byte to be represented as a single byte");
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return;
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
TFE_OpSetAttrString(op, cname, cvalue, env->GetArrayLength(value));
env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
env->ReleaseStringUTFChars(attr_name, cname);
}
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(
JNIEnv* env, jclass object, jlong op_handle, jstring attr_name,
jobjectArray values) {
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return;
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
int num_values = env->GetArrayLength(values);
static_assert(sizeof(jbyte) == 1,
"Require Java byte to be represented as a single byte");
std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]);
std::unique_ptr<jbyte*[]> jvalues(new jbyte*[num_values]);
std::unique_ptr<void*[]> cvalues(new void*[num_values]);
std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
for (int i = 0; i < num_values; ++i) {
jbyteArray v =
static_cast<jbyteArray>(env->GetObjectArrayElement(values, i));
jarrays[i] = v;
jvalues[i] = env->GetByteArrayElements(v, nullptr);
cvalues[i] = jvalues[i];
lengths[i] = static_cast<size_t>(env->GetArrayLength(v));
}
TFE_OpSetAttrStringList(op, cname, cvalues.get(), lengths.get(), num_values);
for (int i = 0; i < num_values; ++i) {
env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
}
env->ReleaseStringUTFChars(attr_name, cname);
}
#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
JNIEXPORT void JNICALL \
Java_org_tensorflow_EagerOperationBuilder_setAttr##name( \
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
jtype value) { \
static_assert( \
sizeof(ctype) >= sizeof(jtype), \
"Information loss when converting between Java and C types"); \
TFE_Op* op = requireOp(env, op_handle); \
if (op == nullptr) return; \
const char* cname = env->GetStringUTFChars(attr_name, nullptr); \
TFE_OpSetAttr##name(op, cname, static_cast<ctype>(value)); \
env->ReleaseStringUTFChars(attr_name, cname); \
}
#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \
JNIEXPORT void JNICALL \
Java_org_tensorflow_EagerOperationBuilder_setAttr##name##List( \
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
jtype##Array value) { \
TFE_Op* op = requireOp(env, op_handle); \
if (op == nullptr) return; \
const char* cname = env->GetStringUTFChars(attr_name, nullptr); \
/* Make a copy of the array to paper over any differences */ \
/* in byte representations of the jtype and ctype */ \
/* For example, jint vs TF_DataType. */ \
/* If this copy turns out to be a problem in practice */ \
/* can avoid it for many types. */ \
const int n = env->GetArrayLength(value); \
std::unique_ptr<ctype[]> cvalue(new ctype[n]); \
jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \
for (int i = 0; i < n; ++i) { \
cvalue[i] = static_cast<ctype>(elems[i]); \
} \
TFE_OpSetAttr##name##List(op, cname, cvalue.get(), n); \
env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \
env->ReleaseStringUTFChars(attr_name, cname); \
}
#define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
DEFINE_SET_ATTR(Float, Float, jfloat, float);
DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
#undef DEFINE_SET_ATTR
#undef DEFINE_SET_ATTR_LIST
#undef DEFINE_SET_ATTR_SCALAR
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(
JNIEnv* env, jclass clazz, jlong handle, jstring attr_name,
jlong tensor_handle) {
TFE_Op* op = requireOp(env, handle);
if (op == nullptr) return;
TF_Tensor* t = requireTensor(env, tensor_handle);
if (t == nullptr) return;
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
TF_Status* status = TF_NewStatus();
TFE_OpSetAttrTensor(op, cname, t, status);
throwExceptionIfNotOK(env, status);
TF_DeleteStatus(status);
env->ReleaseStringUTFChars(attr_name, cname);
}
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape(
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
jlongArray shape, jint num_dims) {
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return;
std::unique_ptr<int64_t[]> cvalue;
// num_dims and env->GetArrayLength(shape) are assumed to be consistent.
// i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
if (num_dims > 0) {
cvalue.reset(new int64_t[num_dims]);
jlong* elems = env->GetLongArrayElements(shape, nullptr);
for (int i = 0; i < num_dims; ++i) {
cvalue[i] = static_cast<int64_t>(elems[i]);
}
env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
}
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
TF_Status* status = TF_NewStatus();
TFE_OpSetAttrShape(op, cname, cvalue.get(), static_cast<int>(num_dims),
status);
throwExceptionIfNotOK(env, status);
TF_DeleteStatus(status);
env->ReleaseStringUTFChars(attr_name, cname);
}
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
jlongArray shapes, jintArray num_dims) {
TFE_Op* op = requireOp(env, op_handle);
if (op == nullptr) return;
std::unique_ptr<int64_t[]> cshapes;
std::unique_ptr<const int64_t*[]> cdims;
std::unique_ptr<int[]> cnum_dims;
const int num_dims_length = env->GetArrayLength(num_dims);
if (num_dims_length > 0) {
const int shapes_length = env->GetArrayLength(shapes);
cshapes.reset(new int64_t[shapes_length]);
cdims.reset(new const int64_t*[num_dims_length]);
cnum_dims.reset(new int[num_dims_length]);
jlong* shapes_elems =
static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
int64_t* cshapes_ptr = cshapes.get();
jint* num_dims_elems =
static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
for (int i = 0; i < num_dims_length; ++i) {
cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
cdims[i] = cshapes_ptr;
if (cnum_dims[i] > 0) {
cshapes_ptr += cnum_dims[i];
}
}
env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
}
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
TF_Status* status = TF_NewStatus();
TFE_OpSetAttrShapeList(op, cname, cdims.get(), cnum_dims.get(),
num_dims_length, status);
throwExceptionIfNotOK(env, status);
TF_DeleteStatus(status);
env->ReleaseStringUTFChars(attr_name, cname);
}

View File

@ -0,0 +1,191 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_
#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: allocate
* Signature: (JLjava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate(
JNIEnv *, jclass, jlong, jstring);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: delete
* Signature: (J)V
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_delete(JNIEnv *, jclass, jlong);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: execute
* Signature: (J)[J
*/
JNIEXPORT jlongArray JNICALL
Java_org_tensorflow_EagerOperationBuilder_execute(JNIEnv *, jclass, jlong);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: addInput
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput(
JNIEnv *, jclass, jlong, jlong);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: addInputList
* Signature: (J[J)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList(
JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setDevice
* Signature: (JLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice(
JNIEnv *, jclass, jlong, jstring);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrString
* Signature: (JLjava/lang/String;[B)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString(
JNIEnv *, jclass, jlong, jstring, jbyteArray);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrStringList
* Signature: (JLjava/lang/String;[L)V
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(JNIEnv *, jclass,
jlong, jstring,
jobjectArray);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrInt
* Signature: (JLjava/lang/String;J)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrInt(
JNIEnv *, jclass, jlong, jstring, jlong);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrIntList
* Signature: (JLjava/lang/String;[J)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrIntList(
JNIEnv *, jclass, jlong, jstring, jlongArray);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrFloat
* Signature: (JLjava/lang/String;F)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrFloat(
JNIEnv *, jclass, jlong, jstring, jfloat);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrFloatList
* Signature: (JLjava/lang/String;[F)V
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrFloatList(JNIEnv *, jclass,
jlong, jstring,
jfloatArray);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrBool
* Signature: (JLjava/lang/String;Z)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrBool(
JNIEnv *, jclass, jlong, jstring, jboolean);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrBoolList
* Signature: (JLjava/lang/String;[Z)V
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrBoolList(JNIEnv *, jclass,
jlong, jstring,
jbooleanArray);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrType
* Signature: (JLjava/lang/String;I)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrType(
JNIEnv *, jclass, jlong, jstring, jint);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrTypeList
* Signature: (JLjava/lang/String;[I)V
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrTypeList(JNIEnv *, jclass,
jlong, jstring,
jintArray);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrTensor
* Signature: (JLjava/lang/String;J)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(
JNIEnv *, jclass, jlong, jstring, jlong);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrShape
* Signature: (JLjava/lang/String;[JI)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape(
JNIEnv *, jclass, jlong, jstring, jlongArray, jint);
/*
* Class: org_tensorflow_EagerOperationBuilder
* Method: setAttrShapeList
* Signature: (JLjava/lang/String;[J[I)V
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(JNIEnv *, jclass,
jlong, jstring,
jlongArray,
jintArray);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_

View File

@ -21,18 +21,6 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
namespace {
TFE_Context* requireContext(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalStateException, "Context has been deleted");
return nullptr;
}
return reinterpret_cast<TFE_Context*>(handle);
}
} // namespace
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate(
JNIEnv* env, jclass clazz, jboolean async, jint dpp, jbyteArray config) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -74,21 +62,3 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv* env,
if (handle == 0) return;
TFE_DeleteContext(reinterpret_cast<TFE_Context*>(handle));
}
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocateOperation(
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
TFE_Context* context = requireContext(env, handle);
if (context == nullptr) return 0;
const char* op_or_function_name = env->GetStringUTFChars(name, nullptr);
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(context, op_or_function_name, status);
env->ReleaseStringUTFChars(name, op_or_function_name);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return 0;
}
TF_DeleteStatus(status);
static_assert(sizeof(jlong) >= sizeof(TFE_Op*),
"Cannot represent a C TFE_Op as a Java long");
return reinterpret_cast<jlong>(op);
}

View File

@ -38,14 +38,6 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate(
JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_EagerSession
* Method: allocateOperation
* Signature: (JLjava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocateOperation(
JNIEnv *, jclass, jlong, jstring);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import static org.junit.Assert.fail;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class EagerOperationBuilderTest {
@Test
public void failToCreateIfSessionIsClosed() {
EagerSession session = EagerSession.create();
session.close();
try {
new EagerOperationBuilder(session, "Add", "add");
fail();
} catch (IllegalStateException e) {
}
}
@Test
public void failToBuildOpIfSessionIsClosed() {
EagerOperationBuilder opBuilder;
try (EagerSession session = EagerSession.create()) {
opBuilder = new EagerOperationBuilder(session, "Empty", "empty");
}
try {
opBuilder.setAttr("dtype", DataType.FLOAT);
fail();
} catch (IllegalStateException e) {
}
}
// TODO (karllessard) add more tests when EagerOperation is implemented as well
}

View File

@ -1,4 +1,4 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.