Updates the configure Python script to support building Bazel rules on Apple platforms.
PiperOrigin-RevId: 234053872
This commit is contained in:
parent
03ebaa77ac
commit
93e707396e
33
configure.py
33
configure.py
@ -55,6 +55,12 @@ NCCL_LIB_PATHS = [
|
||||
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
||||
]
|
||||
|
||||
# List of files to be configured for using Bazel on Apple platforms.
|
||||
APPLE_BAZEL_FILES = [
|
||||
'tensorflow/lite/experimental/objc/BUILD',
|
||||
'tensorflow/lite/experimental/swift/BUILD'
|
||||
]
|
||||
|
||||
if platform.machine() == 'ppc64le':
|
||||
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
|
||||
else:
|
||||
@ -1534,6 +1540,23 @@ def config_info_line(name, help_text):
|
||||
print('\t--config=%-12s\t# %s' % (name, help_text))
|
||||
|
||||
|
||||
def configure_apple_bazel_rules():
|
||||
"""Configures Bazel rules for building on Apple platforms.
|
||||
|
||||
Enables analyzing and building Apple Bazel rules on Apple platforms. This
|
||||
function will only be executed if `is_macos()` is true.
|
||||
"""
|
||||
if not is_macos():
|
||||
return
|
||||
for filepath in APPLE_BAZEL_FILES:
|
||||
print(
|
||||
'Configuring %s file to analyze and build Bazel rules on Apple platforms.'
|
||||
% filepath)
|
||||
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
|
||||
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
|
||||
os.rename(existing_filepath, renamed_filepath)
|
||||
|
||||
|
||||
def main():
|
||||
global _TF_WORKSPACE_ROOT
|
||||
global _TF_BAZELRC
|
||||
@ -1574,6 +1597,8 @@ def main():
|
||||
|
||||
if is_macos():
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
else:
|
||||
environ_cp['TF_CONFIGURE_APPLE_BAZEL_RULES'] = '0'
|
||||
|
||||
# The numpy package on ppc64le uses OpenBLAS which has multi-threading
|
||||
# issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at
|
||||
@ -1676,6 +1701,14 @@ def main():
|
||||
create_android_ndk_rule(environ_cp)
|
||||
create_android_sdk_rule(environ_cp)
|
||||
|
||||
if get_var(
|
||||
environ_cp, 'TF_CONFIGURE_APPLE_BAZEL_RULES',
|
||||
'Configure Bazel rules for Apple platforms', False,
|
||||
('Would you like to configure Bazel rules for building on Apple platforms?'
|
||||
), 'Configuring Bazel rules for Apple platforms.',
|
||||
'Not configuring Bazel rules for Apple platforms.'):
|
||||
configure_apple_bazel_rules()
|
||||
|
||||
print('Preconfigured Bazel build configs. You can use any of the below by '
|
||||
'adding "--config=<>" to your build command. See .bazelrc for more '
|
||||
'details.')
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "third_party/tensorflow/core/framework/types.h"
|
||||
|
||||
std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
|
||||
int* out_width,
|
||||
|
@ -16,8 +16,8 @@
|
||||
#import <UIKit/UIKit.h>
|
||||
|
||||
#include <memory>
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
#include "third_party/tensorflow/core/public/session.h"
|
||||
#include "third_party/tensorflow/core/util/memmapped_file_system.h"
|
||||
|
||||
@interface CameraExampleViewController
|
||||
: UIViewController<UIGestureRecognizerDelegate,
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "third_party/tensorflow/core/framework/types.h"
|
||||
|
||||
std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
|
||||
int* out_width,
|
||||
|
@ -18,8 +18,8 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
#include "third_party/tensorflow/core/public/session.h"
|
||||
#include "third_party/tensorflow/core/util/memmapped_file_system.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
// Reads a serialized GraphDef protobuf file from the bundle, typically
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "third_party/tensorflow/core/framework/types.h"
|
||||
|
||||
std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
|
||||
int* out_width,
|
||||
|
@ -24,17 +24,17 @@
|
||||
#include <queue>
|
||||
|
||||
#if TFLITE_USE_CONTRIB_LITE
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/op_resolver.h"
|
||||
#include "tensorflow/contrib/lite/string_util.h"
|
||||
#include "third_party/tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "third_party/tensorflow/contrib/lite/model.h"
|
||||
#include "third_party/tensorflow/contrib/lite/op_resolver.h"
|
||||
#include "third_party/tensorflow/contrib/lite/string_util.h"
|
||||
#else
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/op_resolver.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "third_party/tensorflow/lite/kernels/register.h"
|
||||
#include "third_party/tensorflow/lite/model.h"
|
||||
#include "third_party/tensorflow/lite/op_resolver.h"
|
||||
#include "third_party/tensorflow/lite/string_util.h"
|
||||
#if TFLITE_USE_GPU_DELEGATE
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#include "third_party/tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
104
tensorflow/lite/experimental/objc/BUILD.apple
Normal file
104
tensorflow/lite/experimental/objc/BUILD.apple
Normal file
@ -0,0 +1,104 @@
|
||||
# TensorFlow Lite Objective-C API.
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
|
||||
|
||||
SOURCES = glob([
|
||||
"sources/*.h",
|
||||
"sources/*.m",
|
||||
"sources/*.mm",
|
||||
])
|
||||
|
||||
API_HEADERS = glob([
|
||||
"apis/*.h",
|
||||
])
|
||||
|
||||
MINIMUM_OS_VERSION = "9.0"
|
||||
|
||||
# Compiler flags for building regular non-test libraries.
|
||||
RELEASE_COPTS = [
|
||||
# Enables language-specific warnings for Objective-C, Objective-C++, C, and C++.
|
||||
"-Wall",
|
||||
# Warns if functions, variables, and types marked with the deprecated attribute are being used.
|
||||
"-Wdeprecated-declarations",
|
||||
# Warns for errors in documentation.
|
||||
"-Wdocumentation",
|
||||
# Turns all warnings into errors.
|
||||
"-Werror",
|
||||
# Enables extra warning flags that are not enabled by -Wall.
|
||||
"-Wextra",
|
||||
# Warns if a global function is defined without a previous prototype declaration.
|
||||
"-Wmissing-prototypes",
|
||||
# From -Wextra. Disables warning when signed value is converted to unsigned value during comparison.
|
||||
"-Wno-sign-compare",
|
||||
# From -Wextra. Disables warning for unused parameters, which are common in delegate methods and block callbacks.
|
||||
"-Wno-unused-parameter",
|
||||
# Warns if a global or local variable or type declaration shadows another variable, parameter, type, class member, or instance variable.
|
||||
"-Wshadow",
|
||||
# Warns if a function is declared or defined without specifying the argument types. For a block with no args, use (void) instead of ().
|
||||
"-Wstrict-prototypes",
|
||||
# Warns if an @selector() expression is encountered with a method name that hasn't been defined yet.
|
||||
"-Wundeclared-selector",
|
||||
# Turn off warnings for headers not part of TensorFlow Lite Objective-C API.
|
||||
"--system-header-prefix=tensorflow/lite/experimental/c/",
|
||||
]
|
||||
|
||||
# Compiler flags for building test libraries.
|
||||
TEST_COPTS = RELEASE_COPTS + [
|
||||
# From -Wall. Disables warning when passing nil to a callee that requires a non-null argument.
|
||||
"-Wno-nonnull",
|
||||
# Disables warning when a global or local variable or type declaration shadows another.
|
||||
"-Wno-shadow",
|
||||
]
|
||||
|
||||
objc_library(
|
||||
name = "TensorFlowLite",
|
||||
srcs = SOURCES,
|
||||
hdrs = API_HEADERS,
|
||||
copts = RELEASE_COPTS,
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/c:c_api",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "TensorFlowLiteTests",
|
||||
size = "small",
|
||||
minimum_os_version = MINIMUM_OS_VERSION,
|
||||
tags = [
|
||||
"manual",
|
||||
# These sanitizer tests are not supported by iOS build toolchain (b/74292221).
|
||||
# Disabled these for iOS test targets.
|
||||
"noasan",
|
||||
"notsan",
|
||||
"nomsan",
|
||||
],
|
||||
deps = [":TensorFlowLiteTestsLib"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TensorFlowLiteTestsLib",
|
||||
testonly = 1,
|
||||
srcs = glob([
|
||||
"tests/*.m",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"apis/*.h",
|
||||
"sources/*.h",
|
||||
"tests/*.h",
|
||||
]),
|
||||
copts = TEST_COPTS,
|
||||
resources = [
|
||||
"//tensorflow/lite:testdata/add.bin",
|
||||
"//tensorflow/lite:testdata/add_quantized.bin",
|
||||
],
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":TensorFlowLite",
|
||||
],
|
||||
)
|
52
tensorflow/lite/experimental/objc/README.md
Normal file
52
tensorflow/lite/experimental/objc/README.md
Normal file
@ -0,0 +1,52 @@
|
||||
# TensorFlow Lite Objective-C Library
|
||||
|
||||
[TensorFlow Lite](https://www.tensorflow.org/lite/) is TensorFlow's lightweight
|
||||
solution for Objective-C developers. It enables low-latency inference of
|
||||
on-device machine learning models with a small binary size and fast performance
|
||||
supporting hardware acceleration.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Bazel
|
||||
|
||||
In your `BUILD` file, add the `TensorFlowLite` dependency:
|
||||
|
||||
```python
|
||||
objc_library(
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/objc:TensorFlowLite",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
If you would like to build the Objective-C TensorFlow Lite library using Bazel on Apple
|
||||
platforms, clone or download the [TensorFlow GitHub repo](https://github.com/tensorflow/tensorflow),
|
||||
then navigate to the root `tensorflow` directory and execute the `configure.py` script:
|
||||
|
||||
```shell
|
||||
python configure.py
|
||||
```
|
||||
|
||||
Follow the prompts and when asked to configure the Bazel rules for Apple
|
||||
platforms, enter `y`.
|
||||
|
||||
Build the `TensorFlowLite` Objective-C library target:
|
||||
|
||||
```shell
|
||||
bazel build tensorflow/lite/experimental/objc:TensorFlowLite
|
||||
```
|
||||
|
||||
Build the `TensorFlowLiteTests` target:
|
||||
|
||||
```shell
|
||||
bazel test tensorflow/lite/experimental/objc:TensorFlowLiteTests
|
||||
```
|
||||
|
||||
### Tulsi
|
||||
|
||||
Open the `TensorFlowLiteObjc.tulsiproj` using the Tulsi application on Mac or by
|
||||
running the following command in Terminal from the root source directory:
|
||||
|
||||
```shell
|
||||
generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/objc/TensorFlowLiteObjc.tulsiproj:TensorFlowLiteObjC --outputfolder ~/path/to/xcodeproj
|
||||
```
|
@ -0,0 +1,60 @@
|
||||
{
|
||||
"sourceFilters" : [
|
||||
"tensorflow/lite",
|
||||
"tensorflow/lite/experimental/c",
|
||||
"tensorflow/lite/experimental/objc",
|
||||
"tensorflow/lite/experimental/objc/apis",
|
||||
"tensorflow/lite/experimental/objc/sources",
|
||||
"tensorflow/lite/experimental/objc/tests",
|
||||
"tensorflow/lite/kernels",
|
||||
"tensorflow/lite/kernels/internal",
|
||||
"tensorflow/lite/nnapi",
|
||||
"tensorflow/lite/schema",
|
||||
],
|
||||
"buildTargets" : [
|
||||
"//tensorflow/lite/experimental/objc:TensorFlowLite",
|
||||
"//tensorflow/lite/experimental/objc:TensorFlowLiteTests",
|
||||
],
|
||||
"projectName" : "TensorFlowLiteObjC",
|
||||
"optionSet" : {
|
||||
"LaunchActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildStartupOptionsRelease" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildOptionsRelease" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildOptionsDebug" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"EnvironmentVariables" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BuildActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"CommandlineArguments" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"TestActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildStartupOptionsDebug" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BuildActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"TestActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"LaunchActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
}
|
||||
},
|
||||
"additionalFilePaths" : [
|
||||
"tensorflow/lite/experimental/objc/BUILD",
|
||||
]
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
{
|
||||
"configDefaults" : {
|
||||
"optionSet" : {
|
||||
"BazelBuildOptionsDebug" : {
|
||||
|
||||
},
|
||||
"BazelBuildOptionsRelease" : {
|
||||
|
||||
},
|
||||
}
|
||||
},
|
||||
"projectName" : "TensorFlowLiteObjC",
|
||||
"packages" : [
|
||||
"tensorflow/lite/experimental/objc"
|
||||
],
|
||||
"workspaceRoot" : "../../../../.."
|
||||
}
|
179
tensorflow/lite/experimental/objc/apis/TFLInterpreter.h
Normal file
179
tensorflow/lite/experimental/objc/apis/TFLInterpreter.h
Normal file
@ -0,0 +1,179 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
@class TFLInterpreterOptions;
|
||||
@class TFLTensor;
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* @enum TFLInterpreterErrorCode
|
||||
* This enum specifies various error codes related to `TFLInterpreter`.
|
||||
*/
|
||||
typedef NS_ENUM(NSUInteger, TFLInterpreterErrorCode) {
|
||||
/** Provided tensor index is invalid. */
|
||||
TFLInterpreterErrorCodeInvalidTensorIndex,
|
||||
|
||||
/** Input data has invalid byte size. */
|
||||
TFLInterpreterErrorCodeInvalidInputByteSize,
|
||||
|
||||
/** Provided shape is invalid. It must be a non-empty array of positive unsigned integers. */
|
||||
TFLInterpreterErrorCodeInvalidShape,
|
||||
|
||||
/** Provided model cannot be loaded. */
|
||||
TFLInterpreterErrorCodeFailedToLoadModel,
|
||||
|
||||
/** Failed to create `TFLInterpreter`. */
|
||||
TFLInterpreterErrorCodeFailedToCreateInterpreter,
|
||||
|
||||
/** Failed to invoke `TFLInterpreter`. */
|
||||
TFLInterpreterErrorCodeFailedToInvoke,
|
||||
|
||||
/** Failed to retrieve a tensor. */
|
||||
TFLInterpreterErrorCodeFailedToGetTensor,
|
||||
|
||||
/** Invalid tensor. */
|
||||
TFLInterpreterErrorCodeInvalidTensor,
|
||||
|
||||
/** Failed to resize an input tensor. */
|
||||
TFLInterpreterErrorCodeFailedToResizeInputTensor,
|
||||
|
||||
/** Failed to copy data into an input tensor. */
|
||||
TFLInterpreterErrorCodeFailedToCopyDataToInputTensor,
|
||||
|
||||
/** Copying data into an output tensor not allowed. */
|
||||
TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed,
|
||||
|
||||
/** Failed to get data from a tensor. */
|
||||
TFLInterpreterErrorCodeFailedToGetDataFromTensor,
|
||||
|
||||
/** Failed to allocate memory for tensors. */
|
||||
TFLInterpreterErrorCodeFailedToAllocateTensors,
|
||||
|
||||
/** Operaton not allowed without allocating memory for tensors first. */
|
||||
TFLInterpreterErrorCodeAllocateTensorsRequired,
|
||||
|
||||
/** Operaton not allowed without invoking the interpreter first. */
|
||||
TFLInterpreterErrorCodeInvokeInterpreterRequired,
|
||||
};
|
||||
|
||||
/**
|
||||
* A TensorFlow Lite model interpreter.
|
||||
*/
|
||||
@interface TFLInterpreter : NSObject
|
||||
|
||||
/** The total number of input tensors. 0 if the interpreter creation failed. */
|
||||
@property(nonatomic, readonly) NSUInteger inputTensorCount;
|
||||
|
||||
/** The total number of output tensors. 0 if the interpreter creation failed. */
|
||||
@property(nonatomic, readonly) NSUInteger outputTensorCount;
|
||||
|
||||
/** Unavailable. */
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
/**
|
||||
* Initializes a new TensorFlow Lite interpreter instance with the given model file path and the
|
||||
* default interpreter options.
|
||||
*
|
||||
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
|
||||
* @param error An optional error parameter populated when there is an error in initializing the
|
||||
* interpreter.
|
||||
*
|
||||
* @return A new instance of `TFLInterpreter` with the given model and the default interpreter
|
||||
* options. `nil` if there is an error in initializing the interpreter.
|
||||
*/
|
||||
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Initializes a new TensorFlow Lite interpreter instance with the given model file path and
|
||||
* options.
|
||||
*
|
||||
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
|
||||
* @param options Options to use for configuring the TensorFlow Lite interpreter.
|
||||
* @param error An optional error parameter populated when there is an error in initializing the
|
||||
* interpreter.
|
||||
*
|
||||
* @return A new instance of `TFLInterpreter` with the given model and options. `nil` if there is an
|
||||
* error in initializing the interpreter.
|
||||
*/
|
||||
- (nullable instancetype)initWithModelPath:(NSString *)modelPath
|
||||
options:(TFLInterpreterOptions *)options
|
||||
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
* Invokes the interpreter to run inference.
|
||||
*
|
||||
* @param error An optional error parameter populated when there is an error in invoking the
|
||||
* interpreter.
|
||||
*
|
||||
* @return Whether the invocation is successful. Returns NO if an error occurred.
|
||||
*/
|
||||
- (BOOL)invokeWithError:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Returns the input tensor at the given index.
|
||||
*
|
||||
* @param index The index of an input tensor.
|
||||
* @param error An optional error parameter populated when there is an error in looking up the input
|
||||
* tensor.
|
||||
*
|
||||
* @return The input tensor at the given index. `nil` if there is an error. See the `TFLTensor`
|
||||
* class documentation for more details on the life expectancy between the returned tensor and
|
||||
* this interpreter.
|
||||
*/
|
||||
- (nullable TFLTensor *)inputTensorAtIndex:(NSUInteger)index error:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Returns the output tensor at the given index.
|
||||
*
|
||||
* @param index The index of an output tensor.
|
||||
* @param error An optional error parameter populated when there is an error in looking up the
|
||||
* output tensor.
|
||||
*
|
||||
* @return The output tensor at the given index. `nil` if there is an error. See the `TFLTensor`
|
||||
* class documentation for more details on the life expectancy between the returned tensor and
|
||||
* this interpreter.
|
||||
*/
|
||||
- (nullable TFLTensor *)outputTensorAtIndex:(NSUInteger)index error:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Resizes the input tensor at the given index to the specified shape (an array of positive unsigned
|
||||
* integers).
|
||||
*
|
||||
* @param index The index of an input tensor.
|
||||
* @param shape Shape that the given input tensor should be resized to. It should be an array of
|
||||
* positive unsigned integer(s) containing the size of each dimension.
|
||||
* @param error An optional error parameter populated when there is an error in resizing the input
|
||||
* tensor.
|
||||
*
|
||||
* @return Whether the input tensor was resized successfully. Returns NO if an error occurred.
|
||||
*/
|
||||
- (BOOL)resizeInputTensorAtIndex:(NSUInteger)index
|
||||
toShape:(NSArray<NSNumber *> *)shape
|
||||
error:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Allocates memory for tensors.
|
||||
*
|
||||
* @param error An optional error parameter populated when there is an error in allocating memory.
|
||||
*
|
||||
* @return Whether memory allocation is successful. Returns NO if an error occurred.
|
||||
*/
|
||||
- (BOOL)allocateTensorsWithError:(NSError **)error;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,37 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/** Custom configuration options for a TensorFlow Lite interpreter. */
|
||||
@interface TFLInterpreterOptions : NSObject
|
||||
|
||||
/**
|
||||
* Maximum number of threads that the interpreter should run on. Defaults to 0 (unspecified, letting
|
||||
* TensorFlow Lite to optimize the threading decision).
|
||||
*/
|
||||
@property(nonatomic) NSUInteger numberOfThreads;
|
||||
|
||||
/**
|
||||
* Initializes a new instance of `TFLInterpreterOptions`.
|
||||
*
|
||||
* @return A new instance of `TFLInterpreterOptions`.
|
||||
*/
|
||||
- (instancetype)init NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,36 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* Parameters for asymmetric quantization. Quantized values can be converted to float values using:
|
||||
* `realValue = scale * (quantizedValue - zeroPoint)`.
|
||||
*/
|
||||
@interface TFLQuantizationParameters : NSObject
|
||||
|
||||
/** Scale of asymmetric quantization. */
|
||||
@property(nonatomic, readonly) float scale;
|
||||
|
||||
/** Zero point of asymmetric quantization. */
|
||||
@property(nonatomic, readonly) int32_t zeroPoint;
|
||||
|
||||
/** Unavailable. */
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
111
tensorflow/lite/experimental/objc/apis/TFLTensor.h
Normal file
111
tensorflow/lite/experimental/objc/apis/TFLTensor.h
Normal file
@ -0,0 +1,111 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
@class TFLQuantizationParameters;
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* @enum TFLTensorDataType
|
||||
* This enum specifies supported TensorFlow Lite tensor data types.
|
||||
*/
|
||||
typedef NS_ENUM(NSUInteger, TFLTensorDataType) {
|
||||
/** Tensor data type not available. This indicates an error with the model. */
|
||||
TFLTensorDataTypeNoType,
|
||||
|
||||
/** 32-bit single precision floating point. */
|
||||
TFLTensorDataTypeFloat32,
|
||||
|
||||
/** 32-bit signed integer. */
|
||||
TFLTensorDataTypeInt32,
|
||||
|
||||
/** 8-bit unsigned integer. */
|
||||
TFLTensorDataTypeUInt8,
|
||||
|
||||
/** 64-bit signed integer. */
|
||||
TFLTensorDataTypeInt64,
|
||||
|
||||
/** Boolean. */
|
||||
TFLTensorDataTypeBool,
|
||||
|
||||
/** 16-bit signed integer. */
|
||||
TFLTensorDataTypeInt16,
|
||||
|
||||
/** 8-bit signed integer. */
|
||||
TFLTensorDataTypeInt8,
|
||||
};
|
||||
|
||||
/**
|
||||
* An input or output tensor in a TensorFlow Lite model.
|
||||
*
|
||||
* @warning Each `TFLTensor` instance is associated with a `TFLInterpreter` instance. Multiple
|
||||
* `TFLTensor` instances of the same TensorFlow Lite model are associated with the same
|
||||
* `TFLInterpreter` instance. As long as a `TFLTensor` instance is still in use, its associated
|
||||
* `TFLInterpreter` instance will not be deallocated.
|
||||
*/
|
||||
@interface TFLTensor : NSObject
|
||||
|
||||
/** Name of the tensor. */
|
||||
@property(nonatomic, readonly, copy) NSString *name;
|
||||
|
||||
/** Data type of the tensor. */
|
||||
@property(nonatomic, readonly) TFLTensorDataType dataType;
|
||||
|
||||
/** Parameters for asymmetric quantization. `nil` if the tensor does not use quantization. */
|
||||
@property(nonatomic, readonly, nullable) TFLQuantizationParameters *quantizationParameters;
|
||||
|
||||
/** Unavailable. */
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
/**
|
||||
* Copies the given data into an input tensor. This is allowed only for an input tensor and only
|
||||
* before the interpreter is invoked; otherwise an error will be returned.
|
||||
*
|
||||
* @param data The data to set. The byte size of the data must match what's required by the input
|
||||
* tensor.
|
||||
* @param error An optional error parameter populated when there is an error in copying the data.
|
||||
*
|
||||
* @return Whether the data was copied into the input tensor successfully. Returns NO if an error
|
||||
* occurred.
|
||||
*/
|
||||
- (BOOL)copyData:(NSData *)data error:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Retrieves a copy of data in the tensor. For an output tensor, the data is only available after
|
||||
* the interpreter invocation has successfully completed; otherwise an error will be returned.
|
||||
*
|
||||
* @param error An optional error parameter populated when there is an error in retrieving the data.
|
||||
*
|
||||
* @return A copy of data in the tensor. `nil` if there is an error in retrieving the data or the
|
||||
* data is not available.
|
||||
*/
|
||||
- (nullable NSData *)dataWithError:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Retrieves the shape of the tensor, an array of positive unsigned integers containing the size
|
||||
* of each dimension. For example: the shape of [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] is
|
||||
* [2, 2, 3] (i.e. an array of 2 arrays of 2 arrays of 3 numbers).
|
||||
*
|
||||
* @param error An optional error parameter populated when there is an error in retrieving the
|
||||
* shape.
|
||||
*
|
||||
* @return The shape of the tensor. `nil` if there is an error in retrieving the shape.
|
||||
*/
|
||||
- (nullable NSArray<NSNumber *> *)shapeWithError:(NSError **)error;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
40
tensorflow/lite/experimental/objc/sources/TFLErrorUtil.h
Normal file
40
tensorflow/lite/experimental/objc/sources/TFLErrorUtil.h
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/** Helper utility for error reporting. */
|
||||
@interface TFLErrorUtil : NSObject
|
||||
|
||||
/**
|
||||
* Creates and saves an interpreter error with the given error code and description.
|
||||
*
|
||||
* @param code Error code.
|
||||
* @param description Error description.
|
||||
* @param error Pointer to where to save the created error. If `nil`, no error will be saved.
|
||||
*/
|
||||
+ (void)saveInterpreterErrorWithCode:(TFLInterpreterErrorCode)code
|
||||
description:(NSString *)description
|
||||
error:(NSError **)error;
|
||||
|
||||
/** Unavailable. */
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
38
tensorflow/lite/experimental/objc/sources/TFLErrorUtil.m
Normal file
38
tensorflow/lite/experimental/objc/sources/TFLErrorUtil.m
Normal file
@ -0,0 +1,38 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "TFLErrorUtil.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/** Error domain of TensorFlow Lite interpreter related errors. */
|
||||
static NSString *const TFLInterpreterErrorDomain = @"org.tensorflow.lite.interpreter";
|
||||
|
||||
@implementation TFLErrorUtil
|
||||
|
||||
#pragma mark - Public
|
||||
|
||||
+ (void)saveInterpreterErrorWithCode:(TFLInterpreterErrorCode)code
|
||||
description:(NSString *)description
|
||||
error:(NSError **)error {
|
||||
if (error) {
|
||||
*error = [NSError errorWithDomain:TFLInterpreterErrorDomain
|
||||
code:code
|
||||
userInfo:@{NSLocalizedDescriptionKey : description}];
|
||||
}
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,63 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h"
|
||||
|
||||
@class TFLTensor;
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface TFLInterpreter (Internal)
|
||||
|
||||
/**
|
||||
* Copies the given data into the input tensor at the given index. This is allowed only before the
|
||||
* interpreter is invoked.
|
||||
*
|
||||
* @param data The data to set. The byte size of the data must match what's required by the input
|
||||
* tensor at the given index.
|
||||
* @param index An input tensor index.
|
||||
* @param error An optional error parameter populated when there is an error in setting the data.
|
||||
*
|
||||
* @return Whether the data was copied into the input tensor at the given index successfully.
|
||||
* Returns NO if an error occurred.
|
||||
*/
|
||||
- (BOOL)copyData:(NSData *)data toInputTensorAtIndex:(NSUInteger)index error:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Retrieves a copy of the data from the given tensor. For an output tensor, the interpreter
|
||||
* invocation has to complete before the data can be retrieved.
|
||||
*
|
||||
* @param tensor A tensor.
|
||||
* @param error An optional error parameter populated when there is an error in getting the data.
|
||||
*
|
||||
* @return The data of the given tensor. `nil` if there is an error or data is not available.
|
||||
*/
|
||||
- (nullable NSData *)dataFromTensor:(TFLTensor *)tensor error:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Retrieves the shape of the given tensor, an array of positive unsigned integer(s) containing the
|
||||
* size of each dimension. For example: shape of [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] is
|
||||
* [2, 2, 3].
|
||||
*
|
||||
* @param tensor An input or output tensor.
|
||||
* @param error An optional error parameter populated when there is an error in retrieving the
|
||||
* shape.
|
||||
*
|
||||
* @return The shape of the tensor. `nil` if there is an error in retrieving the shape.
|
||||
*/
|
||||
- (nullable NSArray<NSNumber *> *)shapeOfTensor:(TFLTensor *)tensor error:(NSError **)error;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
407
tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm
Normal file
407
tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm
Normal file
@ -0,0 +1,407 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h"
|
||||
|
||||
#import "TFLErrorUtil.h"
|
||||
#import "TFLQuantizationParameters+Internal.h"
|
||||
#import "TFLTensor+Internal.h"
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
|
||||
|
||||
#include "tensorflow/lite/experimental/c/c_api.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* Error reporter for TFLInterpreter.
|
||||
*
|
||||
* @param user_data User data. Not used.
|
||||
* @param format Error message which may contain argument formatting specifiers.
|
||||
* @param args Values of the arguments in the error message.
|
||||
*/
|
||||
static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_list args) {
|
||||
NSLog(@"%@", [[NSString alloc] initWithFormat:@(format) arguments:args]);
|
||||
}
|
||||
|
||||
@interface TFLInterpreter ()
|
||||
|
||||
/** TFL_Interpreter backed by C API. */
|
||||
@property(nonatomic, nullable) TFL_Interpreter *interpreter;
|
||||
|
||||
@end
|
||||
|
||||
@implementation TFLInterpreter
|
||||
|
||||
#pragma mark - NSObject
|
||||
|
||||
- (void)dealloc {
|
||||
TFL_DeleteInterpreter(_interpreter);
|
||||
}
|
||||
|
||||
#pragma mark - Public
|
||||
|
||||
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
|
||||
return [self initWithModelPath:modelPath
|
||||
options:[[TFLInterpreterOptions alloc] init]
|
||||
error:error];
|
||||
}
|
||||
|
||||
- (nullable instancetype)initWithModelPath:(NSString *)modelPath
|
||||
options:(TFLInterpreterOptions *)options
|
||||
error:(NSError **)error {
|
||||
self = [super init];
|
||||
|
||||
if (self != nil) {
|
||||
TFL_Model *model = nullptr;
|
||||
TFL_InterpreterOptions *cOptions = nullptr;
|
||||
|
||||
@try {
|
||||
const char *modelPathCString = modelPath.UTF8String;
|
||||
NSString *pathErrorString =
|
||||
[NSString stringWithFormat:@"Cannot load model from path (%@).", modelPath];
|
||||
if (modelPathCString == nullptr) {
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToLoadModel
|
||||
description:pathErrorString
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
|
||||
model = TFL_NewModelFromFile(modelPathCString);
|
||||
if (model == nullptr) {
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToLoadModel
|
||||
description:pathErrorString
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
|
||||
cOptions = TFL_NewInterpreterOptions();
|
||||
if (cOptions == nullptr) {
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
|
||||
description:@"Failed to create the interpreter."
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
|
||||
if (options.numberOfThreads > 0) {
|
||||
TFL_InterpreterOptionsSetNumThreads(cOptions, (int32_t)options.numberOfThreads);
|
||||
}
|
||||
TFL_InterpreterOptionsSetErrorReporter(cOptions, TFLInterpreterErrorReporter, nullptr);
|
||||
|
||||
_interpreter = TFL_NewInterpreter(model, cOptions);
|
||||
if (_interpreter == nullptr) {
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
|
||||
description:@"Failed to create the interpreter."
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
|
||||
_inputTensorCount = (NSUInteger)TFL_InterpreterGetInputTensorCount(_interpreter);
|
||||
_outputTensorCount = (NSUInteger)TFL_InterpreterGetOutputTensorCount(_interpreter);
|
||||
if (_inputTensorCount <= 0 || _outputTensorCount <= 0) {
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
|
||||
description:@"Failed to create the interpreter."
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
} @finally {
|
||||
TFL_DeleteInterpreterOptions(cOptions);
|
||||
TFL_DeleteModel(model);
|
||||
}
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
- (BOOL)invokeWithError:(NSError **)error {
|
||||
if (TFL_InterpreterInvoke(self.interpreter) != kTfLiteOk) {
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToInvoke
|
||||
description:@"Failed to invoke the interpreter."
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
|
||||
return YES;
|
||||
}
|
||||
|
||||
- (nullable TFLTensor *)inputTensorAtIndex:(NSUInteger)index error:(NSError **)error {
|
||||
if (![self isValidTensorIndex:index belowLimit:self.inputTensorCount error:error]) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
return [self tensorOfType:TFLTensorTypeInput atIndex:index error:error];
|
||||
}
|
||||
|
||||
- (nullable TFLTensor *)outputTensorAtIndex:(NSUInteger)index error:(NSError **)error {
|
||||
if (![self isValidTensorIndex:index belowLimit:self.outputTensorCount error:error]) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
return [self tensorOfType:TFLTensorTypeOutput atIndex:index error:error];
|
||||
}
|
||||
|
||||
- (BOOL)resizeInputTensorAtIndex:(NSUInteger)index
|
||||
toShape:(NSArray<NSNumber *> *)shape
|
||||
error:(NSError **)error {
|
||||
if (![self isValidTensorIndex:index belowLimit:self.inputTensorCount error:error]) {
|
||||
return NO;
|
||||
}
|
||||
|
||||
if (shape.count == 0) {
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidShape
|
||||
description:@"Invalid shape. Must not be empty."
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
|
||||
int cDimensions[self.inputTensorCount];
|
||||
for (int dimIndex = 0; dimIndex < shape.count; ++dimIndex) {
|
||||
int dimension = shape[dimIndex].intValue;
|
||||
if (dimension <= 0) {
|
||||
NSString *errorDescription = @"Invalid shape. Dimensions must be positive integers.";
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidShape
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
cDimensions[dimIndex] = dimension;
|
||||
}
|
||||
|
||||
if (TFL_InterpreterResizeInputTensor(self.interpreter, (int32_t)index, cDimensions,
|
||||
(int32_t)shape.count) != kTfLiteOk) {
|
||||
NSString *errorDescription = [NSString
|
||||
stringWithFormat:@"Failed to resize input tensor at index (%lu).", (unsigned long)index];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToResizeInputTensor
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
|
||||
return YES;
|
||||
}
|
||||
|
||||
- (BOOL)allocateTensorsWithError:(NSError **)error {
|
||||
if (TFL_InterpreterAllocateTensors(self.interpreter) != kTfLiteOk) {
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToAllocateTensors
|
||||
description:@"Failed to allocate memory for tensors."
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
return YES;
|
||||
}
|
||||
|
||||
#pragma mark - TFLInterpreter (Internal)
|
||||
|
||||
- (BOOL)copyData:(NSData *)data toInputTensorAtIndex:(NSUInteger)index error:(NSError **)error {
|
||||
const TFL_Tensor *cTensor = [self cTensorOfType:TFLTensorTypeInput atIndex:index error:error];
|
||||
if (cTensor == nullptr) {
|
||||
return NO;
|
||||
}
|
||||
|
||||
NSUInteger byteSize = (NSUInteger)TFL_TensorByteSize(cTensor);
|
||||
if (data.length != byteSize) {
|
||||
NSString *errorDescription = [NSString
|
||||
stringWithFormat:@"Input tensor at index (%lu) expects data size (%lu), but got (%lu).",
|
||||
(unsigned long)index, byteSize, (unsigned long)data.length];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidInputByteSize
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
|
||||
if (TFL_TensorCopyFromBuffer((TFL_Tensor *)cTensor, data.bytes, data.length) != kTfLiteOk) {
|
||||
NSString *errorDescription =
|
||||
[NSString stringWithFormat:@"Failed to copy data into input tensor at index (%lu).",
|
||||
(unsigned long)index];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCopyDataToInputTensor
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
|
||||
return YES;
|
||||
}
|
||||
|
||||
- (nullable NSData *)dataFromTensor:(TFLTensor *)tensor error:(NSError **)error {
|
||||
const TFL_Tensor *cTensor = [self cTensorOfType:tensor.type atIndex:tensor.index error:error];
|
||||
if (cTensor == nullptr) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
void *bytes = TFL_TensorData(cTensor);
|
||||
NSUInteger byteSize = (NSUInteger)TFL_TensorByteSize(cTensor);
|
||||
if (bytes == nullptr || byteSize == 0) {
|
||||
NSString *tensorType = [TFLTensor stringForTensorType:tensor.type];
|
||||
NSString *errorDescription =
|
||||
[NSString stringWithFormat:@"Failed to get data from %@ tensor at index (%lu).", tensorType,
|
||||
(unsigned long)index];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToGetDataFromTensor
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
|
||||
return [NSData dataWithBytes:bytes length:byteSize];
|
||||
}
|
||||
|
||||
- (nullable NSArray<NSNumber *> *)shapeOfTensor:(TFLTensor *)tensor error:(NSError **)error {
|
||||
const TFL_Tensor *cTensor = [self cTensorOfType:tensor.type atIndex:tensor.index error:error];
|
||||
if (cTensor == nullptr) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
NSString *tensorType = [TFLTensor stringForTensorType:tensor.type];
|
||||
int32_t rank = TFL_TensorNumDims(cTensor);
|
||||
if (rank <= 0) {
|
||||
NSString *errorDescription =
|
||||
[NSString stringWithFormat:@"%@ tensor at index (%lu) has invalid rank (%d).", tensorType,
|
||||
(unsigned long)index, rank];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
|
||||
NSMutableArray *shape = [NSMutableArray arrayWithCapacity:rank];
|
||||
for (int32_t dimIndex = 0; dimIndex < rank; dimIndex++) {
|
||||
int32_t dimension = TFL_TensorDim(cTensor, dimIndex);
|
||||
if (dimension <= 0) {
|
||||
NSString *errorDescription =
|
||||
[NSString stringWithFormat:@"%@ tensor at index (%lu) has invalid %d-th dimension (%d).",
|
||||
tensorType, (unsigned long)index, dimIndex, dimension];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
shape[dimIndex] = @((NSUInteger)dimension);
|
||||
}
|
||||
|
||||
return shape;
|
||||
}
|
||||
|
||||
#pragma mark - Private
|
||||
|
||||
- (const TFL_Tensor *)cTensorOfType:(TFLTensorType)type
|
||||
atIndex:(NSUInteger)index
|
||||
error:(NSError **)error {
|
||||
const TFL_Tensor *tensor = nullptr;
|
||||
|
||||
switch (type) {
|
||||
case TFLTensorTypeInput:
|
||||
tensor = TFL_InterpreterGetInputTensor(self.interpreter, (int32_t)index);
|
||||
break;
|
||||
case TFLTensorTypeOutput:
|
||||
tensor = TFL_InterpreterGetOutputTensor(self.interpreter, (int32_t)index);
|
||||
break;
|
||||
}
|
||||
|
||||
if (tensor == nullptr) {
|
||||
NSString *tensorType = [TFLTensor stringForTensorType:type];
|
||||
NSString *errorDescription =
|
||||
[NSString stringWithFormat:@"Failed to get %@ tensor at index (%lu).", tensorType,
|
||||
(unsigned long)index];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToGetTensor
|
||||
description:errorDescription
|
||||
error:error];
|
||||
}
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
- (nullable TFLTensor *)tensorOfType:(TFLTensorType)type
|
||||
atIndex:(NSUInteger)index
|
||||
error:(NSError **)error {
|
||||
const TFL_Tensor *tensor = [self cTensorOfType:type atIndex:index error:error];
|
||||
|
||||
if (tensor == nullptr) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
NSString *tensorType = [TFLTensor stringForTensorType:type];
|
||||
const char *cName = TFL_TensorName(tensor);
|
||||
if (cName == nullptr) {
|
||||
NSString *errorDescription =
|
||||
[NSString stringWithFormat:@"Failed to get name of %@ tensor at index (%lu).", tensorType,
|
||||
(unsigned long)index];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return nil;
|
||||
}
|
||||
NSString *name = [NSString stringWithUTF8String:cName];
|
||||
|
||||
TFLTensorDataType dataType = [self tensorDataTypeFromCTensorType:TFL_TensorType(tensor)];
|
||||
|
||||
TFL_QuantizationParams cParams = TFL_TensorQuantizationParams(tensor);
|
||||
TFLQuantizationParameters *quantizationParams;
|
||||
|
||||
// TODO(b/119735362): Update this check once the TFL_QuantizationParams struct has a mode.
|
||||
if (cParams.scale != 0.0) {
|
||||
quantizationParams = [[TFLQuantizationParameters alloc] initWithScale:cParams.scale
|
||||
zeroPoint:cParams.zero_point];
|
||||
}
|
||||
|
||||
// TODO: Set quantization parameters when C API supports it.
|
||||
return [[TFLTensor alloc] initWithInterpreter:self
|
||||
type:type
|
||||
index:index
|
||||
name:name
|
||||
dataType:dataType
|
||||
quantizationParameters:quantizationParams];
|
||||
}
|
||||
|
||||
- (TFLTensorDataType)tensorDataTypeFromCTensorType:(TFL_Type)cTensorType {
|
||||
switch (cTensorType) {
|
||||
case kTfLiteFloat32:
|
||||
return TFLTensorDataTypeFloat32;
|
||||
case kTfLiteInt32:
|
||||
return TFLTensorDataTypeInt32;
|
||||
case kTfLiteUInt8:
|
||||
return TFLTensorDataTypeUInt8;
|
||||
case kTfLiteInt8:
|
||||
return TFLTensorDataTypeInt8;
|
||||
case kTfLiteInt64:
|
||||
return TFLTensorDataTypeInt64;
|
||||
case kTfLiteBool:
|
||||
return TFLTensorDataTypeBool;
|
||||
case kTfLiteInt16:
|
||||
return TFLTensorDataTypeInt16;
|
||||
case kTfLiteNoType:
|
||||
case kTfLiteString:
|
||||
case kTfLiteComplex64:
|
||||
// kTfLiteString and kTfLiteComplex64 are not supported in TensorFlow Lite Objc API.
|
||||
return TFLTensorDataTypeNoType;
|
||||
}
|
||||
}
|
||||
|
||||
- (BOOL)isValidTensorIndex:(NSUInteger)index
|
||||
belowLimit:(NSUInteger)totalTensorCount
|
||||
error:(NSError **)error {
|
||||
if (index >= totalTensorCount) {
|
||||
NSString *errorDescription =
|
||||
[NSString stringWithFormat:@"Invalid tensor index (%lu) exceeds max (%lu).",
|
||||
(unsigned long)index, (unsigned long)(totalTensorCount - 1)];
|
||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensorIndex
|
||||
description:errorDescription
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
|
||||
return YES;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,30 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@implementation TFLInterpreterOptions
|
||||
|
||||
#pragma mark - Public
|
||||
|
||||
- (instancetype)init {
|
||||
self = [super init];
|
||||
return self;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,33 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface TFLQuantizationParameters (Internal)
|
||||
|
||||
/**
|
||||
* Initializes a `TFLQuantizationParameters` instance with the given scale and zero point.
|
||||
*
|
||||
* @param scale Scale of asymmetric quantization.
|
||||
* @param zeroPoint Zero point of asymmetric quantization.
|
||||
*
|
||||
* @return A new instance of `TFLQuantizationParameters` with the given scale and zero point.
|
||||
*/
|
||||
- (instancetype)initWithScale:(float)scale zeroPoint:(int32_t)zeroPoint;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,36 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h"
|
||||
|
||||
#import "TFLQuantizationParameters+Internal.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@implementation TFLQuantizationParameters
|
||||
|
||||
#pragma mark - TFLTensor (Internal)
|
||||
|
||||
- (instancetype)initWithScale:(float)scale zeroPoint:(int32_t)zeroPoint {
|
||||
self = [super init];
|
||||
if (self != nil) {
|
||||
_scale = scale;
|
||||
_zeroPoint = zeroPoint;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,74 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
|
||||
|
||||
@class TFLInterpreter;
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* @enum TFLTensorType
|
||||
* This enum specifies input or output tensor types.
|
||||
*/
|
||||
typedef NS_ENUM(NSUInteger, TFLTensorType) {
|
||||
/** Input tensor type. */
|
||||
TFLTensorTypeInput,
|
||||
|
||||
/** Output tensor type. */
|
||||
TFLTensorTypeOutput,
|
||||
};
|
||||
|
||||
@interface TFLTensor (Internal)
|
||||
|
||||
/** Input or output tensor type. */
|
||||
@property(nonatomic, readonly) TFLTensorType type;
|
||||
|
||||
/** Index of the tensor. */
|
||||
@property(nonatomic, readonly) NSUInteger index;
|
||||
|
||||
/**
|
||||
* Initializes a `TFLTensor` with the given interpreter, name, data type, and quantization
|
||||
* parameters.
|
||||
*
|
||||
* @param interpreter Interpreter backing the tensor.
|
||||
* @param type Input or output tensor type.
|
||||
* @param index Index of the tensor.
|
||||
* @param name Name of the tensor.
|
||||
* @param dataType Data type of the tensor.
|
||||
* @param quantizationParameters Quantization parameters of the tensor. `nil` if the tensor does not
|
||||
* use quantization.
|
||||
*
|
||||
* @return A new instance of `TFLTensor` with the given name, data type, shape, and quantization
|
||||
* parameters.
|
||||
*/
|
||||
- (instancetype)initWithInterpreter:(TFLInterpreter *)interpreter
|
||||
type:(TFLTensorType)type
|
||||
index:(NSUInteger)index
|
||||
name:(NSString *)name
|
||||
dataType:(TFLTensorDataType)dataType
|
||||
quantizationParameters:(nullable TFLQuantizationParameters *)quantizationParameters;
|
||||
|
||||
/**
|
||||
* Returns the string name of the given input or output tensor type.
|
||||
*
|
||||
* @param type Input or output tensor type.
|
||||
*
|
||||
* @return The string name of the given input or output tensor type.
|
||||
*/
|
||||
+ (NSString *)stringForTensorType:(TFLTensorType)type;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
103
tensorflow/lite/experimental/objc/sources/TFLTensor.m
Normal file
103
tensorflow/lite/experimental/objc/sources/TFLTensor.m
Normal file
@ -0,0 +1,103 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
|
||||
|
||||
#import "TFLErrorUtil.h"
|
||||
#import "TFLInterpreter+Internal.h"
|
||||
#import "TFLTensor+Internal.h"
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
// String names of input or output tensor types.
|
||||
static NSString *const kTFLInputTensorTypeString = @"input";
|
||||
static NSString *const kTFLOutputTensorTypeString = @"output";
|
||||
|
||||
@interface TFLTensor ()
|
||||
|
||||
// Redefines readonly properties.
|
||||
@property(nonatomic) TFLTensorType type;
|
||||
@property(nonatomic) NSUInteger index;
|
||||
@property(nonatomic, copy) NSString *name;
|
||||
@property(nonatomic) TFLTensorDataType dataType;
|
||||
@property(nonatomic, nullable) TFLQuantizationParameters *quantizationParameters;
|
||||
|
||||
/**
|
||||
* The backing interpreter. It's a strong reference to ensure that the interpreter is never released
|
||||
* before this tensor is released.
|
||||
*
|
||||
* @warning Never let the interpreter hold a strong reference to the tensor to avoid retain cycles.
|
||||
*/
|
||||
@property(nonatomic) TFLInterpreter *interpreter;
|
||||
|
||||
@end
|
||||
|
||||
@implementation TFLTensor
|
||||
|
||||
#pragma mark - Public
|
||||
|
||||
- (BOOL)copyData:(NSData *)data error:(NSError **)error {
|
||||
if (self.type == TFLTensorTypeOutput) {
|
||||
[TFLErrorUtil
|
||||
saveInterpreterErrorWithCode:TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed
|
||||
description:@"Cannot copy data into an output tensor."
|
||||
error:error];
|
||||
return NO;
|
||||
}
|
||||
|
||||
return [self.interpreter copyData:data toInputTensorAtIndex:self.index error:error];
|
||||
}
|
||||
|
||||
- (nullable NSData *)dataWithError:(NSError **)error {
|
||||
return [self.interpreter dataFromTensor:self error:error];
|
||||
}
|
||||
|
||||
- (nullable NSArray<NSNumber *> *)shapeWithError:(NSError **)error {
|
||||
return [self.interpreter shapeOfTensor:self error:error];
|
||||
}
|
||||
|
||||
#pragma mark - TFLTensor (Internal)
|
||||
|
||||
- (instancetype)initWithInterpreter:(TFLInterpreter *)interpreter
|
||||
type:(TFLTensorType)type
|
||||
index:(NSUInteger)index
|
||||
name:(NSString *)name
|
||||
dataType:(TFLTensorDataType)dataType
|
||||
quantizationParameters:(nullable TFLQuantizationParameters *)quantizationParameters {
|
||||
self = [super init];
|
||||
if (self != nil) {
|
||||
_interpreter = interpreter;
|
||||
_type = type;
|
||||
_index = index;
|
||||
_name = [name copy];
|
||||
_dataType = dataType;
|
||||
_quantizationParameters = quantizationParameters;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
+ (NSString *)stringForTensorType:(TFLTensorType)type {
|
||||
switch (type) {
|
||||
case TFLTensorTypeInput:
|
||||
return kTFLInputTensorTypeString;
|
||||
case TFLTensorTypeOutput:
|
||||
return kTFLOutputTensorTypeString;
|
||||
}
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,49 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* Unit tests for TFLInterpreterOptions.
|
||||
*/
|
||||
@interface TFLInterpreterOptionsTests : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation TFLInterpreterOptionsTests
|
||||
|
||||
#pragma mark - Tests
|
||||
|
||||
- (void)testInit {
|
||||
TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
|
||||
XCTAssertNotNil(options);
|
||||
XCTAssertEqual(options.numberOfThreads, 0);
|
||||
}
|
||||
|
||||
- (void)testSetNumberOfThread {
|
||||
TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
|
||||
options.numberOfThreads = 2;
|
||||
XCTAssertEqual(options.numberOfThreads, 2);
|
||||
options.numberOfThreads = 0;
|
||||
XCTAssertEqual(options.numberOfThreads, 0);
|
||||
options.numberOfThreads = 3;
|
||||
XCTAssertEqual(options.numberOfThreads, 3);
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
358
tensorflow/lite/experimental/objc/tests/TFLInterpreterTests.m
Normal file
358
tensorflow/lite/experimental/objc/tests/TFLInterpreterTests.m
Normal file
@ -0,0 +1,358 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h"
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/** Float model resource name. */
|
||||
static NSString *const kAddFloatModelResourceName = @"add";
|
||||
|
||||
/** Quantized model resource name. */
|
||||
static NSString *const kAddQuantizedModelResourceName = @"add_quantized";
|
||||
|
||||
/** Model resource type. */
|
||||
static NSString *const kAddModelResourceType = @"bin";
|
||||
|
||||
/** Rank of the input and output tensor in the Add model. */
|
||||
static const NSUInteger kAddModelTensorRank = 1U;
|
||||
|
||||
/** Size of the first (and only) dimension of the input and output tensor in the Add model. */
|
||||
static const NSUInteger kAddModelTensorFirstDimensionSize = 2U;
|
||||
|
||||
/** Quantization scale of the quantized model. */
|
||||
static const float kAddQuantizedModelScale = 0.003922F;
|
||||
|
||||
/** Quantization zero point of the quantized model. */
|
||||
static const int32_t kAddQuantizedModelZeroPoint = 0;
|
||||
|
||||
/** Invalid input tensor index. */
|
||||
static const NSUInteger kInvalidInputTensorIndex = 1U;
|
||||
|
||||
/** Invalid output tensor index. */
|
||||
static const NSUInteger kInvalidOutputTensorIndex = 1U;
|
||||
|
||||
/** Accurary used in comparing floating numbers. */
|
||||
static const float kTestAccuracy = 1E-5F;
|
||||
|
||||
/**
|
||||
* Unit tests for TFLInterpreter.
|
||||
*/
|
||||
@interface TFLInterpreterTests : XCTestCase
|
||||
|
||||
/** Absolute path of the Add float model resource. */
|
||||
@property(nonatomic, nullable) NSString *floatModelPath;
|
||||
|
||||
/** Default interpreter using the Add model. */
|
||||
@property(nonatomic, nullable) TFLInterpreter *interpreter;
|
||||
|
||||
@end
|
||||
|
||||
@implementation TFLInterpreterTests
|
||||
|
||||
#pragma mark - XCTestCase
|
||||
|
||||
- (void)setUp {
|
||||
[super setUp];
|
||||
|
||||
NSBundle *bundle = [NSBundle bundleForClass:[self class]];
|
||||
self.floatModelPath = [bundle pathForResource:kAddFloatModelResourceName
|
||||
ofType:kAddModelResourceType];
|
||||
NSError *error;
|
||||
self.interpreter = [[TFLInterpreter alloc] initWithModelPath:self.floatModelPath error:&error];
|
||||
XCTAssertNil(error);
|
||||
XCTAssertNotNil(self.interpreter);
|
||||
XCTAssertTrue([self.interpreter allocateTensorsWithError:nil]);
|
||||
}
|
||||
|
||||
- (void)tearDown {
|
||||
self.floatModelPath = nil;
|
||||
self.interpreter = nil;
|
||||
|
||||
[super tearDown];
|
||||
}
|
||||
|
||||
#pragma mark - Tests
|
||||
|
||||
- (void)testSuccessfulFullRunAddFloatModel {
|
||||
// Shape for both input and output tensor.
|
||||
NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
|
||||
shape[0] = [NSNumber numberWithUnsignedInteger:kAddModelTensorFirstDimensionSize];
|
||||
|
||||
// Creates the interpreter options.
|
||||
TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
|
||||
XCTAssertNotNil(options);
|
||||
options.numberOfThreads = 2;
|
||||
|
||||
// Creates the interpreter.
|
||||
NSError *error;
|
||||
TFLInterpreter *customInterpreter = [[TFLInterpreter alloc] initWithModelPath:self.floatModelPath
|
||||
options:options
|
||||
error:&error];
|
||||
XCTAssertNil(error);
|
||||
XCTAssertNotNil(customInterpreter);
|
||||
|
||||
// Allocates memory for tensors.
|
||||
XCTAssertTrue([customInterpreter allocateTensorsWithError:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Verifies input and output tensor counts.
|
||||
XCTAssertEqual(customInterpreter.inputTensorCount, 1);
|
||||
XCTAssertEqual(customInterpreter.outputTensorCount, 1);
|
||||
|
||||
// Resizes the intput tensor.
|
||||
XCTAssertTrue([customInterpreter resizeInputTensorAtIndex:0 toShape:shape error:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Re-allocates memory for tensors.
|
||||
XCTAssertTrue([customInterpreter allocateTensorsWithError:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Verifies the input tensor.
|
||||
TFLTensor *inputTensor = [customInterpreter inputTensorAtIndex:0 error:&error];
|
||||
XCTAssertNotNil(inputTensor);
|
||||
XCTAssertNil(error);
|
||||
XCTAssertTrue([inputTensor.name isEqualToString:@"input"]);
|
||||
XCTAssertEqual(inputTensor.dataType, TFLTensorDataTypeFloat32);
|
||||
NSArray *inputTensorShape = [inputTensor shapeWithError:&error];
|
||||
XCTAssertNil(error);
|
||||
XCTAssertTrue([shape isEqualToArray:inputTensorShape]);
|
||||
|
||||
// Copies the input data.
|
||||
NSMutableData *inputData = [NSMutableData dataWithCapacity:0];
|
||||
float one = 1.f;
|
||||
float three = 3.f;
|
||||
[inputData appendBytes:&one length:sizeof(float)];
|
||||
[inputData appendBytes:&three length:sizeof(float)];
|
||||
XCTAssertTrue([inputTensor copyData:inputData error:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Invokes the interpreter.
|
||||
XCTAssertTrue([customInterpreter invokeWithError:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Verifies the output tensor.
|
||||
TFLTensor *outputTensor = [customInterpreter outputTensorAtIndex:0 error:&error];
|
||||
XCTAssertNotNil(outputTensor);
|
||||
XCTAssertNil(error);
|
||||
XCTAssertTrue([outputTensor.name isEqualToString:@"output"]);
|
||||
XCTAssertEqual(outputTensor.dataType, TFLTensorDataTypeFloat32);
|
||||
NSArray *outputTensorShape = [outputTensor shapeWithError:&error];
|
||||
XCTAssertNil(error);
|
||||
XCTAssertTrue([shape isEqualToArray:outputTensorShape]);
|
||||
|
||||
// Tries to query an invalid output tensor index.
|
||||
TFLTensor *invalidOutputTensor = [customInterpreter outputTensorAtIndex:kInvalidOutputTensorIndex
|
||||
error:&error];
|
||||
XCTAssertNil(invalidOutputTensor);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidTensorIndex);
|
||||
|
||||
// Gets the output tensor data.
|
||||
error = nil;
|
||||
NSData *outputData = [outputTensor dataWithError:&error];
|
||||
XCTAssertNotNil(outputData);
|
||||
XCTAssertNil(error);
|
||||
float output[kAddModelTensorFirstDimensionSize];
|
||||
[outputData getBytes:output length:(sizeof(float) * kAddModelTensorFirstDimensionSize)];
|
||||
XCTAssertEqualWithAccuracy(output[0], 3.f, kTestAccuracy);
|
||||
XCTAssertEqualWithAccuracy(output[1], 9.f, kTestAccuracy);
|
||||
}
|
||||
|
||||
- (void)testSuccessfulFullRunQuantizedModel {
|
||||
// Shape for both input and output tensor.
|
||||
NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
|
||||
shape[0] = [NSNumber numberWithUnsignedInteger:kAddModelTensorFirstDimensionSize];
|
||||
|
||||
// Creates the interpreter options.
|
||||
TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
|
||||
XCTAssertNotNil(options);
|
||||
options.numberOfThreads = 2;
|
||||
|
||||
NSBundle *bundle = [NSBundle bundleForClass:[self class]];
|
||||
NSString *quantizedModelPath = [bundle pathForResource:kAddQuantizedModelResourceName
|
||||
ofType:kAddModelResourceType];
|
||||
|
||||
// Creates the interpreter.
|
||||
NSError *error;
|
||||
TFLInterpreter *customInterpreter =
|
||||
[[TFLInterpreter alloc] initWithModelPath:quantizedModelPath options:options error:&error];
|
||||
XCTAssertNil(error);
|
||||
XCTAssertNotNil(customInterpreter);
|
||||
|
||||
// Allocates memory for tensors.
|
||||
XCTAssertTrue([customInterpreter allocateTensorsWithError:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Verifies input and output tensor counts.
|
||||
XCTAssertEqual(customInterpreter.inputTensorCount, 1);
|
||||
XCTAssertEqual(customInterpreter.outputTensorCount, 1);
|
||||
|
||||
// Resizes the intput tensor.
|
||||
XCTAssertTrue([customInterpreter resizeInputTensorAtIndex:0 toShape:shape error:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Re-allocates memory for tensors.
|
||||
XCTAssertTrue([customInterpreter allocateTensorsWithError:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Verifies the input tensor.
|
||||
TFLTensor *inputTensor = [customInterpreter inputTensorAtIndex:0 error:&error];
|
||||
XCTAssertNotNil(inputTensor);
|
||||
XCTAssertNil(error);
|
||||
XCTAssertTrue([inputTensor.name isEqualToString:@"input"]);
|
||||
XCTAssertEqual(inputTensor.dataType, TFLTensorDataTypeUInt8);
|
||||
XCTAssertEqualWithAccuracy(inputTensor.quantizationParameters.scale, kAddQuantizedModelScale,
|
||||
kTestAccuracy);
|
||||
XCTAssertEqual(inputTensor.quantizationParameters.zeroPoint, kAddQuantizedModelZeroPoint);
|
||||
NSArray *inputTensorShape = [inputTensor shapeWithError:&error];
|
||||
XCTAssertNil(error);
|
||||
XCTAssertTrue([shape isEqualToArray:inputTensorShape]);
|
||||
|
||||
// Copies the input data.
|
||||
NSMutableData *inputData = [NSMutableData dataWithCapacity:0];
|
||||
uint8_t one = 1;
|
||||
uint8_t three = 3;
|
||||
[inputData appendBytes:&one length:sizeof(uint8_t)];
|
||||
[inputData appendBytes:&three length:sizeof(uint8_t)];
|
||||
XCTAssertTrue([inputTensor copyData:inputData error:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Invokes the interpreter.
|
||||
XCTAssertTrue([customInterpreter invokeWithError:&error]);
|
||||
XCTAssertNil(error);
|
||||
|
||||
// Verifies the output tensor.
|
||||
TFLTensor *outputTensor = [customInterpreter outputTensorAtIndex:0 error:&error];
|
||||
XCTAssertNotNil(outputTensor);
|
||||
XCTAssertNil(error);
|
||||
XCTAssertTrue([outputTensor.name isEqualToString:@"output"]);
|
||||
XCTAssertEqual(outputTensor.dataType, TFLTensorDataTypeUInt8);
|
||||
XCTAssertEqualWithAccuracy(outputTensor.quantizationParameters.scale, kAddQuantizedModelScale,
|
||||
kTestAccuracy);
|
||||
XCTAssertEqual(outputTensor.quantizationParameters.zeroPoint, kAddQuantizedModelZeroPoint);
|
||||
NSArray *outputTensorShape = [outputTensor shapeWithError:&error];
|
||||
XCTAssertNil(error);
|
||||
XCTAssertTrue([shape isEqualToArray:outputTensorShape]);
|
||||
|
||||
// Tries to query an invalid output tensor index.
|
||||
TFLTensor *invalidOutputTensor = [customInterpreter outputTensorAtIndex:kInvalidOutputTensorIndex
|
||||
error:&error];
|
||||
XCTAssertNil(invalidOutputTensor);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidTensorIndex);
|
||||
|
||||
// Gets the output tensor data.
|
||||
error = nil;
|
||||
NSData *outputData = [outputTensor dataWithError:&error];
|
||||
XCTAssertNotNil(outputData);
|
||||
XCTAssertNil(error);
|
||||
uint8_t output[kAddModelTensorFirstDimensionSize];
|
||||
[outputData getBytes:output length:(sizeof(uint8_t) * kAddModelTensorFirstDimensionSize)];
|
||||
XCTAssertEqual(output[0], 3);
|
||||
XCTAssertEqual(output[1], 9);
|
||||
}
|
||||
|
||||
- (void)testInitWithModelPath_invalidPath {
|
||||
// Shape for both input and output tensor.
|
||||
NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
|
||||
shape[0] = [NSNumber numberWithUnsignedInteger:kAddModelTensorFirstDimensionSize];
|
||||
|
||||
// Creates the interpreter.
|
||||
NSError *error;
|
||||
TFLInterpreter *brokenInterpreter = [[TFLInterpreter alloc] initWithModelPath:@"InvalidPath"
|
||||
error:&error];
|
||||
XCTAssertNil(brokenInterpreter);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeFailedToLoadModel);
|
||||
}
|
||||
|
||||
- (void)testInvoke_beforeAllocation {
|
||||
NSError *error;
|
||||
TFLInterpreter *interpreterWithoutAllocation =
|
||||
[[TFLInterpreter alloc] initWithModelPath:self.floatModelPath error:&error];
|
||||
XCTAssertNotNil(interpreterWithoutAllocation);
|
||||
XCTAssertNil(error);
|
||||
|
||||
XCTAssertFalse([interpreterWithoutAllocation invokeWithError:&error]);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeFailedToInvoke);
|
||||
}
|
||||
|
||||
- (void)testInputTensorAtIndex_invalidIndex {
|
||||
NSError *error;
|
||||
TFLTensor *inputTensor = [self.interpreter inputTensorAtIndex:kInvalidInputTensorIndex
|
||||
error:&error];
|
||||
XCTAssertNil(inputTensor);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidTensorIndex);
|
||||
}
|
||||
|
||||
- (void)testResizeInputTensorAtIndex_invalidIndex {
|
||||
NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
|
||||
shape[0] = [NSNumber numberWithUnsignedInteger:kAddModelTensorFirstDimensionSize];
|
||||
NSError *error;
|
||||
XCTAssertFalse([self.interpreter resizeInputTensorAtIndex:kInvalidInputTensorIndex
|
||||
toShape:shape
|
||||
error:&error]);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidTensorIndex);
|
||||
}
|
||||
|
||||
- (void)testResizeInputTensorAtIndex_emptyShape {
|
||||
NSMutableArray *emptyShape = [NSMutableArray arrayWithCapacity:0];
|
||||
NSError *error;
|
||||
XCTAssertFalse([self.interpreter resizeInputTensorAtIndex:0 toShape:emptyShape error:&error]);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidShape);
|
||||
}
|
||||
|
||||
- (void)testResizeInputTensorAtIndex_zeroDimensionSize {
|
||||
NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
|
||||
shape[0] = [NSNumber numberWithUnsignedInteger:0];
|
||||
NSError *error;
|
||||
XCTAssertFalse([self.interpreter resizeInputTensorAtIndex:0 toShape:shape error:&error]);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidShape);
|
||||
}
|
||||
|
||||
- (void)testCopyDataToInputTensorAtIndex_invalidInputDataByteSize {
|
||||
NSMutableData *inputData = [NSMutableData dataWithCapacity:0];
|
||||
float one = 1.f;
|
||||
float three = 3.f;
|
||||
[inputData appendBytes:&one length:sizeof(float)];
|
||||
[inputData appendBytes:&three length:(sizeof(float) - 1)];
|
||||
NSError *error;
|
||||
TFLTensor *inputTensor = [self.interpreter inputTensorAtIndex:0 error:&error];
|
||||
XCTAssertNotNil(inputTensor);
|
||||
XCTAssertNil(error);
|
||||
XCTAssertFalse([inputTensor copyData:inputData error:&error]);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidInputByteSize);
|
||||
}
|
||||
|
||||
- (void)testCopyDataToOutputTensorAtIndex_notAllowed {
|
||||
NSMutableData *data = [NSMutableData dataWithCapacity:0];
|
||||
float one = 1.f;
|
||||
float three = 3.f;
|
||||
[data appendBytes:&one length:sizeof(float)];
|
||||
[data appendBytes:&three length:(sizeof(float) - 1)];
|
||||
NSError *error;
|
||||
TFLTensor *outputTensor = [self.interpreter outputTensorAtIndex:0 error:&error];
|
||||
XCTAssertNotNil(outputTensor);
|
||||
XCTAssertNil(error);
|
||||
XCTAssertFalse([outputTensor copyData:data error:&error]);
|
||||
XCTAssertEqual(error.code, TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed);
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
@ -0,0 +1,48 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "tensorflow/lite/experimental/objc/sources/TFLQuantizationParameters+Internal.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/** Test scale of quantization parameters. */
|
||||
static const float kTestScale = 2.0;
|
||||
|
||||
/** Test zero point of quantization parameters. */
|
||||
static const int32_t kTestZeroPoint = 128;
|
||||
|
||||
/**
|
||||
* Unit tests for TFLQuantizationParameters.
|
||||
*/
|
||||
@interface TFLQuantizationParametersTests : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation TFLQuantizationParametersTests
|
||||
|
||||
#pragma mark - Tests
|
||||
|
||||
- (void)testInitWithScaleAndZeroPoint {
|
||||
TFLQuantizationParameters *params =
|
||||
[[TFLQuantizationParameters alloc] initWithScale:kTestScale zeroPoint:kTestZeroPoint];
|
||||
XCTAssertEqual(params.scale, kTestScale);
|
||||
XCTAssertEqual(params.zeroPoint, kTestZeroPoint);
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
101
tensorflow/lite/experimental/swift/BUILD.apple
Normal file
101
tensorflow/lite/experimental/swift/BUILD.apple
Normal file
@ -0,0 +1,101 @@
|
||||
# TensorFlow Lite for Swift.
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_application", "ios_unit_test")
|
||||
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
|
||||
|
||||
MINIMUM_OS_VERSION = "9.0"
|
||||
|
||||
SWIFT_COPTS = [
|
||||
"-wmo",
|
||||
]
|
||||
|
||||
swift_library(
|
||||
name = "TensorFlowLite",
|
||||
srcs = glob(["Sources/*.swift"]),
|
||||
copts = SWIFT_COPTS,
|
||||
module_name = "TensorFlowLite",
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/c:c_api",
|
||||
],
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "TensorFlowLiteTests",
|
||||
size = "small",
|
||||
minimum_os_version = MINIMUM_OS_VERSION,
|
||||
tags = [
|
||||
"manual",
|
||||
# DISABLED: Following sanitizer tests are not supported by iOS test targets.
|
||||
"noasan",
|
||||
"nomsan",
|
||||
"notsan",
|
||||
],
|
||||
deps = [":TensorFlowLiteTestsLib"],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "TensorFlowLiteTestsLib",
|
||||
testonly = 1,
|
||||
srcs = glob(["Tests/*.swift"]),
|
||||
copts = SWIFT_COPTS,
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":TensorFlowLite",
|
||||
":TestResources",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TestResources",
|
||||
resources = [
|
||||
"//tensorflow/lite:testdata/add.bin",
|
||||
"//tensorflow/lite:testdata/add_quantized.bin",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
],
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "TensorFlowLiteApp",
|
||||
app_icons = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/**"]),
|
||||
bundle_id = "com.tensorflow.lite.swift.TensorFlowLite",
|
||||
families = [
|
||||
"ipad",
|
||||
"iphone",
|
||||
],
|
||||
infoplists = ["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Info.plist"],
|
||||
launch_storyboard = "TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/LaunchScreen.storyboard",
|
||||
minimum_os_version = MINIMUM_OS_VERSION,
|
||||
sdk_frameworks = [
|
||||
"CoreGraphics",
|
||||
],
|
||||
tags = ["manual"],
|
||||
deps = [":TensorFlowLiteAppLib"],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "TensorFlowLiteAppLib",
|
||||
srcs = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/*.swift"]),
|
||||
module_name = "TensorFlowLiteAppLib",
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":TensorFlowLite",
|
||||
":TensorFlowLiteAppResources",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TensorFlowLiteAppResources",
|
||||
storyboards = glob([
|
||||
"TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/*.storyboard",
|
||||
]),
|
||||
tags = ["manual"],
|
||||
deps = [":TestResources"],
|
||||
)
|
202
tensorflow/lite/experimental/swift/LICENSE
Normal file
202
tensorflow/lite/experimental/swift/LICENSE
Normal file
@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
76
tensorflow/lite/experimental/swift/README.md
Normal file
76
tensorflow/lite/experimental/swift/README.md
Normal file
@ -0,0 +1,76 @@
|
||||
# TensorFlow Lite for Swift
|
||||
|
||||
[TensorFlow Lite](https://www.tensorflow.org/lite/) is TensorFlow's lightweight
|
||||
solution for Swift developers. It enables low-latency inference of on-device
|
||||
machine learning models with a small binary size and fast performance supporting
|
||||
hardware acceleration.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Bazel
|
||||
|
||||
In your `BUILD` file, add the `TensorFlowLite` dependency:
|
||||
|
||||
```python
|
||||
swift_library(
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/swift:TensorFlowLite",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
In your Swift files, import the module:
|
||||
|
||||
```swift
|
||||
import TensorFlowLite
|
||||
```
|
||||
|
||||
If you would like to build the Swift TensorFlow Lite library using Bazel on Apple
|
||||
platforms, clone or download the [TensorFlow GitHub repo](https://github.com/tensorflow/tensorflow),
|
||||
then navigate to the root `tensorflow` directory and execute the `configure.py` script:
|
||||
|
||||
```shell
|
||||
python configure.py
|
||||
```
|
||||
|
||||
Follow the prompts and when asked to configure the Bazel rules for Apple
|
||||
platforms, enter `y`.
|
||||
|
||||
Build the `TensorFlowLite` Swift library target:
|
||||
|
||||
```shell
|
||||
bazel build tensorflow/lite/experimental/swift:TensorFlowLite
|
||||
```
|
||||
|
||||
Build the `TensorFlowLiteTests` target:
|
||||
|
||||
```shell
|
||||
bazel test tensorflow/lite/experimental/swift:TensorFlowLiteTests --swiftcopt=-enable-testing
|
||||
```
|
||||
|
||||
### Tulsi
|
||||
|
||||
Open the `TensorFlowLite.tulsiproj` using the [TulsiApp](https://github.com/bazelbuild/tulsi) or by
|
||||
running the [`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
|
||||
script:
|
||||
|
||||
```shell
|
||||
generate_xcodeproj.sh --genconfig tensorflow/lite/swift/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
|
||||
```
|
||||
|
||||
### CocoaPods
|
||||
|
||||
Add the following to your `Podfile`:
|
||||
|
||||
```ruby
|
||||
use_frameworks!
|
||||
pod 'TensorFlowLiteSwift'
|
||||
```
|
||||
|
||||
Then, run `pod install`.
|
||||
|
||||
In your Swift files, import the module:
|
||||
|
||||
```swift
|
||||
import TensorFlowLite
|
||||
```
|
265
tensorflow/lite/experimental/swift/Sources/Interpreter.swift
Normal file
265
tensorflow/lite/experimental/swift/Sources/Interpreter.swift
Normal file
@ -0,0 +1,265 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
import TensorFlowLiteCAPI
|
||||
|
||||
/// A TensorFlow Lite interpreter that performs inference from a given model.
|
||||
public final class Interpreter {
|
||||
|
||||
/// The `TFL_Interpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
|
||||
private typealias CInterpreter = OpaquePointer
|
||||
|
||||
/// Total number of input tensors associated with the model.
|
||||
public var inputTensorCount: Int {
|
||||
return Int(TFL_InterpreterGetInputTensorCount(cInterpreter))
|
||||
}
|
||||
|
||||
/// Total number of output tensors associated with the model.
|
||||
public var outputTensorCount: Int {
|
||||
return Int(TFL_InterpreterGetOutputTensorCount(cInterpreter))
|
||||
}
|
||||
|
||||
/// The underlying `TFL_Interpreter` C pointer.
|
||||
private var cInterpreter: CInterpreter?
|
||||
|
||||
/// Creates a new model interpreter instance.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - modelPath: Local file path to a TensorFlow Lite model.
|
||||
/// - options: Custom configurations for the interpreter. The default is `nil` indicating that
|
||||
/// interpreter will determine the configuration options.
|
||||
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
|
||||
public init(modelPath: String, options: InterpreterOptions? = nil) throws {
|
||||
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
|
||||
|
||||
let cInterpreterOptions: OpaquePointer? = try options.map { options in
|
||||
guard let cOptions = TFL_NewInterpreterOptions() else {
|
||||
throw InterpreterError.failedToCreateInterpreter
|
||||
}
|
||||
if let threadCount = options.threadCount, threadCount > 0 {
|
||||
TFL_InterpreterOptionsSetNumThreads(cOptions, Int32(threadCount))
|
||||
}
|
||||
if options.isErrorLoggingEnabled {
|
||||
TFL_InterpreterOptionsSetErrorReporter(
|
||||
cOptions,
|
||||
{ (_, format, arguments) in
|
||||
guard let cFormat = format,
|
||||
let message = String(cFormat: cFormat, arguments: arguments)
|
||||
else {
|
||||
return
|
||||
}
|
||||
print(String(describing: InterpreterError.tensorFlowLiteError(message)))
|
||||
},
|
||||
nil
|
||||
)
|
||||
}
|
||||
return cOptions
|
||||
}
|
||||
defer { TFL_DeleteInterpreterOptions(cInterpreterOptions) }
|
||||
|
||||
guard let cInterpreter = TFL_NewInterpreter(model.cModel, cInterpreterOptions) else {
|
||||
throw InterpreterError.failedToCreateInterpreter
|
||||
}
|
||||
self.cInterpreter = cInterpreter
|
||||
}
|
||||
|
||||
deinit {
|
||||
TFL_DeleteInterpreter(cInterpreter)
|
||||
}
|
||||
|
||||
/// Invokes the interpreter to perform inference from the loaded graph.
|
||||
///
|
||||
/// - Throws: An error if the model was not ready because tensors were not allocated.
|
||||
public func invoke() throws {
|
||||
guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else {
|
||||
// TODO(b/117510052): Determine which error to throw.
|
||||
throw InterpreterError.allocateTensorsRequired
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the input tensor at the given index.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - index: The index for the input tensor.
|
||||
/// - Throws: An error if the index is invalid or the tensors have not been allocated.
|
||||
/// - Returns: The input tensor at the given index.
|
||||
public func input(at index: Int) throws -> Tensor {
|
||||
let maxIndex = inputTensorCount - 1
|
||||
guard case 0...maxIndex = index else {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)),
|
||||
let bytes = TFL_TensorData(cTensor),
|
||||
let nameCString = TFL_TensorName(cTensor)
|
||||
else {
|
||||
throw InterpreterError.allocateTensorsRequired
|
||||
}
|
||||
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
||||
throw InterpreterError.invalidTensorDataType
|
||||
}
|
||||
|
||||
let name = String(cString: nameCString)
|
||||
let rank = TFL_TensorNumDims(cTensor)
|
||||
let dimensions = (0..<rank).map { Int(TFL_TensorDim(cTensor, $0)) }
|
||||
let shape = TensorShape(dimensions)
|
||||
let byteCount = TFL_TensorByteSize(cTensor)
|
||||
let data = Data(bytes: bytes, count: byteCount)
|
||||
let cQuantizationParams = TFL_TensorQuantizationParams(cTensor)
|
||||
let scale = cQuantizationParams.scale
|
||||
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||
var quantizationParameters: QuantizationParameters? = nil
|
||||
if scale != 0.0 {
|
||||
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
||||
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||
}
|
||||
let tensor = Tensor(
|
||||
name: name,
|
||||
dataType: dataType,
|
||||
shape: shape,
|
||||
data: data,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
return tensor
|
||||
}
|
||||
|
||||
/// Returns the output tensor at the given index.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - index: The index for the output tensor.
|
||||
/// - Throws: An error if the index is invalid, tensors haven't been allocated, or interpreter
|
||||
/// hasn't been invoked for models that dynamically compute output tensors based on the values
|
||||
/// of its input tensors.
|
||||
/// - Returns: The output tensor at the given index.
|
||||
public func output(at index: Int) throws -> Tensor {
|
||||
let maxIndex = outputTensorCount - 1
|
||||
guard case 0...maxIndex = index else {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard let cTensor = TFL_InterpreterGetOutputTensor(cInterpreter, Int32(index)),
|
||||
let bytes = TFL_TensorData(cTensor),
|
||||
let nameCString = TFL_TensorName(cTensor)
|
||||
else {
|
||||
// TODO(b/117510052): Determine which error to throw.
|
||||
throw InterpreterError.invokeInterpreterRequired
|
||||
}
|
||||
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
||||
throw InterpreterError.invalidTensorDataType
|
||||
}
|
||||
|
||||
let name = String(cString: nameCString)
|
||||
let rank = TFL_TensorNumDims(cTensor)
|
||||
let dimensions = (0..<rank).map { Int(TFL_TensorDim(cTensor, $0)) }
|
||||
let shape = TensorShape(dimensions)
|
||||
let byteCount = TFL_TensorByteSize(cTensor)
|
||||
let data = Data(bytes: bytes, count: byteCount)
|
||||
let cQuantizationParams = TFL_TensorQuantizationParams(cTensor)
|
||||
let scale = cQuantizationParams.scale
|
||||
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||
var quantizationParameters: QuantizationParameters? = nil
|
||||
if scale != 0.0 {
|
||||
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
||||
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||
}
|
||||
let tensor = Tensor(
|
||||
name: name,
|
||||
dataType: dataType,
|
||||
shape: shape,
|
||||
data: data,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
return tensor
|
||||
}
|
||||
|
||||
/// Resizes the input tensor at the given index to the specified tensor shape.
|
||||
///
|
||||
/// - Note: After resizing an input tensor, the client **must** explicitly call
|
||||
/// `allocateTensors()` before attempting to access the resized tensor data or invoking the
|
||||
/// interpreter to perform inference.
|
||||
/// - Parameters:
|
||||
/// - index: The index for the input tensor.
|
||||
/// - shape: The shape that the input tensor should be resized to.
|
||||
/// - Throws: An error if the input tensor at the given index could not be resized.
|
||||
public func resizeInput(at index: Int, to shape: TensorShape) throws {
|
||||
let maxIndex = inputTensorCount - 1
|
||||
guard case 0...maxIndex = index else {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard TFL_InterpreterResizeInputTensor(
|
||||
cInterpreter,
|
||||
Int32(index),
|
||||
shape.int32Dimensions,
|
||||
Int32(shape.rank)
|
||||
) == kTfLiteOk
|
||||
else {
|
||||
throw InterpreterError.failedToResizeInputTensor(index: index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies the given data to the input tensor at the given index.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - data: The data to be copied to the input tensor's data buffer.
|
||||
/// - index: The index for the input tensor.
|
||||
/// - Throws: An error if the `data.count` does not match the input tensor's `data.count` or if
|
||||
/// the given index is invalid.
|
||||
/// - Returns: The input tensor with the copied data.
|
||||
@discardableResult
|
||||
public func copy(_ data: Data, toInputAt index: Int) throws -> Tensor {
|
||||
let maxIndex = inputTensorCount - 1
|
||||
guard case 0...maxIndex = index else {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)) else {
|
||||
throw InterpreterError.allocateTensorsRequired
|
||||
}
|
||||
|
||||
let byteCount = TFL_TensorByteSize(cTensor)
|
||||
guard data.count == byteCount else {
|
||||
throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
|
||||
}
|
||||
|
||||
let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) }
|
||||
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
|
||||
return try input(at: index)
|
||||
}
|
||||
|
||||
/// Allocates memory for all input tensors based on their `TensorShape`s.
|
||||
///
|
||||
/// - Note: This is a relatively expensive operation and should only be called after creating the
|
||||
/// interpreter and/or resizing any input tensors.
|
||||
/// - Throws: An error if memory could not be allocated for the input tensors.
|
||||
public func allocateTensors() throws {
|
||||
guard TFL_InterpreterAllocateTensors(cInterpreter) == kTfLiteOk else {
|
||||
throw InterpreterError.failedToAllocateTensors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension String {
|
||||
/// Returns a new `String` initialized by using the given format C array as a template into which
|
||||
/// the remaining argument values are substituted according to the user’s default locale.
|
||||
///
|
||||
/// - Note: Returns `nil` if a new `String` could not be constructed from the given values.
|
||||
/// - Parameters:
|
||||
/// - cFormat: The format C array as a template for substituting values.
|
||||
/// - arguments: A C pointer to a `va_list` of arguments to substitute into `cFormat`.
|
||||
init?(cFormat: UnsafePointer<CChar>, arguments: CVaListPointer) {
|
||||
var buffer: UnsafeMutablePointer<CChar>?
|
||||
guard vasprintf(&buffer, cFormat, arguments) != 0, let cString = buffer else { return nil }
|
||||
self.init(validatingUTF8: cString)
|
||||
}
|
||||
}
|
@ -0,0 +1,99 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// TensorFlow Lite interpreter errors.
|
||||
public enum InterpreterError: Error {
|
||||
case invalidTensorIndex(index: Int, maxIndex: Int)
|
||||
case invalidTensorDataCount(provided: Int, required: Int)
|
||||
case invalidTensorDataType
|
||||
case failedToLoadModel
|
||||
case failedToCreateInterpreter
|
||||
case failedToResizeInputTensor(index: Int)
|
||||
case failedToCopyDataToInputTensor
|
||||
case failedToAllocateTensors
|
||||
case allocateTensorsRequired
|
||||
case invokeInterpreterRequired
|
||||
case tensorFlowLiteError(String)
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension InterpreterError: LocalizedError {
|
||||
/// Localized description of the interpreter error.
|
||||
public var errorDescription: String? {
|
||||
switch self {
|
||||
case .invalidTensorIndex(let index, let maxIndex):
|
||||
return "Invalid tensor index \(index), max index is \(maxIndex)."
|
||||
case .invalidTensorDataCount(let providedCount, let requiredCount):
|
||||
return "Provided data count \(providedCount) must match the required count \(requiredCount)."
|
||||
case .invalidTensorDataType:
|
||||
return "Tensor data type is unsupported or could not be determined because of a model error."
|
||||
case .failedToLoadModel:
|
||||
return "Failed to load the given model."
|
||||
case .failedToCreateInterpreter:
|
||||
return "Failed to create the interpreter."
|
||||
case .failedToResizeInputTensor(let index):
|
||||
return "Failed to resize input tesnor at index \(index)."
|
||||
case .failedToCopyDataToInputTensor:
|
||||
return "Failed to copy data to input tensor."
|
||||
case .failedToAllocateTensors:
|
||||
return "Failed to allocate memory for input tensors."
|
||||
case .allocateTensorsRequired:
|
||||
return "Must call allocateTensors()."
|
||||
case .invokeInterpreterRequired:
|
||||
return "Must call invoke()."
|
||||
case .tensorFlowLiteError(let message):
|
||||
return "TensorFlow Lite Error: \(message)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extension InterpreterError: CustomStringConvertible {
|
||||
/// Textual representation of the TensorFlow Lite interpreter error.
|
||||
public var description: String {
|
||||
return errorDescription ?? "Unknown error."
|
||||
}
|
||||
}
|
||||
|
||||
#if swift(>=4.2)
|
||||
extension InterpreterError: Equatable {}
|
||||
#else
|
||||
extension InterpreterError: Equatable {
|
||||
public static func == (lhs: InterpreterError, rhs: InterpreterError) -> Bool {
|
||||
switch (lhs, rhs) {
|
||||
case (.invalidTensorDataType, .invalidTensorDataType),
|
||||
(.failedToLoadModel, .failedToLoadModel),
|
||||
(.failedToCreateInterpreter, .failedToCreateInterpreter),
|
||||
(.failedToAllocateTensors, .failedToAllocateTensors),
|
||||
(.allocateTensorsRequired, .allocateTensorsRequired),
|
||||
(.invokeInterpreterRequired, .invokeInterpreterRequired):
|
||||
return true
|
||||
case (.invalidTensorIndex(let lhsIndex, let lhsMaxIndex),
|
||||
.invalidTensorIndex(let rhsIndex, let rhsMaxIndex)):
|
||||
return lhsIndex == rhsIndex && lhsMaxIndex == rhsMaxIndex
|
||||
case (.invalidTensorDataCount(let lhsProvidedCount, let lhsRequiredCount),
|
||||
.invalidTensorDataCount(let rhsProvidedCount, let rhsRequiredCount)):
|
||||
return lhsProvidedCount == rhsProvidedCount && lhsRequiredCount == rhsRequiredCount
|
||||
case (.failedToResizeInputTensor(let lhsIndex), .failedToResizeInputTensor(let rhsIndex)):
|
||||
return lhsIndex == rhsIndex
|
||||
case (.tensorFlowLiteError(let lhsMessage), .tensorFlowLiteError(let rhsMessage)):
|
||||
return lhsMessage == rhsMessage
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // swift(>=4.2)
|
@ -0,0 +1,29 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Custom configuration options for a TensorFlow Lite interpreter.
|
||||
public struct InterpreterOptions: Equatable {
|
||||
|
||||
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` which
|
||||
/// indicates that the `Interpreter` will decide the number of threads to use.
|
||||
public var threadCount: Int? = nil
|
||||
|
||||
/// Whether error logging to the console is enabled. The default is `false`.
|
||||
public var isErrorLoggingEnabled = false
|
||||
|
||||
/// Creates a new instance of interpreter options.
|
||||
public init() {}
|
||||
}
|
40
tensorflow/lite/experimental/swift/Sources/Model.swift
Normal file
40
tensorflow/lite/experimental/swift/Sources/Model.swift
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
import TensorFlowLiteCAPI
|
||||
|
||||
/// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
|
||||
final class Model {
|
||||
|
||||
/// The `TFL_Model` C pointer type represented as an `UnsafePointer<TFL_Model>`.
|
||||
typealias CModel = OpaquePointer
|
||||
|
||||
/// The underlying `TFL_Model` C pointer.
|
||||
let cModel: CModel?
|
||||
|
||||
/// Creates a new model instance.
|
||||
///
|
||||
/// - Precondition: Initialization can fail if the given `filePath` is invalid.
|
||||
/// - Parameters:
|
||||
/// - filePath: Local file path to a TensorFlow Lite model.
|
||||
init?(filePath: String) {
|
||||
guard !filePath.isEmpty, let cModel = TFL_NewModelFromFile(filePath) else { return nil }
|
||||
self.cModel = cModel
|
||||
}
|
||||
|
||||
deinit {
|
||||
TFL_DeleteModel(cModel)
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Parameters that determine the mapping of quantized values to real values. Quantized values can
|
||||
/// be mapped to float values using the following conversion:
|
||||
/// `realValue = scale * (quantizedValue - zeroPoint)`.
|
||||
public struct QuantizationParameters {
|
||||
|
||||
/// Difference between real values corresponding to consecutive quantized values differing by 1.
|
||||
/// For example, the range of quantized values for `UInt8` data type is [0, 255].
|
||||
public let scale: Float
|
||||
|
||||
/// Quantized value that corresponds to the real 0 value.
|
||||
public let zeroPoint: Int
|
||||
|
||||
/// Creates a new quantization parameters instance.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - scale: Scale value for asymmetric quantization.
|
||||
/// - zeroPoint: Zero point for asymmetric quantization.
|
||||
init(scale: Float, zeroPoint: Int) {
|
||||
self.scale = scale
|
||||
self.zeroPoint = zeroPoint
|
||||
}
|
||||
}
|
138
tensorflow/lite/experimental/swift/Sources/Tensor.swift
Normal file
138
tensorflow/lite/experimental/swift/Sources/Tensor.swift
Normal file
@ -0,0 +1,138 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
import TensorFlowLiteCAPI
|
||||
|
||||
/// An input or output tensor in a TensorFlow Lite graph.
|
||||
public struct Tensor {
|
||||
|
||||
/// Name of the tensor.
|
||||
public let name: String
|
||||
|
||||
/// Data type of the tensor.
|
||||
public let dataType: TensorDataType
|
||||
|
||||
/// Shape of the tensor.
|
||||
public let shape: TensorShape
|
||||
|
||||
/// Data in the input or output tensor.
|
||||
public let data: Data
|
||||
|
||||
/// Quantization parameters for the tensor if using a quantized model.
|
||||
public let quantizationParameters: QuantizationParameters?
|
||||
|
||||
/// Creates a new input or output tensor instance.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - name: Name of the tensor.
|
||||
/// - dataType: Data type of the tensor.
|
||||
/// - data: Data in the input tensor.
|
||||
/// - quantizationParameters Quantization parameters for the tensor if using a quantized model.
|
||||
/// The default is `nil`.
|
||||
init(
|
||||
name: String,
|
||||
dataType: TensorDataType,
|
||||
shape: TensorShape,
|
||||
data: Data,
|
||||
quantizationParameters: QuantizationParameters? = nil
|
||||
) {
|
||||
self.name = name
|
||||
self.dataType = dataType
|
||||
self.shape = shape
|
||||
self.data = data
|
||||
self.quantizationParameters = quantizationParameters
|
||||
}
|
||||
}
|
||||
|
||||
/// Supported TensorFlow Lite tensor data types.
|
||||
public enum TensorDataType: Equatable {
|
||||
/// 32-bit single precision floating point tensor data type.
|
||||
case float32
|
||||
/// 8-bit unsigned integer tensor data type.
|
||||
case uInt8
|
||||
/// 16-bit signed integer tensor data type.
|
||||
case int16
|
||||
/// 32-bit signed integer tensor data type.
|
||||
case int32
|
||||
/// 64-bit signed integer tensor data type.
|
||||
case int64
|
||||
/// Boolean tensor data type.
|
||||
case bool
|
||||
|
||||
/// Creates a new tensor data type from the given `TFL_Type` or `nil` if the data type is
|
||||
/// unsupported or could not be determined because there was an error.
|
||||
///
|
||||
/// - Parameter type: A data type supported by a tensor.
|
||||
init?(type: TFL_Type) {
|
||||
switch type {
|
||||
case kTfLiteFloat32:
|
||||
self = .float32
|
||||
case kTfLiteUInt8:
|
||||
self = .uInt8
|
||||
case kTfLiteInt16:
|
||||
self = .int16
|
||||
case kTfLiteInt32:
|
||||
self = .int32
|
||||
case kTfLiteInt64:
|
||||
self = .int64
|
||||
case kTfLiteBool:
|
||||
self = .bool
|
||||
case kTfLiteNoType:
|
||||
fallthrough
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The shape of a TensorFlow Lite tensor.
|
||||
public struct TensorShape {
|
||||
|
||||
/// The number of dimensions of the tensor.
|
||||
public let rank: Int
|
||||
|
||||
/// Array of dimensions for the tensor.
|
||||
public let dimensions: [Int]
|
||||
|
||||
/// Array of `Int32` dimensions for the tensor.
|
||||
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
|
||||
|
||||
/// Creates a new tensor shape instance with the given array of dimensions.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - dimensions: Dimensions for the tensor.
|
||||
public init(_ dimensions: [Int]) {
|
||||
self.rank = dimensions.count
|
||||
self.dimensions = dimensions
|
||||
}
|
||||
|
||||
/// Creates a new tensor shape instance with the given elements representing the dimensions.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - elements: Dimensions for the tensor.
|
||||
public init(_ elements: Int...) {
|
||||
self.init(elements)
|
||||
}
|
||||
}
|
||||
|
||||
extension TensorShape: ExpressibleByArrayLiteral {
|
||||
/// Creates a new tensor shape instance with the given array literal representing the dimensions.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - arrayLiteral: Dimensions for the tensor.
|
||||
public init(arrayLiteral: Int...) {
|
||||
self.init(arrayLiteral)
|
||||
}
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
{
|
||||
"sourceFilters" : [
|
||||
"tensorflow/lite/experimental/c",
|
||||
"tensorflow/lite/experimental/swift",
|
||||
"tensorflow/lite/experimental/swift/Sources",
|
||||
"tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp",
|
||||
"tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj",
|
||||
"tensorflow/lite/experimental/swift/Tests",
|
||||
],
|
||||
"buildTargets" : [
|
||||
"//tensorflow/lite/experimental/swift:TensorFlowLite",
|
||||
"//tensorflow/lite/experimental/swift:TensorFlowLiteApp",
|
||||
"//tensorflow/lite/experimental/swift:TensorFlowLiteTests",
|
||||
],
|
||||
"projectName" : "TensorFlowLite",
|
||||
"optionSet" : {
|
||||
"LaunchActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildStartupOptionsRelease" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildOptionsRelease" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildOptionsDebug" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"EnvironmentVariables" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BuildActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"CommandlineArguments" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"TestActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildStartupOptionsDebug" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BuildActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"TestActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"LaunchActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
}
|
||||
},
|
||||
"additionalFilePaths" : [
|
||||
"tensorflow/lite/experimental/swift/BUILD"
|
||||
]
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
{
|
||||
"configDefaults" : {
|
||||
"optionSet" : {
|
||||
"ProjectPrioritizesSwift" : {
|
||||
"p" : "YES"
|
||||
}
|
||||
}
|
||||
},
|
||||
"projectName" : "TensorFlowLite",
|
||||
"packages" : [
|
||||
"tensorflow/lite/experimental/swift"
|
||||
],
|
||||
"workspaceRoot" : "../../../../.."
|
||||
}
|
@ -0,0 +1,345 @@
|
||||
// !$*UTF8*$!
|
||||
{
|
||||
archiveVersion = 1;
|
||||
classes = {
|
||||
};
|
||||
objectVersion = 50;
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
4A7304B421500B8400C90B21 /* Data+TensorFlowLite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */; };
|
||||
4AA72B732146ED64006C3AEF /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA72B722146ED64006C3AEF /* AppDelegate.swift */; };
|
||||
4AA72B752146ED64006C3AEF /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA72B742146ED64006C3AEF /* ViewController.swift */; };
|
||||
4AA72B782146ED64006C3AEF /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B762146ED64006C3AEF /* Main.storyboard */; };
|
||||
4AA72B7A2146ED66006C3AEF /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B792146ED66006C3AEF /* Assets.xcassets */; };
|
||||
4AA72B7D2146ED66006C3AEF /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */; };
|
||||
4ADDE0CE2176600E00FF07A2 /* Array+TensorFlowLite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Data+TensorFlowLite.swift"; sourceTree = "<group>"; };
|
||||
4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TensorFlowLiteApp.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
4AA72B722146ED64006C3AEF /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; };
|
||||
4AA72B742146ED64006C3AEF /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = "<group>"; };
|
||||
4AA72B772146ED64006C3AEF /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = "<group>"; };
|
||||
4AA72B792146ED66006C3AEF /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
|
||||
4AA72B7C2146ED66006C3AEF /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; };
|
||||
4AA72B7E2146ED66006C3AEF /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
|
||||
4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Array+TensorFlowLite.swift"; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
4AA72B6C2146ED64006C3AEF /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXFrameworksBuildPhase section */
|
||||
|
||||
/* Begin PBXGroup section */
|
||||
4AA72B662146ED64006C3AEF = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
4AA72B712146ED64006C3AEF /* TensorFlowLiteApp */,
|
||||
4AA72B702146ED64006C3AEF /* Products */,
|
||||
);
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
4AA72B702146ED64006C3AEF /* Products */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */,
|
||||
);
|
||||
name = Products;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
4AA72B712146ED64006C3AEF /* TensorFlowLiteApp */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
4AA72B722146ED64006C3AEF /* AppDelegate.swift */,
|
||||
4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */,
|
||||
4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */,
|
||||
4AA72B742146ED64006C3AEF /* ViewController.swift */,
|
||||
4AA72B762146ED64006C3AEF /* Main.storyboard */,
|
||||
4AA72B792146ED66006C3AEF /* Assets.xcassets */,
|
||||
4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */,
|
||||
4AA72B7E2146ED66006C3AEF /* Info.plist */,
|
||||
);
|
||||
path = TensorFlowLiteApp;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXGroup section */
|
||||
|
||||
/* Begin PBXNativeTarget section */
|
||||
4AA72B6E2146ED64006C3AEF /* TensorFlowLiteApp */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = 4AA72B812146ED66006C3AEF /* Build configuration list for PBXNativeTarget "TensorFlowLiteApp" */;
|
||||
buildPhases = (
|
||||
4AA72B6B2146ED64006C3AEF /* Sources */,
|
||||
4AA72B6C2146ED64006C3AEF /* Frameworks */,
|
||||
4AA72B6D2146ED64006C3AEF /* Resources */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
);
|
||||
name = TensorFlowLiteApp;
|
||||
productName = TensorFlowLiteApp;
|
||||
productReference = 4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */;
|
||||
productType = "com.apple.product-type.application";
|
||||
};
|
||||
/* End PBXNativeTarget section */
|
||||
|
||||
/* Begin PBXProject section */
|
||||
4AA72B672146ED64006C3AEF /* Project object */ = {
|
||||
isa = PBXProject;
|
||||
attributes = {
|
||||
LastSwiftUpdateCheck = 0940;
|
||||
LastUpgradeCheck = 0940;
|
||||
ORGANIZATIONNAME = Google;
|
||||
TargetAttributes = {
|
||||
4AA72B6E2146ED64006C3AEF = {
|
||||
CreatedOnToolsVersion = 9.4.1;
|
||||
};
|
||||
};
|
||||
};
|
||||
buildConfigurationList = 4AA72B6A2146ED64006C3AEF /* Build configuration list for PBXProject "TensorFlowLiteApp" */;
|
||||
compatibilityVersion = "Xcode 9.3";
|
||||
developmentRegion = en;
|
||||
hasScannedForEncodings = 0;
|
||||
knownRegions = (
|
||||
en,
|
||||
Base,
|
||||
);
|
||||
mainGroup = 4AA72B662146ED64006C3AEF;
|
||||
productRefGroup = 4AA72B702146ED64006C3AEF /* Products */;
|
||||
projectDirPath = "";
|
||||
projectRoot = "";
|
||||
targets = (
|
||||
4AA72B6E2146ED64006C3AEF /* TensorFlowLiteApp */,
|
||||
);
|
||||
};
|
||||
/* End PBXProject section */
|
||||
|
||||
/* Begin PBXResourcesBuildPhase section */
|
||||
4AA72B6D2146ED64006C3AEF /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
4AA72B7D2146ED66006C3AEF /* LaunchScreen.storyboard in Resources */,
|
||||
4AA72B7A2146ED66006C3AEF /* Assets.xcassets in Resources */,
|
||||
4AA72B782146ED64006C3AEF /* Main.storyboard in Resources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXResourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXSourcesBuildPhase section */
|
||||
4AA72B6B2146ED64006C3AEF /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
4AA72B732146ED64006C3AEF /* AppDelegate.swift in Sources */,
|
||||
4ADDE0CE2176600E00FF07A2 /* Array+TensorFlowLite.swift in Sources */,
|
||||
4A7304B421500B8400C90B21 /* Data+TensorFlowLite.swift in Sources */,
|
||||
4AA72B752146ED64006C3AEF /* ViewController.swift in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXSourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXVariantGroup section */
|
||||
4AA72B762146ED64006C3AEF /* Main.storyboard */ = {
|
||||
isa = PBXVariantGroup;
|
||||
children = (
|
||||
4AA72B772146ED64006C3AEF /* Base */,
|
||||
);
|
||||
name = Main.storyboard;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */ = {
|
||||
isa = PBXVariantGroup;
|
||||
children = (
|
||||
4AA72B7C2146ED66006C3AEF /* Base */,
|
||||
);
|
||||
name = LaunchScreen.storyboard;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXVariantGroup section */
|
||||
|
||||
/* Begin XCBuildConfiguration section */
|
||||
4AA72B7F2146ED66006C3AEF /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
|
||||
CLANG_CXX_LIBRARY = "libc++";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
CODE_SIGN_IDENTITY = "iPhone Developer";
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = dwarf;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_TESTABILITY = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu11;
|
||||
GCC_DYNAMIC_NO_PIC = NO;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_OPTIMIZATION_LEVEL = 0;
|
||||
GCC_PREPROCESSOR_DEFINITIONS = (
|
||||
"DEBUG=1",
|
||||
"$(inherited)",
|
||||
);
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 11.4;
|
||||
MTL_ENABLE_DEBUG_INFO = YES;
|
||||
ONLY_ACTIVE_ARCH = YES;
|
||||
SDKROOT = iphoneos;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
4AA72B802146ED66006C3AEF /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
|
||||
CLANG_CXX_LIBRARY = "libc++";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
CODE_SIGN_IDENTITY = "iPhone Developer";
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
|
||||
ENABLE_NS_ASSERTIONS = NO;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu11;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 11.4;
|
||||
MTL_ENABLE_DEBUG_INFO = NO;
|
||||
SDKROOT = iphoneos;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-O";
|
||||
VALIDATE_PRODUCT = YES;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
4AA72B822146ED66006C3AEF /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
INFOPLIST_FILE = TensorFlowLiteApp/Info.plist;
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/Frameworks",
|
||||
);
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.tensorflow.lite.swift.TensorFlowLite;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SWIFT_VERSION = 4.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
4AA72B832146ED66006C3AEF /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
INFOPLIST_FILE = TensorFlowLiteApp/Info.plist;
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/Frameworks",
|
||||
);
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.tensorflow.lite.swift.TensorFlowLite;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SWIFT_VERSION = 4.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
/* End XCBuildConfiguration section */
|
||||
|
||||
/* Begin XCConfigurationList section */
|
||||
4AA72B6A2146ED64006C3AEF /* Build configuration list for PBXProject "TensorFlowLiteApp" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
4AA72B7F2146ED66006C3AEF /* Debug */,
|
||||
4AA72B802146ED66006C3AEF /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
4AA72B812146ED66006C3AEF /* Build configuration list for PBXNativeTarget "TensorFlowLiteApp" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
4AA72B822146ED66006C3AEF /* Debug */,
|
||||
4AA72B832146ED66006C3AEF /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
/* End XCConfigurationList section */
|
||||
};
|
||||
rootObject = 4AA72B672146ED64006C3AEF /* Project object */;
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
import UIKit
|
||||
|
||||
@UIApplicationMain
|
||||
|
||||
final class AppDelegate: UIResponder, UIApplicationDelegate {
|
||||
|
||||
/// The main window of the app.
|
||||
var window: UIWindow?
|
||||
|
||||
func application(
|
||||
_ application: UIApplication,
|
||||
didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]? = nil
|
||||
) -> Bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
#if !swift(>=4.2)
|
||||
extension UIApplication {
|
||||
typealias LaunchOptionsKey = UIApplicationLaunchOptionsKey
|
||||
}
|
||||
#endif // !swift(>=4.2)
|
@ -0,0 +1,22 @@
|
||||
import Foundation
|
||||
|
||||
extension Array {
|
||||
/// Creates a new array from the bytes of the given unsafe data.
|
||||
///
|
||||
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
|
||||
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
|
||||
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
|
||||
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
||||
/// `MemoryLayout<Element>.stride`.
|
||||
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||
init?(unsafeData: Data) {
|
||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||
let elements = unsafeData.withUnsafeBytes {
|
||||
UnsafeBufferPointer<Element>(
|
||||
start: $0,
|
||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||
)
|
||||
}
|
||||
self.init(elements)
|
||||
}
|
||||
}
|
@ -0,0 +1,98 @@
|
||||
{
|
||||
"images" : [
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "20x20",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "20x20",
|
||||
"scale" : "3x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "29x29",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "29x29",
|
||||
"scale" : "3x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "40x40",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "40x40",
|
||||
"scale" : "3x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "60x60",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "60x60",
|
||||
"scale" : "3x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "20x20",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "20x20",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "29x29",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "29x29",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "40x40",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "40x40",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "76x76",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "76x76",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "83.5x83.5",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ios-marketing",
|
||||
"size" : "1024x1024",
|
||||
"scale" : "1x"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"version" : 1,
|
||||
"author" : "xcode"
|
||||
}
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
{
|
||||
"info" : {
|
||||
"version" : 1,
|
||||
"author" : "xcode"
|
||||
}
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14109" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" launchScreen="YES" useTraitCollections="YES" colorMatched="YES" initialViewController="01J-lp-oVM">
|
||||
<device id="retina4_7" orientation="portrait">
|
||||
<adaptation id="fullscreen"/>
|
||||
</device>
|
||||
<dependencies>
|
||||
<deployment identifier="iOS"/>
|
||||
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14088"/>
|
||||
<capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
|
||||
</dependencies>
|
||||
<scenes>
|
||||
<!--View Controller-->
|
||||
<scene sceneID="EHf-IW-A2E">
|
||||
<objects>
|
||||
<viewController id="01J-lp-oVM" sceneMemberID="viewController">
|
||||
<layoutGuides>
|
||||
<viewControllerLayoutGuide type="top" id="Llm-lL-Icb"/>
|
||||
<viewControllerLayoutGuide type="bottom" id="xb3-aO-Qok"/>
|
||||
</layoutGuides>
|
||||
<view key="view" contentMode="scaleToFill" id="Ze5-6b-2t3">
|
||||
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
|
||||
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
|
||||
<subviews>
|
||||
<label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" text="TensorFlowLite" textAlignment="center" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="3Gq-PV-hia">
|
||||
<rect key="frame" x="16" y="315" width="343" height="38.5"/>
|
||||
<fontDescription key="fontDescription" type="boldSystem" pointSize="32"/>
|
||||
<nil key="textColor"/>
|
||||
<nil key="highlightedColor"/>
|
||||
</label>
|
||||
</subviews>
|
||||
<color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||
<constraints>
|
||||
<constraint firstItem="3Gq-PV-hia" firstAttribute="leading" secondItem="Ze5-6b-2t3" secondAttribute="leading" constant="16" id="aXL-9T-5Pf"/>
|
||||
<constraint firstItem="3Gq-PV-hia" firstAttribute="centerY" secondItem="Ze5-6b-2t3" secondAttribute="centerY" id="cDf-Go-1FR"/>
|
||||
<constraint firstAttribute="trailing" secondItem="3Gq-PV-hia" secondAttribute="trailing" constant="16" id="fB9-BX-A3B"/>
|
||||
</constraints>
|
||||
</view>
|
||||
</viewController>
|
||||
<placeholder placeholderIdentifier="IBFirstResponder" id="iYj-Kq-Ea1" userLabel="First Responder" sceneMemberID="firstResponder"/>
|
||||
</objects>
|
||||
<point key="canvasLocation" x="52" y="374.66266866566718"/>
|
||||
</scene>
|
||||
</scenes>
|
||||
</document>
|
@ -0,0 +1,95 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14460.31" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" useTraitCollections="YES" colorMatched="YES" initialViewController="BYZ-38-t0r">
|
||||
<device id="retina4_7" orientation="portrait">
|
||||
<adaptation id="fullscreen"/>
|
||||
</device>
|
||||
<dependencies>
|
||||
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14460.20"/>
|
||||
<capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
|
||||
</dependencies>
|
||||
<scenes>
|
||||
<!--View Controller-->
|
||||
<scene sceneID="tne-QT-ifu">
|
||||
<objects>
|
||||
<viewController storyboardIdentifier="viewController" useStoryboardIdentifierAsRestorationIdentifier="YES" id="BYZ-38-t0r" customClass="ViewController" customModule="TensorFlowLiteAppLib" sceneMemberID="viewController">
|
||||
<layoutGuides>
|
||||
<viewControllerLayoutGuide type="top" id="y3c-jy-aDJ"/>
|
||||
<viewControllerLayoutGuide type="bottom" id="wfy-db-euE"/>
|
||||
</layoutGuides>
|
||||
<view key="view" contentMode="scaleToFill" id="8bC-Xf-vdC">
|
||||
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
|
||||
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
|
||||
<subviews>
|
||||
<textView clipsSubviews="YES" multipleTouchEnabled="YES" contentMode="scaleToFill" editable="NO" selectable="NO" translatesAutoresizingMaskIntoConstraints="NO" id="7Mj-sL-hrd">
|
||||
<rect key="frame" x="0.0" y="367" width="375" height="300"/>
|
||||
<color key="backgroundColor" red="0.0" green="0.47843137250000001" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||
<constraints>
|
||||
<constraint firstAttribute="height" constant="300" id="YUb-MC-D5w"/>
|
||||
</constraints>
|
||||
<color key="textColor" cocoaTouchSystemColor="tableCellGroupedBackgroundColor"/>
|
||||
<fontDescription key="fontDescription" type="system" pointSize="14"/>
|
||||
<textInputTraits key="textInputTraits" autocapitalizationType="sentences"/>
|
||||
</textView>
|
||||
<toolbar opaque="NO" clearsContextBeforeDrawing="NO" contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="Qwg-EP-bd6" userLabel="Bottom Toolbar">
|
||||
<rect key="frame" x="0.0" y="323" width="375" height="44"/>
|
||||
<constraints>
|
||||
<constraint firstAttribute="height" constant="44" id="jhT-Q0-E9N"/>
|
||||
</constraints>
|
||||
<items>
|
||||
<barButtonItem style="plain" systemItem="flexibleSpace" id="P3q-uA-YUa"/>
|
||||
<barButtonItem title="Invoke Interpreter" id="A4J-Mg-nmd" userLabel="Invoke Button">
|
||||
<connections>
|
||||
<action selector="invokeInterpreter:" destination="BYZ-38-t0r" id="lZU-x7-PsJ"/>
|
||||
</connections>
|
||||
</barButtonItem>
|
||||
<barButtonItem style="plain" systemItem="flexibleSpace" id="Qad-Pa-ySg"/>
|
||||
</items>
|
||||
</toolbar>
|
||||
<toolbar opaque="NO" clearsContextBeforeDrawing="NO" contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="Gkb-TR-PCB" userLabel="Top Toolbar">
|
||||
<rect key="frame" x="0.0" y="28" width="375" height="44"/>
|
||||
<constraints>
|
||||
<constraint firstAttribute="height" constant="44" id="hSD-2q-fUE"/>
|
||||
</constraints>
|
||||
<items>
|
||||
<barButtonItem style="plain" id="LKw-TX-bbH">
|
||||
<segmentedControl key="customView" opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="left" contentVerticalAlignment="top" segmentControlStyle="bar" selectedSegmentIndex="0" id="rhA-nW-xzT">
|
||||
<rect key="frame" x="16" y="7" width="343" height="30"/>
|
||||
<autoresizingMask key="autoresizingMask" flexibleMaxX="YES" flexibleMaxY="YES"/>
|
||||
<segments>
|
||||
<segment title="Add"/>
|
||||
<segment title="AddQuantized"/>
|
||||
<segment title="MultiAdd"/>
|
||||
</segments>
|
||||
<connections>
|
||||
<action selector="modelChanged:" destination="BYZ-38-t0r" eventType="valueChanged" id="YnG-Ov-B5D"/>
|
||||
</connections>
|
||||
</segmentedControl>
|
||||
</barButtonItem>
|
||||
</items>
|
||||
</toolbar>
|
||||
</subviews>
|
||||
<color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||
<constraints>
|
||||
<constraint firstAttribute="trailing" secondItem="Gkb-TR-PCB" secondAttribute="trailing" id="4Cr-Sf-I7n"/>
|
||||
<constraint firstItem="7Mj-sL-hrd" firstAttribute="bottom" secondItem="wfy-db-euE" secondAttribute="top" id="6ot-zD-sze"/>
|
||||
<constraint firstItem="7Mj-sL-hrd" firstAttribute="top" secondItem="Qwg-EP-bd6" secondAttribute="bottom" id="ELA-C6-NiG"/>
|
||||
<constraint firstAttribute="trailing" secondItem="7Mj-sL-hrd" secondAttribute="trailing" id="HDO-xr-mBl"/>
|
||||
<constraint firstItem="Gkb-TR-PCB" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="Kmo-6K-gS4"/>
|
||||
<constraint firstItem="Qwg-EP-bd6" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="hGu-lm-fMG"/>
|
||||
<constraint firstAttribute="trailing" secondItem="Qwg-EP-bd6" secondAttribute="trailing" id="iXR-LK-nTO"/>
|
||||
<constraint firstItem="7Mj-sL-hrd" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="nr7-jW-ZYf"/>
|
||||
<constraint firstItem="Gkb-TR-PCB" firstAttribute="top" secondItem="y3c-jy-aDJ" secondAttribute="bottom" constant="8" id="uCF-VW-rR0"/>
|
||||
</constraints>
|
||||
</view>
|
||||
<connections>
|
||||
<outlet property="invokeButton" destination="A4J-Mg-nmd" id="UxZ-Ft-E45"/>
|
||||
<outlet property="modelControl" destination="rhA-nW-xzT" id="KKf-TT-BQ2"/>
|
||||
<outlet property="resultsTextView" destination="7Mj-sL-hrd" id="T4I-z4-tYA"/>
|
||||
</connections>
|
||||
</viewController>
|
||||
<placeholder placeholderIdentifier="IBFirstResponder" id="dkx-z0-nzr" sceneMemberID="firstResponder"/>
|
||||
</objects>
|
||||
<point key="canvasLocation" x="125.59999999999999" y="133.5832083958021"/>
|
||||
</scene>
|
||||
</scenes>
|
||||
</document>
|
@ -0,0 +1,13 @@
|
||||
import Foundation
|
||||
|
||||
extension Data {
|
||||
/// Creates a new buffer by copying the buffer pointer of the given array.
|
||||
///
|
||||
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
|
||||
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
|
||||
/// data from the resulting buffer has undefined behavior.
|
||||
/// - Parameter array: An array with elements of type `T`.
|
||||
init<T>(copyingBufferOf array: [T]) {
|
||||
self = array.withUnsafeBufferPointer(Data.init)
|
||||
}
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>CFBundleDevelopmentRegion</key>
|
||||
<string>en</string>
|
||||
<key>CFBundleExecutable</key>
|
||||
<string>$(EXECUTABLE_NAME)</string>
|
||||
<key>CFBundleIdentifier</key>
|
||||
<string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
|
||||
<key>CFBundleInfoDictionaryVersion</key>
|
||||
<string>6.0</string>
|
||||
<key>CFBundleName</key>
|
||||
<string>$(PRODUCT_NAME)</string>
|
||||
<key>CFBundlePackageType</key>
|
||||
<string>APPL</string>
|
||||
<key>CFBundleShortVersionString</key>
|
||||
<string>1.0</string>
|
||||
<key>CFBundleVersion</key>
|
||||
<string>0.0.1</string>
|
||||
<key>LSRequiresIPhoneOS</key>
|
||||
<true/>
|
||||
<key>NSCameraUsageDescription</key>
|
||||
<string>NSCameraUsageDescription</string>
|
||||
<key>NSPhotoLibraryUsageDescription</key>
|
||||
<string>Select a photo to detect objects in.</string>
|
||||
<key>UILaunchStoryboardName</key>
|
||||
<string>LaunchScreen</string>
|
||||
<key>UIMainStoryboardFile</key>
|
||||
<string>Main</string>
|
||||
<key>UIRequiredDeviceCapabilities</key>
|
||||
<array>
|
||||
<string>armv7</string>
|
||||
</array>
|
||||
<key>UISupportedInterfaceOrientations</key>
|
||||
<array>
|
||||
<string>UIInterfaceOrientationPortrait</string>
|
||||
<string>UIInterfaceOrientationPortraitUpsideDown</string>
|
||||
</array>
|
||||
<key>UISupportedInterfaceOrientations~ipad</key>
|
||||
<array>
|
||||
<string>UIInterfaceOrientationPortrait</string>
|
||||
<string>UIInterfaceOrientationPortraitUpsideDown</string>
|
||||
</array>
|
||||
</dict>
|
||||
</plist>
|
@ -0,0 +1,299 @@
|
||||
import TensorFlowLite
|
||||
import UIKit
|
||||
|
||||
class ViewController: UIViewController {
|
||||
|
||||
// MARK: - Properties
|
||||
|
||||
/// TensorFlowLite interpreter object for performing inference from a given model.
|
||||
private var interpreter: Interpreter?
|
||||
|
||||
/// Serial dispatch queue for managing `Interpreter` calls.
|
||||
private let interpreterQueue = DispatchQueue(
|
||||
label: Constant.dispatchQueueLabel,
|
||||
qos: .userInitiated
|
||||
)
|
||||
|
||||
/// The currently selected model.
|
||||
private var currentModel: Model {
|
||||
guard let currentModel = Model(rawValue: modelControl.selectedSegmentIndex) else {
|
||||
preconditionFailure("Invalid model for selected segment index.")
|
||||
}
|
||||
return currentModel
|
||||
}
|
||||
|
||||
/// A description of the current model.
|
||||
private var modelDescription: String {
|
||||
guard let interpreter = interpreter else { return "" }
|
||||
let inputCount = interpreter.inputTensorCount
|
||||
let outputCount = interpreter.outputTensorCount
|
||||
let inputTensors = (0..<inputCount).map { index in
|
||||
var tensorInfo = " Input \(index + 1): "
|
||||
do {
|
||||
let tensor = try interpreter.input(at: index)
|
||||
tensorInfo += "\(tensor)"
|
||||
} catch let error {
|
||||
tensorInfo += "\(error.localizedDescription)"
|
||||
}
|
||||
return tensorInfo
|
||||
}.joined(separator: "\n")
|
||||
let outputTensors = (0..<outputCount).map { index in
|
||||
var tensorInfo = " Output \(index + 1): "
|
||||
do {
|
||||
let tensor = try interpreter.output(at: index)
|
||||
tensorInfo += "\(tensor)"
|
||||
} catch let error {
|
||||
tensorInfo += "\(error.localizedDescription)"
|
||||
}
|
||||
return tensorInfo
|
||||
}.joined(separator: "\n")
|
||||
return "Model Description:\n" +
|
||||
" Input Tensor Count = \(inputCount)\n\(inputTensors)\n\n" +
|
||||
" Output Tensor Count = \(outputCount)\n\(outputTensors)"
|
||||
}
|
||||
|
||||
// MARK: - IBOutlets
|
||||
|
||||
/// A segmented control for changing models. See the `Model` enum for available models.
|
||||
@IBOutlet private var modelControl: UISegmentedControl!
|
||||
|
||||
@IBOutlet private var resultsTextView: UITextView!
|
||||
@IBOutlet private var invokeButton: UIBarButtonItem!
|
||||
|
||||
// MARK: - UIViewController
|
||||
|
||||
override func viewDidLoad() {
|
||||
super.viewDidLoad()
|
||||
|
||||
invokeButton.isEnabled = false
|
||||
loadModel()
|
||||
}
|
||||
|
||||
// MARK: - IBActions
|
||||
|
||||
@IBAction func modelChanged(_ sender: Any) {
|
||||
invokeButton.isEnabled = false
|
||||
updateResultsText("Switched to the \(currentModel.description) model.")
|
||||
loadModel()
|
||||
}
|
||||
|
||||
@IBAction func invokeInterpreter(_ sender: Any) {
|
||||
switch currentModel {
|
||||
case .add:
|
||||
invokeAdd()
|
||||
case .addQuantized:
|
||||
invokeAddQuantized()
|
||||
case .multiAdd:
|
||||
invokeMultiAdd()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
private func loadModel() {
|
||||
let fileInfo = currentModel.fileInfo
|
||||
guard let modelPath = Bundle.main.path(forResource: fileInfo.name, ofType: fileInfo.extension)
|
||||
else {
|
||||
updateResultsText("Failed to load the \(currentModel.description) model.")
|
||||
return
|
||||
}
|
||||
setUpInterpreter(withModelPath: modelPath)
|
||||
}
|
||||
|
||||
private func setUpInterpreter(withModelPath modelPath: String) {
|
||||
interpreterQueue.async {
|
||||
do {
|
||||
var options = InterpreterOptions()
|
||||
options.isErrorLoggingEnabled = true
|
||||
self.interpreter = try Interpreter(modelPath: modelPath, options: options)
|
||||
} catch let error {
|
||||
self.updateResultsText(
|
||||
"Failed to create the interpreter with error: \(error.localizedDescription)"
|
||||
)
|
||||
return
|
||||
}
|
||||
safeDispatchOnMain { self.invokeButton.isEnabled = true }
|
||||
}
|
||||
}
|
||||
|
||||
private func invokeAdd() {
|
||||
interpreterQueue.async {
|
||||
guard let interpreter = self.interpreter else {
|
||||
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||
return
|
||||
}
|
||||
do {
|
||||
try interpreter.resizeInput(at: 0, to: [2])
|
||||
try interpreter.allocateTensors()
|
||||
let input: [Float32] = [1, 3]
|
||||
let resultsText = self.modelDescription + "\n\n" +
|
||||
"Performing 2 add operations on input \(input.description) equals: "
|
||||
self.updateResultsText(resultsText)
|
||||
let data = Data(copyingBufferOf: input)
|
||||
try interpreter.copy(data, toInputAt: 0)
|
||||
try interpreter.invoke()
|
||||
let outputTensor = try interpreter.output(at: 0)
|
||||
let results: () -> String = {
|
||||
guard let results = [Float32](unsafeData: outputTensor.data) else { return "No results." }
|
||||
return resultsText + results.description
|
||||
}
|
||||
self.updateResultsText(results())
|
||||
} catch let error {
|
||||
self.updateResultsText(
|
||||
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func invokeAddQuantized() {
|
||||
interpreterQueue.async {
|
||||
guard let interpreter = self.interpreter else {
|
||||
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||
return
|
||||
}
|
||||
do {
|
||||
try interpreter.resizeInput(at: 0, to: [2])
|
||||
try interpreter.allocateTensors()
|
||||
let input: [UInt8] = [1, 3]
|
||||
let resultsText = self.modelDescription + "\n\n" +
|
||||
"Performing 2 add operations on quantized input \(input.description) equals: "
|
||||
self.updateResultsText(resultsText)
|
||||
let data = Data(input)
|
||||
try interpreter.copy(data, toInputAt: 0)
|
||||
try interpreter.invoke()
|
||||
let outputTensor = try interpreter.output(at: 0)
|
||||
let results: () -> String = {
|
||||
guard let quantizationParameters = outputTensor.quantizationParameters else {
|
||||
return "No results."
|
||||
}
|
||||
let quantizedResults = [UInt8](outputTensor.data)
|
||||
let dequantizedResults = quantizedResults.map {
|
||||
quantizationParameters.scale * Float(Int($0) - quantizationParameters.zeroPoint)
|
||||
}
|
||||
return resultsText + quantizedResults.description +
|
||||
", dequantized results: " + dequantizedResults.description
|
||||
}
|
||||
self.updateResultsText(results())
|
||||
} catch let error {
|
||||
self.updateResultsText(
|
||||
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func invokeMultiAdd() {
|
||||
interpreterQueue.async {
|
||||
guard let interpreter = self.interpreter else {
|
||||
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||
return
|
||||
}
|
||||
do {
|
||||
let shape = TensorShape(2)
|
||||
try (0..<interpreter.inputTensorCount).forEach { index in
|
||||
try interpreter.resizeInput(at: index, to: shape)
|
||||
}
|
||||
try interpreter.allocateTensors()
|
||||
let inputs = try (0..<interpreter.inputTensorCount).map { index -> [Float32] in
|
||||
let input = [Float32(index + 1), Float32(index + 2)]
|
||||
let data = Data(copyingBufferOf: input)
|
||||
try interpreter.copy(data, toInputAt: index)
|
||||
return input
|
||||
}
|
||||
let resultsText = self.modelDescription + "\n\n" +
|
||||
"Performing 3 add operations on inputs \(inputs.description) equals: "
|
||||
self.updateResultsText(resultsText)
|
||||
try interpreter.invoke()
|
||||
let results = try (0..<interpreter.outputTensorCount).map { index -> [Float32] in
|
||||
let tensor = try interpreter.output(at: index)
|
||||
return [Float32](unsafeData: tensor.data) ?? []
|
||||
}
|
||||
self.updateResultsText(resultsText + results.description)
|
||||
} catch let error {
|
||||
self.updateResultsText(
|
||||
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func updateResultsText(_ text: String? = nil) {
|
||||
safeDispatchOnMain { self.resultsTextView.text = text }
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Constants
|
||||
|
||||
private enum Constant {
|
||||
static let dispatchQueueLabel = "TensorFlowLiteInterpreterQueue"
|
||||
static let nilInterpreterErrorMessage =
|
||||
"Failed to invoke the interpreter because the interpreter was nil."
|
||||
}
|
||||
|
||||
/// Models that can be loaded by the TensorFlow Lite `Interpreter`.
|
||||
private enum Model: Int, CustomStringConvertible {
|
||||
/// A float model that performs two add operations on one input tensor and returns the result in
|
||||
/// one output tensor.
|
||||
case add = 0
|
||||
/// A quantized model that performs two add operations on one input tensor and returns the result
|
||||
/// in one output tensor.
|
||||
case addQuantized = 1
|
||||
/// A float model that performs three add operations on four input tensors and returns the results
|
||||
/// in 2 output tensors.
|
||||
case multiAdd = 2
|
||||
|
||||
var fileInfo: (name: String, extension: String) {
|
||||
switch self {
|
||||
case .add:
|
||||
return Add.fileInfo
|
||||
case .addQuantized:
|
||||
return AddQuantized.fileInfo
|
||||
case .multiAdd:
|
||||
return MultiAdd.fileInfo
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - CustomStringConvertible
|
||||
|
||||
var description: String {
|
||||
switch self {
|
||||
case .add:
|
||||
return Add.name
|
||||
case .addQuantized:
|
||||
return AddQuantized.name
|
||||
case .multiAdd:
|
||||
return MultiAdd.name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Values for the `Add` model.
|
||||
private enum Add {
|
||||
static let name = "Add"
|
||||
static let fileInfo = (name: "add", extension: "bin")
|
||||
}
|
||||
|
||||
/// Values for the `AddQuantized` model.
|
||||
private enum AddQuantized {
|
||||
static let name = "AddQuantized"
|
||||
static let fileInfo = (name: "add_quantized", extension: "bin")
|
||||
}
|
||||
|
||||
/// Values for the `MultiAdd` model.
|
||||
private enum MultiAdd {
|
||||
static let name = "MultiAdd"
|
||||
static let fileInfo = (name: "multi_add", extension: "bin")
|
||||
}
|
||||
|
||||
// MARK: - Fileprivate
|
||||
|
||||
/// Safely dispatches the given block on the main queue. If the current thread is `main`, the block
|
||||
/// is executed synchronously; otherwise, the block is executed asynchronously on the main thread.
|
||||
fileprivate func safeDispatchOnMain(_ block: @escaping () -> Void) {
|
||||
if Thread.isMainThread { block(); return }
|
||||
DispatchQueue.main.async { block() }
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class InterpreterOptionsTests: XCTestCase {
|
||||
|
||||
func testInterpreterOptions_InitWithDefaultValues() {
|
||||
let options = InterpreterOptions()
|
||||
XCTAssertNil(options.threadCount)
|
||||
XCTAssertFalse(options.isErrorLoggingEnabled)
|
||||
}
|
||||
|
||||
func testInterpreterOptions_InitWithCustomValues() {
|
||||
var options = InterpreterOptions()
|
||||
options.threadCount = 2
|
||||
XCTAssertEqual(options.threadCount, 2)
|
||||
options.isErrorLoggingEnabled = true
|
||||
XCTAssertTrue(options.isErrorLoggingEnabled)
|
||||
}
|
||||
|
||||
func testInterpreterOptions_Equatable() {
|
||||
var options1 = InterpreterOptions()
|
||||
var options2 = InterpreterOptions()
|
||||
XCTAssertEqual(options1, options2)
|
||||
|
||||
options1.threadCount = 2
|
||||
options2.threadCount = 2
|
||||
XCTAssertEqual(options1, options2)
|
||||
|
||||
options2.threadCount = 3
|
||||
XCTAssertNotEqual(options1, options2)
|
||||
options2.threadCount = 2
|
||||
|
||||
options1.isErrorLoggingEnabled = true
|
||||
options2.isErrorLoggingEnabled = true
|
||||
XCTAssertEqual(options1, options2)
|
||||
|
||||
options2.isErrorLoggingEnabled = false
|
||||
XCTAssertNotEqual(options1, options2)
|
||||
}
|
||||
}
|
315
tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift
Normal file
315
tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift
Normal file
@ -0,0 +1,315 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class InterpreterTests: XCTestCase {
|
||||
|
||||
var interpreter: Interpreter!
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
|
||||
interpreter = try! Interpreter(modelPath: AddModel.path)
|
||||
}
|
||||
|
||||
override func tearDown() {
|
||||
interpreter = nil
|
||||
|
||||
super.tearDown()
|
||||
}
|
||||
|
||||
func testInterpreter_InitWithModelPath() {
|
||||
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
|
||||
}
|
||||
|
||||
func testInterpreter_Init_ThrowsFailedToLoadModel() {
|
||||
XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
|
||||
self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_InitWithModelPathAndOptions() {
|
||||
var options = InterpreterOptions()
|
||||
options.threadCount = 2
|
||||
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options))
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorCount() {
|
||||
XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorCount() {
|
||||
XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
|
||||
}
|
||||
|
||||
func testInterpreter_Invoke() throws {
|
||||
try interpreter.allocateTensors()
|
||||
XCTAssertNoThrow(try interpreter.invoke())
|
||||
}
|
||||
|
||||
func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
|
||||
XCTAssertThrowsError(try interpreter.invoke()) { error in
|
||||
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorAtIndex() throws {
|
||||
try setUpAddModelInputTensor()
|
||||
let inputTensor = try interpreter.input(at: AddModel.validIndex)
|
||||
XCTAssertEqual(inputTensor, AddModel.inputTensor)
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorAtIndex_QuantizedModel() throws {
|
||||
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
||||
try setUpAddQuantizedModelInputTensor()
|
||||
let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
|
||||
XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws {
|
||||
try interpreter.allocateTensors()
|
||||
XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
|
||||
let maxIndex = AddModel.inputTensorCount - 1
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() {
|
||||
XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
|
||||
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorAtIndex() throws {
|
||||
try setUpAddModelInputTensor()
|
||||
try interpreter.invoke()
|
||||
let outputTensor = try interpreter.output(at: AddModel.validIndex)
|
||||
XCTAssertEqual(outputTensor, AddModel.outputTensor)
|
||||
let expectedResults = [Float32](unsafeData: outputTensor.data)
|
||||
XCTAssertEqual(expectedResults, AddModel.results)
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws {
|
||||
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
||||
try setUpAddQuantizedModelInputTensor()
|
||||
try interpreter.invoke()
|
||||
let outputTensor = try interpreter.output(at: AddQuantizedModel.inputOutputIndex)
|
||||
XCTAssertEqual(outputTensor, AddQuantizedModel.outputTensor)
|
||||
let expectedResults = [UInt8](outputTensor.data)
|
||||
XCTAssertEqual(expectedResults, AddQuantizedModel.results)
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws {
|
||||
try interpreter.allocateTensors()
|
||||
try interpreter.invoke()
|
||||
XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
|
||||
let maxIndex = AddModel.outputTensorCount - 1
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
|
||||
XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
|
||||
self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_ResizeInputTensorAtIndexToShape() {
|
||||
XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
|
||||
XCTAssertNoThrow(try interpreter.allocateTensors())
|
||||
}
|
||||
|
||||
func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
|
||||
XCTAssertThrowsError(try interpreter.resizeInput(
|
||||
at: AddModel.invalidIndex,
|
||||
to: [2, 2, 3]
|
||||
)) { error in
|
||||
let maxIndex = AddModel.inputTensorCount - 1
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_CopyDataToInputTensorAtIndex() throws {
|
||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||
try interpreter.allocateTensors()
|
||||
let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
|
||||
XCTAssertEqual(inputTensor.data, AddModel.inputData)
|
||||
}
|
||||
|
||||
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
|
||||
XCTAssertThrowsError(try interpreter.copy(
|
||||
AddModel.inputData,
|
||||
toInputAt: AddModel.invalidIndex
|
||||
)) { error in
|
||||
let maxIndex = AddModel.inputTensorCount - 1
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
|
||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||
try interpreter.allocateTensors()
|
||||
let invalidData = Data(count: AddModel.dataCount - 1)
|
||||
XCTAssertThrowsError(try interpreter.copy(
|
||||
invalidData,
|
||||
toInputAt: AddModel.validIndex
|
||||
)) { error in
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_AllocateTensors() {
|
||||
XCTAssertNoThrow(try interpreter.allocateTensors())
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
private func setUpAddModelInputTensor() throws {
|
||||
precondition(interpreter != nil)
|
||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||
try interpreter.allocateTensors()
|
||||
try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
|
||||
}
|
||||
|
||||
private func setUpAddQuantizedModelInputTensor() throws {
|
||||
precondition(interpreter != nil)
|
||||
try interpreter.resizeInput(at: AddQuantizedModel.inputOutputIndex, to: AddQuantizedModel.shape)
|
||||
try interpreter.allocateTensors()
|
||||
try interpreter.copy(AddQuantizedModel.inputData, toInputAt: AddQuantizedModel.inputOutputIndex)
|
||||
}
|
||||
|
||||
private func assertEqualErrors(actual: Error, expected: InterpreterError) {
|
||||
guard let actual = actual as? InterpreterError else {
|
||||
XCTFail("Actual error should be of type InterpreterError.")
|
||||
return
|
||||
}
|
||||
XCTAssertEqual(actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Constants
|
||||
|
||||
/// Values for the `add.bin` model.
|
||||
private enum AddModel {
|
||||
static let info = (name: "add", extension: "bin")
|
||||
static let inputTensorCount = 1
|
||||
static let outputTensorCount = 1
|
||||
static let invalidIndex = 1
|
||||
static let validIndex = 0
|
||||
static let shape: TensorShape = [2]
|
||||
static let dataCount = inputData.count
|
||||
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
|
||||
static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])
|
||||
static let results = [Float32(3.0), Float32(9.0)]
|
||||
|
||||
static let inputTensor = Tensor(
|
||||
name: "input",
|
||||
dataType: .float32,
|
||||
shape: shape,
|
||||
data: inputData
|
||||
)
|
||||
static let outputTensor = Tensor(
|
||||
name: "output",
|
||||
dataType: .float32,
|
||||
shape: shape,
|
||||
data: outputData
|
||||
)
|
||||
|
||||
static var path: String = {
|
||||
let bundle = Bundle(for: InterpreterTests.self)
|
||||
guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" }
|
||||
return path
|
||||
}()
|
||||
}
|
||||
|
||||
/// Values for the `add_quantized.bin` model.
|
||||
private enum AddQuantizedModel {
|
||||
static let info = (name: "add_quantized", extension: "bin")
|
||||
static let inputOutputIndex = 0
|
||||
static let shape: TensorShape = [2]
|
||||
static let inputData = Data([1, 3])
|
||||
static let outputData = Data([3, 9])
|
||||
static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)
|
||||
static let results: [UInt8] = [3, 9]
|
||||
|
||||
static let inputTensor = Tensor(
|
||||
name: "input",
|
||||
dataType: .uInt8,
|
||||
shape: shape,
|
||||
data: inputData,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
static let outputTensor = Tensor(
|
||||
name: "output",
|
||||
dataType: .uInt8,
|
||||
shape: shape,
|
||||
data: outputData,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
|
||||
static var path: String = {
|
||||
let bundle = Bundle(for: InterpreterTests.self)
|
||||
guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" }
|
||||
return path
|
||||
}()
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension Array {
|
||||
/// Creates a new array from the bytes of the given unsafe data.
|
||||
///
|
||||
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
||||
/// `MemoryLayout<Element>.stride`.
|
||||
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||
init?(unsafeData: Data) {
|
||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||
let elements = unsafeData.withUnsafeBytes {
|
||||
UnsafeBufferPointer<Element>(
|
||||
start: $0,
|
||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||
)
|
||||
}
|
||||
self.init(elements)
|
||||
}
|
||||
}
|
||||
|
||||
extension Data {
|
||||
/// Creates a new buffer by copying the buffer pointer of the given array.
|
||||
///
|
||||
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
|
||||
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
|
||||
/// data from the resulting buffer has undefined behavior.
|
||||
/// - Parameter array: An array with elements of type `T`.
|
||||
init<T>(copyingBufferOf array: [T]) {
|
||||
self = array.withUnsafeBufferPointer(Data.init)
|
||||
}
|
||||
}
|
59
tensorflow/lite/experimental/swift/Tests/ModelTests.swift
Normal file
59
tensorflow/lite/experimental/swift/Tests/ModelTests.swift
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class ModelTests: XCTestCase {
|
||||
|
||||
var modelPath: String!
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
|
||||
let bundle = Bundle(for: type(of: self))
|
||||
guard let modelPath = bundle.path(
|
||||
forResource: Constant.modelInfo.name,
|
||||
ofType: Constant.modelInfo.extension)
|
||||
else {
|
||||
XCTFail("Failed to get the model file path.")
|
||||
return
|
||||
}
|
||||
self.modelPath = modelPath
|
||||
}
|
||||
|
||||
override func tearDown() {
|
||||
modelPath = nil
|
||||
|
||||
super.tearDown()
|
||||
}
|
||||
|
||||
func testModel_InitWithFilePath() {
|
||||
XCTAssertNotNil(Model(filePath: modelPath))
|
||||
}
|
||||
|
||||
func testModel_InitWithEmptyFilePath_FailsInitialization() {
|
||||
XCTAssertNil(Model(filePath: ""))
|
||||
}
|
||||
|
||||
func testModel_InitWithInvalidFilePath_FailsInitialization() {
|
||||
XCTAssertNil(Model(filePath: "invalid/path"))
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Constants
|
||||
|
||||
private enum Constant {
|
||||
static let modelInfo = (name: "add", extension: "bin")
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class QuantizationParametersTests: XCTestCase {
|
||||
|
||||
func testQuantizationParameters_InitWithCustomValues() {
|
||||
let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||
XCTAssertEqual(parameters.scale, 0.5)
|
||||
XCTAssertEqual(parameters.zeroPoint, 1)
|
||||
}
|
||||
|
||||
func testQuantizationParameters_Equatable() {
|
||||
let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||
let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||
XCTAssertEqual(parameters1, parameters2)
|
||||
|
||||
let parameters3 = QuantizationParameters(scale: 0.4, zeroPoint: 1)
|
||||
XCTAssertNotEqual(parameters1, parameters3)
|
||||
XCTAssertNotEqual(parameters2, parameters3)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension QuantizationParameters: Equatable {
|
||||
public static func == (lhs: QuantizationParameters, rhs: QuantizationParameters) -> Bool {
|
||||
return lhs.scale == rhs.scale && lhs.zeroPoint == rhs.zeroPoint
|
||||
}
|
||||
}
|
83
tensorflow/lite/experimental/swift/Tests/TensorTests.swift
Normal file
83
tensorflow/lite/experimental/swift/Tests/TensorTests.swift
Normal file
@ -0,0 +1,83 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class TensorTests: XCTestCase {
|
||||
|
||||
// MARK: - Tensor
|
||||
|
||||
func testTensor_Init() {
|
||||
let name = "InputTensor"
|
||||
let dataType: TensorDataType = .uInt8
|
||||
let shape = TensorShape(Constant.dimensions)
|
||||
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
|
||||
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||
let inputTensor = Tensor(
|
||||
name: name,
|
||||
dataType: dataType,
|
||||
shape: shape,
|
||||
data: data,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
XCTAssertEqual(inputTensor.name, name)
|
||||
XCTAssertEqual(inputTensor.dataType, dataType)
|
||||
XCTAssertEqual(inputTensor.shape, shape)
|
||||
XCTAssertEqual(inputTensor.data, data)
|
||||
XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
|
||||
}
|
||||
|
||||
// MARK: - TensorShape
|
||||
|
||||
func testTensorShape_InitWithArray() {
|
||||
let shape = TensorShape(Constant.dimensions)
|
||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||
}
|
||||
|
||||
func testTensorShape_InitWithElements() {
|
||||
let shape = TensorShape(2, 2, 3)
|
||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||
}
|
||||
|
||||
func testTensorShape_InitWithArrayLiteral() {
|
||||
let shape: TensorShape = [2, 2, 3]
|
||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Constants
|
||||
|
||||
private enum Constant {
|
||||
/// Array of 2 arrays of 2 arrays of 3 numbers: [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]].
|
||||
static let dimensions = [2, 2, 3]
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension TensorShape: Equatable {
|
||||
public static func == (lhs: TensorShape, rhs: TensorShape) -> Bool {
|
||||
return lhs.rank == rhs.rank && lhs.dimensions == rhs.dimensions
|
||||
}
|
||||
}
|
||||
|
||||
extension Tensor: Equatable {
|
||||
public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
|
||||
return lhs.name == rhs.name && lhs.dataType == rhs.dataType && lhs.shape == rhs.shape &&
|
||||
lhs.data == rhs.data && lhs.quantizationParameters == rhs.quantizationParameters
|
||||
}
|
||||
}
|
@ -34,6 +34,7 @@ PIP_PACKAGE_QUERY_EXPRESSION = (
|
||||
# pip smoke test.
|
||||
BUILD_BLACKLIST = [
|
||||
"tensorflow/lite/examples/android",
|
||||
"tensorflow/lite/experimental/objc",
|
||||
"tensorflow/lite/experimental/swift",
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user