Add new API for running using SignatureDef in Interpreter Java API.
This change exposes these new methods: * runSignature: Which accepts map inputs/ outputs which uses signatureDef inputs/outputs names as keys. * getSignatureDefNames: Returns List of Strings of available signatures in the loaded model. * getSignatureInputs: Gets the list of SignatureDefs inputs for method provided. * getSignatureOutputs: Gets the list of SignatureDefs outputs for method PiperOrigin-RevId: 348146116 Change-Id: I3e44982b0d27e01074298273530a6ec7a525932b
This commit is contained in:
parent
2d74ab3683
commit
6be2ddd86a
@ -16,6 +16,7 @@ exports_files([
|
||||
"src/testdata/add.bin",
|
||||
"src/testdata/add_unknown_dimensions.bin",
|
||||
"src/testdata/grace_hopper_224.jpg",
|
||||
"src/testdata/mul_add_signature_def.bin",
|
||||
"src/testdata/tile_with_bool_input.bin",
|
||||
"AndroidManifest.xml",
|
||||
"proguard.flags",
|
||||
@ -262,6 +263,7 @@ java_test(
|
||||
data = [
|
||||
"src/testdata/add.bin",
|
||||
"src/testdata/add_unknown_dimensions.bin",
|
||||
"src/testdata/mul_add_signature_def.bin",
|
||||
"src/testdata/tile_with_bool_input.bin",
|
||||
"//tensorflow/lite:testdata/dynamic_shapes.bin",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
|
@ -19,6 +19,7 @@ import java.io.File;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.MappedByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@ -77,6 +78,8 @@ import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
*
|
||||
* <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
|
||||
* but is not guaranteed.
|
||||
*
|
||||
* <p>Note: This class is not thread safe.
|
||||
*/
|
||||
public final class Interpreter implements AutoCloseable {
|
||||
|
||||
@ -222,6 +225,7 @@ public final class Interpreter implements AutoCloseable {
|
||||
*/
|
||||
public Interpreter(@NonNull File modelFile, Options options) {
|
||||
wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
|
||||
signatureNameList = getSignatureDefNames();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -281,6 +285,7 @@ public final class Interpreter implements AutoCloseable {
|
||||
*/
|
||||
public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
|
||||
wrapper = new NativeInterpreterWrapper(byteBuffer, options);
|
||||
signatureNameList = getSignatureDefNames();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -369,6 +374,49 @@ public final class Interpreter implements AutoCloseable {
|
||||
wrapper.run(inputs, outputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs model inference based on SignatureDef provided through @code methodName.
|
||||
*
|
||||
* <p>See {@link Interpreter#run(Object, Object)} for more details on the allowed input and output
|
||||
* data types.
|
||||
*
|
||||
* @param inputs A Map of inputs from input name in the signatureDef to an input object.
|
||||
* @param outputs a map mapping from output name in SignatureDef to output data.
|
||||
* @param methodName The exported method name identifying the SignatureDef.
|
||||
* @throws IllegalArgumentException if {@code inputs} or {@code outputs} or {@code methodName}is
|
||||
* null or empty, or if error occurs when running the inference.
|
||||
*
|
||||
* <p>WARNING: This is an experimental API and subject to change.
|
||||
*/
|
||||
public void runSignature(
|
||||
@NonNull Map<String, Object> inputs,
|
||||
@NonNull Map<String, Object> outputs,
|
||||
String methodName) {
|
||||
checkNotClosed();
|
||||
if (methodName == null && signatureNameList.length == 1) {
|
||||
methodName = signatureNameList[0];
|
||||
}
|
||||
if (methodName == null) {
|
||||
throw new IllegalArgumentException(
|
||||
"Input error: SignatureDef methodName should not be null. null is only allowed if the"
|
||||
+ " model has a single Signature. Available Signatures: "
|
||||
+ Arrays.toString(signatureNameList));
|
||||
}
|
||||
wrapper.runSignature(inputs, outputs, methodName);
|
||||
}
|
||||
|
||||
/* Same as {@link Interpreter#runSignature(Object, Object, String)} but doesn't require
|
||||
* passing a methodName, assuming the model has one SignatureDef. If the model has more than
|
||||
* one SignatureDef it will throw an exception.
|
||||
*
|
||||
* * <p>WARNING: This is an experimental API and subject to change.
|
||||
* */
|
||||
public void runSignature(
|
||||
@NonNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs) {
|
||||
checkNotClosed();
|
||||
runSignature(inputs, outputs, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Expicitly updates allocations for all tensors, if necessary.
|
||||
*
|
||||
@ -450,6 +498,36 @@ public final class Interpreter implements AutoCloseable {
|
||||
return wrapper.getInputTensor(inputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the list of SignatureDef exported method names available in the model.
|
||||
*
|
||||
* <p>WARNING: This is an experimental API and subject to change.
|
||||
*/
|
||||
public String[] getSignatureDefNames() {
|
||||
checkNotClosed();
|
||||
return wrapper.getSignatureDefNames();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the list of SignatureDefs inputs for method {@code methodName}
|
||||
*
|
||||
* <p>WARNING: This is an experimental API and subject to change.
|
||||
*/
|
||||
public String[] getSignatureInputs(String methodName) {
|
||||
checkNotClosed();
|
||||
return wrapper.getSignatureInputs(methodName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the list of SignatureDefs outputs for method {@code methodName}
|
||||
*
|
||||
* <p>WARNING: This is an experimental API and subject to change.
|
||||
*/
|
||||
public String[] getSignatureOutputs(String methodName) {
|
||||
checkNotClosed();
|
||||
return wrapper.getSignatureOutputs(methodName);
|
||||
}
|
||||
|
||||
/** Gets the number of output Tensors. */
|
||||
public int getOutputTensorCount() {
|
||||
checkNotClosed();
|
||||
@ -584,4 +662,5 @@ public final class Interpreter implements AutoCloseable {
|
||||
}
|
||||
|
||||
NativeInterpreterWrapper wrapper;
|
||||
String[] signatureNameList;
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import org.tensorflow.lite.nnapi.NnApiDelegate;
|
||||
|
||||
/**
|
||||
@ -30,6 +31,8 @@ import org.tensorflow.lite.nnapi.NnApiDelegate;
|
||||
* <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
|
||||
* explicitly freed by invoking the {@link #close()} method when the {@code
|
||||
* NativeInterpreterWrapper} object is no longer needed.
|
||||
*
|
||||
* Note: This class is not thread safe.
|
||||
*/
|
||||
final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
|
||||
@ -136,6 +139,36 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
ownedDelegates.clear();
|
||||
}
|
||||
|
||||
public void runSignature(
|
||||
Map<String, Object> inputs, Map<String, Object> outputs, String methodName) {
|
||||
if (inputs == null || inputs.isEmpty()) {
|
||||
throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
|
||||
}
|
||||
if (outputs == null || outputs.isEmpty()) {
|
||||
throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
|
||||
}
|
||||
initTensorIndexesMaps();
|
||||
// Map inputs/output to input indexes.
|
||||
Map<Integer, Object> inputsWithInputIndex = new TreeMap<>();
|
||||
Map<Integer, Object> outputsWithOutputIndex = new TreeMap<>();
|
||||
for (Map.Entry<String, Object> input : inputs.entrySet()) {
|
||||
int tensorIndex =
|
||||
getInputTensorIndexFromSignature(interpreterHandle, input.getKey(), methodName);
|
||||
inputsWithInputIndex.put(tensorToInputsIndexes.get(tensorIndex), input.getValue());
|
||||
}
|
||||
for (Map.Entry<String, Object> output : outputs.entrySet()) {
|
||||
int tensorIndex =
|
||||
getOutputTensorIndexFromSignature(interpreterHandle, output.getKey(), methodName);
|
||||
outputsWithOutputIndex.put(tensorToOutputsIndexes.get(tensorIndex), output.getValue());
|
||||
}
|
||||
Object[] inputsList = new Object[inputs.size()];
|
||||
int index = 0;
|
||||
for (Map.Entry<Integer, Object> input : inputsWithInputIndex.entrySet()) {
|
||||
inputsList[index++] = input.getValue();
|
||||
}
|
||||
run(inputsList, outputsWithOutputIndex);
|
||||
}
|
||||
|
||||
/** Sets inputs, runs model inference and returns outputs. */
|
||||
void run(Object[] inputs, Map<Integer, Object> outputs) {
|
||||
inferenceDurationNanoseconds = -1;
|
||||
@ -257,7 +290,26 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
String.format(
|
||||
"Input error: '%s' is not a valid name for any input. Names of inputs and their "
|
||||
+ "indexes are %s",
|
||||
name, inputsIndexes.toString()));
|
||||
name, inputsIndexes));
|
||||
}
|
||||
}
|
||||
|
||||
/** Initializes mapping from tensor index to input/output index. **/
|
||||
private void initTensorIndexesMaps() {
|
||||
if (tensorToInputsIndexes != null) {
|
||||
return;
|
||||
}
|
||||
tensorToInputsIndexes = new HashMap<>();
|
||||
tensorToOutputsIndexes = new HashMap<>();
|
||||
int inputCount = getInputTensorCount();
|
||||
for (int i = 0; i < inputCount; ++i) {
|
||||
int tensorIndex = getInputTensorIndex(interpreterHandle, i);
|
||||
tensorToInputsIndexes.put(tensorIndex, i);
|
||||
}
|
||||
int outputCount = getOutputTensorCount();
|
||||
for (int i = 0; i < outputCount; ++i) {
|
||||
int tensorIndex = getOutputTensorIndex(interpreterHandle, i);
|
||||
tensorToOutputsIndexes.put(tensorIndex, i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,7 +331,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
String.format(
|
||||
"Input error: '%s' is not a valid name for any output. Names of outputs and their "
|
||||
+ "indexes are %s",
|
||||
name, outputsIndexes.toString()));
|
||||
name, outputsIndexes));
|
||||
}
|
||||
}
|
||||
|
||||
@ -314,6 +366,27 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
return inputTensor;
|
||||
}
|
||||
|
||||
/** Gets the list of SignatureDefs available in the model, if any. */
|
||||
public String[] getSignatureDefNames() {
|
||||
return getSignatureDefNames(interpreterHandle);
|
||||
}
|
||||
|
||||
private static native String[] getSignatureDefNames(long interpreterHandle);
|
||||
|
||||
/** Gets the list of SignatureDefs inputs for method {@code methodName} */
|
||||
String[] getSignatureInputs(String methodName) {
|
||||
return getSignatureInputs(interpreterHandle, methodName);
|
||||
}
|
||||
|
||||
private static native String[] getSignatureInputs(long interpreterHandle, String methodName);
|
||||
|
||||
/** Gets the list of SignatureDefs outputs for method {@code methodName} */
|
||||
String[] getSignatureOutputs(String methodName) {
|
||||
return getSignatureOutputs(interpreterHandle, methodName);
|
||||
}
|
||||
|
||||
private static native String[] getSignatureOutputs(long interpreterHandle, String methodName);
|
||||
|
||||
/** Gets the number of output tensors. */
|
||||
int getOutputTensorCount() {
|
||||
return outputTensors.length;
|
||||
@ -430,6 +503,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
// Lazily constructed maps of input and output names to input and output Tensor indexes.
|
||||
private Map<String, Integer> inputsIndexes;
|
||||
private Map<String, Integer> outputsIndexes;
|
||||
// Lazily constructed maps of tensor index to index in input and output indexes.
|
||||
private Map<Integer, Integer> tensorToInputsIndexes;
|
||||
private Map<Integer, Integer> tensorToOutputsIndexes;
|
||||
|
||||
// Lazily constructed and populated arrays of input and output Tensor wrappers.
|
||||
private Tensor[] inputTensors;
|
||||
@ -448,6 +524,12 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
|
||||
private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
|
||||
|
||||
private static native int getInputTensorIndexFromSignature(
|
||||
long interpreterHandle, String signatureInputName, String methodName);
|
||||
|
||||
private static native int getOutputTensorIndexFromSignature(
|
||||
long interpreterHandle, String signatureInputName, String methodName);
|
||||
|
||||
private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);
|
||||
|
||||
private static native int getInputCount(long interpreterHandle);
|
||||
|
@ -158,6 +158,48 @@ bool VerifyModel(const void* buf, size_t len) {
|
||||
return tflite::VerifyModelBuffer(verifier);
|
||||
}
|
||||
|
||||
// Helper method that fetches the tensor index based on SignatureDef details
|
||||
// from either inputs or outputs.
|
||||
// Returns -1 if invalid names are passed.
|
||||
int GetTensorIndexForSignature(JNIEnv* env, jstring signature_tensor_name,
|
||||
jstring method_name,
|
||||
tflite::Interpreter* interpreter,
|
||||
bool is_input) {
|
||||
// Fetch name strings.
|
||||
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
|
||||
const char* signature_input_name_ptr =
|
||||
env->GetStringUTFChars(signature_tensor_name, nullptr);
|
||||
// Lookup if the input is valid.
|
||||
const auto& signature_list =
|
||||
(is_input ? interpreter->signature_inputs(method_name_ptr)
|
||||
: interpreter->signature_outputs(method_name_ptr));
|
||||
const auto& tensor = signature_list.find(signature_input_name_ptr);
|
||||
// Release the memory before returning.
|
||||
env->ReleaseStringUTFChars(method_name, method_name_ptr);
|
||||
env->ReleaseStringUTFChars(signature_tensor_name, signature_input_name_ptr);
|
||||
return tensor == signature_list.end() ? -1 : tensor->second;
|
||||
}
|
||||
|
||||
jobjectArray GetSignatureInputsOutputsList(
|
||||
const std::map<std::string, uint32_t>& input_output_list, JNIEnv* env) {
|
||||
jclass string_class = env->FindClass("java/lang/String");
|
||||
if (string_class == nullptr) {
|
||||
ThrowException(env, tflite::jni::kUnsupportedOperationException,
|
||||
"Internal error: Can not find java/lang/String class to get "
|
||||
"SignatureDef names.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
jobjectArray names = env->NewObjectArray(input_output_list.size(),
|
||||
string_class, env->NewStringUTF(""));
|
||||
int i = 0;
|
||||
for (const auto& input : input_output_list) {
|
||||
env->SetObjectArrayElement(names, i++,
|
||||
env->NewStringUTF(input.first.c_str()));
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
@ -226,6 +268,74 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
|
||||
return JNI_FALSE;
|
||||
}
|
||||
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames(
|
||||
JNIEnv* env, jclass clazz, jlong handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return nullptr;
|
||||
jclass string_class = env->FindClass("java/lang/String");
|
||||
if (string_class == nullptr) {
|
||||
ThrowException(env, tflite::jni::kUnsupportedOperationException,
|
||||
"Internal error: Can not find java/lang/String class to get "
|
||||
"SignatureDef names.");
|
||||
return nullptr;
|
||||
}
|
||||
const auto& signature_defs = interpreter->signature_def_names();
|
||||
jobjectArray names = static_cast<jobjectArray>(env->NewObjectArray(
|
||||
signature_defs.size(), string_class, env->NewStringUTF("")));
|
||||
for (int i = 0; i < signature_defs.size(); ++i) {
|
||||
env->SetObjectArrayElement(names, i,
|
||||
env->NewStringUTF(signature_defs[i]->c_str()));
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return nullptr;
|
||||
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
|
||||
const jobjectArray signature_inputs = GetSignatureInputsOutputsList(
|
||||
interpreter->signature_inputs(method_name_ptr), env);
|
||||
// Release the memory before returning.
|
||||
env->ReleaseStringUTFChars(method_name, method_name_ptr);
|
||||
return signature_inputs;
|
||||
}
|
||||
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureOutputs(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return nullptr;
|
||||
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
|
||||
const jobjectArray signature_outputs = GetSignatureInputsOutputsList(
|
||||
interpreter->signature_outputs(method_name_ptr), env);
|
||||
// Release the memory before returning.
|
||||
env->ReleaseStringUTFChars(method_name, method_name_ptr);
|
||||
return signature_outputs;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndexFromSignature(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring signature_input_name,
|
||||
jstring method_name) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return -1;
|
||||
return GetTensorIndexForSignature(env, signature_input_name, method_name,
|
||||
interpreter, /*is_input=*/true);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndexFromSignature(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring signature_output_name,
|
||||
jstring method_name) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return -1;
|
||||
return GetTensorIndexForSignature(env, signature_output_name, method_name,
|
||||
interpreter, /*is_input=*/false);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
|
||||
|
@ -44,6 +44,8 @@ public final class InterpreterTest {
|
||||
"tensorflow/lite/testdata/dynamic_shapes.bin";
|
||||
private static final String BOOL_MODEL =
|
||||
"tensorflow/lite/java/src/testdata/tile_with_bool_input.bin";
|
||||
private static final String MODEL_WITH_SIGNATURE_PATH =
|
||||
"tensorflow/lite/java/src/testdata/mul_add_signature_def.bin";
|
||||
|
||||
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
|
||||
private static final ByteBuffer MULTIPLE_INPUTS_MODEL_BUFFER =
|
||||
@ -55,6 +57,8 @@ public final class InterpreterTest {
|
||||
private static final ByteBuffer DYNAMIC_SHAPES_MODEL_BUFFER =
|
||||
TestUtils.getTestFileAsBuffer(DYNAMIC_SHAPES_MODEL_PATH);
|
||||
private static final ByteBuffer BOOL_MODEL_BUFFER = TestUtils.getTestFileAsBuffer(BOOL_MODEL);
|
||||
private static final ByteBuffer MODEL_WITH_SIGNATURE_BUFFER =
|
||||
TestUtils.getTestFileAsBuffer(MODEL_WITH_SIGNATURE_PATH);
|
||||
|
||||
@Test
|
||||
public void testInterpreter() throws Exception {
|
||||
@ -723,6 +727,91 @@ public final class InterpreterTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelWithSignatureDef() {
|
||||
try (Interpreter interpreter = new Interpreter(MODEL_WITH_SIGNATURE_BUFFER)) {
|
||||
String[] signatureNames = interpreter.getSignatureDefNames();
|
||||
String[] expectedSignatureNames = {"mul_add"};
|
||||
assertThat(signatureNames).isEqualTo(expectedSignatureNames);
|
||||
|
||||
String[] signatureInputs = interpreter.getSignatureInputs(expectedSignatureNames[0]);
|
||||
String[] expectedSignatureInputs = {"x", "y"};
|
||||
assertThat(signatureInputs).isEqualTo(expectedSignatureInputs);
|
||||
|
||||
String[] signatureOutputs = interpreter.getSignatureOutputs(expectedSignatureNames[0]);
|
||||
String[] expectedSignatureOutputs = {"output_0"};
|
||||
assertThat(signatureOutputs).isEqualTo(expectedSignatureOutputs);
|
||||
|
||||
FloatBuffer output = FloatBuffer.allocate(1);
|
||||
float[] inputX = {2.0f};
|
||||
float[] inputY = {4.0f};
|
||||
Map<String, Object> inputs = new HashMap<>();
|
||||
inputs.put("x", inputX);
|
||||
inputs.put("y", inputY);
|
||||
Map<String, Object> outputs = new HashMap<>();
|
||||
outputs.put("output_0", output);
|
||||
interpreter.runSignature(inputs, outputs, "mul_add");
|
||||
// Result should be x * 3.0 + y
|
||||
FloatBuffer expected = fill(FloatBuffer.allocate(1), 10.0f);
|
||||
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelWithSignatureDefNullMethodName() {
|
||||
try (Interpreter interpreter = new Interpreter(MODEL_WITH_SIGNATURE_BUFFER)) {
|
||||
String[] signatureNames = interpreter.getSignatureDefNames();
|
||||
String[] expectedSignatureNames = {"mul_add"};
|
||||
assertThat(signatureNames).isEqualTo(expectedSignatureNames);
|
||||
|
||||
String[] signatureInputs = interpreter.getSignatureInputs(expectedSignatureNames[0]);
|
||||
String[] expectedSignatureInputs = {"x", "y"};
|
||||
assertThat(signatureInputs).isEqualTo(expectedSignatureInputs);
|
||||
|
||||
String[] signatureOutputs = interpreter.getSignatureOutputs(expectedSignatureNames[0]);
|
||||
String[] expectedSignatureOutputs = {"output_0"};
|
||||
assertThat(signatureOutputs).isEqualTo(expectedSignatureOutputs);
|
||||
|
||||
FloatBuffer output = FloatBuffer.allocate(1);
|
||||
float[] inputX = {2.0f};
|
||||
float[] inputY = {4.0f};
|
||||
Map<String, Object> inputs = new HashMap<>();
|
||||
inputs.put("x", inputX);
|
||||
inputs.put("y", inputY);
|
||||
Map<String, Object> outputs = new HashMap<>();
|
||||
outputs.put("output_0", output);
|
||||
interpreter.runSignature(inputs, outputs, null);
|
||||
// Result should be x * 3.0 + y
|
||||
FloatBuffer expected = fill(FloatBuffer.allocate(1), 10.0f);
|
||||
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
|
||||
output = FloatBuffer.allocate(1);
|
||||
outputs.put("output_0", output);
|
||||
interpreter.runSignature(inputs, outputs);
|
||||
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelWithSignatureDefNoSignatures() {
|
||||
try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
|
||||
String[] signatureNames = interpreter.getSignatureDefNames();
|
||||
String[] expectedSignatureNames = {};
|
||||
assertThat(signatureNames).isEqualTo(expectedSignatureNames);
|
||||
Map<String, Object> inputs = new HashMap<>();
|
||||
Map<String, Object> outputs = new HashMap<>();
|
||||
try {
|
||||
interpreter.runSignature(inputs, outputs);
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
assertThat(e)
|
||||
.hasMessageThat()
|
||||
.contains(
|
||||
"Input error: SignatureDef methodName should not be null. null is only allowed if"
|
||||
+ " the model has a single Signature");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static native long getNativeHandleForDelegate();
|
||||
|
||||
private static native long getNativeHandleForInvalidDelegate();
|
||||
|
BIN
tensorflow/lite/java/src/testdata/mul_add_signature_def.bin
vendored
Normal file
BIN
tensorflow/lite/java/src/testdata/mul_add_signature_def.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user