Add TfLite flex delegate with support for TF ops
Create a org.tensorflow.lite.flex.FlexDelegate class which wraps its native counterpart for using TensorFlow ops in TensorFlow Lite. Clients can either instantiate this delegate directly when using ops that require TF ops, or add it as a dependency to their project, and it will be instantiated when necessary. Also introduce a tensorflow-lite-select-tf-ops.aar target, which is a plugin that should be used alongside tensorflow-lite.aar. The existing tensorflow-lite-with-select-tf-ops, which is a monolithic build that includes all of core TFLite, is now deprecated and will soon be removed. This work is in anticipation of pushing prebuilt tensorflow-lite-select-tf-ops libraries to the TensorFlow Lite Bintray repository. PiperOrigin-RevId: 264948003
This commit is contained in:
parent
095f802808
commit
4250bf575c
@ -640,7 +640,7 @@ TfLiteStatus Subgraph::OpPrepare(const TfLiteRegistration& op_reg,
|
||||
if (IsFlexOp(op_reg.custom_name)) {
|
||||
ReportError(
|
||||
"Regular TensorFlow ops are not supported by this interpreter. "
|
||||
"Make sure you invoke the Flex delegate before inference.");
|
||||
"Make sure you apply/link the Flex delegate before inference.");
|
||||
} else {
|
||||
ReportError("Encountered unresolved custom op: %s.",
|
||||
op_reg.custom_name);
|
||||
|
@ -0,0 +1,7 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
filegroup(
|
||||
name = "flex_delegate",
|
||||
srcs = ["FlexDelegate.java"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
@ -0,0 +1,69 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.flex;
|
||||
|
||||
import java.io.Closeable;
|
||||
import org.tensorflow.lite.Delegate;
|
||||
import org.tensorflow.lite.annotations.UsedByReflection;
|
||||
|
||||
/** {@link Delegate} for using select TensorFlow ops. */
|
||||
@UsedByReflection("Interpreter")
|
||||
public class FlexDelegate implements Delegate, Closeable {
|
||||
|
||||
private static final long INVALID_DELEGATE_HANDLE = 0;
|
||||
private static final String TFLITE_FLEX_LIB = "tensorflowlite_flex_jni";
|
||||
|
||||
private long delegateHandle;
|
||||
|
||||
@UsedByReflection("Interpreter")
|
||||
public FlexDelegate() {
|
||||
delegateHandle = nativeCreateDelegate();
|
||||
}
|
||||
|
||||
@Override
|
||||
@UsedByReflection("Interpreter")
|
||||
public long getNativeHandle() {
|
||||
return delegateHandle;
|
||||
}
|
||||
|
||||
/**
|
||||
* Releases native resources held by the delegate.
|
||||
*
|
||||
* <p>User is expected to call this method explicitly.
|
||||
*/
|
||||
@Override
|
||||
@UsedByReflection("Interpreter")
|
||||
public void close() {
|
||||
if (delegateHandle != INVALID_DELEGATE_HANDLE) {
|
||||
nativeDeleteDelegate(delegateHandle);
|
||||
delegateHandle = INVALID_DELEGATE_HANDLE;
|
||||
}
|
||||
}
|
||||
|
||||
public static void initTensorFlowForTesting() {
|
||||
nativeInitTensorFlow();
|
||||
}
|
||||
|
||||
static {
|
||||
System.loadLibrary(TFLITE_FLEX_LIB);
|
||||
}
|
||||
|
||||
private static native long nativeInitTensorFlow();
|
||||
|
||||
private static native long nativeCreateDelegate();
|
||||
|
||||
private static native void nativeDeleteDelegate(long delegateHandle);
|
||||
}
|
25
tensorflow/lite/delegates/flex/java/src/main/native/BUILD
Normal file
25
tensorflow/lite/delegates/flex/java/src/main/native/BUILD
Normal file
@ -0,0 +1,25 @@
|
||||
# Description:
|
||||
# Java Native Interface (JNI) library intended for implementing the
|
||||
# TensorFlow Lite Flex delegate for using TensorFlow ops with TensorFlow Lite.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
cc_library(
|
||||
name = "native",
|
||||
srcs = ["flex_delegate_jni.cc"],
|
||||
copts = tflite_copts(),
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
"//tensorflow/lite/java/jni",
|
||||
"//tensorflow/lite/testing:init_tensorflow",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,17 +15,31 @@ limitations under the License.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
#include "tensorflow/lite/testing/init_tensorflow.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_lite_TensorFlowLite_initTensorFlow(
|
||||
JNIEnv* env, jclass clazz) {
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_flex_FlexDelegate_nativeInitTensorFlow(JNIEnv* env,
|
||||
jclass clazz) {
|
||||
::tflite::InitTensorFlow();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_tensorflow_lite_flex_FlexDelegate_nativeCreateDelegate(JNIEnv* env,
|
||||
jclass clazz) {
|
||||
return reinterpret_cast<jlong>(tflite::FlexDelegate::Create().release());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_flex_FlexDelegate_nativeDeleteDelegate(
|
||||
JNIEnv* env, jclass clazz, jlong delegate) {
|
||||
delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
@ -13,6 +13,7 @@ package(
|
||||
|
||||
JAVA_SRCS = glob([
|
||||
"src/main/java/org/tensorflow/lite/*.java",
|
||||
"src/main/java/org/tensorflow/lite/annotations/*.java",
|
||||
]) + ["//tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi:nnapi_delegate_src"]
|
||||
|
||||
# Building tensorflow-lite.aar including 4 variants of .so
|
||||
@ -24,10 +25,20 @@ aar_with_jni(
|
||||
android_library = ":tensorflowlite",
|
||||
)
|
||||
|
||||
# EXPERIMENTAL: AAR target that supports TensorFlow op execution with TFLite.
|
||||
# EXPERIMENTAL: AAR target for using TensorFlow ops with TFLite. Note that this
|
||||
# .aar contains *only* the Flex delegate for using select tf ops; clients must
|
||||
# also include the core `tensorflow-lite` runtime.
|
||||
aar_with_jni(
|
||||
name = "tensorflow-lite-select-tf-ops",
|
||||
android_library = ":tensorflowlite_flex",
|
||||
)
|
||||
|
||||
# DEPRECATED: AAR target that supports TensorFlow op execution with TFLite.
|
||||
# Please use `tensorflowlite-select-tf-ops` instead (along with the standard
|
||||
# `tensorflowlite` AAR).
|
||||
aar_with_jni(
|
||||
name = "tensorflow-lite-with-select-tf-ops",
|
||||
android_library = ":tensorflowlite_flex",
|
||||
android_library = ":tensorflowlite_flex_deprecated",
|
||||
)
|
||||
|
||||
# EXPERIMENTAL: AAR target for GPU acceleration. Note that this .aar contains
|
||||
@ -50,12 +61,31 @@ android_library(
|
||||
)
|
||||
|
||||
# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite.
|
||||
# Note that this library contains *only* the Flex delegate and its Java wrapper for using
|
||||
# select TF ops; clients must also include the core `tensorflowlite` runtime.
|
||||
android_library(
|
||||
name = "tensorflowlite_flex",
|
||||
srcs = JAVA_SRCS,
|
||||
srcs = ["//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate"],
|
||||
manifest = "AndroidManifest.xml",
|
||||
proguard_specs = ["proguard.flags"],
|
||||
deps = [
|
||||
":tensorflowlite_java",
|
||||
":tensorflowlite_native_flex",
|
||||
"@org_checkerframework_qual",
|
||||
],
|
||||
)
|
||||
|
||||
# DEPRECATED: Android target that supports TensorFlow op execution with TFLite.
|
||||
# Please use `tensorflowlite_flex`.
|
||||
android_library(
|
||||
name = "tensorflowlite_flex_deprecated",
|
||||
srcs = JAVA_SRCS + [
|
||||
"//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate",
|
||||
],
|
||||
manifest = "AndroidManifest.xml",
|
||||
proguard_specs = ["proguard.flags"],
|
||||
deps = [
|
||||
":tensorflowlite",
|
||||
":tensorflowlite_native_flex",
|
||||
"@org_checkerframework_qual",
|
||||
],
|
||||
@ -98,10 +128,11 @@ java_library(
|
||||
# EXPERIMENTAL: Java target that supports TensorFlow op execution with TFLite.
|
||||
java_library(
|
||||
name = "tensorflowlitelib_flex",
|
||||
srcs = JAVA_SRCS,
|
||||
srcs = ["//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate"],
|
||||
javacopts = JAVACOPTS,
|
||||
deps = [
|
||||
":libtensorflowlite_flex_jni.so",
|
||||
":tensorflowlitelib",
|
||||
"@org_checkerframework_qual",
|
||||
],
|
||||
)
|
||||
@ -219,7 +250,10 @@ java_test(
|
||||
java_test(
|
||||
name = "InterpreterFlexTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"],
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/InterpreterFlexTest.java",
|
||||
"src/test/java/org/tensorflow/lite/TestUtils.java",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/multi_add_flex.bin",
|
||||
],
|
||||
@ -234,6 +268,7 @@ java_test(
|
||||
test_class = "org.tensorflow.lite.InterpreterFlexTest",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":tensorflowlitelib",
|
||||
":tensorflowlitelib_flex",
|
||||
"@com_google_truth",
|
||||
"@junit",
|
||||
@ -265,11 +300,20 @@ filegroup(
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java",
|
||||
"src/test/java/org/tensorflow/lite/InterpreterTest.java",
|
||||
"src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java",
|
||||
"src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "portable_flex_tests",
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/InterpreterFlexTest.java",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "libtensorflowlite_jni",
|
||||
srcs = select({
|
||||
@ -310,10 +354,7 @@ tflite_jni_binary(
|
||||
tflite_jni_binary(
|
||||
name = "libtensorflowlite_flex_jni.so",
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
"//tensorflow/lite/delegates/nnapi/java/src/main/native",
|
||||
"//tensorflow/lite/java/src/main/native",
|
||||
"//tensorflow/lite/java/src/main/native:init_tensorflow",
|
||||
"//tensorflow/lite/delegates/flex/java/src/main/native",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,3 +1,9 @@
|
||||
-keepclassmembers class org.tensorflow.lite.NativeInterpreterWrapper {
|
||||
private long inferenceDurationNanoseconds;
|
||||
}
|
||||
}
|
||||
|
||||
-keep class org.tensorflow.lite.annotations.UsedByReflection
|
||||
-keep @org.tensorflow.lite.annotations.UsedByReflection class *
|
||||
-keepclassmembers class * {
|
||||
@org.tensorflow.lite.annotations.UsedByReflection *;
|
||||
}
|
||||
|
@ -85,7 +85,20 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
|
||||
delegates.add(delegate);
|
||||
}
|
||||
allocateTensors(interpreterHandle, errorHandle);
|
||||
|
||||
try {
|
||||
allocateTensors(interpreterHandle, errorHandle);
|
||||
} catch (IllegalStateException e) {
|
||||
// Only try flex delegate usage if allocation fails. This avoids unnecessary creation of the
|
||||
// flex delegate, which can be expensive.
|
||||
optionalFlexDelegate = maybeCreateFlexDelegate();
|
||||
if (optionalFlexDelegate != null) {
|
||||
applyDelegate(interpreterHandle, errorHandle, optionalFlexDelegate.getNativeHandle());
|
||||
allocateTensors(interpreterHandle, errorHandle);
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
this.isMemoryAllocated = true;
|
||||
}
|
||||
|
||||
@ -118,6 +131,14 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
optionalNnApiDelegate.close();
|
||||
optionalNnApiDelegate = null;
|
||||
}
|
||||
if (optionalFlexDelegate instanceof AutoCloseable) {
|
||||
try {
|
||||
((AutoCloseable) optionalFlexDelegate).close();
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to close flex delegate: " + e);
|
||||
}
|
||||
}
|
||||
optionalFlexDelegate = null;
|
||||
}
|
||||
|
||||
/** Sets inputs, runs model inference and returns outputs. */
|
||||
@ -319,6 +340,16 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
return outputTensor;
|
||||
}
|
||||
|
||||
private static Delegate maybeCreateFlexDelegate() {
|
||||
try {
|
||||
Class<?> clazz = Class.forName("org.tensorflow.lite.flex.FlexDelegate");
|
||||
return (Delegate) clazz.getConstructor().newInstance();
|
||||
} catch (Exception e) {
|
||||
// The error will propagate when tensors are allocated.
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static native int getOutputDataType(long interpreterHandle, int outputIdx);
|
||||
|
||||
private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx);
|
||||
@ -355,6 +386,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
// NNAPI is enabled via Interpreter.Options.
|
||||
private NnApiDelegate optionalNnApiDelegate;
|
||||
|
||||
// Only used if 1) flex ops are used, and 2) the flex delegate is available.
|
||||
private Delegate optionalFlexDelegate;
|
||||
|
||||
private static native long allocateTensors(long interpreterHandle, long errorHandle);
|
||||
|
||||
private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
|
||||
|
@ -18,8 +18,7 @@ package org.tensorflow.lite;
|
||||
/** Static utility methods loading the TensorFlowLite runtime. */
|
||||
public final class TensorFlowLite {
|
||||
|
||||
private static final String PRIMARY_LIBNAME = "tensorflowlite_jni";
|
||||
private static final String FALLBACK_LIBNAME = "tensorflowlite_flex_jni";
|
||||
private static final String LIBNAME = "tensorflowlite_jni";
|
||||
|
||||
private TensorFlowLite() {}
|
||||
|
||||
@ -39,36 +38,19 @@ public final class TensorFlowLite {
|
||||
/** Returns the version of the underlying TensorFlowLite model schema. */
|
||||
public static native String schemaVersion();
|
||||
|
||||
/**
|
||||
* Initialize tensorflow's libraries. This will throw an exception if used when TensorFlow isn't
|
||||
* linked in.
|
||||
*/
|
||||
static native void initTensorFlow();
|
||||
|
||||
/**
|
||||
* Load the TensorFlowLite runtime C library.
|
||||
*
|
||||
* @hide
|
||||
*/
|
||||
public static boolean init() {
|
||||
Throwable primaryLibException;
|
||||
try {
|
||||
System.loadLibrary(PRIMARY_LIBNAME);
|
||||
System.loadLibrary(LIBNAME);
|
||||
return true;
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
primaryLibException = e;
|
||||
System.err.println("TensorFlowLite: failed to load native library: " + e);
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
System.loadLibrary(FALLBACK_LIBNAME);
|
||||
return true;
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
// If the fallback fails, log the error for the primary load instead.
|
||||
System.err.println(
|
||||
"TensorFlowLite: failed to load native library: " + primaryLibException.getMessage());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static {
|
||||
|
@ -0,0 +1,31 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.annotations;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* Annotation used for marking methods and fields that are called by reflection. Useful for keeping
|
||||
* components that would otherwise be removed by Proguard. Use the value parameter to mention a file
|
||||
* that calls this method.
|
||||
*
|
||||
* @hide
|
||||
*/
|
||||
@Target({ElementType.METHOD, ElementType.FIELD, ElementType.TYPE, ElementType.CONSTRUCTOR})
|
||||
public @interface UsedByReflection {
|
||||
String value();
|
||||
}
|
@ -35,19 +35,6 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "init_tensorflow",
|
||||
srcs = [
|
||||
"init_tensorflow_jni.cc",
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/java/jni",
|
||||
"//tensorflow/lite/testing:init_tensorflow",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# This includes all ops. If you want a smaller binary, you should copy and
|
||||
# modify builtin_ops_jni.cc. You should then link your binary against both
|
||||
# ":native_framework_only" and your own version of ":native_builtin_ops".
|
||||
|
@ -63,6 +63,11 @@ BufferErrorReporter::~BufferErrorReporter() { delete[] buffer_; }
|
||||
|
||||
int BufferErrorReporter::Report(const char* format, va_list args) {
|
||||
int size = 0;
|
||||
// If an error has already been logged, insert a newline.
|
||||
if (start_idx_ > 0 && start_idx_ < end_idx_) {
|
||||
buffer_[start_idx_++] = '\n';
|
||||
++size;
|
||||
}
|
||||
if (start_idx_ < end_idx_) {
|
||||
size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args);
|
||||
}
|
||||
|
@ -17,12 +17,13 @@ package org.tensorflow.lite;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.tensorflow.lite.flex.FlexDelegate;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that
|
||||
@ -31,30 +32,60 @@ import org.junit.runners.JUnit4;
|
||||
@RunWith(JUnit4.class)
|
||||
public final class InterpreterFlexTest {
|
||||
|
||||
private static final File FLEX_MODEL_FILE =
|
||||
new File("tensorflow/lite/testdata/multi_add_flex.bin");
|
||||
private static final ByteBuffer FLEX_MODEL_BUFFER =
|
||||
TestUtils.getTestFileAsBuffer("tensorflow/lite/testdata/multi_add_flex.bin");
|
||||
|
||||
/** Smoke test validating that flex model loading works when the flex delegate is linked. */
|
||||
/** Smoke test validating that flex model loading works when the flex delegate is used. */
|
||||
@Test
|
||||
public void testFlexModel() throws Exception {
|
||||
try (Interpreter interpreter = new Interpreter(FLEX_MODEL_FILE)) {
|
||||
assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
|
||||
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getInputTensor(1).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getInputTensor(2).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getInputTensor(3).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getOutputTensorCount()).isEqualTo(2);
|
||||
assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getOutputTensor(1).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
Object[] inputs = new Object[] {new float[1], new float[1], new float[1], new float[1]};
|
||||
Map<Integer, Object> outputs = new HashMap<>();
|
||||
outputs.put(0, new float[1]);
|
||||
outputs.put(1, new float[1]);
|
||||
interpreter.runForMultipleInputsOutputs(inputs, outputs);
|
||||
FlexDelegate delegate = new FlexDelegate();
|
||||
Interpreter.Options options = new Interpreter.Options().addDelegate(delegate);
|
||||
try (Interpreter interpreter = new Interpreter(FLEX_MODEL_BUFFER, options)) {
|
||||
testCommon(interpreter);
|
||||
} finally {
|
||||
delegate.close();
|
||||
}
|
||||
}
|
||||
|
||||
/** Smoke test validating that flex model loading works when the flex delegate is linked. */
|
||||
@Test
|
||||
public void testFlexModelDelegateAutomaticallyApplied() throws Exception {
|
||||
try (Interpreter interpreter = new Interpreter(FLEX_MODEL_BUFFER)) {
|
||||
testCommon(interpreter);
|
||||
}
|
||||
}
|
||||
|
||||
private static void testCommon(Interpreter interpreter) {
|
||||
assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
|
||||
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getInputTensor(1).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getInputTensor(2).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getInputTensor(3).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getOutputTensorCount()).isEqualTo(2);
|
||||
assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
assertThat(interpreter.getOutputTensor(1).dataType()).isEqualTo(DataType.FLOAT32);
|
||||
|
||||
float[] input1 = {1};
|
||||
float[] input2 = {2};
|
||||
float[] input3 = {3};
|
||||
float[] input4 = {5};
|
||||
Object[] inputs = new Object[] {input1, input2, input3, input4};
|
||||
|
||||
float[] parsedOutput1 = new float[1];
|
||||
float[] parsedOutput2 = new float[1];
|
||||
Map<Integer, Object> outputs = new HashMap<>();
|
||||
outputs.put(0, parsedOutput1);
|
||||
outputs.put(1, parsedOutput2);
|
||||
|
||||
interpreter.runForMultipleInputsOutputs(inputs, outputs);
|
||||
|
||||
float[] expectedOutput1 = {6};
|
||||
float[] expectedOutput2 = {10};
|
||||
assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expectedOutput1).inOrder();
|
||||
assertThat(parsedOutput2).usingTolerance(0.1f).containsExactly(expectedOutput2).inOrder();
|
||||
}
|
||||
|
||||
static {
|
||||
TensorFlowLite.initTensorFlow();
|
||||
FlexDelegate.initTensorFlowForTesting();
|
||||
}
|
||||
}
|
||||
|
@ -54,8 +54,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.allocate_tensors()
|
||||
self.assertIn(
|
||||
'Regular TensorFlow ops are not supported by this interpreter. Make '
|
||||
'sure you invoke the Flex delegate before inference.',
|
||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
||||
str(error.exception))
|
||||
|
||||
def testDeprecatedFlags(self):
|
||||
@ -84,8 +83,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.allocate_tensors()
|
||||
self.assertIn(
|
||||
'Regular TensorFlow ops are not supported by this interpreter. Make '
|
||||
'sure you invoke the Flex delegate before inference.',
|
||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
||||
str(error.exception))
|
||||
|
||||
|
||||
@ -111,8 +109,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.allocate_tensors()
|
||||
self.assertIn(
|
||||
'Regular TensorFlow ops are not supported by this interpreter. Make '
|
||||
'sure you invoke the Flex delegate before inference.',
|
||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
||||
str(error.exception))
|
||||
|
||||
|
||||
|
@ -474,8 +474,7 @@ class TestFlexMode(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.allocate_tensors()
|
||||
self.assertIn(
|
||||
'Regular TensorFlow ops are not supported by this interpreter. Make '
|
||||
'sure you invoke the Flex delegate before inference.',
|
||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
||||
str(error.exception))
|
||||
|
||||
@test_util.run_v2_only
|
||||
@ -499,8 +498,7 @@ class TestFlexMode(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.allocate_tensors()
|
||||
self.assertIn(
|
||||
'Regular TensorFlow ops are not supported by this interpreter. Make '
|
||||
'sure you invoke the Flex delegate before inference.',
|
||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
||||
str(error.exception))
|
||||
|
||||
|
||||
|
@ -344,6 +344,7 @@ cc_library(
|
||||
"init_tensorflow.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/lite/delegates/flex:__subpackages__",
|
||||
"//tensorflow/lite/java/src/main/native:__subpackages__",
|
||||
"//tensorflow/lite/testing:__subpackages__",
|
||||
"//tensorflow/lite/tools/benchmark:__subpackages__",
|
||||
|
Loading…
Reference in New Issue
Block a user