Add new API for running using SignatureDef in Interpreter Java API.

This change exposes these new methods:
* runSignature: Which accepts map inputs/ outputs which uses signatureDef inputs/outputs names as keys.
* getSignatureDefNames: Returns List of Strings of available signatures in the loaded model.
* getSignatureInputs: Gets the list of SignatureDefs inputs for method provided.
* getSignatureOutputs: Gets the list of SignatureDefs outputs for method

PiperOrigin-RevId: 348146116
Change-Id: I3e44982b0d27e01074298273530a6ec7a525932b
This commit is contained in:
Karim Nosir 2020-12-17 21:38:16 -08:00 committed by TensorFlower Gardener
parent 2d74ab3683
commit 6be2ddd86a
6 changed files with 364 additions and 2 deletions

View File

@ -16,6 +16,7 @@ exports_files([
"src/testdata/add.bin",
"src/testdata/add_unknown_dimensions.bin",
"src/testdata/grace_hopper_224.jpg",
"src/testdata/mul_add_signature_def.bin",
"src/testdata/tile_with_bool_input.bin",
"AndroidManifest.xml",
"proguard.flags",
@ -262,6 +263,7 @@ java_test(
data = [
"src/testdata/add.bin",
"src/testdata/add_unknown_dimensions.bin",
"src/testdata/mul_add_signature_def.bin",
"src/testdata/tile_with_bool_input.bin",
"//tensorflow/lite:testdata/dynamic_shapes.bin",
"//tensorflow/lite:testdata/multi_add.bin",

View File

@ -19,6 +19,7 @@ import java.io.File;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -77,6 +78,8 @@ import org.checkerframework.checker.nullness.qual.NonNull;
*
* <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
* but is not guaranteed.
*
* <p>Note: This class is not thread safe.
*/
public final class Interpreter implements AutoCloseable {
@ -222,6 +225,7 @@ public final class Interpreter implements AutoCloseable {
*/
public Interpreter(@NonNull File modelFile, Options options) {
wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
signatureNameList = getSignatureDefNames();
}
/**
@ -281,6 +285,7 @@ public final class Interpreter implements AutoCloseable {
*/
public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
wrapper = new NativeInterpreterWrapper(byteBuffer, options);
signatureNameList = getSignatureDefNames();
}
/**
@ -369,6 +374,49 @@ public final class Interpreter implements AutoCloseable {
wrapper.run(inputs, outputs);
}
/**
* Runs model inference based on SignatureDef provided through @code methodName.
*
* <p>See {@link Interpreter#run(Object, Object)} for more details on the allowed input and output
* data types.
*
* @param inputs A Map of inputs from input name in the signatureDef to an input object.
* @param outputs a map mapping from output name in SignatureDef to output data.
* @param methodName The exported method name identifying the SignatureDef.
* @throws IllegalArgumentException if {@code inputs} or {@code outputs} or {@code methodName}is
* null or empty, or if error occurs when running the inference.
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public void runSignature(
@NonNull Map<String, Object> inputs,
@NonNull Map<String, Object> outputs,
String methodName) {
checkNotClosed();
if (methodName == null && signatureNameList.length == 1) {
methodName = signatureNameList[0];
}
if (methodName == null) {
throw new IllegalArgumentException(
"Input error: SignatureDef methodName should not be null. null is only allowed if the"
+ " model has a single Signature. Available Signatures: "
+ Arrays.toString(signatureNameList));
}
wrapper.runSignature(inputs, outputs, methodName);
}
/* Same as {@link Interpreter#runSignature(Object, Object, String)} but doesn't require
* passing a methodName, assuming the model has one SignatureDef. If the model has more than
* one SignatureDef it will throw an exception.
*
* * <p>WARNING: This is an experimental API and subject to change.
* */
public void runSignature(
@NonNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs) {
checkNotClosed();
runSignature(inputs, outputs, null);
}
/**
* Expicitly updates allocations for all tensors, if necessary.
*
@ -450,6 +498,36 @@ public final class Interpreter implements AutoCloseable {
return wrapper.getInputTensor(inputIndex);
}
/**
* Gets the list of SignatureDef exported method names available in the model.
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public String[] getSignatureDefNames() {
checkNotClosed();
return wrapper.getSignatureDefNames();
}
/**
* Gets the list of SignatureDefs inputs for method {@code methodName}
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public String[] getSignatureInputs(String methodName) {
checkNotClosed();
return wrapper.getSignatureInputs(methodName);
}
/**
* Gets the list of SignatureDefs outputs for method {@code methodName}
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public String[] getSignatureOutputs(String methodName) {
checkNotClosed();
return wrapper.getSignatureOutputs(methodName);
}
/** Gets the number of output Tensors. */
public int getOutputTensorCount() {
checkNotClosed();
@ -584,4 +662,5 @@ public final class Interpreter implements AutoCloseable {
}
NativeInterpreterWrapper wrapper;
String[] signatureNameList;
}

View File

@ -22,6 +22,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.tensorflow.lite.nnapi.NnApiDelegate;
/**
@ -30,6 +31,8 @@ import org.tensorflow.lite.nnapi.NnApiDelegate;
* <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
* explicitly freed by invoking the {@link #close()} method when the {@code
* NativeInterpreterWrapper} object is no longer needed.
*
* Note: This class is not thread safe.
*/
final class NativeInterpreterWrapper implements AutoCloseable {
@ -136,6 +139,36 @@ final class NativeInterpreterWrapper implements AutoCloseable {
ownedDelegates.clear();
}
public void runSignature(
Map<String, Object> inputs, Map<String, Object> outputs, String methodName) {
if (inputs == null || inputs.isEmpty()) {
throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
}
if (outputs == null || outputs.isEmpty()) {
throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
}
initTensorIndexesMaps();
// Map inputs/output to input indexes.
Map<Integer, Object> inputsWithInputIndex = new TreeMap<>();
Map<Integer, Object> outputsWithOutputIndex = new TreeMap<>();
for (Map.Entry<String, Object> input : inputs.entrySet()) {
int tensorIndex =
getInputTensorIndexFromSignature(interpreterHandle, input.getKey(), methodName);
inputsWithInputIndex.put(tensorToInputsIndexes.get(tensorIndex), input.getValue());
}
for (Map.Entry<String, Object> output : outputs.entrySet()) {
int tensorIndex =
getOutputTensorIndexFromSignature(interpreterHandle, output.getKey(), methodName);
outputsWithOutputIndex.put(tensorToOutputsIndexes.get(tensorIndex), output.getValue());
}
Object[] inputsList = new Object[inputs.size()];
int index = 0;
for (Map.Entry<Integer, Object> input : inputsWithInputIndex.entrySet()) {
inputsList[index++] = input.getValue();
}
run(inputsList, outputsWithOutputIndex);
}
/** Sets inputs, runs model inference and returns outputs. */
void run(Object[] inputs, Map<Integer, Object> outputs) {
inferenceDurationNanoseconds = -1;
@ -257,7 +290,26 @@ final class NativeInterpreterWrapper implements AutoCloseable {
String.format(
"Input error: '%s' is not a valid name for any input. Names of inputs and their "
+ "indexes are %s",
name, inputsIndexes.toString()));
name, inputsIndexes));
}
}
/** Initializes mapping from tensor index to input/output index. **/
private void initTensorIndexesMaps() {
if (tensorToInputsIndexes != null) {
return;
}
tensorToInputsIndexes = new HashMap<>();
tensorToOutputsIndexes = new HashMap<>();
int inputCount = getInputTensorCount();
for (int i = 0; i < inputCount; ++i) {
int tensorIndex = getInputTensorIndex(interpreterHandle, i);
tensorToInputsIndexes.put(tensorIndex, i);
}
int outputCount = getOutputTensorCount();
for (int i = 0; i < outputCount; ++i) {
int tensorIndex = getOutputTensorIndex(interpreterHandle, i);
tensorToOutputsIndexes.put(tensorIndex, i);
}
}
@ -279,7 +331,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
String.format(
"Input error: '%s' is not a valid name for any output. Names of outputs and their "
+ "indexes are %s",
name, outputsIndexes.toString()));
name, outputsIndexes));
}
}
@ -314,6 +366,27 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return inputTensor;
}
/** Gets the list of SignatureDefs available in the model, if any. */
public String[] getSignatureDefNames() {
return getSignatureDefNames(interpreterHandle);
}
private static native String[] getSignatureDefNames(long interpreterHandle);
/** Gets the list of SignatureDefs inputs for method {@code methodName} */
String[] getSignatureInputs(String methodName) {
return getSignatureInputs(interpreterHandle, methodName);
}
private static native String[] getSignatureInputs(long interpreterHandle, String methodName);
/** Gets the list of SignatureDefs outputs for method {@code methodName} */
String[] getSignatureOutputs(String methodName) {
return getSignatureOutputs(interpreterHandle, methodName);
}
private static native String[] getSignatureOutputs(long interpreterHandle, String methodName);
/** Gets the number of output tensors. */
int getOutputTensorCount() {
return outputTensors.length;
@ -430,6 +503,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
// Lazily constructed maps of input and output names to input and output Tensor indexes.
private Map<String, Integer> inputsIndexes;
private Map<String, Integer> outputsIndexes;
// Lazily constructed maps of tensor index to index in input and output indexes.
private Map<Integer, Integer> tensorToInputsIndexes;
private Map<Integer, Integer> tensorToOutputsIndexes;
// Lazily constructed and populated arrays of input and output Tensor wrappers.
private Tensor[] inputTensors;
@ -448,6 +524,12 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
private static native int getInputTensorIndexFromSignature(
long interpreterHandle, String signatureInputName, String methodName);
private static native int getOutputTensorIndexFromSignature(
long interpreterHandle, String signatureInputName, String methodName);
private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);
private static native int getInputCount(long interpreterHandle);

View File

@ -158,6 +158,48 @@ bool VerifyModel(const void* buf, size_t len) {
return tflite::VerifyModelBuffer(verifier);
}
// Helper method that fetches the tensor index based on SignatureDef details
// from either inputs or outputs.
// Returns -1 if invalid names are passed.
int GetTensorIndexForSignature(JNIEnv* env, jstring signature_tensor_name,
jstring method_name,
tflite::Interpreter* interpreter,
bool is_input) {
// Fetch name strings.
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
const char* signature_input_name_ptr =
env->GetStringUTFChars(signature_tensor_name, nullptr);
// Lookup if the input is valid.
const auto& signature_list =
(is_input ? interpreter->signature_inputs(method_name_ptr)
: interpreter->signature_outputs(method_name_ptr));
const auto& tensor = signature_list.find(signature_input_name_ptr);
// Release the memory before returning.
env->ReleaseStringUTFChars(method_name, method_name_ptr);
env->ReleaseStringUTFChars(signature_tensor_name, signature_input_name_ptr);
return tensor == signature_list.end() ? -1 : tensor->second;
}
jobjectArray GetSignatureInputsOutputsList(
const std::map<std::string, uint32_t>& input_output_list, JNIEnv* env) {
jclass string_class = env->FindClass("java/lang/String");
if (string_class == nullptr) {
ThrowException(env, tflite::jni::kUnsupportedOperationException,
"Internal error: Can not find java/lang/String class to get "
"SignatureDef names.");
return nullptr;
}
jobjectArray names = env->NewObjectArray(input_output_list.size(),
string_class, env->NewStringUTF(""));
int i = 0;
for (const auto& input : input_output_list) {
env->SetObjectArrayElement(names, i++,
env->NewStringUTF(input.first.c_str()));
}
return names;
}
} // namespace
extern "C" {
@ -226,6 +268,74 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
return JNI_FALSE;
}
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames(
JNIEnv* env, jclass clazz, jlong handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
jclass string_class = env->FindClass("java/lang/String");
if (string_class == nullptr) {
ThrowException(env, tflite::jni::kUnsupportedOperationException,
"Internal error: Can not find java/lang/String class to get "
"SignatureDef names.");
return nullptr;
}
const auto& signature_defs = interpreter->signature_def_names();
jobjectArray names = static_cast<jobjectArray>(env->NewObjectArray(
signature_defs.size(), string_class, env->NewStringUTF("")));
for (int i = 0; i < signature_defs.size(); ++i) {
env->SetObjectArrayElement(names, i,
env->NewStringUTF(signature_defs[i]->c_str()));
}
return names;
}
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs(
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
const jobjectArray signature_inputs = GetSignatureInputsOutputsList(
interpreter->signature_inputs(method_name_ptr), env);
// Release the memory before returning.
env->ReleaseStringUTFChars(method_name, method_name_ptr);
return signature_inputs;
}
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureOutputs(
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
const jobjectArray signature_outputs = GetSignatureInputsOutputsList(
interpreter->signature_outputs(method_name_ptr), env);
// Release the memory before returning.
env->ReleaseStringUTFChars(method_name, method_name_ptr);
return signature_outputs;
}
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndexFromSignature(
JNIEnv* env, jclass clazz, jlong handle, jstring signature_input_name,
jstring method_name) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return -1;
return GetTensorIndexForSignature(env, signature_input_name, method_name,
interpreter, /*is_input=*/true);
}
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndexFromSignature(
JNIEnv* env, jclass clazz, jlong handle, jstring signature_output_name,
jstring method_name) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return -1;
return GetTensorIndexForSignature(env, signature_output_name, method_name,
interpreter, /*is_input=*/false);
}
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
JNIEnv* env, jclass clazz, jlong handle, jint input_index) {

View File

@ -44,6 +44,8 @@ public final class InterpreterTest {
"tensorflow/lite/testdata/dynamic_shapes.bin";
private static final String BOOL_MODEL =
"tensorflow/lite/java/src/testdata/tile_with_bool_input.bin";
private static final String MODEL_WITH_SIGNATURE_PATH =
"tensorflow/lite/java/src/testdata/mul_add_signature_def.bin";
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
private static final ByteBuffer MULTIPLE_INPUTS_MODEL_BUFFER =
@ -55,6 +57,8 @@ public final class InterpreterTest {
private static final ByteBuffer DYNAMIC_SHAPES_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(DYNAMIC_SHAPES_MODEL_PATH);
private static final ByteBuffer BOOL_MODEL_BUFFER = TestUtils.getTestFileAsBuffer(BOOL_MODEL);
private static final ByteBuffer MODEL_WITH_SIGNATURE_BUFFER =
TestUtils.getTestFileAsBuffer(MODEL_WITH_SIGNATURE_PATH);
@Test
public void testInterpreter() throws Exception {
@ -723,6 +727,91 @@ public final class InterpreterTest {
}
}
@Test
public void testModelWithSignatureDef() {
try (Interpreter interpreter = new Interpreter(MODEL_WITH_SIGNATURE_BUFFER)) {
String[] signatureNames = interpreter.getSignatureDefNames();
String[] expectedSignatureNames = {"mul_add"};
assertThat(signatureNames).isEqualTo(expectedSignatureNames);
String[] signatureInputs = interpreter.getSignatureInputs(expectedSignatureNames[0]);
String[] expectedSignatureInputs = {"x", "y"};
assertThat(signatureInputs).isEqualTo(expectedSignatureInputs);
String[] signatureOutputs = interpreter.getSignatureOutputs(expectedSignatureNames[0]);
String[] expectedSignatureOutputs = {"output_0"};
assertThat(signatureOutputs).isEqualTo(expectedSignatureOutputs);
FloatBuffer output = FloatBuffer.allocate(1);
float[] inputX = {2.0f};
float[] inputY = {4.0f};
Map<String, Object> inputs = new HashMap<>();
inputs.put("x", inputX);
inputs.put("y", inputY);
Map<String, Object> outputs = new HashMap<>();
outputs.put("output_0", output);
interpreter.runSignature(inputs, outputs, "mul_add");
// Result should be x * 3.0 + y
FloatBuffer expected = fill(FloatBuffer.allocate(1), 10.0f);
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
}
}
@Test
public void testModelWithSignatureDefNullMethodName() {
try (Interpreter interpreter = new Interpreter(MODEL_WITH_SIGNATURE_BUFFER)) {
String[] signatureNames = interpreter.getSignatureDefNames();
String[] expectedSignatureNames = {"mul_add"};
assertThat(signatureNames).isEqualTo(expectedSignatureNames);
String[] signatureInputs = interpreter.getSignatureInputs(expectedSignatureNames[0]);
String[] expectedSignatureInputs = {"x", "y"};
assertThat(signatureInputs).isEqualTo(expectedSignatureInputs);
String[] signatureOutputs = interpreter.getSignatureOutputs(expectedSignatureNames[0]);
String[] expectedSignatureOutputs = {"output_0"};
assertThat(signatureOutputs).isEqualTo(expectedSignatureOutputs);
FloatBuffer output = FloatBuffer.allocate(1);
float[] inputX = {2.0f};
float[] inputY = {4.0f};
Map<String, Object> inputs = new HashMap<>();
inputs.put("x", inputX);
inputs.put("y", inputY);
Map<String, Object> outputs = new HashMap<>();
outputs.put("output_0", output);
interpreter.runSignature(inputs, outputs, null);
// Result should be x * 3.0 + y
FloatBuffer expected = fill(FloatBuffer.allocate(1), 10.0f);
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
output = FloatBuffer.allocate(1);
outputs.put("output_0", output);
interpreter.runSignature(inputs, outputs);
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
}
}
@Test
public void testModelWithSignatureDefNoSignatures() {
try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
String[] signatureNames = interpreter.getSignatureDefNames();
String[] expectedSignatureNames = {};
assertThat(signatureNames).isEqualTo(expectedSignatureNames);
Map<String, Object> inputs = new HashMap<>();
Map<String, Object> outputs = new HashMap<>();
try {
interpreter.runSignature(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
"Input error: SignatureDef methodName should not be null. null is only allowed if"
+ " the model has a single Signature");
}
}
}
private static native long getNativeHandleForDelegate();
private static native long getNativeHandleForInvalidDelegate();

Binary file not shown.