Merge pull request #28212 from karllessard:java-eager-operation-builder
PiperOrigin-RevId: 246208940
This commit is contained in:
commit
cba813e758
@ -146,6 +146,19 @@ tf_java_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_java_test(
|
||||
name = "EagerOperationBuilderTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/EagerOperationBuilderTest.java"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.EagerOperationBuilderTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
tf_java_test(
|
||||
name = "GraphTest",
|
||||
size = "small",
|
||||
|
@ -0,0 +1,254 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import java.nio.charset.Charset;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
/**
|
||||
* An {@link OperationBuilder} for building {@link Operation Operations} that are executed eagerly.
|
||||
*/
|
||||
final class EagerOperationBuilder implements OperationBuilder {
|
||||
|
||||
EagerOperationBuilder(EagerSession session, String type, String name) {
|
||||
this.session = session;
|
||||
this.type = type;
|
||||
this.name = name;
|
||||
this.nativeRef = new NativeReference(session, this, allocate(session.nativeHandle(), type));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Operation build() {
|
||||
// TODO (karllessard) Execute the eager operation and pass output tensor handles to new
|
||||
// EagerOperation class
|
||||
throw new UnsupportedOperationException("Eager execution is not supported yet");
|
||||
}
|
||||
|
||||
@Override
|
||||
public EagerOperationBuilder addInput(Output<?> input) {
|
||||
addInput(nativeRef.opHandle, input.getUnsafeNativeHandle());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder addInputList(Output<?>[] inputs) {
|
||||
long[] inputHandles = new long[inputs.length];
|
||||
for (int i = 0; i < inputs.length; ++i) {
|
||||
inputHandles[i] = inputs[i].getUnsafeNativeHandle();
|
||||
}
|
||||
addInputList(nativeRef.opHandle, inputHandles);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder addControlInput(Operation control) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Control inputs are not supported in an eager execution environment");
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setDevice(String device) {
|
||||
setDevice(nativeRef.opHandle, device);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, String value) {
|
||||
return setAttr(name, value.getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, String[] values) {
|
||||
Charset utf8 = StandardCharsets.UTF_8;
|
||||
Object[] objects = new Object[values.length];
|
||||
for (int i = 0; i < values.length; ++i) {
|
||||
objects[i] = values[i].getBytes(utf8);
|
||||
}
|
||||
setAttrStringList(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, byte[] values) {
|
||||
setAttrString(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, long value) {
|
||||
setAttrInt(nativeRef.opHandle, name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, long[] values) {
|
||||
setAttrIntList(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, float value) {
|
||||
setAttrFloat(nativeRef.opHandle, name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, float[] values) {
|
||||
setAttrFloatList(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, boolean value) {
|
||||
setAttrBool(nativeRef.opHandle, name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, boolean[] values) {
|
||||
setAttrBoolList(nativeRef.opHandle, name, values);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, DataType value) {
|
||||
setAttrType(nativeRef.opHandle, name, value.c());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, DataType[] values) {
|
||||
int[] c = new int[values.length];
|
||||
for (int i = 0; i < values.length; ++i) {
|
||||
c[i] = values[i].c();
|
||||
}
|
||||
setAttrTypeList(nativeRef.opHandle, name, c);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, Tensor<?> value) {
|
||||
setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, Tensor<?>[] values) {
|
||||
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
|
||||
throw new UnsupportedOperationException(
|
||||
"Tensor list attributes are not supported in eager mode");
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, Shape value) {
|
||||
setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperationBuilder setAttr(String name, Shape[] values) {
|
||||
int[] numDimensions = new int[values.length];
|
||||
int totalNumDimensions = 0;
|
||||
for (int idx = 0; idx < values.length; ++idx) {
|
||||
int n = values[idx].numDimensions();
|
||||
numDimensions[idx] = n;
|
||||
if (n > 0) {
|
||||
totalNumDimensions += n;
|
||||
}
|
||||
}
|
||||
// Flatten the shapes into a single array to avoid too much overhead in the
|
||||
// native part
|
||||
long[] shapes = new long[totalNumDimensions];
|
||||
int shapeIdx = 0;
|
||||
for (Shape shape : values) {
|
||||
if (shape.numDimensions() > 0) {
|
||||
for (long dim : shape.asArray()) {
|
||||
shapes[shapeIdx++] = dim;
|
||||
}
|
||||
}
|
||||
}
|
||||
setAttrShapeList(nativeRef.opHandle, name, shapes, numDimensions);
|
||||
return this;
|
||||
}
|
||||
|
||||
private static class NativeReference extends EagerSession.NativeReference {
|
||||
|
||||
NativeReference(EagerSession session, EagerOperationBuilder operation, long opHandle) {
|
||||
super(session, operation);
|
||||
this.opHandle = opHandle;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear() {
|
||||
super.clear();
|
||||
opHandle = 0L;
|
||||
}
|
||||
|
||||
@Override
|
||||
synchronized void delete() {
|
||||
if (opHandle != 0L) {
|
||||
EagerOperationBuilder.delete(opHandle);
|
||||
opHandle = 0L;
|
||||
}
|
||||
}
|
||||
|
||||
private long opHandle;
|
||||
}
|
||||
|
||||
private final EagerSession session;
|
||||
private final String type;
|
||||
private final String name;
|
||||
private final NativeReference nativeRef;
|
||||
|
||||
private static native long allocate(long ctxHandle, String type);
|
||||
|
||||
private static native void delete(long opHandle);
|
||||
|
||||
private static native long[] execute(long opHandle);
|
||||
|
||||
private static native void addInput(long opHandle, long tensorHandle);
|
||||
|
||||
private static native void addInputList(long opHandle, long[] tensorHandles);
|
||||
|
||||
private static native void setDevice(long opHandle, String device);
|
||||
|
||||
private static native void setAttrString(long opHandle, String name, byte[] value);
|
||||
|
||||
private static native void setAttrStringList(long opHandle, String name, Object[] value);
|
||||
|
||||
private static native void setAttrInt(long opHandle, String name, long value);
|
||||
|
||||
private static native void setAttrIntList(long opHandle, String name, long[] values);
|
||||
|
||||
private static native void setAttrFloat(long opHandle, String name, float value);
|
||||
|
||||
private static native void setAttrFloatList(long opHandle, String name, float[] values);
|
||||
|
||||
private static native void setAttrBool(long opHandle, String name, boolean value);
|
||||
|
||||
private static native void setAttrBoolList(long opHandle, String name, boolean[] values);
|
||||
|
||||
private static native void setAttrType(long opHandle, String name, int type);
|
||||
|
||||
private static native void setAttrTypeList(long opHandle, String name, int[] types);
|
||||
|
||||
private static native void setAttrTensor(long opHandle, String name, long tensorHandle);
|
||||
|
||||
private static native void setAttrShape(long opHandle, String name, long[] shape, int numDims);
|
||||
|
||||
private static native void setAttrShapeList(
|
||||
long opHandle, String name, long[] shapes, int[] numDims);
|
||||
}
|
@ -245,9 +245,12 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
|
||||
nativeResources.tryCleanup();
|
||||
}
|
||||
checkSession();
|
||||
// TODO (karllessard) create a new EagerOperationBuilder
|
||||
throw new UnsupportedOperationException(
|
||||
"Eager execution mode is not yet supported in this version of TensorFlow");
|
||||
return new EagerOperationBuilder(this, type, name);
|
||||
}
|
||||
|
||||
long nativeHandle() {
|
||||
checkSession();
|
||||
return nativeHandle;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -408,8 +411,6 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable {
|
||||
|
||||
private static native void delete(long handle);
|
||||
|
||||
private static native long allocateOperation(long contextHandle, String name);
|
||||
|
||||
static {
|
||||
TensorFlow.init();
|
||||
}
|
||||
|
@ -523,7 +523,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
|
||||
*/
|
||||
public static Constant<String> create(Scope scope, String data, Charset charset) {
|
||||
try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) {
|
||||
return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class));
|
||||
return createWithTensor(scope, value);
|
||||
}
|
||||
}
|
||||
|
||||
|
332
tensorflow/java/src/main/native/eager_operation_builder_jni.cc
Normal file
332
tensorflow/java/src/main/native/eager_operation_builder_jni.cc
Normal file
@ -0,0 +1,332 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/java/src/main/native/eager_operation_builder_jni.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||
|
||||
namespace {
|
||||
|
||||
TFE_Op* requireOp(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
throwException(env, kIllegalStateException,
|
||||
"Operation has already been built");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TFE_Op*>(handle);
|
||||
}
|
||||
|
||||
TFE_Context* requireContext(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
throwException(env, kIllegalStateException, "Context has been deleted");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TFE_Context*>(handle);
|
||||
}
|
||||
|
||||
TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
throwException(env, kIllegalStateException,
|
||||
"close() has been called on the Tensor");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TF_Tensor*>(handle);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
throwException(env, kIllegalStateException,
|
||||
"Tensor handle has been deleted");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TFE_TensorHandle*>(handle);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate(
|
||||
JNIEnv* env, jclass clazz, jlong context_handle, jstring name) {
|
||||
TFE_Context* context = requireContext(env, context_handle);
|
||||
if (context == nullptr) return 0;
|
||||
const char* op_or_function_name = env->GetStringUTFChars(name, nullptr);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(context, op_or_function_name, status);
|
||||
env->ReleaseStringUTFChars(name, op_or_function_name);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return 0;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
static_assert(sizeof(jlong) >= sizeof(TFE_Op*),
|
||||
"Cannot represent a C TFE_Op as a Java long");
|
||||
return reinterpret_cast<jlong>(op);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_delete(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle) {
|
||||
if (op_handle == 0) return;
|
||||
TFE_DeleteOp(reinterpret_cast<TFE_Op*>(op_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle) {
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return 0;
|
||||
int num_retvals = 8; // should be >= than max number of outputs in any op
|
||||
std::unique_ptr<TFE_TensorHandle*[]> retvals(
|
||||
new TFE_TensorHandle*[num_retvals]);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Execute(op, retvals.get(), &num_retvals, status);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return nullptr;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
jlongArray rethandles = env->NewLongArray(num_retvals);
|
||||
if (num_retvals > 0) {
|
||||
jlong* retval = env->GetLongArrayElements(rethandles, nullptr);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
retval[i] = reinterpret_cast<jlong>(retvals[i]);
|
||||
}
|
||||
env->ReleaseLongArrayElements(rethandles, retval, 0);
|
||||
}
|
||||
return rethandles;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle, jstring device_name) {
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return;
|
||||
const char* cname = env->GetStringUTFChars(device_name, nullptr);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_OpSetDevice(op, cname, status);
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
env->ReleaseStringUTFChars(device_name, cname);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle, jlong input_handle) {
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return;
|
||||
TFE_TensorHandle* tensor_handle = requireTensorHandle(env, input_handle);
|
||||
if (tensor_handle == nullptr) return;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_OpAddInput(op, tensor_handle, status);
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle, jlongArray input_handles) {
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return;
|
||||
jlong* cinput_handles = env->GetLongArrayElements(input_handles, nullptr);
|
||||
size_t num_inputs = static_cast<size_t>(env->GetArrayLength(input_handles));
|
||||
std::unique_ptr<TFE_TensorHandle*[]> tensor_handles(
|
||||
new TFE_TensorHandle*[num_inputs]);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
tensor_handles[i] = requireTensorHandle(env, cinput_handles[i]);
|
||||
if (tensor_handles[i] == nullptr) {
|
||||
env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
|
||||
return;
|
||||
}
|
||||
}
|
||||
env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_OpAddInputList(op, tensor_handles.get(), num_inputs, status);
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
|
||||
jbyteArray value) {
|
||||
static_assert(sizeof(jbyte) == 1,
|
||||
"Require Java byte to be represented as a single byte");
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return;
|
||||
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
|
||||
jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
|
||||
TFE_OpSetAttrString(op, cname, cvalue, env->GetArrayLength(value));
|
||||
env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
|
||||
env->ReleaseStringUTFChars(attr_name, cname);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(
|
||||
JNIEnv* env, jclass object, jlong op_handle, jstring attr_name,
|
||||
jobjectArray values) {
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return;
|
||||
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
|
||||
int num_values = env->GetArrayLength(values);
|
||||
static_assert(sizeof(jbyte) == 1,
|
||||
"Require Java byte to be represented as a single byte");
|
||||
std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]);
|
||||
std::unique_ptr<jbyte*[]> jvalues(new jbyte*[num_values]);
|
||||
std::unique_ptr<void*[]> cvalues(new void*[num_values]);
|
||||
std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
|
||||
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
jbyteArray v =
|
||||
static_cast<jbyteArray>(env->GetObjectArrayElement(values, i));
|
||||
jarrays[i] = v;
|
||||
jvalues[i] = env->GetByteArrayElements(v, nullptr);
|
||||
cvalues[i] = jvalues[i];
|
||||
lengths[i] = static_cast<size_t>(env->GetArrayLength(v));
|
||||
}
|
||||
TFE_OpSetAttrStringList(op, cname, cvalues.get(), lengths.get(), num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
|
||||
}
|
||||
env->ReleaseStringUTFChars(attr_name, cname);
|
||||
}
|
||||
|
||||
#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
|
||||
JNIEXPORT void JNICALL \
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttr##name( \
|
||||
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
|
||||
jtype value) { \
|
||||
static_assert( \
|
||||
sizeof(ctype) >= sizeof(jtype), \
|
||||
"Information loss when converting between Java and C types"); \
|
||||
TFE_Op* op = requireOp(env, op_handle); \
|
||||
if (op == nullptr) return; \
|
||||
const char* cname = env->GetStringUTFChars(attr_name, nullptr); \
|
||||
TFE_OpSetAttr##name(op, cname, static_cast<ctype>(value)); \
|
||||
env->ReleaseStringUTFChars(attr_name, cname); \
|
||||
}
|
||||
|
||||
#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \
|
||||
JNIEXPORT void JNICALL \
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttr##name##List( \
|
||||
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
|
||||
jtype##Array value) { \
|
||||
TFE_Op* op = requireOp(env, op_handle); \
|
||||
if (op == nullptr) return; \
|
||||
const char* cname = env->GetStringUTFChars(attr_name, nullptr); \
|
||||
/* Make a copy of the array to paper over any differences */ \
|
||||
/* in byte representations of the jtype and ctype */ \
|
||||
/* For example, jint vs TF_DataType. */ \
|
||||
/* If this copy turns out to be a problem in practice */ \
|
||||
/* can avoid it for many types. */ \
|
||||
const int n = env->GetArrayLength(value); \
|
||||
std::unique_ptr<ctype[]> cvalue(new ctype[n]); \
|
||||
jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \
|
||||
for (int i = 0; i < n; ++i) { \
|
||||
cvalue[i] = static_cast<ctype>(elems[i]); \
|
||||
} \
|
||||
TFE_OpSetAttr##name##List(op, cname, cvalue.get(), n); \
|
||||
env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \
|
||||
env->ReleaseStringUTFChars(attr_name, cname); \
|
||||
}
|
||||
|
||||
#define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
|
||||
DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
|
||||
DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
|
||||
|
||||
DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
|
||||
DEFINE_SET_ATTR(Float, Float, jfloat, float);
|
||||
DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
|
||||
DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
|
||||
#undef DEFINE_SET_ATTR
|
||||
#undef DEFINE_SET_ATTR_LIST
|
||||
#undef DEFINE_SET_ATTR_SCALAR
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring attr_name,
|
||||
jlong tensor_handle) {
|
||||
TFE_Op* op = requireOp(env, handle);
|
||||
if (op == nullptr) return;
|
||||
TF_Tensor* t = requireTensor(env, tensor_handle);
|
||||
if (t == nullptr) return;
|
||||
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_OpSetAttrTensor(op, cname, t, status);
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
env->ReleaseStringUTFChars(attr_name, cname);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
|
||||
jlongArray shape, jint num_dims) {
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return;
|
||||
std::unique_ptr<int64_t[]> cvalue;
|
||||
// num_dims and env->GetArrayLength(shape) are assumed to be consistent.
|
||||
// i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
|
||||
if (num_dims > 0) {
|
||||
cvalue.reset(new int64_t[num_dims]);
|
||||
jlong* elems = env->GetLongArrayElements(shape, nullptr);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
cvalue[i] = static_cast<int64_t>(elems[i]);
|
||||
}
|
||||
env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
|
||||
}
|
||||
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_OpSetAttrShape(op, cname, cvalue.get(), static_cast<int>(num_dims),
|
||||
status);
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
env->ReleaseStringUTFChars(attr_name, cname);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(
|
||||
JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
|
||||
jlongArray shapes, jintArray num_dims) {
|
||||
TFE_Op* op = requireOp(env, op_handle);
|
||||
if (op == nullptr) return;
|
||||
std::unique_ptr<int64_t[]> cshapes;
|
||||
std::unique_ptr<const int64_t*[]> cdims;
|
||||
std::unique_ptr<int[]> cnum_dims;
|
||||
const int num_dims_length = env->GetArrayLength(num_dims);
|
||||
if (num_dims_length > 0) {
|
||||
const int shapes_length = env->GetArrayLength(shapes);
|
||||
cshapes.reset(new int64_t[shapes_length]);
|
||||
cdims.reset(new const int64_t*[num_dims_length]);
|
||||
cnum_dims.reset(new int[num_dims_length]);
|
||||
jlong* shapes_elems =
|
||||
static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
|
||||
std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
|
||||
env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
|
||||
int64_t* cshapes_ptr = cshapes.get();
|
||||
jint* num_dims_elems =
|
||||
static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
|
||||
for (int i = 0; i < num_dims_length; ++i) {
|
||||
cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
|
||||
cdims[i] = cshapes_ptr;
|
||||
if (cnum_dims[i] > 0) {
|
||||
cshapes_ptr += cnum_dims[i];
|
||||
}
|
||||
}
|
||||
env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
|
||||
}
|
||||
const char* cname = env->GetStringUTFChars(attr_name, nullptr);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_OpSetAttrShapeList(op, cname, cdims.get(), cnum_dims.get(),
|
||||
num_dims_length, status);
|
||||
throwExceptionIfNotOK(env, status);
|
||||
TF_DeleteStatus(status);
|
||||
env->ReleaseStringUTFChars(attr_name, cname);
|
||||
}
|
191
tensorflow/java/src/main/native/eager_operation_builder_jni.h
Normal file
191
tensorflow/java/src/main/native/eager_operation_builder_jni.h
Normal file
@ -0,0 +1,191 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_
|
||||
#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: allocate
|
||||
* Signature: (JLjava/lang/String;)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate(
|
||||
JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: delete
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_delete(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: execute
|
||||
* Signature: (J)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_execute(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: addInput
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput(
|
||||
JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: addInputList
|
||||
* Signature: (J[J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList(
|
||||
JNIEnv *, jclass, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setDevice
|
||||
* Signature: (JLjava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice(
|
||||
JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrString
|
||||
* Signature: (JLjava/lang/String;[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString(
|
||||
JNIEnv *, jclass, jlong, jstring, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrStringList
|
||||
* Signature: (JLjava/lang/String;[L)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(JNIEnv *, jclass,
|
||||
jlong, jstring,
|
||||
jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrInt
|
||||
* Signature: (JLjava/lang/String;J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrInt(
|
||||
JNIEnv *, jclass, jlong, jstring, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrIntList
|
||||
* Signature: (JLjava/lang/String;[J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrIntList(
|
||||
JNIEnv *, jclass, jlong, jstring, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrFloat
|
||||
* Signature: (JLjava/lang/String;F)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrFloat(
|
||||
JNIEnv *, jclass, jlong, jstring, jfloat);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrFloatList
|
||||
* Signature: (JLjava/lang/String;[F)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttrFloatList(JNIEnv *, jclass,
|
||||
jlong, jstring,
|
||||
jfloatArray);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrBool
|
||||
* Signature: (JLjava/lang/String;Z)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrBool(
|
||||
JNIEnv *, jclass, jlong, jstring, jboolean);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrBoolList
|
||||
* Signature: (JLjava/lang/String;[Z)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttrBoolList(JNIEnv *, jclass,
|
||||
jlong, jstring,
|
||||
jbooleanArray);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrType
|
||||
* Signature: (JLjava/lang/String;I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrType(
|
||||
JNIEnv *, jclass, jlong, jstring, jint);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrTypeList
|
||||
* Signature: (JLjava/lang/String;[I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttrTypeList(JNIEnv *, jclass,
|
||||
jlong, jstring,
|
||||
jintArray);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrTensor
|
||||
* Signature: (JLjava/lang/String;J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(
|
||||
JNIEnv *, jclass, jlong, jstring, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrShape
|
||||
* Signature: (JLjava/lang/String;[JI)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape(
|
||||
JNIEnv *, jclass, jlong, jstring, jlongArray, jint);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerOperationBuilder
|
||||
* Method: setAttrShapeList
|
||||
* Signature: (JLjava/lang/String;[J[I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(JNIEnv *, jclass,
|
||||
jlong, jstring,
|
||||
jlongArray,
|
||||
jintArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_
|
@ -21,18 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||
|
||||
namespace {
|
||||
|
||||
TFE_Context* requireContext(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
throwException(env, kIllegalStateException, "Context has been deleted");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TFE_Context*>(handle);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate(
|
||||
JNIEnv* env, jclass clazz, jboolean async, jint dpp, jbyteArray config) {
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -74,21 +62,3 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv* env,
|
||||
if (handle == 0) return;
|
||||
TFE_DeleteContext(reinterpret_cast<TFE_Context*>(handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocateOperation(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring name) {
|
||||
TFE_Context* context = requireContext(env, handle);
|
||||
if (context == nullptr) return 0;
|
||||
const char* op_or_function_name = env->GetStringUTFChars(name, nullptr);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(context, op_or_function_name, status);
|
||||
env->ReleaseStringUTFChars(name, op_or_function_name);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return 0;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
static_assert(sizeof(jlong) >= sizeof(TFE_Op*),
|
||||
"Cannot represent a C TFE_Op as a Java long");
|
||||
return reinterpret_cast<jlong>(op);
|
||||
}
|
||||
|
@ -38,14 +38,6 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate(
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv *, jclass,
|
||||
jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_EagerSession
|
||||
* Method: allocateOperation
|
||||
* Signature: (JLjava/lang/String;)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocateOperation(
|
||||
JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
@ -0,0 +1,52 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
public class EagerOperationBuilderTest {
|
||||
|
||||
@Test
|
||||
public void failToCreateIfSessionIsClosed() {
|
||||
EagerSession session = EagerSession.create();
|
||||
session.close();
|
||||
try {
|
||||
new EagerOperationBuilder(session, "Add", "add");
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void failToBuildOpIfSessionIsClosed() {
|
||||
EagerOperationBuilder opBuilder;
|
||||
try (EagerSession session = EagerSession.create()) {
|
||||
opBuilder = new EagerOperationBuilder(session, "Empty", "empty");
|
||||
}
|
||||
try {
|
||||
opBuilder.setAttr("dtype", DataType.FLOAT);
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO (karllessard) add more tests when EagerOperation is implemented as well
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
Loading…
Reference in New Issue
Block a user