diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index e8b00547b89..a23ae9c60e9 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -146,6 +146,19 @@ tf_java_test( ], ) +tf_java_test( + name = "EagerOperationBuilderTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/EagerOperationBuilderTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.EagerOperationBuilderTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + tf_java_test( name = "GraphTest", size = "small", diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java new file mode 100644 index 00000000000..a8cb5b5c318 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -0,0 +1,254 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +/** + * An {@link OperationBuilder} for building {@link Operation Operations} that are executed eagerly. + */ +final class EagerOperationBuilder implements OperationBuilder { + + EagerOperationBuilder(EagerSession session, String type, String name) { + this.session = session; + this.type = type; + this.name = name; + this.nativeRef = new NativeReference(session, this, allocate(session.nativeHandle(), type)); + } + + @Override + public Operation build() { + // TODO (karllessard) Execute the eager operation and pass output tensor handles to new + // EagerOperation class + throw new UnsupportedOperationException("Eager execution is not supported yet"); + } + + @Override + public EagerOperationBuilder addInput(Output<?> input) { + addInput(nativeRef.opHandle, input.getUnsafeNativeHandle()); + return this; + } + + @Override + public OperationBuilder addInputList(Output<?>[] inputs) { + long[] inputHandles = new long[inputs.length]; + for (int i = 0; i < inputs.length; ++i) { + inputHandles[i] = inputs[i].getUnsafeNativeHandle(); + } + addInputList(nativeRef.opHandle, inputHandles); + return this; + } + + @Override + public OperationBuilder addControlInput(Operation control) { + throw new UnsupportedOperationException( + "Control inputs are not supported in an eager execution environment"); + } + + @Override + public OperationBuilder setDevice(String device) { + setDevice(nativeRef.opHandle, device); + return this; + } + + @Override + public OperationBuilder setAttr(String name, String value) { + return setAttr(name, value.getBytes(StandardCharsets.UTF_8)); + } + + @Override + public OperationBuilder setAttr(String name, String[] values) { + Charset utf8 = StandardCharsets.UTF_8; + Object[] objects = new Object[values.length]; + for (int i = 0; i < values.length; ++i) { + objects[i] = values[i].getBytes(utf8); + } + setAttrStringList(nativeRef.opHandle, name, values); + return this; + } + + @Override + public OperationBuilder setAttr(String name, byte[] values) { + setAttrString(nativeRef.opHandle, name, values); + return this; + } + + @Override + public OperationBuilder setAttr(String name, long value) { + setAttrInt(nativeRef.opHandle, name, value); + return this; + } + + @Override + public OperationBuilder setAttr(String name, long[] values) { + setAttrIntList(nativeRef.opHandle, name, values); + return this; + } + + @Override + public OperationBuilder setAttr(String name, float value) { + setAttrFloat(nativeRef.opHandle, name, value); + return this; + } + + @Override + public OperationBuilder setAttr(String name, float[] values) { + setAttrFloatList(nativeRef.opHandle, name, values); + return this; + } + + @Override + public OperationBuilder setAttr(String name, boolean value) { + setAttrBool(nativeRef.opHandle, name, value); + return this; + } + + @Override + public OperationBuilder setAttr(String name, boolean[] values) { + setAttrBoolList(nativeRef.opHandle, name, values); + return this; + } + + @Override + public OperationBuilder setAttr(String name, DataType value) { + setAttrType(nativeRef.opHandle, name, value.c()); + return this; + } + + @Override + public OperationBuilder setAttr(String name, DataType[] values) { + int[] c = new int[values.length]; + for (int i = 0; i < values.length; ++i) { + c[i] = values[i].c(); + } + setAttrTypeList(nativeRef.opHandle, name, c); + return this; + } + + @Override + public OperationBuilder setAttr(String name, Tensor<?> value) { + setAttrTensor(nativeRef.opHandle, name, value.getNativeHandle()); + return this; + } + + @Override + public OperationBuilder setAttr(String name, Tensor<?>[] values) { + // TODO (karllessard) could be supported by adding this attribute type in the eager C API + throw new UnsupportedOperationException( + "Tensor list attributes are not supported in eager mode"); + } + + @Override + public OperationBuilder setAttr(String name, Shape value) { + setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions()); + return this; + } + + @Override + public OperationBuilder setAttr(String name, Shape[] values) { + int[] numDimensions = new int[values.length]; + int totalNumDimensions = 0; + for (int idx = 0; idx < values.length; ++idx) { + int n = values[idx].numDimensions(); + numDimensions[idx] = n; + if (n > 0) { + totalNumDimensions += n; + } + } + // Flatten the shapes into a single array to avoid too much overhead in the + // native part + long[] shapes = new long[totalNumDimensions]; + int shapeIdx = 0; + for (Shape shape : values) { + if (shape.numDimensions() > 0) { + for (long dim : shape.asArray()) { + shapes[shapeIdx++] = dim; + } + } + } + setAttrShapeList(nativeRef.opHandle, name, shapes, numDimensions); + return this; + } + + private static class NativeReference extends EagerSession.NativeReference { + + NativeReference(EagerSession session, EagerOperationBuilder operation, long opHandle) { + super(session, operation); + this.opHandle = opHandle; + } + + @Override + public void clear() { + super.clear(); + opHandle = 0L; + } + + @Override + synchronized void delete() { + if (opHandle != 0L) { + EagerOperationBuilder.delete(opHandle); + opHandle = 0L; + } + } + + private long opHandle; + } + + private final EagerSession session; + private final String type; + private final String name; + private final NativeReference nativeRef; + + private static native long allocate(long ctxHandle, String type); + + private static native void delete(long opHandle); + + private static native long[] execute(long opHandle); + + private static native void addInput(long opHandle, long tensorHandle); + + private static native void addInputList(long opHandle, long[] tensorHandles); + + private static native void setDevice(long opHandle, String device); + + private static native void setAttrString(long opHandle, String name, byte[] value); + + private static native void setAttrStringList(long opHandle, String name, Object[] value); + + private static native void setAttrInt(long opHandle, String name, long value); + + private static native void setAttrIntList(long opHandle, String name, long[] values); + + private static native void setAttrFloat(long opHandle, String name, float value); + + private static native void setAttrFloatList(long opHandle, String name, float[] values); + + private static native void setAttrBool(long opHandle, String name, boolean value); + + private static native void setAttrBoolList(long opHandle, String name, boolean[] values); + + private static native void setAttrType(long opHandle, String name, int type); + + private static native void setAttrTypeList(long opHandle, String name, int[] types); + + private static native void setAttrTensor(long opHandle, String name, long tensorHandle); + + private static native void setAttrShape(long opHandle, String name, long[] shape, int numDims); + + private static native void setAttrShapeList( + long opHandle, String name, long[] shapes, int[] numDims); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java index fced103bf44..7f36da173e6 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java @@ -245,9 +245,12 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { nativeResources.tryCleanup(); } checkSession(); - // TODO (karllessard) create a new EagerOperationBuilder - throw new UnsupportedOperationException( - "Eager execution mode is not yet supported in this version of TensorFlow"); + return new EagerOperationBuilder(this, type, name); + } + + long nativeHandle() { + checkSession(); + return nativeHandle; } /** @@ -408,8 +411,6 @@ public final class EagerSession implements ExecutionEnvironment, AutoCloseable { private static native void delete(long handle); - private static native long allocateOperation(long contextHandle, String name); - static { TensorFlow.init(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index a712204dd47..ee4301f1159 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -523,7 +523,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { */ public static Constant<String> create(Scope scope, String data, Charset charset) { try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) { - return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class)); + return createWithTensor(scope, value); } } diff --git a/tensorflow/java/src/main/native/eager_operation_builder_jni.cc b/tensorflow/java/src/main/native/eager_operation_builder_jni.cc new file mode 100644 index 00000000000..654c453176b --- /dev/null +++ b/tensorflow/java/src/main/native/eager_operation_builder_jni.cc @@ -0,0 +1,332 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/eager_operation_builder_jni.h" + +#include <cstring> +#include <memory> +#include <set> + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/java/src/main/native/exception_jni.h" + +namespace { + +TFE_Op* requireOp(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "Operation has already been built"); + return nullptr; + } + return reinterpret_cast<TFE_Op*>(handle); +} + +TFE_Context* requireContext(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, "Context has been deleted"); + return nullptr; + } + return reinterpret_cast<TFE_Context*>(handle); +} + +TF_Tensor* requireTensor(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "close() has been called on the Tensor"); + return nullptr; + } + return reinterpret_cast<TF_Tensor*>(handle); +} + +TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "Tensor handle has been deleted"); + return nullptr; + } + return reinterpret_cast<TFE_TensorHandle*>(handle); +} + +} // namespace + +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate( + JNIEnv* env, jclass clazz, jlong context_handle, jstring name) { + TFE_Context* context = requireContext(env, context_handle); + if (context == nullptr) return 0; + const char* op_or_function_name = env->GetStringUTFChars(name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(context, op_or_function_name, status); + env->ReleaseStringUTFChars(name, op_or_function_name); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + static_assert(sizeof(jlong) >= sizeof(TFE_Op*), + "Cannot represent a C TFE_Op as a Java long"); + return reinterpret_cast<jlong>(op); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_delete( + JNIEnv* env, jclass clazz, jlong op_handle) { + if (op_handle == 0) return; + TFE_DeleteOp(reinterpret_cast<TFE_Op*>(op_handle)); +} + +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute( + JNIEnv* env, jclass clazz, jlong op_handle) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return 0; + int num_retvals = 8; // should be >= than max number of outputs in any op + std::unique_ptr<TFE_TensorHandle*[]> retvals( + new TFE_TensorHandle*[num_retvals]); + TF_Status* status = TF_NewStatus(); + TFE_Execute(op, retvals.get(), &num_retvals, status); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return nullptr; + } + TF_DeleteStatus(status); + jlongArray rethandles = env->NewLongArray(num_retvals); + if (num_retvals > 0) { + jlong* retval = env->GetLongArrayElements(rethandles, nullptr); + for (int i = 0; i < num_retvals; ++i) { + retval[i] = reinterpret_cast<jlong>(retvals[i]); + } + env->ReleaseLongArrayElements(rethandles, retval, 0); + } + return rethandles; +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice( + JNIEnv* env, jclass clazz, jlong op_handle, jstring device_name) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + const char* cname = env->GetStringUTFChars(device_name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_OpSetDevice(op, cname, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + env->ReleaseStringUTFChars(device_name, cname); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput( + JNIEnv* env, jclass clazz, jlong op_handle, jlong input_handle) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, input_handle); + if (tensor_handle == nullptr) return; + TF_Status* status = TF_NewStatus(); + TFE_OpAddInput(op, tensor_handle, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList( + JNIEnv* env, jclass clazz, jlong op_handle, jlongArray input_handles) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + jlong* cinput_handles = env->GetLongArrayElements(input_handles, nullptr); + size_t num_inputs = static_cast<size_t>(env->GetArrayLength(input_handles)); + std::unique_ptr<TFE_TensorHandle*[]> tensor_handles( + new TFE_TensorHandle*[num_inputs]); + for (int i = 0; i < num_inputs; ++i) { + tensor_handles[i] = requireTensorHandle(env, cinput_handles[i]); + if (tensor_handles[i] == nullptr) { + env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT); + return; + } + } + env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT); + TF_Status* status = TF_NewStatus(); + TFE_OpAddInputList(op, tensor_handles.get(), num_inputs, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString( + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, + jbyteArray value) { + static_assert(sizeof(jbyte) == 1, + "Require Java byte to be represented as a single byte"); + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + jbyte* cvalue = env->GetByteArrayElements(value, nullptr); + TFE_OpSetAttrString(op, cname, cvalue, env->GetArrayLength(value)); + env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT); + env->ReleaseStringUTFChars(attr_name, cname); +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrStringList( + JNIEnv* env, jclass object, jlong op_handle, jstring attr_name, + jobjectArray values) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + int num_values = env->GetArrayLength(values); + static_assert(sizeof(jbyte) == 1, + "Require Java byte to be represented as a single byte"); + std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]); + std::unique_ptr<jbyte*[]> jvalues(new jbyte*[num_values]); + std::unique_ptr<void*[]> cvalues(new void*[num_values]); + std::unique_ptr<size_t[]> lengths(new size_t[num_values]); + + for (int i = 0; i < num_values; ++i) { + jbyteArray v = + static_cast<jbyteArray>(env->GetObjectArrayElement(values, i)); + jarrays[i] = v; + jvalues[i] = env->GetByteArrayElements(v, nullptr); + cvalues[i] = jvalues[i]; + lengths[i] = static_cast<size_t>(env->GetArrayLength(v)); + } + TFE_OpSetAttrStringList(op, cname, cvalues.get(), lengths.get(), num_values); + for (int i = 0; i < num_values; ++i) { + env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT); + } + env->ReleaseStringUTFChars(attr_name, cname); +} + +#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \ + JNIEXPORT void JNICALL \ + Java_org_tensorflow_EagerOperationBuilder_setAttr##name( \ + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \ + jtype value) { \ + static_assert( \ + sizeof(ctype) >= sizeof(jtype), \ + "Information loss when converting between Java and C types"); \ + TFE_Op* op = requireOp(env, op_handle); \ + if (op == nullptr) return; \ + const char* cname = env->GetStringUTFChars(attr_name, nullptr); \ + TFE_OpSetAttr##name(op, cname, static_cast<ctype>(value)); \ + env->ReleaseStringUTFChars(attr_name, cname); \ + } + +#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \ + JNIEXPORT void JNICALL \ + Java_org_tensorflow_EagerOperationBuilder_setAttr##name##List( \ + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \ + jtype##Array value) { \ + TFE_Op* op = requireOp(env, op_handle); \ + if (op == nullptr) return; \ + const char* cname = env->GetStringUTFChars(attr_name, nullptr); \ + /* Make a copy of the array to paper over any differences */ \ + /* in byte representations of the jtype and ctype */ \ + /* For example, jint vs TF_DataType. */ \ + /* If this copy turns out to be a problem in practice */ \ + /* can avoid it for many types. */ \ + const int n = env->GetArrayLength(value); \ + std::unique_ptr<ctype[]> cvalue(new ctype[n]); \ + jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \ + for (int i = 0; i < n; ++i) { \ + cvalue[i] = static_cast<ctype>(elems[i]); \ + } \ + TFE_OpSetAttr##name##List(op, cname, cvalue.get(), n); \ + env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \ + env->ReleaseStringUTFChars(attr_name, cname); \ + } + +#define DEFINE_SET_ATTR(name, jname, jtype, ctype) \ + DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \ + DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) + +DEFINE_SET_ATTR(Int, Long, jlong, int64_t); +DEFINE_SET_ATTR(Float, Float, jfloat, float); +DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char); +DEFINE_SET_ATTR(Type, Int, jint, TF_DataType); +#undef DEFINE_SET_ATTR +#undef DEFINE_SET_ATTR_LIST +#undef DEFINE_SET_ATTR_SCALAR + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor( + JNIEnv* env, jclass clazz, jlong handle, jstring attr_name, + jlong tensor_handle) { + TFE_Op* op = requireOp(env, handle); + if (op == nullptr) return; + TF_Tensor* t = requireTensor(env, tensor_handle); + if (t == nullptr) return; + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_OpSetAttrTensor(op, cname, t, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + env->ReleaseStringUTFChars(attr_name, cname); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape( + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, + jlongArray shape, jint num_dims) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + std::unique_ptr<int64_t[]> cvalue; + // num_dims and env->GetArrayLength(shape) are assumed to be consistent. + // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape). + if (num_dims > 0) { + cvalue.reset(new int64_t[num_dims]); + jlong* elems = env->GetLongArrayElements(shape, nullptr); + for (int i = 0; i < num_dims; ++i) { + cvalue[i] = static_cast<int64_t>(elems[i]); + } + env->ReleaseLongArrayElements(shape, elems, JNI_ABORT); + } + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_OpSetAttrShape(op, cname, cvalue.get(), static_cast<int>(num_dims), + status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + env->ReleaseStringUTFChars(attr_name, cname); +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList( + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, + jlongArray shapes, jintArray num_dims) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + std::unique_ptr<int64_t[]> cshapes; + std::unique_ptr<const int64_t*[]> cdims; + std::unique_ptr<int[]> cnum_dims; + const int num_dims_length = env->GetArrayLength(num_dims); + if (num_dims_length > 0) { + const int shapes_length = env->GetArrayLength(shapes); + cshapes.reset(new int64_t[shapes_length]); + cdims.reset(new const int64_t*[num_dims_length]); + cnum_dims.reset(new int[num_dims_length]); + jlong* shapes_elems = + static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr)); + std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3); + env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT); + int64_t* cshapes_ptr = cshapes.get(); + jint* num_dims_elems = + static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr)); + for (int i = 0; i < num_dims_length; ++i) { + cnum_dims[i] = static_cast<int>(num_dims_elems[i]); + cdims[i] = cshapes_ptr; + if (cnum_dims[i] > 0) { + cshapes_ptr += cnum_dims[i]; + } + } + env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT); + } + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_OpSetAttrShapeList(op, cname, cdims.get(), cnum_dims.get(), + num_dims_length, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + env->ReleaseStringUTFChars(attr_name, cname); +} diff --git a/tensorflow/java/src/main/native/eager_operation_builder_jni.h b/tensorflow/java/src/main/native/eager_operation_builder_jni.h new file mode 100644 index 00000000000..6da891d7ae2 --- /dev/null +++ b/tensorflow/java/src/main/native/eager_operation_builder_jni.h @@ -0,0 +1,191 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_ +#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_ + +#include <jni.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: allocate + * Signature: (JLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate( + JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: delete + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_delete(JNIEnv *, jclass, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: execute + * Signature: (J)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_EagerOperationBuilder_execute(JNIEnv *, jclass, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: addInput + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput( + JNIEnv *, jclass, jlong, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: addInputList + * Signature: (J[J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList( + JNIEnv *, jclass, jlong, jlongArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setDevice + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice( + JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrString + * Signature: (JLjava/lang/String;[B)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString( + JNIEnv *, jclass, jlong, jstring, jbyteArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrStringList + * Signature: (JLjava/lang/String;[L)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(JNIEnv *, jclass, + jlong, jstring, + jobjectArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrInt + * Signature: (JLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrInt( + JNIEnv *, jclass, jlong, jstring, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrIntList + * Signature: (JLjava/lang/String;[J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrIntList( + JNIEnv *, jclass, jlong, jstring, jlongArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrFloat + * Signature: (JLjava/lang/String;F)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrFloat( + JNIEnv *, jclass, jlong, jstring, jfloat); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrFloatList + * Signature: (JLjava/lang/String;[F)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrFloatList(JNIEnv *, jclass, + jlong, jstring, + jfloatArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrBool + * Signature: (JLjava/lang/String;Z)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrBool( + JNIEnv *, jclass, jlong, jstring, jboolean); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrBoolList + * Signature: (JLjava/lang/String;[Z)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrBoolList(JNIEnv *, jclass, + jlong, jstring, + jbooleanArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrType + * Signature: (JLjava/lang/String;I)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrType( + JNIEnv *, jclass, jlong, jstring, jint); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrTypeList + * Signature: (JLjava/lang/String;[I)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrTypeList(JNIEnv *, jclass, + jlong, jstring, + jintArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrTensor + * Signature: (JLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor( + JNIEnv *, jclass, jlong, jstring, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrShape + * Signature: (JLjava/lang/String;[JI)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape( + JNIEnv *, jclass, jlong, jstring, jlongArray, jint); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrShapeList + * Signature: (JLjava/lang/String;[J[I)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(JNIEnv *, jclass, + jlong, jstring, + jlongArray, + jintArray); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_ diff --git a/tensorflow/java/src/main/native/eager_session_jni.cc b/tensorflow/java/src/main/native/eager_session_jni.cc index eb6e9f0b581..58905205c94 100644 --- a/tensorflow/java/src/main/native/eager_session_jni.cc +++ b/tensorflow/java/src/main/native/eager_session_jni.cc @@ -21,18 +21,6 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/java/src/main/native/exception_jni.h" -namespace { - -TFE_Context* requireContext(JNIEnv* env, jlong handle) { - if (handle == 0) { - throwException(env, kIllegalStateException, "Context has been deleted"); - return nullptr; - } - return reinterpret_cast<TFE_Context*>(handle); -} - -} // namespace - JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate( JNIEnv* env, jclass clazz, jboolean async, jint dpp, jbyteArray config) { TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -74,21 +62,3 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv* env, if (handle == 0) return; TFE_DeleteContext(reinterpret_cast<TFE_Context*>(handle)); } - -JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocateOperation( - JNIEnv* env, jclass clazz, jlong handle, jstring name) { - TFE_Context* context = requireContext(env, handle); - if (context == nullptr) return 0; - const char* op_or_function_name = env->GetStringUTFChars(name, nullptr); - TF_Status* status = TF_NewStatus(); - TFE_Op* op = TFE_NewOp(context, op_or_function_name, status); - env->ReleaseStringUTFChars(name, op_or_function_name); - if (!throwExceptionIfNotOK(env, status)) { - TF_DeleteStatus(status); - return 0; - } - TF_DeleteStatus(status); - static_assert(sizeof(jlong) >= sizeof(TFE_Op*), - "Cannot represent a C TFE_Op as a Java long"); - return reinterpret_cast<jlong>(op); -} diff --git a/tensorflow/java/src/main/native/eager_session_jni.h b/tensorflow/java/src/main/native/eager_session_jni.h index ffe9d549012..9f7bdaccd36 100644 --- a/tensorflow/java/src/main/native/eager_session_jni.h +++ b/tensorflow/java/src/main/native/eager_session_jni.h @@ -38,14 +38,6 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate( JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv *, jclass, jlong); -/* - * Class: org_tensorflow_EagerSession - * Method: allocateOperation - * Signature: (JLjava/lang/String;)J - */ -JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocateOperation( - JNIEnv *, jclass, jlong, jstring); - #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java new file mode 100644 index 00000000000..38534297e8c --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class EagerOperationBuilderTest { + + @Test + public void failToCreateIfSessionIsClosed() { + EagerSession session = EagerSession.create(); + session.close(); + try { + new EagerOperationBuilder(session, "Add", "add"); + fail(); + } catch (IllegalStateException e) { + } + } + + @Test + public void failToBuildOpIfSessionIsClosed() { + EagerOperationBuilder opBuilder; + try (EagerSession session = EagerSession.create()) { + opBuilder = new EagerOperationBuilder(session, "Empty", "empty"); + } + try { + opBuilder.setAttr("dtype", DataType.FLOAT); + fail(); + } catch (IllegalStateException e) { + } + } + + // TODO (karllessard) add more tests when EagerOperation is implemented as well +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java index 68673f5cf30..77f38bb6160 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.