From 5254b55e35d366eecb05aa2582670857fd14ea59 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 3 Jun 2020 02:42:05 -0700 Subject: [PATCH] Add C++ and Java APIs for GPU delegate whitelisting. The GPU delegate is not supported on all Android devices due to differences in OpenGL ES versions, driver capabilities and stability. This change provides a whitelist that can be used to detect whether the delegate is supported on the current device. PiperOrigin-RevId: 314495681 Change-Id: I86c444188ebf998d6cfb1ea27428ce0900e926db --- .../main/java/org/tensorflow/lite/gpu/BUILD | 5 +- .../org/tensorflow/lite/gpu/Whitelist.java | 93 ++++++++++ .../delegates/gpu/java/src/main/native/BUILD | 6 + .../java/src/main/native/gpu_delegate_jni.cc | 66 +++++++ .../lite/experimental/acceleration/README.md | 15 ++ .../experimental/acceleration/whitelist/BUILD | 157 ++++++++++++++++ .../acceleration/whitelist/README.md | 13 ++ .../acceleration/whitelist/android_info.cc | 52 ++++++ .../acceleration/whitelist/android_info.h | 43 +++++ .../acceleration/whitelist/database.fbs | 58 ++++++ .../whitelist/devicedb-sample.json | 169 ++++++++++++++++++ .../acceleration/whitelist/devicedb.cc | 91 ++++++++++ .../acceleration/whitelist/devicedb.h | 38 ++++ .../acceleration/whitelist/devicedb_test.cc | 142 +++++++++++++++ .../acceleration/whitelist/gpu_whitelist.bin | Bin 0 -> 33604 bytes .../acceleration/whitelist/gpu_whitelist.cc | 99 ++++++++++ .../acceleration/whitelist/gpu_whitelist.h | 85 +++++++++ .../acceleration/whitelist/json_to_fb.cc | 92 ++++++++++ .../acceleration/whitelist/variables.h | 87 +++++++++ tensorflow/lite/java/BUILD | 1 + .../tensorflow/lite/gpu/WhitelistTest.java | 34 ++++ 21 files changed, 1345 insertions(+), 1 deletion(-) create mode 100644 tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/Whitelist.java create mode 100644 tensorflow/lite/experimental/acceleration/README.md create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/BUILD create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/README.md create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/android_info.cc create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/android_info.h create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/database.fbs create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/devicedb-sample.json create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/devicedb.cc create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/devicedb.h create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/devicedb_test.cc create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.bin create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.cc create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.h create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/json_to_fb.cc create mode 100644 tensorflow/lite/experimental/acceleration/whitelist/variables.h create mode 100644 tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/WhitelistTest.java diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD index ab2ad036d66..fbd7f09ce65 100644 --- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD +++ 
b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD @@ -4,6 +4,9 @@ package( filegroup( name = "gpu_delegate", - srcs = ["GpuDelegate.java"], + srcs = [ + "GpuDelegate.java", + "Whitelist.java", + ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/Whitelist.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/Whitelist.java new file mode 100644 index 00000000000..c0b3bf2ca37 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/Whitelist.java @@ -0,0 +1,93 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.gpu; + +import java.io.Closeable; + +/** + * GPU Delegate Whitelisting data. + * + *
<p>The GPU delegate is not supported on all Android devices, due to differences in available + * OpenGL versions, driver features, and device resources. This class provides information on + * whether the GPU delegate is suitable for the current device. + * + *
<p>This API is experimental and subject to change. + * + *
<p>WARNING: the whitelist is constructed from testing done on a limited set of models. You + * should plan to verify that your own model(s) work. + * + *
<p>Example usage: + * + *
<pre>{@code
+ * Interpreter.Options options = new Interpreter.Options();
+ * try (Whitelist whitelist = new Whitelist()) {
+ *   if (whitelist.isDelegateSupportedOnThisDevice()) {
+ *     GpuDelegate.Options delegateOptions = whitelist.getBestOptionsForThisDevice();
+ *     gpuDelegate = new GpuDelegate(delegateOptions);
+ *     options.addDelegate(gpuDelegate);
+ *   }
+ * }
+ * Interpreter interpreter = new Interpreter(modelBuffer, options);
+ * }</pre>
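+ *
+ * <p>A minimal cleanup sketch, assuming the {@code gpuDelegate} and {@code interpreter} from the
+ * example above. Both the interpreter and the delegate own native resources and should be closed
+ * once inference is done, interpreter first:
+ *
+ * <pre>{@code
+ * interpreter.close();
+ * if (gpuDelegate != null) {
+ *   gpuDelegate.close();
+ * }
+ * }</pre>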
+ */ +public class Whitelist implements Closeable { + + private static final long INVALID_WHITELIST_HANDLE = 0; + private static final String TFLITE_GPU_LIB = "tensorflowlite_gpu_jni"; + + private long whitelistHandle = INVALID_WHITELIST_HANDLE; + + /** Whether the GPU delegate is supported on this device. */ + public boolean isDelegateSupportedOnThisDevice() { + if (whitelistHandle == INVALID_WHITELIST_HANDLE) { + throw new IllegalStateException("Trying to query a closed whitelist."); + } + return nativeIsDelegateSupportedOnThisDevice(whitelistHandle); + } + + /** What options should be used for the GPU delegate. */ + public GpuDelegate.Options getBestOptionsForThisDevice() { + // For forward compatibility, when the whitelist contains more information. + return new GpuDelegate.Options(); + } + + public Whitelist() { + whitelistHandle = createWhitelist(); + } + + /** + * Frees TFLite resources in C runtime. + * + *
User is expected to call this method explicitly. + */ + @Override + public void close() { + if (whitelistHandle != INVALID_WHITELIST_HANDLE) { + deleteWhitelist(whitelistHandle); + whitelistHandle = INVALID_WHITELIST_HANDLE; + } + } + + static { + System.loadLibrary(TFLITE_GPU_LIB); + } + + private static native long createWhitelist(); + + private static native void deleteWhitelist(long handle); + + private static native boolean nativeIsDelegateSupportedOnThisDevice(long handle); +} diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD b/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD index 774fd417758..57d6e013a4a 100644 --- a/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD +++ b/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD @@ -26,7 +26,13 @@ cc_library( ], deps = [ "//tensorflow/lite/delegates/gpu:delegate", + "//tensorflow/lite/delegates/gpu/common:gpu_info", + "//tensorflow/lite/delegates/gpu/gl:egl_environment", + "//tensorflow/lite/delegates/gpu/gl:request_gpu_info", + "//tensorflow/lite/experimental/acceleration/whitelist:android_info", + "//tensorflow/lite/experimental/acceleration/whitelist:gpu_whitelist", "//tensorflow/lite/java/jni", + "@com_google_absl//absl/status", ], alwayslink = 1, ) diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc index 900cc0e0d75..017ffcfcd32 100644 --- a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc +++ b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc @@ -15,7 +15,13 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/delegate.h" +#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" +#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/android_info.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.h" #ifdef __cplusplus extern "C" { @@ -44,6 +50,66 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_deleteDelegate( TfLiteGpuDelegateV2Delete(reinterpret_cast(delegate)); } +namespace { +class WhitelistHelper { + public: + absl::Status ReadInfo() { + auto status = tflite::acceleration::RequestAndroidInfo(&android_info_); + if (!status.ok()) return status; + + if (android_info_.android_sdk_version < "21") { + // Weakly linked symbols may not be available on pre-21, and the GPU is + // not supported anyway so return early. + return absl::OkStatus(); + } + + std::unique_ptr env; + status = tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env); + if (!status.ok()) return status; + + status = tflite::gpu::gl::RequestGpuInfo(&gpu_info_); + if (!status.ok()) return status; + + return absl::OkStatus(); + } + + bool IsDelegateSupportedOnThisDevice() { + return whitelist_.Includes(android_info_, gpu_info_); + } + + private: + tflite::acceleration::AndroidInfo android_info_; + tflite::gpu::GpuInfo gpu_info_; + tflite::acceleration::GPUWhitelist whitelist_; +}; +} // namespace + +JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_Whitelist_createWhitelist( + JNIEnv* env, jclass clazz) { + WhitelistHelper* whitelist = new WhitelistHelper; + auto status = whitelist->ReadInfo(); + // Errors in ReadInfo should almost always be failures to construct the OpenGL + // environment. 
Treating that as "GPU unsupported" is reasonable, and we can + // swallow the error. + status.IgnoreError(); + return reinterpret_cast(whitelist); +} + +JNIEXPORT jboolean JNICALL +Java_org_tensorflow_lite_gpu_Whitelist_nativeIsDelegateSupportedOnThisDevice( + JNIEnv* env, jclass clazz, jlong whitelist_handle) { + WhitelistHelper* whitelist = + reinterpret_cast(whitelist_handle); + return whitelist->IsDelegateSupportedOnThisDevice() ? JNI_TRUE : JNI_FALSE; +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_gpu_Whitelist_deleteWhitelist( + JNIEnv* env, jclass clazz, jlong whitelist_handle) { + WhitelistHelper* whitelist = + reinterpret_cast(whitelist_handle); + delete whitelist; +} + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/experimental/acceleration/README.md b/tensorflow/lite/experimental/acceleration/README.md new file mode 100644 index 00000000000..c3209fe99e9 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/README.md @@ -0,0 +1,15 @@ +# Accelerator whitelisting + +Experimental library and tools for determining whether an accelerator engine +works well on a given device, and for a given model. + +## Platform-agnostic, Android-first + +Android-focused, since the much smaller set of configurations on iOS means there +is much less need for whitelisting on iOS. + +## Not just for TfLite + +This code lives in the TfLite codebase, since TfLite is the first open-source +customer. It is however meant to support other users (direct use of NNAPI, +mediapipe). diff --git a/tensorflow/lite/experimental/acceleration/whitelist/BUILD b/tensorflow/lite/experimental/acceleration/whitelist/BUILD new file mode 100644 index 00000000000..96c3a6da8c0 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/BUILD @@ -0,0 +1,157 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +flatbuffer_cc_library( + name = "database_fbs", + srcs = ["database.fbs"], +) + +cc_library( + name = "devicedb", + srcs = [ + "devicedb.cc", + ], + hdrs = [ + "devicedb.h", + "variables.h", + ], + deps = [ + ":database_fbs", + ], +) + +cc_binary( + name = "json_to_fb", + srcs = ["json_to_fb.cc"], + deps = [ + "//tensorflow/lite/tools:command_line_flags", + "@flatbuffers", + ], +) + +genrule( + name = "devicedb-sample_bin", + srcs = [ + "database.fbs", + "devicedb-sample.json", + ], + outs = ["devicedb-sample.bin"], + cmd = """ + $(location :json_to_fb) \ + --fbs=$(location :database.fbs) \ + --json_input=$(location :devicedb-sample.json) \ + --fb_output=$(@) + """, + tools = [":json_to_fb"], +) + +genrule( + name = "devicedb-sample_cc", + srcs = ["devicedb-sample.bin"], + outs = [ + "devicedb-sample.cc", + "devicedb-sample.h", + ], + # convert_file_to_c_source for some reason doesn't define the global with + # 'extern', which is needed for global const variables in C++. + cmd = """ + $(location //tensorflow/lite/python:convert_file_to_c_source) \ + --input_tflite_file $(location :devicedb-sample.bin) \ + --output_header_file $(location :devicedb-sample.h) \ + --output_source_file $(location :devicedb-sample.cc) \ + --array_variable_name g_tflite_acceleration_devicedb_sample_binary + perl -p -i -e 's/const unsigned char/extern const unsigned char/' $(location :devicedb-sample.cc) + """, + tools = ["//tensorflow/lite/python:convert_file_to_c_source"], +) + +cc_test( + name = "devicedb_test", + srcs = [ + "devicedb-sample.cc", + "devicedb-sample.h", + "devicedb_test.cc", + ], + deps = [ + ":database_fbs", + ":devicedb", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +genrule( + name = "gpu_whitelist_binary", + srcs = ["gpu_whitelist.bin"], + outs = [ + "gpu_whitelist_binary.h", + "gpu_whitelist_binary.cc", + ], + # convert_file_to_c_source for some reason doesn't define the global with + # 'extern', which is needed for global const variables in C++. 
+ cmd = """ + $(location //tensorflow/lite/python:convert_file_to_c_source) \ + --input_tflite_file $(location :gpu_whitelist.bin) \ + --output_header_file $(location :gpu_whitelist_binary.h) \ + --output_source_file $(location :gpu_whitelist_binary.cc) \ + --array_variable_name g_tflite_acceleration_gpu_whitelist_binary + perl -p -i -e 's/const unsigned char/extern const unsigned char/' $(location :gpu_whitelist_binary.cc) + """, + tools = ["//tensorflow/lite/python:convert_file_to_c_source"], +) + +cc_library( + name = "android_info", + srcs = ["android_info.cc"], + hdrs = ["android_info.h"], + deps = [ + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "gpu_whitelist", + srcs = [ + "gpu_whitelist.cc", + "gpu_whitelist_binary.cc", + "gpu_whitelist_binary.h", + ], + hdrs = [ + "gpu_whitelist.h", + ], + deps = [ + ":android_info", + ":database_fbs", + ":devicedb", + "//tensorflow/lite/delegates/gpu:delegate", + "//tensorflow/lite/delegates/gpu/common:gpu_info", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@flatbuffers", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/lite/experimental/acceleration/whitelist/README.md b/tensorflow/lite/experimental/acceleration/whitelist/README.md new file mode 100644 index 00000000000..24ee794aef6 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/README.md @@ -0,0 +1,13 @@ +# GPU delegate whitelist + +This package provides data and code for deciding if the GPU delegate is +supported on a specific Android device. + +## Customizing the GPU whitelist + +- Convert from checked-in flatbuffer to json by running `flatc -t --raw-binary + --strict-json database.fbs -- gpu_whitelist.bin` +- Edit the json +- Convert from json to flatbuffer `flatc -b database.fbs -- + gpu_whitelist.json` +- Rebuild ../../../java:tensorflow-lite-gpu diff --git a/tensorflow/lite/experimental/acceleration/whitelist/android_info.cc b/tensorflow/lite/experimental/acceleration/whitelist/android_info.cc new file mode 100644 index 00000000000..4618ac90807 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/android_info.cc @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/acceleration/whitelist/android_info.h" + +#include + +#include "absl/status/status.h" + +#ifdef __ANDROID__ +#include +#endif // __ANDROID__ + +namespace { +std::string GetPropertyValue(const std::string& property) { +#ifdef __ANDROID__ + char value[PROP_VALUE_MAX]; + __system_property_get(property.c_str(), value); + return std::string(value); +#else // !__ANDROID__ + return std::string(); +#endif // __ANDROID__ +} +} // namespace + +namespace tflite { +namespace acceleration { + +absl::Status RequestAndroidInfo(AndroidInfo* info_out) { + if (!info_out) { + return absl::InvalidArgumentError("info_out may not be null"); + } + info_out->android_sdk_version = GetPropertyValue("ro.build.version.sdk"); + info_out->device = GetPropertyValue("ro.product.device"); + info_out->model = GetPropertyValue("ro.product.model"); + info_out->manufacturer = GetPropertyValue("ro.product.manufacturer"); + return absl::OkStatus(); +} + +} // namespace acceleration +} // namespace tflite diff --git a/tensorflow/lite/experimental/acceleration/whitelist/android_info.h b/tensorflow/lite/experimental/acceleration/whitelist/android_info.h new file mode 100644 index 00000000000..81b3ee7479c --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/android_info.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_ANDROID_INFO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_ANDROID_INFO_H_ + +#include + +#include "absl/status/status.h" + +namespace tflite { +namespace acceleration { + +// Information about and Android device, used for determining whitelisting +// status. +struct AndroidInfo { + // Property ro.build.version.sdk + std::string android_sdk_version; + // Property ro.product.model + std::string model; + // Property ro.product.device + std::string device; + // Property ro.product.manufacturer + std::string manufacturer; +}; + +absl::Status RequestAndroidInfo(AndroidInfo* info_out); + +} // namespace acceleration +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_ANDROID_INFO_H_ diff --git a/tensorflow/lite/experimental/acceleration/whitelist/database.fbs b/tensorflow/lite/experimental/acceleration/whitelist/database.fbs new file mode 100644 index 00000000000..6340fcfcf3a --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/database.fbs @@ -0,0 +1,58 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tflite.acceleration; + +enum Comparison : byte { + EQUAL = 0, + MINIMUM = 1, +} + +// Mapping from available device features to whitelisting decisions. Basic usage is to: +// 1) Map easily available device data (like Android version, +// Manufacturer, Device) to things like SoC vendor, SoC model. +// 2) Map complete device data to delegate-specific features and support status +// 3) Map delegate-specific features to delegate configuration. +// +// The structure describes a decision tree, with multiple matching branches. +// The branches are applied depth-first. +table DeviceDatabase { + root:[tflite.acceleration.DeviceDecisionTreeNode]; +} + +table DeviceDecisionTreeNode { + // The variables are strings, as we have multiple clients that want to + // introduce their own fields. Known variables are listed in variables.h. + variable:string (shared); + comparison:tflite.acceleration.Comparison; + items:[tflite.acceleration.DeviceDecisionTreeEdge]; +} + +table DeviceDecisionTreeEdge { + // Under which variable value does this item match. + value:string (key, shared); + // Which child branches should also be consulted and used to override this + // node. + children:[tflite.acceleration.DeviceDecisionTreeNode]; + // What information can be derived about this device. + derived_properties:[tflite.acceleration.DerivedProperty]; +} + +// Derived variable value to combine with detected variables. 
+table DerivedProperty { + variable:string (shared); + value:string (shared); +} + +root_type DeviceDatabase; diff --git a/tensorflow/lite/experimental/acceleration/whitelist/devicedb-sample.json b/tensorflow/lite/experimental/acceleration/whitelist/devicedb-sample.json new file mode 100644 index 00000000000..187989673d1 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/devicedb-sample.json @@ -0,0 +1,169 @@ +{ + "root": [ + { + "variable": "tflite.device_model", + "items": [ + { + "value": "m712c", + "derived_properties": [ + { + "variable": "tflite.soc_model", + "value": "exynos_7872" + } + ] + }, + { + "value": "sc_02l", + "derived_properties": [ + { + "variable": "tflite.soc_model", + "value": "exynos_7885" + } + ] + } + ] + }, + { + "variable": "tflite.opengl_es_version", + "items": [ + { + "value": "3.1", + "children": [ + { + "variable": "tflite.soc_model", + "items": [ + { + "value": "exynos_7872", + "children": [ + { + "variable": "tflite.android_sdk_version", + "items": [ + { + "value": "24", + "derived_properties": [ + { + "variable": "tflite.gpu.status", + "value": "WHITELISTED" + } + ] + } + ], + "comparison": "MINIMUM" + } + ] + }, + { + "value": "exynos_7883", + "children": [ + { + "variable": "tflite.android_sdk_version", + "items": [ + { + "value": "28", + "derived_properties": [ + { + "variable": "tflite.gpu.status", + "value": "WHITELISTED" + } + ] + } + ], + "comparison": "MINIMUM" + } + ] + } + ] + } + + ] + } + ] + }, + { + "variable": "tflite.android_sdk_version", + "items": [ + { + "value": "21", + "children": [ + { + "variable": "tflite.device_model", + "items": [ + { + "value": "huawei_gra_l09", + "children": [ + { + "variable": "tflite.device_name", + "items": [ + { + "value": "hwgra", + "derived_properties": [ + { + "variable": "tflite.gpu.status", + "value": "WHITELISTED" + }, + { + "variable": "tflite.gpu.opencl_status", + "value": "WHITELISTED" + } + ] + } + ] + } + ] + } + ] + } + ] + }, + { + "value": "24", + "children": [ + { + "variable": "tflite.device_model", + "items": [ + { + "value": "sm_j810f", + "children": [ + { + "variable": "tflite.device_name", + "items": [ + { + "value": "j8y18lte", + "derived_properties": [ + { + "variable": "tflite.gpu.status", + "value": "BLACKLISTED" + } + ] + } + ] + } + ] + }, + { + "value": "sm_j810m", + "children": [ + { + "variable": "tflite.device_name", + "items": [ + { + "value": "j8y18lte", + "derived_properties": [ + { + "variable": "tflite.gpu.opencl_status", + "value": "WHITELISTED" + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] +} diff --git a/tensorflow/lite/experimental/acceleration/whitelist/devicedb.cc b/tensorflow/lite/experimental/acceleration/whitelist/devicedb.cc new file mode 100644 index 00000000000..978495a3234 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/devicedb.cc @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/acceleration/whitelist/devicedb.h" + +#include +#include +#include + +#include "tensorflow/lite/experimental/acceleration/whitelist/database_generated.h" + +namespace tflite { +namespace acceleration { +namespace { + +std::vector Find( + const DeviceDecisionTreeNode* root, const std::string& value) { + std::vector found; + if (root->comparison() == Comparison_EQUAL) { + // Exact match. + const DeviceDecisionTreeEdge* possible = + root->items()->LookupByKey(value.c_str()); + if (possible) { + found.push_back(possible); + } + } else { + // Minimum: value should be at least item's value. + for (const DeviceDecisionTreeEdge* item : *(root->items())) { + if (value >= item->value()->str()) { + found.push_back(item); + } + } + } + return found; +} + +void UpdateVariablesFromDeviceDecisionTreeEdges( + std::map* variable_values, + const DeviceDecisionTreeEdge& item) { + if (item.derived_properties()) { + for (const DerivedProperty* p : *(item.derived_properties())) { + (*variable_values)[p->variable()->str()] = p->value()->str(); + } + } +} + +void Follow(const DeviceDecisionTreeNode* root, + std::map* variable_values) { + if (!root->variable()) { + return; + } + auto possible_value = variable_values->find(root->variable()->str()); + if (possible_value == variable_values->end()) { + return; + } + std::vector edges = + Find(root, possible_value->second); + for (const DeviceDecisionTreeEdge* edge : edges) { + UpdateVariablesFromDeviceDecisionTreeEdges(variable_values, *edge); + if (edge->children()) { + for (const DeviceDecisionTreeNode* root : *(edge->children())) { + Follow(root, variable_values); + } + } + } +} + +} // namespace + +void UpdateVariablesFromDatabase( + std::map* variable_values, + const DeviceDatabase& database) { + if (!database.root()) return; + for (const DeviceDecisionTreeNode* root : *(database.root())) { + Follow(root, variable_values); + } +} + +} // namespace acceleration +} // namespace tflite diff --git a/tensorflow/lite/experimental/acceleration/whitelist/devicedb.h b/tensorflow/lite/experimental/acceleration/whitelist/devicedb.h new file mode 100644 index 00000000000..74a0d78e44e --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/devicedb.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_DECISION_TREE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_DECISION_TREE_H_ + +#include +#include + +#include "tensorflow/lite/experimental/acceleration/whitelist/database_generated.h" + +namespace tflite { +namespace acceleration { + +// Use the variables in `variable_values` to evaluate the decision tree in +// `database` and update the `variable_values` based on derived properties in +// the decision tree. 
+// +// See database.fbs for a description of the decision tree. +void UpdateVariablesFromDatabase( + std::map* variable_values, + const DeviceDatabase& database); + +} // namespace acceleration +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_DECISION_TREE_H_ diff --git a/tensorflow/lite/experimental/acceleration/whitelist/devicedb_test.cc b/tensorflow/lite/experimental/acceleration/whitelist/devicedb_test.cc new file mode 100644 index 00000000000..ae020dd7ba2 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/devicedb_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/acceleration/whitelist/devicedb.h" + +#include +#include + +#include +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/experimental/acceleration/whitelist/database_generated.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/devicedb-sample.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/variables.h" +#include "tensorflow/lite/testing/util.h" + +namespace tflite { +namespace acceleration { +namespace { + +class DeviceDbTest : public ::testing::Test { + protected: + void LoadSample() { + device_db_ = flatbuffers::GetRoot( + g_tflite_acceleration_devicedb_sample_binary); + } + + const DeviceDatabase* device_db_ = nullptr; +}; + +TEST_F(DeviceDbTest, Load) { + LoadSample(); + ASSERT_TRUE(device_db_); + ASSERT_TRUE(device_db_->root()); + EXPECT_EQ(device_db_->root()->size(), 3); +} + +TEST_F(DeviceDbTest, SocLookup) { + LoadSample(); + ASSERT_TRUE(device_db_); + std::map variables; + + // Find first device mapping. + variables[kDeviceModel] = "m712c"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables[kSoCModel], "exynos_7872"); + + // Find second device mapping. + variables.clear(); + variables[kDeviceModel] = "sc_02l"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables[kSoCModel], "exynos_7885"); + + // Make sure no results are returned without a match. + variables.clear(); + variables[kDeviceModel] = "nosuch"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables.find(kSoCModel), variables.end()); +} + +TEST_F(DeviceDbTest, StatusLookupWithSoC) { + LoadSample(); + ASSERT_TRUE(device_db_); + std::map variables; + + // Find exact match. + variables[kOpenGLESVersion] = "3.1"; + variables[kSoCModel] = "exynos_7872"; + variables[kAndroidSdkVersion] = "24"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables[gpu::kStatus], gpu::kStatusWhitelisted); + + // Ensure no results without a match. 
+ variables[kOpenGLESVersion] = "3.0"; + variables.erase(variables.find(gpu::kStatus)); + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables.find(gpu::kStatus), variables.end()); + + // Find no results with too low an android version. + variables.clear(); + variables[kOpenGLESVersion] = "3.1"; + variables[kSoCModel] = "exynos_7883"; + variables[kAndroidSdkVersion] = "24"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables.find(gpu::kStatus), variables.end()); + // Find a match with android version above minimum. + variables[kAndroidSdkVersion] = "29"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables[gpu::kStatus], gpu::kStatusWhitelisted); +} + +TEST_F(DeviceDbTest, StatusLookupWithDevice) { + LoadSample(); + ASSERT_TRUE(device_db_); + std::map variables; + // Find blacklisted device (same model, different device). + variables[kAndroidSdkVersion] = "24"; + variables[kDeviceModel] = "sm_j810f"; + variables[kDeviceName] = "j8y18lte"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables[gpu::kStatus], gpu::kStatusBlacklisted); + + // Find whitelisted device (same model, different device). + variables.clear(); + variables[kAndroidSdkVersion] = "24"; + variables[kDeviceModel] = "sm_j810m"; + variables[kDeviceName] = "j8y18lte"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables[gpu::kOpenCLStatus], gpu::kStatusWhitelisted); +} + +TEST_F(DeviceDbTest, StatusLookupBasedOnDerivedProperties) { + LoadSample(); + ASSERT_TRUE(device_db_); + std::map variables; + // Find status based on SoC derived from model. + variables[kOpenGLESVersion] = "3.1"; + variables[kAndroidSdkVersion] = "24"; + variables[kDeviceModel] = "m712c"; + UpdateVariablesFromDatabase(&variables, *device_db_); + EXPECT_EQ(variables[gpu::kStatus], gpu::kStatusWhitelisted); +} + +} // namespace +} // namespace acceleration +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.bin b/tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.bin new file mode 100644 index 0000000000000000000000000000000000000000..5d42a3ec242ac2975e23b45e9f39a1897b002204 GIT binary patch literal 33604 zcmZ{t4_sYmn%CdSFde4DW|sm0y?g2!Ejn}0RPO~JKY2y(aJ zc8DQ9nPWeZV}B;dt-?(tj47%yW`phZ-y7WjbLjVH$2tL*RjMhMVsrtDa#hA!`~ixapJoo<^8a;%E+h9$-r-1-^Xg!4xFc3Idhx13XM zU0h-o-5UK`G<)qY$NnhDeQt@d1vgBLJYvulO;~X{DDW0I2{9?)o|m=^EjN& zO&webo&|tY#+<21_om|Mj&x@-o#=IQ)Srm-*HqPVfmi#nQ&RCOX56WpKkY~fAGu|dT!8%OCEmGoH0o%I&d;+)reDIjh z;dGg|;D+(&24AI2dwrFT>ML(r%3C-}epNY+!wso8r{HFlt{JX?7zVVlou}&k23$z_ zT!mXvK2vbZ^nf{VjqS?s1nc-yxXm8~xw~)=4hOjnxKTV(V4H2-mk$4-M~|*3pK7=n zRi09~T0Ba?UcSmhj7D6i*f*~9=iwTZ??t#~r5lCY;4HokCfE+sIR1Y90i01jKM_vp zmf@!G7z3+p>%RT?Yn=aeLGP#GbU!J8YsRAvIDNZ9jviRN~B_ZBn_r1Gn+3!8rHeic}xigc~77 z9PF^IbMwp}CU?r`SvZ}W5L_6KV!*V>oGEYb?d!;tct1Dg)gh?ID^s5JU17Bvf0;1tQMr)p#j2`u#?MRWb>2RL z+x>09`rNU^T7jE9ZOjOeKIO(D6c;dG9OwFF{z!J5a<~Z<=LI<37i!^3iBSmTI_bo5 z!}#L3*U4_QIZQNj(0u{O=ViDO73VtK@)JSbG~6P+X;#3t&dVoo>$SmSK8Mrg-GUp& zqZ>$hhuV7wJNhc6srSlTam;P&!q>C=({Z>V73UP3j?)ZRK#b`l?mR16E$7ExjA(Vp z?X$z^bX~5?u`0#~mKbwzvu(yq0huF4t7U*nq-B7Tzq$4jCzLK@djoAfpZ?rG`9~Sw zeNAUxm{-8iL!7&_}=T-cv>^LS`IJ zm-8lEj`I5$Zdui53GNDh;~-Pc@=(<)QqB#F?~~)y|Jh*KkHDpsZxLJro|V8mccS<@ z=V&ea8kIf?*Q|VBh0Cfx-1aH%?Z6lDz|>!EC%;o@^q6)FZihP>e_MWH@n40TzyM 
z`H05Pbfi1_oO>8APIV?uF~?sy5%fF>m&eH_xh()gm=u5;B%%>)e2IAoI0SAWY5|)x zbPDVq;93saz|I$p83lzv`m{_a#PbsV`daf5+-^;<9CzS!yR5)Xwi+`6SkjRDyku{> zV<>*Er?cnmkW*jAhg4NKWXvD8v_IU0it_^8bYU=dEnF!v3W0ZSFyD9YEn3(&s`RhJ zg_ZAHaE(eg3Agflj9CCz*w$m@$8g)#!Lob`cgV{fsdRF<1~-n!5OC&JH7R!ickZahPs#nTt82MvE`*Er&LYy6pEoc|M8O+ncKk?|%jjAAb7;a4I3gMcF zQ45OMUi((?n3HgE<#Pt^iqf^g<>ImPyP12k-9ovfe`nh0Ex2Ok^Ddmu@in*wJZ3;3 z$1+jozkJXiP4o{WyCTkgXL(Ij_;pLnC*I)NSs5(HGjKZQA-G23g)=cLy#Dk5iCN*r zZ00!KzOTX6s+gB8F(=?wNzx*a_Dyv3oF8aEmg-|IMe6PR9ce!naTv}K^!4!moZf@G zT@mCS`;)Zi!65elU7R>AK(3kbL}aM4ERV+8m-_yx?0HNz+_;Ky z98TAx4z2|MxnD4b%b7`ZyVuirS0qsr3Yj+c>m0oS7gF9=E#6ab%fAb6;N&Pi=&ZZS z^;-NZm@_Si|F?prU*j(S3bw# zbbjjKO7JKEGS}E!2V?CkcF=t0O!$?uuyYr;f1PC=wUb?@9<7cu z2-m3Myl#mz0=M}h>p#Gqqqyroawq5IPS#MJ`?wjjx}EO8ZLmJ#Z@YVzIGb=I#E1hK zBlp@#t!WV}C3PzNGud@I4mYG?or2T3YlbTzj!|{;uW1oShC4H2iSsU8zKZjqCC)tD z%oD7$M-ak+()C(>sYP^9|Ml z(^8)%w7M_63RkD%ylIIu4!6e1y#)MmqCLHPeStX9^Te4$ThFJz-9Lj{{ly^n5UxO7 z_jlmZ#ApMVK2}i?>P+I7|7Wwy|2*7?@;wQs^VBFDo->TGavhU(G+E(dx}Ax`dR=Is zjD16BTUEKT;BFrY9`o2#_O<;1x;P#!z{^8d zby-mosoJ%JTwTj?$&!c78i!D;QzBRqcNlE?@b42v?$X^KdhGTmnmM>pI?t zTPg{b@!(smB~lN6xpUxRcu~%ks$@>I#J+AIZF-wybXgBv!TT9ie>h#%Jh;|hCk|km zX5s@wk%3sh)9;7UWmVPXTm#Thok0IN4cDgp2QB__xCh5s2Le>ZBs8lD&~m@jVFQ9m!taggF~#J#(O+wnMLTb`vf~#roJ1YYFZOamIo4-9$8* z>`8X1II;}LIES_#O%CAnaXz2JR<_RU?z`o4D5WBdmA9neCHph-+=9}a4!T3!AvgaJQYfV2Q<_Ks02$~ z3*@??@0*%Av1EO90u+N~D%dTR1T%bO=Fhv7P3PdqH1DMy4%Xp0IGuw^xXNFnZveuX z^T$G_0Z#j!hig!N7cG9Ha2qd>AK>`0E_=!1_YvIgw}a)oWAR&oo5X7*<5v?hn{b)B zOFE?OpZZITugdQzoUU&%oV>#n0@820dU}%`9g+TX=gO*SdZ)}otnf6V)BX2?9IIly zVTmyYC+{#V11VdAe#M6~uNmwQSIABGjCk=W`!jX81GiBWWfOq{>-kH{F&vm*GlO zoa>f2({PJad=^NYzMd{84yT53<8(M<@G?3if&4s_V_BP#<1#sW;xA{{B^R!ZgBt;- zpy^Nbw0A^0D?;R`n0mKck0)A&yS$;p`-cTJql%)DNnOLQjqdmAmq3FPDOHO9`TgIqOSTIovRYKbB5 zEJ+->?_;gKi%rwOKAofUa1F}$BAo6=qi`F~bN&OT{N9?r_%Z}$;yC4>wZ!=ZZvE$j z<^S9gXA5qa7~PpT75jP0)JN_EC2rwg$*#|FxFHqi6r8S4Gh6{N43P5gy??;A7pLo3 zmHP~Q#1iLSxO^4oLra`_xS1a?w*=ywdH;Y-UmWlLAFXcpr`}=wqwqQCGv~%pIDNld z3|CK#kQaw`@2aYCpRv@jPq%v?T$A#B4Nm827;dYMca?yYKk91_uic|ebvonhyM|Ve z-#>-h{JCKH?^@z)z>N|k1-RtyKcA9?1^)XQ(9cd}jp5K=&5n5jF0EpohSPOxfy=Ms zn*FPRYsR_L@tP`mv-OJ-v%24mpqpdc-v;l>v0e<(N&Ul~=9VSaBHT1_CV-R2bazLx zBO<4-d1%q^z@^c)1@%t9KRT0LmtwejwDQi9TsxwUzdKI4&x#`@mEULVM5p`01v&Pq zVEx{(#2ABHe~#-PkhV`nlO6GRBqnudYGj-^?e0YJ23mcc{S;0=qq}d3vkf;yjLbVr zFIQGN^IWE6CegziwTb9Z#3K*=wd^=2;L>Uh{WM%w`@`iE=K=S0GS0V`N1aLT&-&vy zWWO0j+X(OP1Mk7*sW=~5;w->jBgQ0Ap6^>cXW{1XoA!CiG;Yb_c^__x=Nt=R8w@&wv#$J15r z!{=Jjs^gPz`W^IFEpdk6wz&?h0ew$V7Iyn!v@8$Xly|4a zyA>`EUwLOqt^uj)imE7Mii%SuamLWrOPocxX=2DbOLFg$=ud=?Rh37o zIm_g4r%rsQsiup5fmYY&@D1kwPXu#S4X4|r6t0#SC4g>V%BtllwtKyHzoS#*jBm~8 zbd1;JSQX>4CB_8YYBkS@fsD!NzK#ltote)jkCn-{F6e{kWDXP_gCG8S9&7Qj#V)}vcy<`yN3THkh(|=dD0@!p7-N~ z2H9n{(AM+m&(G8E(RcELTsd4?=}O@0h*1iBzCGew&puuLLAXZc`#PM?(Foio{d5J0 zZ>Id>>x0^G2Q z^Aen{PZL}rF><^(^6X3AO!n7@f#ntXdl;=gPp-ois5l>3;>^L#aJ zIHPbIOk!>WE;%Mz6Lm7?)g{U-#BF~|j5~5H*E@e*?pb1N!j0e`2a+$o`RC79w8DFr zA^-1W*W-D(5f$SkoX%IHaQNo}8M}HD9nL}(msr!+5li%VFRi65F|Nw7D#rVk7_)G5 zjJnf6VsxKN#n197bz&w%0X-RBaoL?iKP z-dAv9IDeIunJ_xt-e=@k72|>>MjCD>WXw8{Iwa0kr;=VBn0KG${Wx<8jjqE-aJv-B zUxzyu{}s4Nyhniaq4T}Hi3s17WYd}7de0qLwvIYCe>s0Q`~IsMZd}DW4yW@~2UkLz z0wDKasng}oyE$GA{)|*tak_V-)AhM5$Eq0DEitCy7Jr%bPayT-<@K`a(|loo{}?JV z&r$a16-%tg=4pQ>1OD7S1*hwi57$DRdVef8O8T`I|a7`-KYj8tKHw?FR zl=UCLFk!j}qUSn#B7@0(9=mc5O^J^><`&$J%v<3z*Yi){bUjw#Ch#2wa&B~=6Wi|W zIP~+Jzn9%M)o|k~&T%*$rw*=!7zIGi^X_wfJ-yCZ7^u@8nL@Ag^)_6QiuE3x82alu z12_Lm%(a2EyO*!7zG!41&Yjw_#rq-LBAv$X^Tf^UdgQ{j;oAu0ob2@ORg2(td0&K+ z_xk*vofgklxIFxJnSf^}%JY4=6ncN0S-3g;rh&BUX%>m%!F4PDq{Tl5mxK5AFM8v2bwy1l 
z4;Ex@?)~OKwLE7SN2kY#n{up*@v$Yw65JL1#{ol%iT3mO<$NeK5Gixw?EM|%UL38i z%QFi+|K#H9&(*VVy1$3u!o(;B(xx5p{wPyiXS*sCcGgdt(dc@-23M>6FI)U4;8u?? z{|Azv{_Y-@GM(GtOpK6>BXj7wDS*GcpTP}z_XIzKMxKqP!7c;y7$^aY+^e>L9R`na zka>S1j$XX-|3P+co`)Mzu};G2+%yVDxpRT!X0U?AL_Vb+&i#=`q|w#GNxaM$c2$m5 zG2VwOq|AQZEZp4p>Hk2k1?iqN?~u69>X+H4+w~#bA_jhsCqBq7doElXo{d1tezqfd zzL(ED;Lcs-*?EMSAUQ)@&!^wB6RuUoId6%Rg4-?Ooh%@8y;QY3DH?-sQJ(L^XHU zlz2oC} zGp|drUyluM!xbs-_blEsaP!Y_{R1+_b)S>oFLTo>|MhV1lvqQr+vDIr;@wG3I)9!H z!|8ey!Zi`67C6`1bbExgoAw0HB+>(M^G%5l7whQn>FbEpl!qg+foMF^-;*%S9INBL z1}E!d{ zMj+!=dn#OhEF9^Pk%gu)G9N2NGp)Q%!R4@4>W|qB*R6D~!PP46GMvu)1l(#d^#@-< z-+hkh4WCbdbf;Vgj>b9$<=K2Pn&=43v1A?;53x?_l$VLdeY}dOkCaWg%=-Ecd@c=` z;di}uIr5LW_x9FVA45M0Uk`TJo(7qITZv8{$jR~Sc0P};!SQEpage`&wn%YlxSj79 zvkpcirnjD!@jCB$ZL{AW<0G_%-w1LGaM#dHf@QXIl-Gxz*R=9_a*6(}yz<~$(KUcV zwl}zr`QtzF@7y}JD6i9KN4_4^w!r11li%Wqv3*VE*}i<8^1Skt*LAep9OKtc!!5$i z3fOK_b$rM3T2RN`N4xyhK-#~sN>qahfO|boCJis!Y zEz<`lDATnk-7+7*vJOZtF5Yo-zeYj*xw(Trhcf!P6}U-sBj6$1x4mQJe6scDoR640 zelh6r(2EyW@lK#EL2TYxFH zr=JSOdeie-VZXl&pQ2s=xgfU&H;!%yY_ffg-20FD#Pb?cUgiIk{3)*zxH@#Dppxx2 zjQn0l*S+&!d0jx;puEy>JHNpB4@TIox3u>K&uOSMRU0+K6`(Vq zjqL_D{pEeh^U7CV@1Q-Tysp44JPggC`zQDRVdb=F~WK+Vd8@zkKhaEl}Rq;1(DwX21=$D^>rz z?s?6t@#G=eCGT3F=?^<_X>@JC{Bzcc7#RKW?|WV;<#inGu=1*dD?wKP8rjzE{lbQu zzhdR}Cfbnl8i!kBtXKjv|8VbPxPHLrDeYUH@3I@IXO9JnT$9p2_u(P0= z|1Yxh^Z2LU`LDcRLK|1!O>l*x3wYO5hlmrmcz1fnZer^#?ew$+g_{v@$Nu(;of5YN` z$MYXo{zv~M{hRrp-@h2H9$g6d%3dzsb?nRBPx|}&bDM5%N|pab^iv4^{-bakoZs@! z(sj0rykkToHyG1=WUkG#>`$xl{2ecr{;iAqXg4XaKh`$f5OHGQ0o%n)68!r=_q@{V z_wz5TFvnNpcP(5gxu_Uurh&|V+Xqhfrup=>BX=hJ{jajif8aB>{M{Vu_k0oUD#!S_2Dl=0 zc>?x1V@eky)Gu-o1=OaXAo^%jq7o^ z2L;T^^oMBAwHGY+@L*s}z(@>5oAW-R_!J%71(^_TT? zv|H*Lxdk_juG{KgW%5?D%m3JabITu7{wL9nEB{6~xp$U#ma@xVZsHdIPR~C_`M-y@ zRQb=q%`&|@AAA0nl>gzM z;?JPvk68%Ugsv78v#rbe^nZ8j-Ke~hXyyLW?-hf~LAT9)YgSoVup71bU-105IL7aF z3+*N4zX&JqFilwf!)Am1JF0$nJpXa!fArtt&w7nNW-*+s`Gf>)%lG&E`;Tn9xu{cK z=g~GQuN2%aNsxDzhS_dXx0XlA47_n0j*xS4Eu)EIDFqLLrRt5I9gqXIyianR=zXi zl%a~b8eB;A(G&OF934`gZ=sdn$?=zW5^jY)wg9fMt8I@}n# zG$6F;;+r9TJf}$Xp6-cvxX)-Rj+UET*84JJo77{@%g=MN4tUs&BYd;Rhv6R-1>=|h zPq&V;{^>8*t7!GRO+#?DBHVg4s^lu+EV@F7TVj2TZEJGYXaP5JI{+R&xj`I z%cWq6`;nO3?13qA6b38Iu~w;+{2K#O-aOU=OTF#Pz2t zp4Wsr?g(p+X~h-6HNaH@QpUSRnVu84ca~pLG>vG7RJmmxv&r*xa?JjDv~58C9ua?U zJCpPC|I5wcwi<^%Lc5~41-NT)lYrCMBnM9SMEP8`_hbL`>xRS1w-8R=VXF0f zo!_0qHv~VVv`0R7%Uh@T^Jt3|mx9}+4{m^AwpU1kznup?Fa3_whiK&)f?qoiHv@MG z@W0GDhO^eN%DxTdE$f(zo(HFhiHRAIxrO|;#1`j*d~duGtkM|6ARowBkQpyScg&w}+$srJ1XAAa zp8k$xe}}nd@qOF#mEVH!*ZCgW6^`|Dn{XrOjm{@ObS4RgMtP$!t=zY&OrF@U0 z&6@wf$?p;s0Gia4SG)f0(f$-aCv4g*{xAJsHy`=R{~e3}6}TlTasxR2m2SP{eKx}R z#yx)NjK6QA*ZKd*i+2SNe@>oa4N!h_(9h+=$!`_c1E#Yk>RY4WyY42$e(itcfp`8Z zuk&b&6_LA>H zJkJF*{^ReXomSj7oLpC8!13(rFgdIN`gQs2leeJqsnM5N&1DKF|9zp2|MUC5#BueC zYl18EIHtYkHyeK3{3GwXWh{YY-eKxW4IJg?0n^SPM$;Q!VKzJh%IuFOV|~3nu}FXH zY=nQKp`Txzam&j)XU@M@=DiEW&m0YOM%&W^eGIl)=S!%{`CVvFOH6M2VjXe9ICZZ$ zT4toI-hO@?JlPeGbnvq+nTW8?r!4wrIOzkmL48v67cKfxI9bQc3zVhGyl>IY!llrz(_i-M z>AVtdKC|f8;p))OqTjD)$UMpVpUz1hTmjlqpLa#b{eAowEc#lw4a(e%Ub+(B>Fnp< z#&CYKV$h?g84op#H4rPg(TMaLZ`dD7)82Ql3uHU$p2);ik~bTHe0i zZPS|;{S7#&>u^x-)c3AMzX2z8Z4Bz2`aaECpRR8KoYYm;>U`yK>U+YXuZNSm%KDj4 z@6`8#MW2R~y2_fEr%zObsuGf)>lXbqoYYm;!hCwi|Bgk!0w;Bq@B3x+^s}V+A7t%N z*EdHv^y5LjQ=aE6`bs#dYim&7CH_f^J_aXacqw}CJV{U<=l6Esvgjw_#?e2JaU4(1 zD@QN>w=DWaxK{Ms8aUT#^zwVegQ*n#pM9n9GH%L{DxW*Tp_gq!n2w=WlD++rnlN8< z_2N0lv&%fkI;755C7iT{uY6tfZ)ZGATJ$kEX^UlodG#0lIT_F2vgjw_q%EfQ>o0nz zf8Vm`7vZEW;z4~}%Ky-!-+`01aITYR$d~w^75!1x26cUl;iRs4L4T+JowDeg;iRr> zTyMR6;_sf%vW}_!N8zNd^4k>q`3X7sxoOegfRnn)nw`(zJ^wBG4LGT*{O*NM@8suc 
z)&zBV3gD!!hl2XElAjY6eLbAib&JCL{N3}PeWe`d?|*4HsjKRLRh7JCzju7d{hfoG zmUuVdWKJ_oyiC5r;czlUoVynN23#%rMql~RJLP+t^~S9F!{wo$qcC21qBT`+`<<}p z>*3`0M8<-8XS}&!(Wl{L-LeI}SAOyDk^`<=^wV%s*OH*WGv3^>=vUyRuCmVN^LNIZ zgRKAQ`sN6SP`)eZtFM!v=Pde4I9Zpx6x6p&Xr{svqez0ARV{!V|pYte7O)uP{Ie(uvd z<$szrKV9DfxIFYb#NX3<`8;9K*Tbz*mc^jnDgOnFJ`Fd6elntAa+{#*i zCSGO8y7_$GXVF)}$?qYxqR-S<{E?faMIVEkM_(HBcjrGA{UqEtdbv;B zFTeY5`Q5VU7vbc)m&?BMnfau1K0LJOci>9V%X^4EeYhpgVOux+2 ztR3q57Qp49m*-zTeVOP_SoHO9Ym~h)sF%yJxnR+!;bzbuLZ8V$^I_-Tg}83fPs0tP z-@-3b9?^^c9gBVit`U7}raWEUkz_fGzx;QQwLM+m9O2NH`t@ZsPJN%V=qus2(6-{o zFv-7(Fj!d;88Bxo+BUcaw33W{?HThXydH}2H3ReVW1&jJ>FLJd^JB}gOK>fg zW5eMpXT8(;w;35R53*L6Re#|)R?0)Oo8;h-sfN?#D}|G5f+}A)% +#include +#include + +#include "absl/strings/string_view.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/experimental/acceleration/whitelist/database_generated.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/devicedb.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist_binary.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/variables.h" + +namespace tflite { +namespace acceleration { +namespace { + +std::string CanonicalizeValue(absl::string_view input) { + // This assumes ASCII, which holds for all values we have in the whitelist. + std::string output(input); + for (int i = 0; i < output.size(); i++) { + char c = output[i]; + if (c == ' ' || c == '-') { + output[i] = '_'; + } else if (isalpha(c)) { + output[i] = tolower(c); + } + } + return output; +} + +void CanonicalizeValues(std::map* variable_values) { + for (auto& i : *variable_values) { + i.second = CanonicalizeValue(i.second); + } +} + +} // namespace + +GPUWhitelist::GPUWhitelist() + : GPUWhitelist(g_tflite_acceleration_gpu_whitelist_binary) {} + +GPUWhitelist::GPUWhitelist(const unsigned char* whitelist_flatbuffer) { + if (!whitelist_flatbuffer) return; + database_ = flatbuffers::GetRoot(whitelist_flatbuffer); +} + +std::map GPUWhitelist::CalculateVariables( + const AndroidInfo& android_info, + const ::tflite::gpu::GpuInfo& gpu_info) const { + std::map variables; + + variables[kAndroidSdkVersion] = android_info.android_sdk_version; + variables[kDeviceModel] = android_info.model; + variables[kDeviceName] = android_info.device; + variables[kManufacturer] = android_info.manufacturer; + variables[kGPUModel] = gpu_info.renderer_name; + char buffer[128]; + int len = snprintf(buffer, 128 - 1, "%d.%d", gpu_info.major_version, + gpu_info.minor_version); + buffer[len] = '\0'; + variables[kOpenGLESVersion] = std::string(buffer); + CanonicalizeValues(&variables); + if (!database_) return variables; + UpdateVariablesFromDatabase(&variables, *database_); + return variables; +} + +bool GPUWhitelist::Includes(const AndroidInfo& android_info, + const ::tflite::gpu::GpuInfo& gpu_info) const { + auto variables = CalculateVariables(android_info, gpu_info); + return variables[gpu::kStatus] == std::string(gpu::kStatusWhitelisted); +} + +TfLiteGpuDelegateOptionsV2 GPUWhitelist::GetBestOptionsFor( + const AndroidInfo& /* android_info */, + const ::tflite::gpu::GpuInfo& /* gpu_info */) const { + // This method is for forwards-compatibility: the whitelist may later include + // information about which backend to choose (OpenGL/OpenCL/Vulkan) or other + // options. 
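+  // A hypothetical sketch (not something the current whitelist data supports):
+  // once backend or priority information is added, the defaults could be
+  // adjusted here before returning, e.g.:
+  //   TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
+  //   options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
+  //   return options;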
+ return TfLiteGpuDelegateOptionsV2Default(); +} + +} // namespace acceleration +} // namespace tflite diff --git a/tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.h b/tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.h new file mode 100644 index 00000000000..a28e0d1e5d2 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/gpu_whitelist.h @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_GPU_WHITELIST_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_GPU_WHITELIST_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/delegate.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/android_info.h" +#include "tensorflow/lite/experimental/acceleration/whitelist/devicedb.h" + +namespace tflite { +namespace acceleration { + +// This class provides information on GPU delegate support. +// +// The GPU delegate is supported on a subset of Android devices, depending on +// Android version, OpenGL ES version, GPU chipset etc. The support is based on +// measure stability, correctness and peformance. For more detail see README.md. +// +// Example usage: +// tflite::Interpreter* interpreter = ... ; +// tflite::acceleration::AndroidInfo android_info; +// tflite::gpu::GpuInfo gpu_info; +// EXPECT_OK(tflite::acceleration::RequestAndroidInfo(&android_info)); +// EXPECT_OK(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env)); +// EXPECT_OK(tflite::gpu::gl::RequestGpuInfo(&tflite_gpu_info)); +// tflite::acceleration::GPUWhitelist whitelist; +// TfLiteDelegate* gpu_delegate = nullptr; +// TfLiteGpuDelegateOptions gpu_options; +// if (whitelist.Includes(android_info, gpu_info)) { +// gpu_options = whitelist.BestOptionsFor(android_info, gpu_info); +// gpu_delegate = TfLiteGpuDelegateCreate(&gpu_options); +// EXPECT_EQ(interpreter->ModifyGraphWithDelegate(gpu_delegate), TfLiteOk); +// } else { +// // Fallback path. +// } +class GPUWhitelist { + public: + // Construct whitelist from bundled data. + GPUWhitelist(); + // Returns true if the provided device specs are whitelisted by the database. + bool Includes(const AndroidInfo& android_info, + const ::tflite::gpu::GpuInfo& gpu_info) const; + + // Returns the best TfLiteGpuDelegateOptionsV2 for the provided device specs + // based on the database. The output can be modified as desired before passing + // to delegate creation. + TfLiteGpuDelegateOptionsV2 GetBestOptionsFor( + const AndroidInfo& android_info, + const ::tflite::gpu::GpuInfo& gpu_info) const; + + // Convert android_info and gpu_info into a set of variables used for querying + // the whitelist, and update variables from whitelist data. See variables.h + // and devicedb.h for more information. 
+ std::map CalculateVariables( + const AndroidInfo& android_info, + const ::tflite::gpu::GpuInfo& gpu_info) const; + + GPUWhitelist(const GPUWhitelist&) = delete; + GPUWhitelist& operator=(const GPUWhitelist&) = delete; + + protected: + explicit GPUWhitelist(const unsigned char* whitelist_flatbuffer); + const DeviceDatabase* database_; +}; + +} // namespace acceleration +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_GPU_WHITELIST_H_ diff --git a/tensorflow/lite/experimental/acceleration/whitelist/json_to_fb.cc b/tensorflow/lite/experimental/acceleration/whitelist/json_to_fb.cc new file mode 100644 index 00000000000..11638895ae8 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/json_to_fb.cc @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Simple program to convert from JSON to binary flatbuffers for given schema. +// +// Used for creating the binary version of a whitelist. +// +// The flatc command line is not available in all build environments. +#include +#include +#include +#include + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/idl.h" // from @flatbuffers +#include "flatbuffers/reflection.h" // from @flatbuffers +#include "flatbuffers/reflection_generated.h" // from @flatbuffers +#include "flatbuffers/util.h" // from @flatbuffers +#include "tensorflow/lite/tools/command_line_flags.h" + +int main(int argc, char** argv) { + std::string json_path, fbs_path, fb_path; + std::vector flags = { + tflite::Flag::CreateFlag("json_input", &json_path, + "Path to input json file."), + tflite::Flag::CreateFlag("fbs", &fbs_path, + "Path to flatbuffer schema to use."), + tflite::Flag::CreateFlag("fb_output", &fb_path, + "Path to a output binary flatbuffer."), + }; + const bool parse_result = + tflite::Flags::Parse(&argc, const_cast(argv), flags); + if (!parse_result || json_path.empty() || fbs_path.empty() || + fb_path.empty()) { + std::cerr << tflite::Flags::Usage(argv[0], flags); + return 1; + } + std::string json_contents; + if (!flatbuffers::LoadFile(json_path.c_str(), false, &json_contents)) { + std::cerr << "Unable to load file " << json_path << std::endl; + return 2; + } + std::string fbs_contents; + if (!flatbuffers::LoadFile(fbs_path.c_str(), false, &fbs_contents)) { + std::cerr << "Unable to load file " << fbs_path << std::endl; + return 3; + } + const char* include_directories[] = {nullptr}; + flatbuffers::Parser schema_parser; + if (!schema_parser.Parse(fbs_contents.c_str(), include_directories)) { + std::cerr << "Unable to parse schema " << schema_parser.error_ << std::endl; + return 4; + } + schema_parser.Serialize(); + auto schema = + reflection::GetSchema(schema_parser.builder_.GetBufferPointer()); + auto root_table = schema->root_table(); + flatbuffers::Parser parser; + parser.Deserialize(schema_parser.builder_.GetBufferPointer(), + 
schema_parser.builder_.GetSize()); + + if (!parser.Parse(json_contents.c_str(), include_directories, + json_path.c_str())) { + std::cerr << "Unable to parse json " << parser.error_ << std::endl; + return 5; + } + + // Use CopyTable() to deduplicate the strings. + const uint8_t* buffer = parser.builder_.GetBufferPointer(); + flatbuffers::FlatBufferBuilder fbb; + auto root_offset = flatbuffers::CopyTable( + fbb, *schema, *root_table, *flatbuffers::GetAnyRoot(buffer), true); + fbb.Finish(root_offset); + std::string binary(reinterpret_cast(fbb.GetBufferPointer()), + fbb.GetSize()); + std::ofstream output; + output.open(fb_path); + output << binary; + output.close(); + return 0; +} diff --git a/tensorflow/lite/experimental/acceleration/whitelist/variables.h b/tensorflow/lite/experimental/acceleration/whitelist/variables.h new file mode 100644 index 00000000000..178343e5c9c --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/whitelist/variables.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_VARIABLES_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_VARIABLES_H_ + +// This file lists generally useful whitelisting properties. +// Properties starting with "tflite." are reserved. +// Users of the whitelisting library can use arbitrary other property names. + +namespace tflite { +namespace acceleration { +// System properties, not specific to any single delegate. + +// Android properties. +// +// Android SDK version number. Android system property ro.build.version.sdk. +// E.g., "28". +constexpr char kAndroidSdkVersion[] = "tflite.android_sdk_version"; +// SoC model. Looked up from database or possibly returned from Android system +// property ro.board.platform, normalized. E.g., "sdm450". +constexpr char kSoCModel[] = "tflite.soc_model"; +// SoC vendor. Looked up from database. E.g., "qualcomm". +constexpr char kSoCVendor[] = "tflite.soc_vendor"; +// Device manufacturer. Android API android.os.Build.MANUFACTURER, normalized. +// E.g., "google". +constexpr char kManufacturer[] = "tflite.manufacturer"; +// Device model. Android API android.os.Build.MODEL, normalized. +// E.g., "pixel_2". +constexpr char kDeviceModel[] = "tflite.device_model"; +// Device name. Android API android.os.Build.DEVICE, normalized. +// E.g., "walleye". +constexpr char kDeviceName[] = "tflite.device_name"; + +// GPU-related properties. +// +// OpenGL ES version. E.g., 3.2. +constexpr char kOpenGLESVersion[] = "tflite.opengl_es_version"; +// GPU model, result of querying GL_RENDERER, normalized. E.g., +// "adreno_(tm)_505". +constexpr char kGPUModel[] = "tflite.gpu_model"; +// GPU vendor, normalized. E.g., "adreno_(tm)_505". +constexpr char kGPUVendor[] = "tflite.gpu_vendor"; +// OpenGL driver version, result of querying GL_VERSION. 
E.g., +// "opengl_es_3.2_v@328.0_(git@6fb5a5b,_ife855c4895)_(date:08/21/18)" +constexpr char kOpenGLDriverVersion[] = "tflite.opengl_driver_version"; + +// NNAPI-related properties. +// +// NNAPI accelerator name, returned by ANeuralNetworksDevice_getName. E.g., +// "qti-dsp". +constexpr char kNNAPIAccelerator[] = "tflite.nnapi_accelerator"; +// NNAPI accelerator feature level, returned by +// ANeuralNetworksDevice_getFeatureLevel. E.g., 29. Actual variables are named +// "tflite.nnapi_feature_level.", e.g., +// "tflite.nnapi_feature_level.qti-dsp". +constexpr char kNNAPIFeatureLevelPrefix[] = "tflite.nnapi_feature_level"; + +namespace gpu { +// GPU-delegate derived properties. + +// Whether the GPU delegate works in general. +// ("UNSET", "UNKNOWN", "WHITELISTED", "BLACKLISTED"). +constexpr char kStatus[] = "tflite.gpu.status"; + +// Whether OpenCL should be allowed. Possible values are the SupportStatus enums +// ("UNSET", "UNKNOWN", "WHITELISTED", "BLACKLISTED"). +constexpr char kOpenCLStatus[] = "tflite.gpu.opencl_status"; +constexpr char kStatusWhitelisted[] = "WHITELISTED"; +constexpr char kStatusBlacklisted[] = "BLACKLISTED"; +} // namespace gpu + +} // namespace acceleration +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_WHITELIST_VARIABLES_H_ diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index 5eb5e8ab023..32305d8bd89 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -355,6 +355,7 @@ filegroup( srcs = [ "src/test/java/org/tensorflow/lite/InterpreterTestHelper.java", "src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java", + "src/test/java/org/tensorflow/lite/gpu/WhitelistTest.java", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/WhitelistTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/WhitelistTest.java new file mode 100644 index 00000000000..2c6b2d95f55 --- /dev/null +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/WhitelistTest.java @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.gpu; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.gpu.Whitelist}. */ +@RunWith(JUnit4.class) +public final class WhitelistTest { + + @Test + public void testBasic() throws Exception { + try (Whitelist whitelist = new Whitelist()) { + assertThat(whitelist.isDelegateSupportedOnThisDevice()).isTrue(); + } + } +}