diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index a5e19d7a8f6..68a52f2162b 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -17,6 +17,10 @@ package org.tensorflow.contrib.android;
import android.content.res.AssetManager;
import android.util.Log;
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
import java.util.Random;
/**
@@ -91,16 +95,153 @@ public class TensorFlowInferenceInterface {
*/
public native void close();
- // Methods for creating a native Tensor and filling it with values.
- public native void fillNodeFloat(String inputName, int[] dims, float[] values);
- public native void fillNodeInt(String inputName, int[] dims, int[] values);
- public native void fillNodeDouble(String inputName, int[] dims, double[] values);
- public native void fillNodeByte(String inputName, int[] dims, byte[] values);
+ // Methods for taking a native Tensor and filling it with values from Java arrays.
- public native void readNodeFloat(String outputName, float[] values);
- public native void readNodeInt(String outputName, int[] values);
- public native void readNodeDouble(String outputName, double[] values);
- public native void readNodeByte(String outputName, byte[] values);
+ /**
+ * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
+ * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
+ * as many elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
+ public native void fillNodeFloat(String inputName, int[] dims, float[] src);
+
+ /**
+ * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
+ * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
+ * as many elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
+ public native void fillNodeInt(String inputName, int[] dims, int[] src);
+
+ /**
+ * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
+ * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
+ * as many elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
+ public native void fillNodeDouble(String inputName, int[] dims, double[] src);
+
+ /**
+ * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
+ * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
+ * as many elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
+ public native void fillNodeByte(String inputName, int[] dims, byte[] src);
+
+ // Methods for taking a native Tensor and filling it with src from Java native IO buffers.
+
+ /**
+ * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
+ * direct and native ordered java.nio buffers, copy the contents into the input
+ * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
+ * elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
+ public native void fillNodeFromFloatBuffer(String inputName, IntBuffer dims, FloatBuffer src);
+
+ /**
+ * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
+ * direct and native ordered java.nio buffers, copy the contents into the input
+ * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
+ * elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
+ public native void fillNodeFromIntBuffer(String inputName, IntBuffer dims, IntBuffer src);
+
+ /**
+ * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
+ * direct and native ordered java.nio buffers, copy the contents into the input
+ * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
+ * elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
+ public native void fillNodeFromDoubleBuffer(String inputName, IntBuffer dims, DoubleBuffer src);
+
+ /**
+ * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
+ * direct and native ordered java.nio buffers, copy the contents into the input
+ * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
+ * elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
+ public native void fillNodeFromByteBuffer(String inputName, IntBuffer dims, ByteBuffer src);
+
+ /**
+ * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
+ * dst} must have length greater than or equal to that of the source Tensor. This operation will
+ * not affect dst's content past the source Tensor's size.
+ *
+ * @return 0 on success, -1 on failure.
+ */
+ public native int readNodeFloat(String outputName, float[] dst);
+
+ /**
+ * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
+ * dst} must have length greater than or equal to that of the source Tensor. This operation will
+ * not affect dst's content past the source Tensor's size.
+ *
+ * @return 0 on success, -1 on failure.
+ */
+ public native int readNodeInt(String outputName, int[] dst);
+
+ /**
+ * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
+ * dst} must have length greater than or equal to that of the source Tensor. This operation will
+ * not affect dst's content past the source Tensor's size.
+ *
+ * @return 0 on success, -1 on failure.
+ */
+ public native int readNodeDouble(String outputName, double[] dst);
+
+ /**
+ * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
+ * dst} must have length greater than or equal to that of the source Tensor. This operation will
+ * not affect dst's content past the source Tensor's size.
+ *
+ * @return 0 on success, -1 on failure.
+ */
+ public native int readNodeByte(String outputName, byte[] dst);
+
+ /**
+ * Read from a Tensor named {@link outputName} and copy the contents into the direct and
+ * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than
+ * or equal to that of the source Tensor. This operation will not affect dst's content past the
+ * source Tensor's size.
+ *
+ * @return 0 on success, -1 on failure.
+ */
+ public native int readNodeIntoFloatBuffer(String outputName, FloatBuffer dst);
+
+ /**
+ * Read from a Tensor named {@link outputName} and copy the contents into the direct and
+ * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than
+ * or equal to that of the source Tensor. This operation will not affect dst's content past the
+ * source Tensor's size.
+ *
+ * @return 0 on success, -1 on failure.
+ */
+ public native int readNodeIntoIntBuffer(String outputName, IntBuffer dst);
+
+ /**
+ * Read from a Tensor named {@link outputName} and copy the contents into the direct and
+ * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than
+ * or equal to that of the source Tensor. This operation will not affect dst's content past the
+ * source Tensor's size.
+ *
+ * @return 0 on success, -1 on failure.
+ */
+ public native int readNodeIntoDoubleBuffer(String outputName, DoubleBuffer dst);
+
+ /**
+ * Read from a Tensor named {@link outputName} and copy the contents into the direct and
+ * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than
+ * or equal to that of the source Tensor. This operation will not affect dst's content past the
+ * source Tensor's size.
+ *
+ * @return 0 on success, -1 on failure.
+ */
+ public native int readNodeIntoByteBuffer(String outputName, ByteBuffer dst);
/**
* Canary method solely for determining if the tensorflow_inference native library should be
diff --git a/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc b/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
index 0a5d10e5c2f..d3cfe1fdf0a 100644
--- a/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
+++ b/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
@@ -285,47 +285,121 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
env->ReleaseIntArrayElements(dims, dim_vals, JNI_ABORT); \
tensorflow::Tensor input_tensor(TENSOR_DTYPE, shape); \
auto tensor_mapped = input_tensor.flat(); \
- j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(arr, &iCopied); \
+ j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(src, &iCopied); \
j##JAVA_DTYPE* value_ptr = values; \
- const int array_size = env->GetArrayLength(arr); \
- for (int i = 0; \
- i < std::min(static_cast(tensor_mapped.size()), array_size); \
- ++i) { \
+ const int src_size = static_cast(env->GetArrayLength(src)); \
+ const int dst_size = static_cast(tensor_mapped.size()); \
+ CHECK_GE(src_size, dst_size) \
+ << "src array must have at least as many elements as dst Tensor."; \
+ const int num_items = std::min(src_size, dst_size); \
+ for (int i = 0; i < num_items; ++i) { \
tensor_mapped(i) = *value_ptr++; \
} \
- env->Release##DTYPE##ArrayElements(arr, values, JNI_ABORT); \
+ env->Release##DTYPE##ArrayElements(src, values, JNI_ABORT); \
std::string input_name = GetString(env, node_name); \
std::pair input_pair(input_name, \
input_tensor); \
vars->input_tensors[input_name] = input_pair; \
}
+#define FILL_NODE_NIO_BUFFER_METHOD(DTYPE, CTYPE, TENSOR_DTYPE) \
+ FILL_NODE_NIO_BUFFER_SIGNATURE(DTYPE) { \
+ SessionVariables* vars = GetSessionVars(env, thiz); \
+ tensorflow::TensorShape shape; \
+ const int* dim_vals = reinterpret_cast( \
+ env->GetDirectBufferAddress(dims_buffer)); \
+ const int num_dims = env->GetDirectBufferCapacity(dims_buffer); \
+ for (int i = 0; i < num_dims; ++i) { \
+ shape.AddDim(dim_vals[i]); \
+ } \
+ tensorflow::Tensor input_tensor(TENSOR_DTYPE, shape); \
+ auto tensor_mapped = input_tensor.flat(); \
+ const CTYPE* values = reinterpret_cast( \
+ env->GetDirectBufferAddress(src_buffer)); \
+ const CTYPE* value_ptr = values; \
+ const int src_size = \
+ static_cast(env->GetDirectBufferCapacity(src_buffer)); \
+ const int dst_size = static_cast(tensor_mapped.size()); \
+ CHECK_GE(src_size, dst_size) \
+ << "src buffer must have at least as many elements as dst Tensor."; \
+ const int num_items = std::min(src_size, dst_size); \
+ for (int i = 0; i < num_items; ++i) { \
+ tensor_mapped(i) = *value_ptr++; \
+ } \
+ std::string input_name = GetString(env, node_name); \
+ std::pair input_pair(input_name, \
+ input_tensor); \
+ vars->input_tensors[input_name] = input_pair; \
+ }
+
#define READ_NODE_METHOD(DTYPE, JAVA_DTYPE, CTYPE) \
READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) { \
SessionVariables* vars = GetSessionVars(env, thiz); \
- Tensor* t = GetTensor(env, thiz, node_name_jstring); \
+ Tensor* t = GetTensor(env, thiz, node_name); \
if (t == nullptr) { \
return -1; \
} \
auto tensor_mapped = t->flat(); \
jboolean iCopied = JNI_FALSE; \
- j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(arr, &iCopied); \
+ j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(dst, &iCopied); \
+ if (values == nullptr) { \
+ return -1; \
+ } \
j##JAVA_DTYPE* value_ptr = values; \
- const int num_items = std::min(static_cast(tensor_mapped.size()), \
- env->GetArrayLength(arr)); \
+ const int src_size = static_cast(tensor_mapped.size()); \
+ const int dst_size = static_cast(env->GetArrayLength(dst)); \
+ CHECK_GE(dst_size, src_size) \
+ << "dst array must have length >= src Tensor's flattened size."; \
+ const int num_items = std::min(src_size, dst_size); \
for (int i = 0; i < num_items; ++i) { \
*value_ptr++ = tensor_mapped(i); \
} \
- env->Release##DTYPE##ArrayElements(arr, values, 0); \
+ env->Release##DTYPE##ArrayElements(dst, values, 0); \
return 0; \
}
+#define READ_NODE_NIO_BUFFER_METHOD(DTYPE, CTYPE) \
+ READ_NODE_NIO_BUFFER_SIGNATURE(DTYPE) { \
+ SessionVariables* vars = GetSessionVars(env, thiz); \
+ Tensor* t = GetTensor(env, thiz, node_name); \
+ if (t == nullptr) { \
+ return -1; \
+ } \
+ auto tensor_mapped = t->flat(); \
+ CTYPE* values = \
+ reinterpret_cast(env->GetDirectBufferAddress(dst_buffer)); \
+ if (values == nullptr) { \
+ return -1; \
+ } \
+ CTYPE* value_ptr = values; \
+ const int src_size = static_cast(tensor_mapped.size()); \
+ const int dst_size = \
+ static_cast(env->GetDirectBufferCapacity(dst_buffer)); \
+ CHECK_GE(dst_size, src_size) \
+ << "dst buffer must have capacity >= src Tensor's flattened size."; \
+ const int num_items = std::min(src_size, dst_size); \
+ for (int i = 0; i < num_items; ++i) { \
+ *value_ptr++ = tensor_mapped(i); \
+ } \
+ return 0; \
+ }
+
FILL_NODE_METHOD(Float, float, float, tensorflow::DT_FLOAT)
FILL_NODE_METHOD(Int, int, int, tensorflow::DT_INT32)
FILL_NODE_METHOD(Double, double, double, tensorflow::DT_DOUBLE)
FILL_NODE_METHOD(Byte, byte, uint8_t, tensorflow::DT_UINT8)
+FILL_NODE_NIO_BUFFER_METHOD(Float, float, tensorflow::DT_FLOAT)
+FILL_NODE_NIO_BUFFER_METHOD(Int, int, tensorflow::DT_INT32)
+FILL_NODE_NIO_BUFFER_METHOD(Double, double, tensorflow::DT_DOUBLE)
+FILL_NODE_NIO_BUFFER_METHOD(Byte, uint8_t, tensorflow::DT_UINT8)
+
READ_NODE_METHOD(Float, float, float)
READ_NODE_METHOD(Int, int, int)
READ_NODE_METHOD(Double, double, double)
READ_NODE_METHOD(Byte, byte, uint8_t)
+
+READ_NODE_NIO_BUFFER_METHOD(Float, float);
+READ_NODE_NIO_BUFFER_METHOD(Int, int);
+READ_NODE_NIO_BUFFER_METHOD(Double, double);
+READ_NODE_NIO_BUFFER_METHOD(Byte, uint8_t);
diff --git a/tensorflow/contrib/android/jni/tensorflow_inference_jni.h b/tensorflow/contrib/android/jni/tensorflow_inference_jni.h
index d0aff10b098..93fb8ba3159 100644
--- a/tensorflow/contrib/android/jni/tensorflow_inference_jni.h
+++ b/tensorflow/contrib/android/jni/tensorflow_inference_jni.h
@@ -33,12 +33,20 @@ extern "C" {
#define FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) \
JNIEXPORT void TENSORFLOW_METHOD(fillNode##DTYPE)( \
JNIEnv * env, jobject thiz, jstring node_name, jintArray dims, \
- j##JAVA_DTYPE##Array arr)
+ j##JAVA_DTYPE##Array src)
-#define READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) \
- JNIEXPORT jint TENSORFLOW_METHOD(readNode##DTYPE)( \
- JNIEnv * env, jobject thiz, jstring node_name_jstring, \
- j##JAVA_DTYPE##Array arr)
+#define FILL_NODE_NIO_BUFFER_SIGNATURE(DTYPE) \
+ JNIEXPORT void TENSORFLOW_METHOD(fillNodeFrom##DTYPE##Buffer)( \
+ JNIEnv * env, jobject thiz, jstring node_name, jobject dims_buffer, \
+ jobject src_buffer)
+
+#define READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) \
+ JNIEXPORT jint TENSORFLOW_METHOD(readNode##DTYPE)( \
+ JNIEnv * env, jobject thiz, jstring node_name, j##JAVA_DTYPE##Array dst)
+
+#define READ_NODE_NIO_BUFFER_SIGNATURE(DTYPE) \
+ JNIEXPORT jint TENSORFLOW_METHOD(readNodeInto##DTYPE##Buffer)( \
+ JNIEnv * env, jobject thiz, jstring node_name, jobject dst_buffer)
JNIEXPORT void JNICALL TENSORFLOW_METHOD(testLoaded)(JNIEnv* env, jobject thiz);
@@ -61,11 +69,21 @@ FILL_NODE_SIGNATURE(Int, int);
FILL_NODE_SIGNATURE(Double, double);
FILL_NODE_SIGNATURE(Byte, byte);
+FILL_NODE_NIO_BUFFER_SIGNATURE(Float);
+FILL_NODE_NIO_BUFFER_SIGNATURE(Int);
+FILL_NODE_NIO_BUFFER_SIGNATURE(Double);
+FILL_NODE_NIO_BUFFER_SIGNATURE(Byte);
+
READ_NODE_SIGNATURE(Float, float);
READ_NODE_SIGNATURE(Int, int);
READ_NODE_SIGNATURE(Double, double);
READ_NODE_SIGNATURE(Byte, byte);
+READ_NODE_NIO_BUFFER_SIGNATURE(Float);
+READ_NODE_NIO_BUFFER_SIGNATURE(Int);
+READ_NODE_NIO_BUFFER_SIGNATURE(Double);
+READ_NODE_NIO_BUFFER_SIGNATURE(Byte);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus