parent
d183818d4d
commit
59d5fbcc1a
100
tensorflow/lite/experimental/swift/BUILD
Normal file
100
tensorflow/lite/experimental/swift/BUILD
Normal file
@ -0,0 +1,100 @@
|
||||
# TensorFlow Lite for Swift.
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_application", "ios_unit_test")
|
||||
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
|
||||
|
||||
MINIMUM_OS_VERSION = "9.0"
|
||||
|
||||
SWIFT_COPTS = [
|
||||
"-wmo",
|
||||
]
|
||||
|
||||
swift_library(
|
||||
name = "TensorFlowLite",
|
||||
srcs = glob(["Sources/*.swift"]),
|
||||
copts = SWIFT_COPTS,
|
||||
module_name = "TensorFlowLite",
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/c:c_api",
|
||||
],
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "TensorFlowLiteTests",
|
||||
size = "small",
|
||||
minimum_os_version = MINIMUM_OS_VERSION,
|
||||
tags = [
|
||||
"manual",
|
||||
# DISABLED: Following sanitizer tests are not supported by iOS test targets.
|
||||
"noasan",
|
||||
"nomsan",
|
||||
"notsan",
|
||||
],
|
||||
deps = [":TensorFlowLiteTestsLib"],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "TensorFlowLiteTestsLib",
|
||||
testonly = 1,
|
||||
srcs = glob(["Tests/*.swift"]),
|
||||
copts = SWIFT_COPTS,
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":TensorFlowLite",
|
||||
":TestResources",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TestResources",
|
||||
resources = [
|
||||
"//tensorflow/lite:testdata/add.bin",
|
||||
"//tensorflow/lite:testdata/add_quantized.bin",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
],
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "TensorFlowLiteApp",
|
||||
app_icons = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/**"]),
|
||||
bundle_id = "com.tensorflow.lite.swift.TensorFlowLite",
|
||||
families = [
|
||||
"ipad",
|
||||
"iphone",
|
||||
],
|
||||
infoplists = ["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Info.plist"],
|
||||
launch_storyboard = "TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/LaunchScreen.storyboard",
|
||||
minimum_os_version = MINIMUM_OS_VERSION,
|
||||
sdk_frameworks = [
|
||||
"CoreGraphics",
|
||||
],
|
||||
tags = ["manual"],
|
||||
deps = [":TensorFlowLiteAppLib"],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "TensorFlowLiteAppLib",
|
||||
srcs = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/*.swift"]),
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":TensorFlowLite",
|
||||
":TensorFlowLiteAppResources",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TensorFlowLiteAppResources",
|
||||
storyboards = glob([
|
||||
"TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/*.storyboard",
|
||||
]),
|
||||
tags = ["manual"],
|
||||
deps = [":TestResources"],
|
||||
)
|
202
tensorflow/lite/experimental/swift/LICENSE
Normal file
202
tensorflow/lite/experimental/swift/LICENSE
Normal file
@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
54
tensorflow/lite/experimental/swift/README.md
Normal file
54
tensorflow/lite/experimental/swift/README.md
Normal file
@ -0,0 +1,54 @@
|
||||
# TensorFlow Lite for Swift
|
||||
|
||||
[TensorFlow Lite](https://www.tensorflow.org/lite/) is TensorFlow's lightweight
|
||||
solution for Swift developers. It enables low-latency inference of on-device
|
||||
machine learning models with a small binary size and fast performance supporting
|
||||
hardware acceleration.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Bazel
|
||||
|
||||
In your `BUILD` file, add the `TensorFlowLite` dependency:
|
||||
|
||||
```python
|
||||
swift_library(
|
||||
# ...
|
||||
deps = [
|
||||
"//tensorflow/lite/swift:TensorFlowLite",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
In your Swift files, import the module:
|
||||
|
||||
```swift
|
||||
import TensorFlowLite
|
||||
```
|
||||
|
||||
### Tulsi
|
||||
|
||||
Open the `TensorFlowLite.tulsiproj` using the [TulsiApp](https://github.com/bazelbuild/tulsi) or by
|
||||
running the [`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
|
||||
script:
|
||||
|
||||
```shell
|
||||
generate_xcodeproj.sh --genconfig tensorflow/lite/swift/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
|
||||
```
|
||||
|
||||
### CocoaPods
|
||||
|
||||
Add the following to your `Podfile`:
|
||||
|
||||
```ruby
|
||||
use_frameworks!
|
||||
pod 'TensorFlowLiteSwift'
|
||||
```
|
||||
|
||||
Then, run `pod install`.
|
||||
|
||||
In your Swift files, import the module:
|
||||
|
||||
```swift
|
||||
import TensorFlowLite
|
||||
```
|
265
tensorflow/lite/experimental/swift/Sources/Interpreter.swift
Normal file
265
tensorflow/lite/experimental/swift/Sources/Interpreter.swift
Normal file
@ -0,0 +1,265 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
import TensorFlowLiteCAPI
|
||||
|
||||
/// A TensorFlow Lite interpreter that performs inference from a given model.
|
||||
public final class Interpreter {
|
||||
|
||||
/// The `TFL_Interpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
|
||||
private typealias CInterpreter = OpaquePointer
|
||||
|
||||
/// Total number of input tensors associated with the model.
|
||||
public var inputTensorCount: Int {
|
||||
return Int(TFL_InterpreterGetInputTensorCount(cInterpreter))
|
||||
}
|
||||
|
||||
/// Total number of output tensors associated with the model.
|
||||
public var outputTensorCount: Int {
|
||||
return Int(TFL_InterpreterGetOutputTensorCount(cInterpreter))
|
||||
}
|
||||
|
||||
/// The underlying `TFL_Interpreter` C pointer.
|
||||
private var cInterpreter: CInterpreter?
|
||||
|
||||
/// Creates a new model interpreter instance.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - modelPath: Local file path to a TensorFlow Lite model.
|
||||
/// - options: Custom configurations for the interpreter. The default is `nil` indicating that
|
||||
/// interpreter will determine the configuration options.
|
||||
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
|
||||
public init(modelPath: String, options: InterpreterOptions? = nil) throws {
|
||||
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
|
||||
|
||||
let cInterpreterOptions: OpaquePointer? = try options.map { options in
|
||||
guard let cOptions = TFL_NewInterpreterOptions() else {
|
||||
throw InterpreterError.failedToCreateInterpreter
|
||||
}
|
||||
if let threadCount = options.threadCount, threadCount > 0 {
|
||||
TFL_InterpreterOptionsSetNumThreads(cOptions, Int32(threadCount))
|
||||
}
|
||||
if options.isErrorLoggingEnabled {
|
||||
TFL_InterpreterOptionsSetErrorReporter(
|
||||
cOptions,
|
||||
{ (_, format, arguments) in
|
||||
guard let cFormat = format,
|
||||
let message = String(cFormat: cFormat, arguments: arguments)
|
||||
else {
|
||||
return
|
||||
}
|
||||
print(String(describing: InterpreterError.tensorFlowLiteError(message)))
|
||||
},
|
||||
nil
|
||||
)
|
||||
}
|
||||
return cOptions
|
||||
}
|
||||
defer { TFL_DeleteInterpreterOptions(cInterpreterOptions) }
|
||||
|
||||
guard let cInterpreter = TFL_NewInterpreter(model.cModel, cInterpreterOptions) else {
|
||||
throw InterpreterError.failedToCreateInterpreter
|
||||
}
|
||||
self.cInterpreter = cInterpreter
|
||||
}
|
||||
|
||||
deinit {
|
||||
TFL_DeleteInterpreter(cInterpreter)
|
||||
}
|
||||
|
||||
/// Invokes the interpreter to perform inference from the loaded graph.
|
||||
///
|
||||
/// - Throws: An error if the model was not ready because tensors were not allocated.
|
||||
public func invoke() throws {
|
||||
guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else {
|
||||
// TODO(b/117510052): Determine which error to throw.
|
||||
throw InterpreterError.allocateTensorsRequired
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the input tensor at the given index.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - index: The index for the input tensor.
|
||||
/// - Throws: An error if the index is invalid or the tensors have not been allocated.
|
||||
/// - Returns: The input tensor at the given index.
|
||||
public func input(at index: Int) throws -> Tensor {
|
||||
let maxIndex = inputTensorCount - 1
|
||||
guard case 0...maxIndex = index else {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)),
|
||||
let bytes = TFL_TensorData(cTensor),
|
||||
let nameCString = TFL_TensorName(cTensor)
|
||||
else {
|
||||
throw InterpreterError.allocateTensorsRequired
|
||||
}
|
||||
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
||||
throw InterpreterError.invalidTensorDataType
|
||||
}
|
||||
|
||||
let name = String(cString: nameCString)
|
||||
let rank = TFL_TensorNumDims(cTensor)
|
||||
let dimensions = (0..<rank).map { Int(TFL_TensorDim(cTensor, $0)) }
|
||||
let shape = TensorShape(dimensions)
|
||||
let byteCount = TFL_TensorByteSize(cTensor)
|
||||
let data = Data(bytes: bytes, count: byteCount)
|
||||
let cQuantizationParams = TFL_TensorQuantizationParams(cTensor)
|
||||
let scale = cQuantizationParams.scale
|
||||
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||
var quantizationParameters: QuantizationParameters? = nil
|
||||
if scale != 0.0 {
|
||||
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
||||
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||
}
|
||||
let tensor = Tensor(
|
||||
name: name,
|
||||
dataType: dataType,
|
||||
shape: shape,
|
||||
data: data,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
return tensor
|
||||
}
|
||||
|
||||
/// Returns the output tensor at the given index.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - index: The index for the output tensor.
|
||||
/// - Throws: An error if the index is invalid, tensors haven't been allocated, or interpreter
|
||||
/// hasn't been invoked for models that dynamically compute output tensors based on the values
|
||||
/// of its input tensors.
|
||||
/// - Returns: The output tensor at the given index.
|
||||
public func output(at index: Int) throws -> Tensor {
|
||||
let maxIndex = outputTensorCount - 1
|
||||
guard case 0...maxIndex = index else {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard let cTensor = TFL_InterpreterGetOutputTensor(cInterpreter, Int32(index)),
|
||||
let bytes = TFL_TensorData(cTensor),
|
||||
let nameCString = TFL_TensorName(cTensor)
|
||||
else {
|
||||
// TODO(b/117510052): Determine which error to throw.
|
||||
throw InterpreterError.invokeInterpreterRequired
|
||||
}
|
||||
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
||||
throw InterpreterError.invalidTensorDataType
|
||||
}
|
||||
|
||||
let name = String(cString: nameCString)
|
||||
let rank = TFL_TensorNumDims(cTensor)
|
||||
let dimensions = (0..<rank).map { Int(TFL_TensorDim(cTensor, $0)) }
|
||||
let shape = TensorShape(dimensions)
|
||||
let byteCount = TFL_TensorByteSize(cTensor)
|
||||
let data = Data(bytes: bytes, count: byteCount)
|
||||
let cQuantizationParams = TFL_TensorQuantizationParams(cTensor)
|
||||
let scale = cQuantizationParams.scale
|
||||
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||
var quantizationParameters: QuantizationParameters? = nil
|
||||
if scale != 0.0 {
|
||||
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
||||
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||
}
|
||||
let tensor = Tensor(
|
||||
name: name,
|
||||
dataType: dataType,
|
||||
shape: shape,
|
||||
data: data,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
return tensor
|
||||
}
|
||||
|
||||
/// Resizes the input tensor at the given index to the specified tensor shape.
|
||||
///
|
||||
/// - Note: After resizing an input tensor, the client **must** explicitly call
|
||||
/// `allocateTensors()` before attempting to access the resized tensor data or invoking the
|
||||
/// interpreter to perform inference.
|
||||
/// - Parameters:
|
||||
/// - index: The index for the input tensor.
|
||||
/// - shape: The shape that the input tensor should be resized to.
|
||||
/// - Throws: An error if the input tensor at the given index could not be resized.
|
||||
public func resizeInput(at index: Int, to shape: TensorShape) throws {
|
||||
let maxIndex = inputTensorCount - 1
|
||||
guard case 0...maxIndex = index else {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard TFL_InterpreterResizeInputTensor(
|
||||
cInterpreter,
|
||||
Int32(index),
|
||||
shape.int32Dimensions,
|
||||
Int32(shape.rank)
|
||||
) == kTfLiteOk
|
||||
else {
|
||||
throw InterpreterError.failedToResizeInputTensor(index: index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies the given data to the input tensor at the given index.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - data: The data to be copied to the input tensor's data buffer.
|
||||
/// - index: The index for the input tensor.
|
||||
/// - Throws: An error if the `data.count` does not match the input tensor's `data.count` or if
|
||||
/// the given index is invalid.
|
||||
/// - Returns: The input tensor with the copied data.
|
||||
@discardableResult
|
||||
public func copy(_ data: Data, toInputAt index: Int) throws -> Tensor {
|
||||
let maxIndex = inputTensorCount - 1
|
||||
guard case 0...maxIndex = index else {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)) else {
|
||||
throw InterpreterError.allocateTensorsRequired
|
||||
}
|
||||
|
||||
let byteCount = TFL_TensorByteSize(cTensor)
|
||||
guard data.count == byteCount else {
|
||||
throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
|
||||
}
|
||||
|
||||
let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) }
|
||||
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
|
||||
return try input(at: index)
|
||||
}
|
||||
|
||||
/// Allocates memory for all input tensors based on their `TensorShape`s.
|
||||
///
|
||||
/// - Note: This is a relatively expensive operation and should only be called after creating the
|
||||
/// interpreter and/or resizing any input tensors.
|
||||
/// - Throws: An error if memory could not be allocated for the input tensors.
|
||||
public func allocateTensors() throws {
|
||||
guard TFL_InterpreterAllocateTensors(cInterpreter) == kTfLiteOk else {
|
||||
throw InterpreterError.failedToAllocateTensors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension String {
|
||||
/// Returns a new `String` initialized by using the given format C array as a template into which
|
||||
/// the remaining argument values are substituted according to the user’s default locale.
|
||||
///
|
||||
/// - Note: Returns `nil` if a new `String` could not be constructed from the given values.
|
||||
/// - Parameters:
|
||||
/// - cFormat: The format C array as a template for substituting values.
|
||||
/// - arguments: A C pointer to a `va_list` of arguments to substitute into `cFormat`.
|
||||
init?(cFormat: UnsafePointer<CChar>, arguments: CVaListPointer) {
|
||||
var buffer: UnsafeMutablePointer<CChar>?
|
||||
guard vasprintf(&buffer, cFormat, arguments) != 0, let cString = buffer else { return nil }
|
||||
self.init(validatingUTF8: cString)
|
||||
}
|
||||
}
|
@ -0,0 +1,99 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// TensorFlow Lite interpreter errors.
|
||||
public enum InterpreterError: Error {
|
||||
case invalidTensorIndex(index: Int, maxIndex: Int)
|
||||
case invalidTensorDataCount(provided: Int, required: Int)
|
||||
case invalidTensorDataType
|
||||
case failedToLoadModel
|
||||
case failedToCreateInterpreter
|
||||
case failedToResizeInputTensor(index: Int)
|
||||
case failedToCopyDataToInputTensor
|
||||
case failedToAllocateTensors
|
||||
case allocateTensorsRequired
|
||||
case invokeInterpreterRequired
|
||||
case tensorFlowLiteError(String)
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension InterpreterError: LocalizedError {
|
||||
/// Localized description of the interpreter error.
|
||||
public var errorDescription: String? {
|
||||
switch self {
|
||||
case .invalidTensorIndex(let index, let maxIndex):
|
||||
return "Invalid tensor index \(index), max index is \(maxIndex)."
|
||||
case .invalidTensorDataCount(let providedCount, let requiredCount):
|
||||
return "Provided data count \(providedCount) must match the required count \(requiredCount)."
|
||||
case .invalidTensorDataType:
|
||||
return "Tensor data type is unsupported or could not be determined because of a model error."
|
||||
case .failedToLoadModel:
|
||||
return "Failed to load the given model."
|
||||
case .failedToCreateInterpreter:
|
||||
return "Failed to create the interpreter."
|
||||
case .failedToResizeInputTensor(let index):
|
||||
return "Failed to resize input tesnor at index \(index)."
|
||||
case .failedToCopyDataToInputTensor:
|
||||
return "Failed to copy data to input tensor."
|
||||
case .failedToAllocateTensors:
|
||||
return "Failed to allocate memory for input tensors."
|
||||
case .allocateTensorsRequired:
|
||||
return "Must call allocateTensors()."
|
||||
case .invokeInterpreterRequired:
|
||||
return "Must call invoke()."
|
||||
case .tensorFlowLiteError(let message):
|
||||
return "TensorFlow Lite Error: \(message)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extension InterpreterError: CustomStringConvertible {
|
||||
/// Textual representation of the TensorFlow Lite interpreter error.
|
||||
public var description: String {
|
||||
return errorDescription ?? "Unknown error."
|
||||
}
|
||||
}
|
||||
|
||||
#if swift(>=4.2)
|
||||
extension InterpreterError: Equatable {}
|
||||
#else
|
||||
extension InterpreterError: Equatable {
|
||||
public static func == (lhs: InterpreterError, rhs: InterpreterError) -> Bool {
|
||||
switch (lhs, rhs) {
|
||||
case (.invalidTensorDataType, .invalidTensorDataType),
|
||||
(.failedToLoadModel, .failedToLoadModel),
|
||||
(.failedToCreateInterpreter, .failedToCreateInterpreter),
|
||||
(.failedToAllocateTensors, .failedToAllocateTensors),
|
||||
(.allocateTensorsRequired, .allocateTensorsRequired),
|
||||
(.invokeInterpreterRequired, .invokeInterpreterRequired):
|
||||
return true
|
||||
case (.invalidTensorIndex(let lhsIndex, let lhsMaxIndex),
|
||||
.invalidTensorIndex(let rhsIndex, let rhsMaxIndex)):
|
||||
return lhsIndex == rhsIndex && lhsMaxIndex == rhsMaxIndex
|
||||
case (.invalidTensorDataCount(let lhsProvidedCount, let lhsRequiredCount),
|
||||
.invalidTensorDataCount(let rhsProvidedCount, let rhsRequiredCount)):
|
||||
return lhsProvidedCount == rhsProvidedCount && lhsRequiredCount == rhsRequiredCount
|
||||
case (.failedToResizeInputTensor(let lhsIndex), .failedToResizeInputTensor(let rhsIndex)):
|
||||
return lhsIndex == rhsIndex
|
||||
case (.tensorFlowLiteError(let lhsMessage), .tensorFlowLiteError(let rhsMessage)):
|
||||
return lhsMessage == rhsMessage
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // swift(>=4.2)
|
@ -0,0 +1,29 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Custom configuration options for a TensorFlow Lite interpreter.
|
||||
public struct InterpreterOptions: Equatable {
|
||||
|
||||
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` which
|
||||
/// indicates that the `Interpreter` will decide the number of threads to use.
|
||||
public var threadCount: Int? = nil
|
||||
|
||||
/// Whether error logging to the console is enabled. The default is `false`.
|
||||
public var isErrorLoggingEnabled = false
|
||||
|
||||
/// Creates a new instance of interpreter options.
|
||||
public init() {}
|
||||
}
|
40
tensorflow/lite/experimental/swift/Sources/Model.swift
Normal file
40
tensorflow/lite/experimental/swift/Sources/Model.swift
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
import TensorFlowLiteCAPI
|
||||
|
||||
/// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
|
||||
final class Model {
|
||||
|
||||
/// The `TFL_Model` C pointer type represented as an `UnsafePointer<TFL_Model>`.
|
||||
typealias CModel = OpaquePointer
|
||||
|
||||
/// The underlying `TFL_Model` C pointer.
|
||||
let cModel: CModel?
|
||||
|
||||
/// Creates a new model instance.
|
||||
///
|
||||
/// - Precondition: Initialization can fail if the given `filePath` is invalid.
|
||||
/// - Parameters:
|
||||
/// - filePath: Local file path to a TensorFlow Lite model.
|
||||
init?(filePath: String) {
|
||||
guard !filePath.isEmpty, let cModel = TFL_NewModelFromFile(filePath) else { return nil }
|
||||
self.cModel = cModel
|
||||
}
|
||||
|
||||
deinit {
|
||||
TFL_DeleteModel(cModel)
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Parameters that determine the mapping of quantized values to real values. Quantized values can
|
||||
/// be mapped to float values using the following conversion:
|
||||
/// `realValue = scale * (quantizedValue - zeroPoint)`.
|
||||
public struct QuantizationParameters {
|
||||
|
||||
/// Difference between real values corresponding to consecutive quantized values differing by 1.
|
||||
/// For example, the range of quantized values for `UInt8` data type is [0, 255].
|
||||
public let scale: Float
|
||||
|
||||
/// Quantized value that corresponds to the real 0 value.
|
||||
public let zeroPoint: Int
|
||||
|
||||
/// Creates a new quantization parameters instance.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - scale: Scale value for asymmetric quantization.
|
||||
/// - zeroPoint: Zero point for asymmetric quantization.
|
||||
init(scale: Float, zeroPoint: Int) {
|
||||
self.scale = scale
|
||||
self.zeroPoint = zeroPoint
|
||||
}
|
||||
}
|
138
tensorflow/lite/experimental/swift/Sources/Tensor.swift
Normal file
138
tensorflow/lite/experimental/swift/Sources/Tensor.swift
Normal file
@ -0,0 +1,138 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
import TensorFlowLiteCAPI
|
||||
|
||||
/// An input or output tensor in a TensorFlow Lite graph.
|
||||
public struct Tensor {
|
||||
|
||||
/// Name of the tensor.
|
||||
public let name: String
|
||||
|
||||
/// Data type of the tensor.
|
||||
public let dataType: TensorDataType
|
||||
|
||||
/// Shape of the tensor.
|
||||
public let shape: TensorShape
|
||||
|
||||
/// Data in the input or output tensor.
|
||||
public let data: Data
|
||||
|
||||
/// Quantization parameters for the tensor if using a quantized model.
|
||||
public let quantizationParameters: QuantizationParameters?
|
||||
|
||||
/// Creates a new input or output tensor instance.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - name: Name of the tensor.
|
||||
/// - dataType: Data type of the tensor.
|
||||
/// - data: Data in the input tensor.
|
||||
/// - quantizationParameters Quantization parameters for the tensor if using a quantized model.
|
||||
/// The default is `nil`.
|
||||
init(
|
||||
name: String,
|
||||
dataType: TensorDataType,
|
||||
shape: TensorShape,
|
||||
data: Data,
|
||||
quantizationParameters: QuantizationParameters? = nil
|
||||
) {
|
||||
self.name = name
|
||||
self.dataType = dataType
|
||||
self.shape = shape
|
||||
self.data = data
|
||||
self.quantizationParameters = quantizationParameters
|
||||
}
|
||||
}
|
||||
|
||||
/// Supported TensorFlow Lite tensor data types.
|
||||
public enum TensorDataType: Equatable {
|
||||
/// 32-bit single precision floating point tensor data type.
|
||||
case float32
|
||||
/// 8-bit unsigned integer tensor data type.
|
||||
case uInt8
|
||||
/// 16-bit signed integer tensor data type.
|
||||
case int16
|
||||
/// 32-bit signed integer tensor data type.
|
||||
case int32
|
||||
/// 64-bit signed integer tensor data type.
|
||||
case int64
|
||||
/// Boolean tensor data type.
|
||||
case bool
|
||||
|
||||
/// Creates a new tensor data type from the given `TFL_Type` or `nil` if the data type is
|
||||
/// unsupported or could not be determined because there was an error.
|
||||
///
|
||||
/// - Parameter type: A data type supported by a tensor.
|
||||
init?(type: TFL_Type) {
|
||||
switch type {
|
||||
case kTfLiteFloat32:
|
||||
self = .float32
|
||||
case kTfLiteUInt8:
|
||||
self = .uInt8
|
||||
case kTfLiteInt16:
|
||||
self = .int16
|
||||
case kTfLiteInt32:
|
||||
self = .int32
|
||||
case kTfLiteInt64:
|
||||
self = .int64
|
||||
case kTfLiteBool:
|
||||
self = .bool
|
||||
case kTfLiteNoType:
|
||||
fallthrough
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The shape of a TensorFlow Lite tensor.
|
||||
public struct TensorShape {
|
||||
|
||||
/// The number of dimensions of the tensor.
|
||||
public let rank: Int
|
||||
|
||||
/// Array of dimensions for the tensor.
|
||||
public let dimensions: [Int]
|
||||
|
||||
/// Array of `Int32` dimensions for the tensor.
|
||||
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
|
||||
|
||||
/// Creates a new tensor shape instance with the given array of dimensions.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - dimensions: Dimensions for the tensor.
|
||||
public init(_ dimensions: [Int]) {
|
||||
self.rank = dimensions.count
|
||||
self.dimensions = dimensions
|
||||
}
|
||||
|
||||
/// Creates a new tensor shape instance with the given elements representing the dimensions.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - elements: Dimensions for the tensor.
|
||||
public init(_ elements: Int...) {
|
||||
self.init(elements)
|
||||
}
|
||||
}
|
||||
|
||||
extension TensorShape: ExpressibleByArrayLiteral {
|
||||
/// Creates a new tensor shape instance with the given array literal representing the dimensions.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - arrayLiteral: Dimensions for the tensor.
|
||||
public init(arrayLiteral: Int...) {
|
||||
self.init(arrayLiteral)
|
||||
}
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
{
|
||||
"sourceFilters" : [
|
||||
"third_party/tensorflow/lite/experimental/c",
|
||||
"third_party/tensorflow/lite/experimental/swift",
|
||||
"third_party/tensorflow/lite/experimental/swift/Sources",
|
||||
"third_party/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp",
|
||||
"third_party/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj",
|
||||
"third_party/tensorflow/lite/experimental/swift/Tests",
|
||||
],
|
||||
"buildTargets" : [
|
||||
"//third_party/tensorflow/lite/experimental/swift:TensorFlowLite",
|
||||
"//third_party/tensorflow/lite/experimental/swift:TensorFlowLiteApp",
|
||||
"//third_party/tensorflow/lite/experimental/swift:TensorFlowLiteTests",
|
||||
],
|
||||
"projectName" : "TensorFlowLite",
|
||||
"optionSet" : {
|
||||
"LaunchActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildStartupOptionsRelease" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildOptionsRelease" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildOptionsDebug" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"EnvironmentVariables" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BuildActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"CommandlineArguments" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"TestActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildStartupOptionsDebug" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BuildActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"TestActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"LaunchActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
}
|
||||
},
|
||||
"additionalFilePaths" : [
|
||||
"third_party/tensorflow/lite/experimental/swift/BUILD"
|
||||
]
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
{
|
||||
"configDefaults" : {
|
||||
"optionSet" : {
|
||||
"ProjectPrioritizesSwift" : {
|
||||
"p" : "YES"
|
||||
}
|
||||
}
|
||||
},
|
||||
"projectName" : "TensorFlowLite",
|
||||
"packages" : [
|
||||
"third_party/tensorflow/lite/experimental/swift"
|
||||
],
|
||||
"workspaceRoot" : "../../../../../.."
|
||||
}
|
@ -0,0 +1,345 @@
|
||||
// !$*UTF8*$!
|
||||
{
|
||||
archiveVersion = 1;
|
||||
classes = {
|
||||
};
|
||||
objectVersion = 50;
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
4A7304B421500B8400C90B21 /* Data+TensorFlowLite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */; };
|
||||
4AA72B732146ED64006C3AEF /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA72B722146ED64006C3AEF /* AppDelegate.swift */; };
|
||||
4AA72B752146ED64006C3AEF /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA72B742146ED64006C3AEF /* ViewController.swift */; };
|
||||
4AA72B782146ED64006C3AEF /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B762146ED64006C3AEF /* Main.storyboard */; };
|
||||
4AA72B7A2146ED66006C3AEF /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B792146ED66006C3AEF /* Assets.xcassets */; };
|
||||
4AA72B7D2146ED66006C3AEF /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */; };
|
||||
4ADDE0CE2176600E00FF07A2 /* Array+TensorFlowLite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Data+TensorFlowLite.swift"; sourceTree = "<group>"; };
|
||||
4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TensorFlowLiteApp.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
4AA72B722146ED64006C3AEF /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; };
|
||||
4AA72B742146ED64006C3AEF /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = "<group>"; };
|
||||
4AA72B772146ED64006C3AEF /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = "<group>"; };
|
||||
4AA72B792146ED66006C3AEF /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
|
||||
4AA72B7C2146ED66006C3AEF /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; };
|
||||
4AA72B7E2146ED66006C3AEF /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
|
||||
4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Array+TensorFlowLite.swift"; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
4AA72B6C2146ED64006C3AEF /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXFrameworksBuildPhase section */
|
||||
|
||||
/* Begin PBXGroup section */
|
||||
4AA72B662146ED64006C3AEF = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
4AA72B712146ED64006C3AEF /* TensorFlowLiteApp */,
|
||||
4AA72B702146ED64006C3AEF /* Products */,
|
||||
);
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
4AA72B702146ED64006C3AEF /* Products */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */,
|
||||
);
|
||||
name = Products;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
4AA72B712146ED64006C3AEF /* TensorFlowLiteApp */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
4AA72B722146ED64006C3AEF /* AppDelegate.swift */,
|
||||
4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */,
|
||||
4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */,
|
||||
4AA72B742146ED64006C3AEF /* ViewController.swift */,
|
||||
4AA72B762146ED64006C3AEF /* Main.storyboard */,
|
||||
4AA72B792146ED66006C3AEF /* Assets.xcassets */,
|
||||
4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */,
|
||||
4AA72B7E2146ED66006C3AEF /* Info.plist */,
|
||||
);
|
||||
path = TensorFlowLiteApp;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXGroup section */
|
||||
|
||||
/* Begin PBXNativeTarget section */
|
||||
4AA72B6E2146ED64006C3AEF /* TensorFlowLiteApp */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = 4AA72B812146ED66006C3AEF /* Build configuration list for PBXNativeTarget "TensorFlowLiteApp" */;
|
||||
buildPhases = (
|
||||
4AA72B6B2146ED64006C3AEF /* Sources */,
|
||||
4AA72B6C2146ED64006C3AEF /* Frameworks */,
|
||||
4AA72B6D2146ED64006C3AEF /* Resources */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
);
|
||||
name = TensorFlowLiteApp;
|
||||
productName = TensorFlowLiteApp;
|
||||
productReference = 4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */;
|
||||
productType = "com.apple.product-type.application";
|
||||
};
|
||||
/* End PBXNativeTarget section */
|
||||
|
||||
/* Begin PBXProject section */
|
||||
4AA72B672146ED64006C3AEF /* Project object */ = {
|
||||
isa = PBXProject;
|
||||
attributes = {
|
||||
LastSwiftUpdateCheck = 0940;
|
||||
LastUpgradeCheck = 0940;
|
||||
ORGANIZATIONNAME = Google;
|
||||
TargetAttributes = {
|
||||
4AA72B6E2146ED64006C3AEF = {
|
||||
CreatedOnToolsVersion = 9.4.1;
|
||||
};
|
||||
};
|
||||
};
|
||||
buildConfigurationList = 4AA72B6A2146ED64006C3AEF /* Build configuration list for PBXProject "TensorFlowLiteApp" */;
|
||||
compatibilityVersion = "Xcode 9.3";
|
||||
developmentRegion = en;
|
||||
hasScannedForEncodings = 0;
|
||||
knownRegions = (
|
||||
en,
|
||||
Base,
|
||||
);
|
||||
mainGroup = 4AA72B662146ED64006C3AEF;
|
||||
productRefGroup = 4AA72B702146ED64006C3AEF /* Products */;
|
||||
projectDirPath = "";
|
||||
projectRoot = "";
|
||||
targets = (
|
||||
4AA72B6E2146ED64006C3AEF /* TensorFlowLiteApp */,
|
||||
);
|
||||
};
|
||||
/* End PBXProject section */
|
||||
|
||||
/* Begin PBXResourcesBuildPhase section */
|
||||
4AA72B6D2146ED64006C3AEF /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
4AA72B7D2146ED66006C3AEF /* LaunchScreen.storyboard in Resources */,
|
||||
4AA72B7A2146ED66006C3AEF /* Assets.xcassets in Resources */,
|
||||
4AA72B782146ED64006C3AEF /* Main.storyboard in Resources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXResourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXSourcesBuildPhase section */
|
||||
4AA72B6B2146ED64006C3AEF /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
4AA72B732146ED64006C3AEF /* AppDelegate.swift in Sources */,
|
||||
4ADDE0CE2176600E00FF07A2 /* Array+TensorFlowLite.swift in Sources */,
|
||||
4A7304B421500B8400C90B21 /* Data+TensorFlowLite.swift in Sources */,
|
||||
4AA72B752146ED64006C3AEF /* ViewController.swift in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXSourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXVariantGroup section */
|
||||
4AA72B762146ED64006C3AEF /* Main.storyboard */ = {
|
||||
isa = PBXVariantGroup;
|
||||
children = (
|
||||
4AA72B772146ED64006C3AEF /* Base */,
|
||||
);
|
||||
name = Main.storyboard;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */ = {
|
||||
isa = PBXVariantGroup;
|
||||
children = (
|
||||
4AA72B7C2146ED66006C3AEF /* Base */,
|
||||
);
|
||||
name = LaunchScreen.storyboard;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXVariantGroup section */
|
||||
|
||||
/* Begin XCBuildConfiguration section */
|
||||
4AA72B7F2146ED66006C3AEF /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
|
||||
CLANG_CXX_LIBRARY = "libc++";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
CODE_SIGN_IDENTITY = "iPhone Developer";
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = dwarf;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_TESTABILITY = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu11;
|
||||
GCC_DYNAMIC_NO_PIC = NO;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_OPTIMIZATION_LEVEL = 0;
|
||||
GCC_PREPROCESSOR_DEFINITIONS = (
|
||||
"DEBUG=1",
|
||||
"$(inherited)",
|
||||
);
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 11.4;
|
||||
MTL_ENABLE_DEBUG_INFO = YES;
|
||||
ONLY_ACTIVE_ARCH = YES;
|
||||
SDKROOT = iphoneos;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
4AA72B802146ED66006C3AEF /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
|
||||
CLANG_CXX_LIBRARY = "libc++";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
CODE_SIGN_IDENTITY = "iPhone Developer";
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
|
||||
ENABLE_NS_ASSERTIONS = NO;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu11;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 11.4;
|
||||
MTL_ENABLE_DEBUG_INFO = NO;
|
||||
SDKROOT = iphoneos;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-O";
|
||||
VALIDATE_PRODUCT = YES;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
4AA72B822146ED66006C3AEF /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
INFOPLIST_FILE = TensorFlowLiteApp/Info.plist;
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/Frameworks",
|
||||
);
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.tensorflow.lite.swift.TensorFlowLite;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SWIFT_VERSION = 4.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
4AA72B832146ED66006C3AEF /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
INFOPLIST_FILE = TensorFlowLiteApp/Info.plist;
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/Frameworks",
|
||||
);
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.tensorflow.lite.swift.TensorFlowLite;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SWIFT_VERSION = 4.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
/* End XCBuildConfiguration section */
|
||||
|
||||
/* Begin XCConfigurationList section */
|
||||
4AA72B6A2146ED64006C3AEF /* Build configuration list for PBXProject "TensorFlowLiteApp" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
4AA72B7F2146ED66006C3AEF /* Debug */,
|
||||
4AA72B802146ED66006C3AEF /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
4AA72B812146ED66006C3AEF /* Build configuration list for PBXNativeTarget "TensorFlowLiteApp" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
4AA72B822146ED66006C3AEF /* Debug */,
|
||||
4AA72B832146ED66006C3AEF /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
/* End XCConfigurationList section */
|
||||
};
|
||||
rootObject = 4AA72B672146ED64006C3AEF /* Project object */;
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
import UIKit
|
||||
|
||||
@UIApplicationMain
|
||||
|
||||
final class AppDelegate: UIResponder, UIApplicationDelegate {
|
||||
|
||||
/// The main window of the app.
|
||||
var window: UIWindow?
|
||||
|
||||
func application(
|
||||
_ application: UIApplication,
|
||||
didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]? = nil
|
||||
) -> Bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
#if !swift(>=4.2)
|
||||
extension UIApplication {
|
||||
typealias LaunchOptionsKey = UIApplicationLaunchOptionsKey
|
||||
}
|
||||
#endif // !swift(>=4.2)
|
@ -0,0 +1,22 @@
|
||||
import Foundation
|
||||
|
||||
extension Array {
|
||||
/// Creates a new array from the bytes of the given unsafe data.
|
||||
///
|
||||
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
|
||||
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
|
||||
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
|
||||
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
||||
/// `MemoryLayout<Element>.stride`.
|
||||
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||
init?(unsafeData: Data) {
|
||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||
let elements = unsafeData.withUnsafeBytes {
|
||||
UnsafeBufferPointer<Element>(
|
||||
start: $0,
|
||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||
)
|
||||
}
|
||||
self.init(elements)
|
||||
}
|
||||
}
|
@ -0,0 +1,98 @@
|
||||
{
|
||||
"images" : [
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "20x20",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "20x20",
|
||||
"scale" : "3x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "29x29",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "29x29",
|
||||
"scale" : "3x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "40x40",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "40x40",
|
||||
"scale" : "3x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "60x60",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "iphone",
|
||||
"size" : "60x60",
|
||||
"scale" : "3x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "20x20",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "20x20",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "29x29",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "29x29",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "40x40",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "40x40",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "76x76",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "76x76",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ipad",
|
||||
"size" : "83.5x83.5",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "ios-marketing",
|
||||
"size" : "1024x1024",
|
||||
"scale" : "1x"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"version" : 1,
|
||||
"author" : "xcode"
|
||||
}
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
{
|
||||
"info" : {
|
||||
"version" : 1,
|
||||
"author" : "xcode"
|
||||
}
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14109" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" launchScreen="YES" useTraitCollections="YES" colorMatched="YES" initialViewController="01J-lp-oVM">
|
||||
<device id="retina4_7" orientation="portrait">
|
||||
<adaptation id="fullscreen"/>
|
||||
</device>
|
||||
<dependencies>
|
||||
<deployment identifier="iOS"/>
|
||||
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14088"/>
|
||||
<capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
|
||||
</dependencies>
|
||||
<scenes>
|
||||
<!--View Controller-->
|
||||
<scene sceneID="EHf-IW-A2E">
|
||||
<objects>
|
||||
<viewController id="01J-lp-oVM" sceneMemberID="viewController">
|
||||
<layoutGuides>
|
||||
<viewControllerLayoutGuide type="top" id="Llm-lL-Icb"/>
|
||||
<viewControllerLayoutGuide type="bottom" id="xb3-aO-Qok"/>
|
||||
</layoutGuides>
|
||||
<view key="view" contentMode="scaleToFill" id="Ze5-6b-2t3">
|
||||
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
|
||||
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
|
||||
<subviews>
|
||||
<label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" text="TensorFlowLite" textAlignment="center" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="3Gq-PV-hia">
|
||||
<rect key="frame" x="16" y="315" width="343" height="38.5"/>
|
||||
<fontDescription key="fontDescription" type="boldSystem" pointSize="32"/>
|
||||
<nil key="textColor"/>
|
||||
<nil key="highlightedColor"/>
|
||||
</label>
|
||||
</subviews>
|
||||
<color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||
<constraints>
|
||||
<constraint firstItem="3Gq-PV-hia" firstAttribute="leading" secondItem="Ze5-6b-2t3" secondAttribute="leading" constant="16" id="aXL-9T-5Pf"/>
|
||||
<constraint firstItem="3Gq-PV-hia" firstAttribute="centerY" secondItem="Ze5-6b-2t3" secondAttribute="centerY" id="cDf-Go-1FR"/>
|
||||
<constraint firstAttribute="trailing" secondItem="3Gq-PV-hia" secondAttribute="trailing" constant="16" id="fB9-BX-A3B"/>
|
||||
</constraints>
|
||||
</view>
|
||||
</viewController>
|
||||
<placeholder placeholderIdentifier="IBFirstResponder" id="iYj-Kq-Ea1" userLabel="First Responder" sceneMemberID="firstResponder"/>
|
||||
</objects>
|
||||
<point key="canvasLocation" x="52" y="374.66266866566718"/>
|
||||
</scene>
|
||||
</scenes>
|
||||
</document>
|
@ -0,0 +1,95 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14313.18" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" useTraitCollections="YES" colorMatched="YES" initialViewController="BYZ-38-t0r">
|
||||
<device id="retina4_7" orientation="portrait">
|
||||
<adaptation id="fullscreen"/>
|
||||
</device>
|
||||
<dependencies>
|
||||
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14283.14"/>
|
||||
<capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
|
||||
</dependencies>
|
||||
<scenes>
|
||||
<!--View Controller-->
|
||||
<scene sceneID="tne-QT-ifu">
|
||||
<objects>
|
||||
<viewController storyboardIdentifier="viewController" useStoryboardIdentifierAsRestorationIdentifier="YES" id="BYZ-38-t0r" customClass="ViewController" customModule="third_party_tensorflow_lite_experimental_swift_TensorFlowLiteAppLib" sceneMemberID="viewController">
|
||||
<layoutGuides>
|
||||
<viewControllerLayoutGuide type="top" id="y3c-jy-aDJ"/>
|
||||
<viewControllerLayoutGuide type="bottom" id="wfy-db-euE"/>
|
||||
</layoutGuides>
|
||||
<view key="view" contentMode="scaleToFill" id="8bC-Xf-vdC">
|
||||
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
|
||||
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
|
||||
<subviews>
|
||||
<textView clipsSubviews="YES" multipleTouchEnabled="YES" contentMode="scaleToFill" editable="NO" selectable="NO" translatesAutoresizingMaskIntoConstraints="NO" id="7Mj-sL-hrd">
|
||||
<rect key="frame" x="0.0" y="367" width="375" height="300"/>
|
||||
<color key="backgroundColor" red="0.0" green="0.47843137250000001" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||
<constraints>
|
||||
<constraint firstAttribute="height" constant="300" id="YUb-MC-D5w"/>
|
||||
</constraints>
|
||||
<color key="textColor" cocoaTouchSystemColor="tableCellGroupedBackgroundColor"/>
|
||||
<fontDescription key="fontDescription" type="system" pointSize="14"/>
|
||||
<textInputTraits key="textInputTraits" autocapitalizationType="sentences"/>
|
||||
</textView>
|
||||
<toolbar opaque="NO" clearsContextBeforeDrawing="NO" contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="Qwg-EP-bd6" userLabel="Bottom Toolbar">
|
||||
<rect key="frame" x="0.0" y="323" width="375" height="44"/>
|
||||
<constraints>
|
||||
<constraint firstAttribute="height" constant="44" id="jhT-Q0-E9N"/>
|
||||
</constraints>
|
||||
<items>
|
||||
<barButtonItem style="plain" systemItem="flexibleSpace" id="P3q-uA-YUa"/>
|
||||
<barButtonItem title="Invoke Interpreter" id="A4J-Mg-nmd" userLabel="Invoke Button">
|
||||
<connections>
|
||||
<action selector="invokeInterpreter:" destination="BYZ-38-t0r" id="lZU-x7-PsJ"/>
|
||||
</connections>
|
||||
</barButtonItem>
|
||||
<barButtonItem style="plain" systemItem="flexibleSpace" id="Qad-Pa-ySg"/>
|
||||
</items>
|
||||
</toolbar>
|
||||
<toolbar opaque="NO" clearsContextBeforeDrawing="NO" contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="Gkb-TR-PCB" userLabel="Top Toolbar">
|
||||
<rect key="frame" x="0.0" y="28" width="375" height="44"/>
|
||||
<constraints>
|
||||
<constraint firstAttribute="height" constant="44" id="hSD-2q-fUE"/>
|
||||
</constraints>
|
||||
<items>
|
||||
<barButtonItem style="plain" id="LKw-TX-bbH">
|
||||
<segmentedControl key="customView" opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="left" contentVerticalAlignment="top" segmentControlStyle="bar" selectedSegmentIndex="0" id="rhA-nW-xzT">
|
||||
<rect key="frame" x="16" y="7" width="343" height="30"/>
|
||||
<autoresizingMask key="autoresizingMask" flexibleMaxX="YES" flexibleMaxY="YES"/>
|
||||
<segments>
|
||||
<segment title="Add"/>
|
||||
<segment title="AddQuantized"/>
|
||||
<segment title="MultiAdd"/>
|
||||
</segments>
|
||||
<connections>
|
||||
<action selector="modelChanged:" destination="BYZ-38-t0r" eventType="valueChanged" id="YnG-Ov-B5D"/>
|
||||
</connections>
|
||||
</segmentedControl>
|
||||
</barButtonItem>
|
||||
</items>
|
||||
</toolbar>
|
||||
</subviews>
|
||||
<color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||
<constraints>
|
||||
<constraint firstAttribute="trailing" secondItem="Gkb-TR-PCB" secondAttribute="trailing" id="4Cr-Sf-I7n"/>
|
||||
<constraint firstItem="7Mj-sL-hrd" firstAttribute="bottom" secondItem="wfy-db-euE" secondAttribute="top" id="6ot-zD-sze"/>
|
||||
<constraint firstItem="7Mj-sL-hrd" firstAttribute="top" secondItem="Qwg-EP-bd6" secondAttribute="bottom" id="ELA-C6-NiG"/>
|
||||
<constraint firstAttribute="trailing" secondItem="7Mj-sL-hrd" secondAttribute="trailing" id="HDO-xr-mBl"/>
|
||||
<constraint firstItem="Gkb-TR-PCB" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="Kmo-6K-gS4"/>
|
||||
<constraint firstItem="Qwg-EP-bd6" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="hGu-lm-fMG"/>
|
||||
<constraint firstAttribute="trailing" secondItem="Qwg-EP-bd6" secondAttribute="trailing" id="iXR-LK-nTO"/>
|
||||
<constraint firstItem="7Mj-sL-hrd" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="nr7-jW-ZYf"/>
|
||||
<constraint firstItem="Gkb-TR-PCB" firstAttribute="top" secondItem="y3c-jy-aDJ" secondAttribute="bottom" constant="8" id="uCF-VW-rR0"/>
|
||||
</constraints>
|
||||
</view>
|
||||
<connections>
|
||||
<outlet property="invokeButton" destination="A4J-Mg-nmd" id="UxZ-Ft-E45"/>
|
||||
<outlet property="modelControl" destination="rhA-nW-xzT" id="KKf-TT-BQ2"/>
|
||||
<outlet property="resultsTextView" destination="7Mj-sL-hrd" id="T4I-z4-tYA"/>
|
||||
</connections>
|
||||
</viewController>
|
||||
<placeholder placeholderIdentifier="IBFirstResponder" id="dkx-z0-nzr" sceneMemberID="firstResponder"/>
|
||||
</objects>
|
||||
<point key="canvasLocation" x="125.59999999999999" y="133.5832083958021"/>
|
||||
</scene>
|
||||
</scenes>
|
||||
</document>
|
@ -0,0 +1,13 @@
|
||||
import Foundation
|
||||
|
||||
extension Data {
|
||||
/// Creates a new buffer by copying the buffer pointer of the given array.
|
||||
///
|
||||
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
|
||||
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
|
||||
/// data from the resulting buffer has undefined behavior.
|
||||
/// - Parameter array: An array with elements of type `T`.
|
||||
init<T>(copyingBufferOf array: [T]) {
|
||||
self = array.withUnsafeBufferPointer(Data.init)
|
||||
}
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>CFBundleDevelopmentRegion</key>
|
||||
<string>en</string>
|
||||
<key>CFBundleExecutable</key>
|
||||
<string>$(EXECUTABLE_NAME)</string>
|
||||
<key>CFBundleIdentifier</key>
|
||||
<string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
|
||||
<key>CFBundleInfoDictionaryVersion</key>
|
||||
<string>6.0</string>
|
||||
<key>CFBundleName</key>
|
||||
<string>$(PRODUCT_NAME)</string>
|
||||
<key>CFBundlePackageType</key>
|
||||
<string>APPL</string>
|
||||
<key>CFBundleShortVersionString</key>
|
||||
<string>1.0</string>
|
||||
<key>CFBundleVersion</key>
|
||||
<string>0.0.1</string>
|
||||
<key>LSRequiresIPhoneOS</key>
|
||||
<true/>
|
||||
<key>NSCameraUsageDescription</key>
|
||||
<string>NSCameraUsageDescription</string>
|
||||
<key>NSPhotoLibraryUsageDescription</key>
|
||||
<string>Select a photo to detect objects in.</string>
|
||||
<key>UILaunchStoryboardName</key>
|
||||
<string>LaunchScreen</string>
|
||||
<key>UIMainStoryboardFile</key>
|
||||
<string>Main</string>
|
||||
<key>UIRequiredDeviceCapabilities</key>
|
||||
<array>
|
||||
<string>armv7</string>
|
||||
</array>
|
||||
<key>UISupportedInterfaceOrientations</key>
|
||||
<array>
|
||||
<string>UIInterfaceOrientationPortrait</string>
|
||||
<string>UIInterfaceOrientationPortraitUpsideDown</string>
|
||||
</array>
|
||||
<key>UISupportedInterfaceOrientations~ipad</key>
|
||||
<array>
|
||||
<string>UIInterfaceOrientationPortrait</string>
|
||||
<string>UIInterfaceOrientationPortraitUpsideDown</string>
|
||||
</array>
|
||||
</dict>
|
||||
</plist>
|
@ -0,0 +1,299 @@
|
||||
import TensorFlowLite
|
||||
import UIKit
|
||||
|
||||
class ViewController: UIViewController {
|
||||
|
||||
// MARK: - Properties
|
||||
|
||||
/// TensorFlowLite interpreter object for performing inference from a given model.
|
||||
private var interpreter: Interpreter?
|
||||
|
||||
/// Serial dispatch queue for managing `Interpreter` calls.
|
||||
private let interpreterQueue = DispatchQueue(
|
||||
label: Constant.dispatchQueueLabel,
|
||||
qos: .userInitiated
|
||||
)
|
||||
|
||||
/// The currently selected model.
|
||||
private var currentModel: Model {
|
||||
guard let currentModel = Model(rawValue: modelControl.selectedSegmentIndex) else {
|
||||
preconditionFailure("Invalid model for selected segment index.")
|
||||
}
|
||||
return currentModel
|
||||
}
|
||||
|
||||
/// A description of the current model.
|
||||
private var modelDescription: String {
|
||||
guard let interpreter = interpreter else { return "" }
|
||||
let inputCount = interpreter.inputTensorCount
|
||||
let outputCount = interpreter.outputTensorCount
|
||||
let inputTensors = (0..<inputCount).map { index in
|
||||
var tensorInfo = " Input \(index + 1): "
|
||||
do {
|
||||
let tensor = try interpreter.input(at: index)
|
||||
tensorInfo += "\(tensor)"
|
||||
} catch let error {
|
||||
tensorInfo += "\(error.localizedDescription)"
|
||||
}
|
||||
return tensorInfo
|
||||
}.joined(separator: "\n")
|
||||
let outputTensors = (0..<outputCount).map { index in
|
||||
var tensorInfo = " Output \(index + 1): "
|
||||
do {
|
||||
let tensor = try interpreter.output(at: index)
|
||||
tensorInfo += "\(tensor)"
|
||||
} catch let error {
|
||||
tensorInfo += "\(error.localizedDescription)"
|
||||
}
|
||||
return tensorInfo
|
||||
}.joined(separator: "\n")
|
||||
return "Model Description:\n" +
|
||||
" Input Tensor Count = \(inputCount)\n\(inputTensors)\n\n" +
|
||||
" Output Tensor Count = \(outputCount)\n\(outputTensors)"
|
||||
}
|
||||
|
||||
// MARK: - IBOutlets
|
||||
|
||||
/// A segmented control for changing models. See the `Model` enum for available models.
|
||||
@IBOutlet private var modelControl: UISegmentedControl!
|
||||
|
||||
@IBOutlet private var resultsTextView: UITextView!
|
||||
@IBOutlet private var invokeButton: UIBarButtonItem!
|
||||
|
||||
// MARK: - UIViewController
|
||||
|
||||
override func viewDidLoad() {
|
||||
super.viewDidLoad()
|
||||
|
||||
invokeButton.isEnabled = false
|
||||
loadModel()
|
||||
}
|
||||
|
||||
// MARK: - IBActions
|
||||
|
||||
@IBAction func modelChanged(_ sender: Any) {
|
||||
invokeButton.isEnabled = false
|
||||
updateResultsText("Switched to the \(currentModel.description) model.")
|
||||
loadModel()
|
||||
}
|
||||
|
||||
@IBAction func invokeInterpreter(_ sender: Any) {
|
||||
switch currentModel {
|
||||
case .add:
|
||||
invokeAdd()
|
||||
case .addQuantized:
|
||||
invokeAddQuantized()
|
||||
case .multiAdd:
|
||||
invokeMultiAdd()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
private func loadModel() {
|
||||
let fileInfo = currentModel.fileInfo
|
||||
guard let modelPath = Bundle.main.path(forResource: fileInfo.name, ofType: fileInfo.extension)
|
||||
else {
|
||||
updateResultsText("Failed to load the \(currentModel.description) model.")
|
||||
return
|
||||
}
|
||||
setUpInterpreter(withModelPath: modelPath)
|
||||
}
|
||||
|
||||
private func setUpInterpreter(withModelPath modelPath: String) {
|
||||
interpreterQueue.async {
|
||||
do {
|
||||
var options = InterpreterOptions()
|
||||
options.isErrorLoggingEnabled = true
|
||||
self.interpreter = try Interpreter(modelPath: modelPath, options: options)
|
||||
} catch let error {
|
||||
self.updateResultsText(
|
||||
"Failed to create the interpreter with error: \(error.localizedDescription)"
|
||||
)
|
||||
return
|
||||
}
|
||||
safeDispatchOnMain { self.invokeButton.isEnabled = true }
|
||||
}
|
||||
}
|
||||
|
||||
private func invokeAdd() {
|
||||
interpreterQueue.async {
|
||||
guard let interpreter = self.interpreter else {
|
||||
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||
return
|
||||
}
|
||||
do {
|
||||
try interpreter.resizeInput(at: 0, to: [2])
|
||||
try interpreter.allocateTensors()
|
||||
let input: [Float32] = [1, 3]
|
||||
let resultsText = self.modelDescription + "\n\n" +
|
||||
"Performing 2 add operations on input \(input.description) equals: "
|
||||
self.updateResultsText(resultsText)
|
||||
let data = Data(copyingBufferOf: input)
|
||||
try interpreter.copy(data, toInputAt: 0)
|
||||
try interpreter.invoke()
|
||||
let outputTensor = try interpreter.output(at: 0)
|
||||
let results: () -> String = {
|
||||
guard let results = [Float32](unsafeData: outputTensor.data) else { return "No results." }
|
||||
return resultsText + results.description
|
||||
}
|
||||
self.updateResultsText(results())
|
||||
} catch let error {
|
||||
self.updateResultsText(
|
||||
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func invokeAddQuantized() {
|
||||
interpreterQueue.async {
|
||||
guard let interpreter = self.interpreter else {
|
||||
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||
return
|
||||
}
|
||||
do {
|
||||
try interpreter.resizeInput(at: 0, to: [2])
|
||||
try interpreter.allocateTensors()
|
||||
let input: [UInt8] = [1, 3]
|
||||
let resultsText = self.modelDescription + "\n\n" +
|
||||
"Performing 2 add operations on quantized input \(input.description) equals: "
|
||||
self.updateResultsText(resultsText)
|
||||
let data = Data(input)
|
||||
try interpreter.copy(data, toInputAt: 0)
|
||||
try interpreter.invoke()
|
||||
let outputTensor = try interpreter.output(at: 0)
|
||||
let results: () -> String = {
|
||||
guard let quantizationParameters = outputTensor.quantizationParameters else {
|
||||
return "No results."
|
||||
}
|
||||
let quantizedResults = [UInt8](outputTensor.data)
|
||||
let dequantizedResults = quantizedResults.map {
|
||||
quantizationParameters.scale * Float(Int($0) - quantizationParameters.zeroPoint)
|
||||
}
|
||||
return resultsText + quantizedResults.description +
|
||||
", dequantized results: " + dequantizedResults.description
|
||||
}
|
||||
self.updateResultsText(results())
|
||||
} catch let error {
|
||||
self.updateResultsText(
|
||||
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func invokeMultiAdd() {
|
||||
interpreterQueue.async {
|
||||
guard let interpreter = self.interpreter else {
|
||||
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||
return
|
||||
}
|
||||
do {
|
||||
let shape = TensorShape(2)
|
||||
try (0..<interpreter.inputTensorCount).forEach { index in
|
||||
try interpreter.resizeInput(at: index, to: shape)
|
||||
}
|
||||
try interpreter.allocateTensors()
|
||||
let inputs = try (0..<interpreter.inputTensorCount).map { index -> [Float32] in
|
||||
let input = [Float32(index + 1), Float32(index + 2)]
|
||||
let data = Data(copyingBufferOf: input)
|
||||
try interpreter.copy(data, toInputAt: index)
|
||||
return input
|
||||
}
|
||||
let resultsText = self.modelDescription + "\n\n" +
|
||||
"Performing 3 add operations on inputs \(inputs.description) equals: "
|
||||
self.updateResultsText(resultsText)
|
||||
try interpreter.invoke()
|
||||
let results = try (0..<interpreter.outputTensorCount).map { index -> [Float32] in
|
||||
let tensor = try interpreter.output(at: index)
|
||||
return [Float32](unsafeData: tensor.data) ?? []
|
||||
}
|
||||
self.updateResultsText(resultsText + results.description)
|
||||
} catch let error {
|
||||
self.updateResultsText(
|
||||
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func updateResultsText(_ text: String? = nil) {
|
||||
safeDispatchOnMain { self.resultsTextView.text = text }
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Constants
|
||||
|
||||
private enum Constant {
|
||||
static let dispatchQueueLabel = "TensorFlowLiteInterpreterQueue"
|
||||
static let nilInterpreterErrorMessage =
|
||||
"Failed to invoke the interpreter because the interpreter was nil."
|
||||
}
|
||||
|
||||
/// Models that can be loaded by the TensorFlow Lite `Interpreter`.
|
||||
private enum Model: Int, CustomStringConvertible {
|
||||
/// A float model that performs two add operations on one input tensor and returns the result in
|
||||
/// one output tensor.
|
||||
case add = 0
|
||||
/// A quantized model that performs two add operations on one input tensor and returns the result
|
||||
/// in one output tensor.
|
||||
case addQuantized = 1
|
||||
/// A float model that performs three add operations on four input tensors and returns the results
|
||||
/// in 2 output tensors.
|
||||
case multiAdd = 2
|
||||
|
||||
var fileInfo: (name: String, extension: String) {
|
||||
switch self {
|
||||
case .add:
|
||||
return Add.fileInfo
|
||||
case .addQuantized:
|
||||
return AddQuantized.fileInfo
|
||||
case .multiAdd:
|
||||
return MultiAdd.fileInfo
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - CustomStringConvertible
|
||||
|
||||
var description: String {
|
||||
switch self {
|
||||
case .add:
|
||||
return Add.name
|
||||
case .addQuantized:
|
||||
return AddQuantized.name
|
||||
case .multiAdd:
|
||||
return MultiAdd.name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Values for the `Add` model.
|
||||
private enum Add {
|
||||
static let name = "Add"
|
||||
static let fileInfo = (name: "add", extension: "bin")
|
||||
}
|
||||
|
||||
/// Values for the `AddQuantized` model.
|
||||
private enum AddQuantized {
|
||||
static let name = "AddQuantized"
|
||||
static let fileInfo = (name: "add_quantized", extension: "bin")
|
||||
}
|
||||
|
||||
/// Values for the `MultiAdd` model.
|
||||
private enum MultiAdd {
|
||||
static let name = "MultiAdd"
|
||||
static let fileInfo = (name: "multi_add", extension: "bin")
|
||||
}
|
||||
|
||||
// MARK: - Fileprivate
|
||||
|
||||
/// Safely dispatches the given block on the main queue. If the current thread is `main`, the block
|
||||
/// is executed synchronously; otherwise, the block is executed asynchronously on the main thread.
|
||||
fileprivate func safeDispatchOnMain(_ block: @escaping () -> Void) {
|
||||
if Thread.isMainThread { block(); return }
|
||||
DispatchQueue.main.async { block() }
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class InterpreterOptionsTests: XCTestCase {
|
||||
|
||||
func testInterpreterOptions_InitWithDefaultValues() {
|
||||
let options = InterpreterOptions()
|
||||
XCTAssertNil(options.threadCount)
|
||||
XCTAssertFalse(options.isErrorLoggingEnabled)
|
||||
}
|
||||
|
||||
func testInterpreterOptions_InitWithCustomValues() {
|
||||
var options = InterpreterOptions()
|
||||
options.threadCount = 2
|
||||
XCTAssertEqual(options.threadCount, 2)
|
||||
options.isErrorLoggingEnabled = true
|
||||
XCTAssertTrue(options.isErrorLoggingEnabled)
|
||||
}
|
||||
|
||||
func testInterpreterOptions_Equatable() {
|
||||
var options1 = InterpreterOptions()
|
||||
var options2 = InterpreterOptions()
|
||||
XCTAssertEqual(options1, options2)
|
||||
|
||||
options1.threadCount = 2
|
||||
options2.threadCount = 2
|
||||
XCTAssertEqual(options1, options2)
|
||||
|
||||
options2.threadCount = 3
|
||||
XCTAssertNotEqual(options1, options2)
|
||||
options2.threadCount = 2
|
||||
|
||||
options1.isErrorLoggingEnabled = true
|
||||
options2.isErrorLoggingEnabled = true
|
||||
XCTAssertEqual(options1, options2)
|
||||
|
||||
options2.isErrorLoggingEnabled = false
|
||||
XCTAssertNotEqual(options1, options2)
|
||||
}
|
||||
}
|
315
tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift
Normal file
315
tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift
Normal file
@ -0,0 +1,315 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class InterpreterTests: XCTestCase {
|
||||
|
||||
var interpreter: Interpreter!
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
|
||||
interpreter = try! Interpreter(modelPath: AddModel.path)
|
||||
}
|
||||
|
||||
override func tearDown() {
|
||||
interpreter = nil
|
||||
|
||||
super.tearDown()
|
||||
}
|
||||
|
||||
func testInterpreter_InitWithModelPath() {
|
||||
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
|
||||
}
|
||||
|
||||
func testInterpreter_Init_ThrowsFailedToLoadModel() {
|
||||
XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
|
||||
self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_InitWithModelPathAndOptions() {
|
||||
var options = InterpreterOptions()
|
||||
options.threadCount = 2
|
||||
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options))
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorCount() {
|
||||
XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorCount() {
|
||||
XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
|
||||
}
|
||||
|
||||
func testInterpreter_Invoke() throws {
|
||||
try interpreter.allocateTensors()
|
||||
XCTAssertNoThrow(try interpreter.invoke())
|
||||
}
|
||||
|
||||
func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
|
||||
XCTAssertThrowsError(try interpreter.invoke()) { error in
|
||||
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorAtIndex() throws {
|
||||
try setUpAddModelInputTensor()
|
||||
let inputTensor = try interpreter.input(at: AddModel.validIndex)
|
||||
XCTAssertEqual(inputTensor, AddModel.inputTensor)
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorAtIndex_QuantizedModel() throws {
|
||||
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
||||
try setUpAddQuantizedModelInputTensor()
|
||||
let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
|
||||
XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws {
|
||||
try interpreter.allocateTensors()
|
||||
XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
|
||||
let maxIndex = AddModel.inputTensorCount - 1
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() {
|
||||
XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
|
||||
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorAtIndex() throws {
|
||||
try setUpAddModelInputTensor()
|
||||
try interpreter.invoke()
|
||||
let outputTensor = try interpreter.output(at: AddModel.validIndex)
|
||||
XCTAssertEqual(outputTensor, AddModel.outputTensor)
|
||||
let expectedResults = [Float32](unsafeData: outputTensor.data)
|
||||
XCTAssertEqual(expectedResults, AddModel.results)
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws {
|
||||
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
||||
try setUpAddQuantizedModelInputTensor()
|
||||
try interpreter.invoke()
|
||||
let outputTensor = try interpreter.output(at: AddQuantizedModel.inputOutputIndex)
|
||||
XCTAssertEqual(outputTensor, AddQuantizedModel.outputTensor)
|
||||
let expectedResults = [UInt8](outputTensor.data)
|
||||
XCTAssertEqual(expectedResults, AddQuantizedModel.results)
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws {
|
||||
try interpreter.allocateTensors()
|
||||
try interpreter.invoke()
|
||||
XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
|
||||
let maxIndex = AddModel.outputTensorCount - 1
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
|
||||
XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
|
||||
self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_ResizeInputTensorAtIndexToShape() {
|
||||
XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
|
||||
XCTAssertNoThrow(try interpreter.allocateTensors())
|
||||
}
|
||||
|
||||
func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
|
||||
XCTAssertThrowsError(try interpreter.resizeInput(
|
||||
at: AddModel.invalidIndex,
|
||||
to: [2, 2, 3]
|
||||
)) { error in
|
||||
let maxIndex = AddModel.inputTensorCount - 1
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_CopyDataToInputTensorAtIndex() throws {
|
||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||
try interpreter.allocateTensors()
|
||||
let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
|
||||
XCTAssertEqual(inputTensor.data, AddModel.inputData)
|
||||
}
|
||||
|
||||
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
|
||||
XCTAssertThrowsError(try interpreter.copy(
|
||||
AddModel.inputData,
|
||||
toInputAt: AddModel.invalidIndex
|
||||
)) { error in
|
||||
let maxIndex = AddModel.inputTensorCount - 1
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
|
||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||
try interpreter.allocateTensors()
|
||||
let invalidData = Data(count: AddModel.dataCount - 1)
|
||||
XCTAssertThrowsError(try interpreter.copy(
|
||||
invalidData,
|
||||
toInputAt: AddModel.validIndex
|
||||
)) { error in
|
||||
self.assertEqualErrors(
|
||||
actual: error,
|
||||
expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testInterpreter_AllocateTensors() {
|
||||
XCTAssertNoThrow(try interpreter.allocateTensors())
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
private func setUpAddModelInputTensor() throws {
|
||||
precondition(interpreter != nil)
|
||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||
try interpreter.allocateTensors()
|
||||
try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
|
||||
}
|
||||
|
||||
private func setUpAddQuantizedModelInputTensor() throws {
|
||||
precondition(interpreter != nil)
|
||||
try interpreter.resizeInput(at: AddQuantizedModel.inputOutputIndex, to: AddQuantizedModel.shape)
|
||||
try interpreter.allocateTensors()
|
||||
try interpreter.copy(AddQuantizedModel.inputData, toInputAt: AddQuantizedModel.inputOutputIndex)
|
||||
}
|
||||
|
||||
private func assertEqualErrors(actual: Error, expected: InterpreterError) {
|
||||
guard let actual = actual as? InterpreterError else {
|
||||
XCTFail("Actual error should be of type InterpreterError.")
|
||||
return
|
||||
}
|
||||
XCTAssertEqual(actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Constants
|
||||
|
||||
/// Values for the `add.bin` model.
|
||||
private enum AddModel {
|
||||
static let info = (name: "add", extension: "bin")
|
||||
static let inputTensorCount = 1
|
||||
static let outputTensorCount = 1
|
||||
static let invalidIndex = 1
|
||||
static let validIndex = 0
|
||||
static let shape: TensorShape = [2]
|
||||
static let dataCount = inputData.count
|
||||
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
|
||||
static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])
|
||||
static let results = [Float32(3.0), Float32(9.0)]
|
||||
|
||||
static let inputTensor = Tensor(
|
||||
name: "input",
|
||||
dataType: .float32,
|
||||
shape: shape,
|
||||
data: inputData
|
||||
)
|
||||
static let outputTensor = Tensor(
|
||||
name: "output",
|
||||
dataType: .float32,
|
||||
shape: shape,
|
||||
data: outputData
|
||||
)
|
||||
|
||||
static var path: String = {
|
||||
let bundle = Bundle(for: InterpreterTests.self)
|
||||
guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" }
|
||||
return path
|
||||
}()
|
||||
}
|
||||
|
||||
/// Values for the `add_quantized.bin` model.
|
||||
private enum AddQuantizedModel {
|
||||
static let info = (name: "add_quantized", extension: "bin")
|
||||
static let inputOutputIndex = 0
|
||||
static let shape: TensorShape = [2]
|
||||
static let inputData = Data([1, 3])
|
||||
static let outputData = Data([3, 9])
|
||||
static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)
|
||||
static let results: [UInt8] = [3, 9]
|
||||
|
||||
static let inputTensor = Tensor(
|
||||
name: "input",
|
||||
dataType: .uInt8,
|
||||
shape: shape,
|
||||
data: inputData,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
static let outputTensor = Tensor(
|
||||
name: "output",
|
||||
dataType: .uInt8,
|
||||
shape: shape,
|
||||
data: outputData,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
|
||||
static var path: String = {
|
||||
let bundle = Bundle(for: InterpreterTests.self)
|
||||
guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" }
|
||||
return path
|
||||
}()
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension Array {
|
||||
/// Creates a new array from the bytes of the given unsafe data.
|
||||
///
|
||||
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
||||
/// `MemoryLayout<Element>.stride`.
|
||||
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||
init?(unsafeData: Data) {
|
||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||
let elements = unsafeData.withUnsafeBytes {
|
||||
UnsafeBufferPointer<Element>(
|
||||
start: $0,
|
||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||
)
|
||||
}
|
||||
self.init(elements)
|
||||
}
|
||||
}
|
||||
|
||||
extension Data {
|
||||
/// Creates a new buffer by copying the buffer pointer of the given array.
|
||||
///
|
||||
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
|
||||
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
|
||||
/// data from the resulting buffer has undefined behavior.
|
||||
/// - Parameter array: An array with elements of type `T`.
|
||||
init<T>(copyingBufferOf array: [T]) {
|
||||
self = array.withUnsafeBufferPointer(Data.init)
|
||||
}
|
||||
}
|
59
tensorflow/lite/experimental/swift/Tests/ModelTests.swift
Normal file
59
tensorflow/lite/experimental/swift/Tests/ModelTests.swift
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class ModelTests: XCTestCase {
|
||||
|
||||
var modelPath: String!
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
|
||||
let bundle = Bundle(for: type(of: self))
|
||||
guard let modelPath = bundle.path(
|
||||
forResource: Constant.modelInfo.name,
|
||||
ofType: Constant.modelInfo.extension)
|
||||
else {
|
||||
XCTFail("Failed to get the model file path.")
|
||||
return
|
||||
}
|
||||
self.modelPath = modelPath
|
||||
}
|
||||
|
||||
override func tearDown() {
|
||||
modelPath = nil
|
||||
|
||||
super.tearDown()
|
||||
}
|
||||
|
||||
func testModel_InitWithFilePath() {
|
||||
XCTAssertNotNil(Model(filePath: modelPath))
|
||||
}
|
||||
|
||||
func testModel_InitWithEmptyFilePath_FailsInitialization() {
|
||||
XCTAssertNil(Model(filePath: ""))
|
||||
}
|
||||
|
||||
func testModel_InitWithInvalidFilePath_FailsInitialization() {
|
||||
XCTAssertNil(Model(filePath: "invalid/path"))
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Constants
|
||||
|
||||
private enum Constant {
|
||||
static let modelInfo = (name: "add", extension: "bin")
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class QuantizationParametersTests: XCTestCase {
|
||||
|
||||
func testQuantizationParameters_InitWithCustomValues() {
|
||||
let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||
XCTAssertEqual(parameters.scale, 0.5)
|
||||
XCTAssertEqual(parameters.zeroPoint, 1)
|
||||
}
|
||||
|
||||
func testQuantizationParameters_Equatable() {
|
||||
let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||
let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||
XCTAssertEqual(parameters1, parameters2)
|
||||
|
||||
let parameters3 = QuantizationParameters(scale: 0.4, zeroPoint: 1)
|
||||
XCTAssertNotEqual(parameters1, parameters3)
|
||||
XCTAssertNotEqual(parameters2, parameters3)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension QuantizationParameters: Equatable {
|
||||
public static func == (lhs: QuantizationParameters, rhs: QuantizationParameters) -> Bool {
|
||||
return lhs.scale == rhs.scale && lhs.zeroPoint == rhs.zeroPoint
|
||||
}
|
||||
}
|
83
tensorflow/lite/experimental/swift/Tests/TensorTests.swift
Normal file
83
tensorflow/lite/experimental/swift/Tests/TensorTests.swift
Normal file
@ -0,0 +1,83 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at:
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
@testable import TensorFlowLite
|
||||
import XCTest
|
||||
|
||||
class TensorTests: XCTestCase {
|
||||
|
||||
// MARK: - Tensor
|
||||
|
||||
func testTensor_Init() {
|
||||
let name = "InputTensor"
|
||||
let dataType: TensorDataType = .uInt8
|
||||
let shape = TensorShape(Constant.dimensions)
|
||||
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
|
||||
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||
let inputTensor = Tensor(
|
||||
name: name,
|
||||
dataType: dataType,
|
||||
shape: shape,
|
||||
data: data,
|
||||
quantizationParameters: quantizationParameters
|
||||
)
|
||||
XCTAssertEqual(inputTensor.name, name)
|
||||
XCTAssertEqual(inputTensor.dataType, dataType)
|
||||
XCTAssertEqual(inputTensor.shape, shape)
|
||||
XCTAssertEqual(inputTensor.data, data)
|
||||
XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
|
||||
}
|
||||
|
||||
// MARK: - TensorShape
|
||||
|
||||
func testTensorShape_InitWithArray() {
|
||||
let shape = TensorShape(Constant.dimensions)
|
||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||
}
|
||||
|
||||
func testTensorShape_InitWithElements() {
|
||||
let shape = TensorShape(2, 2, 3)
|
||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||
}
|
||||
|
||||
func testTensorShape_InitWithArrayLiteral() {
|
||||
let shape: TensorShape = [2, 2, 3]
|
||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Constants
|
||||
|
||||
private enum Constant {
|
||||
/// Array of 2 arrays of 2 arrays of 3 numbers: [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]].
|
||||
static let dimensions = [2, 2, 3]
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
extension TensorShape: Equatable {
|
||||
public static func == (lhs: TensorShape, rhs: TensorShape) -> Bool {
|
||||
return lhs.rank == rhs.rank && lhs.dimensions == rhs.dimensions
|
||||
}
|
||||
}
|
||||
|
||||
extension Tensor: Equatable {
|
||||
public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
|
||||
return lhs.name == rhs.name && lhs.dataType == rhs.dataType && lhs.shape == rhs.shape &&
|
||||
lhs.data == rhs.data && lhs.quantizationParameters == rhs.quantizationParameters
|
||||
}
|
||||
}
|
@ -30,14 +30,19 @@ os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
|
||||
PIP_PACKAGE_QUERY_EXPRESSION = (
|
||||
"deps(//tensorflow/tools/pip_package:build_pip_package)")
|
||||
|
||||
# List of file paths containing BUILD files that should not be included for the
|
||||
# pip smoke test.
|
||||
BUILD_BLACKLIST = [
|
||||
"tensorflow/lite/examples/android",
|
||||
"tensorflow/lite/experimental/swift",
|
||||
]
|
||||
|
||||
def GetBuild(dir_base):
|
||||
"""Get the list of BUILD file all targets recursively startind at dir_base."""
|
||||
items = []
|
||||
for root, _, files in os.walk(dir_base):
|
||||
for name in files:
|
||||
if (name == "BUILD" and
|
||||
root.find("tensorflow/lite/examples/android") == -1):
|
||||
if (name == "BUILD" and root not in BUILD_BLACKLIST):
|
||||
items.append("//" + root + ":all")
|
||||
return items
|
||||
|
||||
@ -67,9 +72,9 @@ def BuildPyTestDependencies():
|
||||
|
||||
PYTHON_TARGETS, PY_TEST_QUERY_EXPRESSION = BuildPyTestDependencies()
|
||||
|
||||
# Hard-coded blacklist of files if not included in pip package
|
||||
# TODO(amitpatankar): Clean up blacklist.
|
||||
BLACKLIST = [
|
||||
# List of dependencies that should not included in the pip package.
|
||||
DEPENDENCY_BLACKLIST = [
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
"//tensorflow:no_tensorflow_py_deps",
|
||||
@ -82,9 +87,7 @@ BLACKLIST = [
|
||||
"//tensorflow/core/kernels/cloud:bigquery_reader_ops",
|
||||
"//tensorflow/python/feature_column:vocabulary_testdata",
|
||||
"//tensorflow/python:framework/test_file_system.so",
|
||||
# contrib
|
||||
"//tensorflow/contrib/session_bundle:session_bundle_half_plus_two",
|
||||
"//tensorflow/contrib/keras:testing_utils",
|
||||
# lite
|
||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm",
|
||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm.py",
|
||||
"//tensorflow/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test", # pylint:disable=line-too-long
|
||||
@ -93,6 +96,9 @@ BLACKLIST = [
|
||||
"//tensorflow/lite/python:interpreter_test",
|
||||
"//tensorflow/lite/python:interpreter.py",
|
||||
"//tensorflow/lite/python:interpreter_test.py",
|
||||
# contrib
|
||||
"//tensorflow/contrib/session_bundle:session_bundle_half_plus_two",
|
||||
"//tensorflow/contrib/keras:testing_utils",
|
||||
"//tensorflow/contrib/ffmpeg:test_data",
|
||||
"//tensorflow/contrib/fused_conv:fused_conv2d_bias_activation_op_test_base",
|
||||
"//tensorflow/contrib/hadoop:test_data",
|
||||
@ -149,8 +155,8 @@ def main():
|
||||
# File extensions and endings to ignore
|
||||
ignore_extensions = ["_test", "_test.py", "_test_gpu", "_test_gpu.py"]
|
||||
|
||||
ignored_files = 0
|
||||
blacklisted_files = len(BLACKLIST)
|
||||
ignored_files_count = 0
|
||||
blacklisted_dependencies_count = len(DEPENDENCY_BLACKLIST)
|
||||
# Compare dependencies
|
||||
for dependency in tf_py_test_dependencies_list:
|
||||
if dependency and dependency.startswith("//tensorflow"):
|
||||
@ -158,16 +164,16 @@ def main():
|
||||
# Ignore extensions
|
||||
if any(dependency.endswith(ext) for ext in ignore_extensions):
|
||||
ignore = True
|
||||
ignored_files += 1
|
||||
ignored_files_count += 1
|
||||
|
||||
# Check if the dependency is in the pip package, the blacklist, or
|
||||
# should be ignored because of its file extension
|
||||
# Check if the dependency is in the pip package, the dependency blacklist,
|
||||
# or should be ignored because of its file extension.
|
||||
if not (ignore or dependency in pip_package_dependencies_list or
|
||||
dependency in BLACKLIST):
|
||||
dependency in DEPENDENCY_BLACKLIST):
|
||||
missing_dependencies.append(dependency)
|
||||
|
||||
print("Ignored files: %d" % ignored_files)
|
||||
print("Blacklisted files: %d" % blacklisted_files)
|
||||
print("Ignored files count: %d" % ignored_files_count)
|
||||
print("Blacklisted dependencies count: %d" % blacklisted_dependencies_count)
|
||||
if missing_dependencies:
|
||||
print("Missing the following dependencies from pip_packages:")
|
||||
for missing_dependency in missing_dependencies:
|
||||
|
Loading…
Reference in New Issue
Block a user