parent
d183818d4d
commit
59d5fbcc1a
100
tensorflow/lite/experimental/swift/BUILD
Normal file
100
tensorflow/lite/experimental/swift/BUILD
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
# TensorFlow Lite for Swift.
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:private"])
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_application", "ios_unit_test")
|
||||||
|
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
|
||||||
|
|
||||||
|
MINIMUM_OS_VERSION = "9.0"
|
||||||
|
|
||||||
|
SWIFT_COPTS = [
|
||||||
|
"-wmo",
|
||||||
|
]
|
||||||
|
|
||||||
|
swift_library(
|
||||||
|
name = "TensorFlowLite",
|
||||||
|
srcs = glob(["Sources/*.swift"]),
|
||||||
|
copts = SWIFT_COPTS,
|
||||||
|
module_name = "TensorFlowLite",
|
||||||
|
tags = ["manual"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/experimental/c:c_api",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "TensorFlowLiteTests",
|
||||||
|
size = "small",
|
||||||
|
minimum_os_version = MINIMUM_OS_VERSION,
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
# DISABLED: Following sanitizer tests are not supported by iOS test targets.
|
||||||
|
"noasan",
|
||||||
|
"nomsan",
|
||||||
|
"notsan",
|
||||||
|
],
|
||||||
|
deps = [":TensorFlowLiteTestsLib"],
|
||||||
|
)
|
||||||
|
|
||||||
|
swift_library(
|
||||||
|
name = "TensorFlowLiteTestsLib",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = glob(["Tests/*.swift"]),
|
||||||
|
copts = SWIFT_COPTS,
|
||||||
|
tags = ["manual"],
|
||||||
|
deps = [
|
||||||
|
":TensorFlowLite",
|
||||||
|
":TestResources",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "TestResources",
|
||||||
|
resources = [
|
||||||
|
"//tensorflow/lite:testdata/add.bin",
|
||||||
|
"//tensorflow/lite:testdata/add_quantized.bin",
|
||||||
|
"//tensorflow/lite:testdata/multi_add.bin",
|
||||||
|
],
|
||||||
|
tags = ["manual"],
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_application(
|
||||||
|
name = "TensorFlowLiteApp",
|
||||||
|
app_icons = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/**"]),
|
||||||
|
bundle_id = "com.tensorflow.lite.swift.TensorFlowLite",
|
||||||
|
families = [
|
||||||
|
"ipad",
|
||||||
|
"iphone",
|
||||||
|
],
|
||||||
|
infoplists = ["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Info.plist"],
|
||||||
|
launch_storyboard = "TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/LaunchScreen.storyboard",
|
||||||
|
minimum_os_version = MINIMUM_OS_VERSION,
|
||||||
|
sdk_frameworks = [
|
||||||
|
"CoreGraphics",
|
||||||
|
],
|
||||||
|
tags = ["manual"],
|
||||||
|
deps = [":TensorFlowLiteAppLib"],
|
||||||
|
)
|
||||||
|
|
||||||
|
swift_library(
|
||||||
|
name = "TensorFlowLiteAppLib",
|
||||||
|
srcs = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/*.swift"]),
|
||||||
|
tags = ["manual"],
|
||||||
|
deps = [
|
||||||
|
":TensorFlowLite",
|
||||||
|
":TensorFlowLiteAppResources",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "TensorFlowLiteAppResources",
|
||||||
|
storyboards = glob([
|
||||||
|
"TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/*.storyboard",
|
||||||
|
]),
|
||||||
|
tags = ["manual"],
|
||||||
|
deps = [":TestResources"],
|
||||||
|
)
|
202
tensorflow/lite/experimental/swift/LICENSE
Normal file
202
tensorflow/lite/experimental/swift/LICENSE
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
54
tensorflow/lite/experimental/swift/README.md
Normal file
54
tensorflow/lite/experimental/swift/README.md
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# TensorFlow Lite for Swift
|
||||||
|
|
||||||
|
[TensorFlow Lite](https://www.tensorflow.org/lite/) is TensorFlow's lightweight
|
||||||
|
solution for Swift developers. It enables low-latency inference of on-device
|
||||||
|
machine learning models with a small binary size and fast performance supporting
|
||||||
|
hardware acceleration.
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Bazel
|
||||||
|
|
||||||
|
In your `BUILD` file, add the `TensorFlowLite` dependency:
|
||||||
|
|
||||||
|
```python
|
||||||
|
swift_library(
|
||||||
|
# ...
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/swift:TensorFlowLite",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
In your Swift files, import the module:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
import TensorFlowLite
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tulsi
|
||||||
|
|
||||||
|
Open the `TensorFlowLite.tulsiproj` using the [TulsiApp](https://github.com/bazelbuild/tulsi) or by
|
||||||
|
running the [`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
|
||||||
|
script:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
generate_xcodeproj.sh --genconfig tensorflow/lite/swift/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
|
||||||
|
```
|
||||||
|
|
||||||
|
### CocoaPods
|
||||||
|
|
||||||
|
Add the following to your `Podfile`:
|
||||||
|
|
||||||
|
```ruby
|
||||||
|
use_frameworks!
|
||||||
|
pod 'TensorFlowLiteSwift'
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, run `pod install`.
|
||||||
|
|
||||||
|
In your Swift files, import the module:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
import TensorFlowLite
|
||||||
|
```
|
265
tensorflow/lite/experimental/swift/Sources/Interpreter.swift
Normal file
265
tensorflow/lite/experimental/swift/Sources/Interpreter.swift
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import TensorFlowLiteCAPI
|
||||||
|
|
||||||
|
/// A TensorFlow Lite interpreter that performs inference from a given model.
|
||||||
|
public final class Interpreter {
|
||||||
|
|
||||||
|
/// The `TFL_Interpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
|
||||||
|
private typealias CInterpreter = OpaquePointer
|
||||||
|
|
||||||
|
/// Total number of input tensors associated with the model.
|
||||||
|
public var inputTensorCount: Int {
|
||||||
|
return Int(TFL_InterpreterGetInputTensorCount(cInterpreter))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Total number of output tensors associated with the model.
|
||||||
|
public var outputTensorCount: Int {
|
||||||
|
return Int(TFL_InterpreterGetOutputTensorCount(cInterpreter))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The underlying `TFL_Interpreter` C pointer.
|
||||||
|
private var cInterpreter: CInterpreter?
|
||||||
|
|
||||||
|
/// Creates a new model interpreter instance.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - modelPath: Local file path to a TensorFlow Lite model.
|
||||||
|
/// - options: Custom configurations for the interpreter. The default is `nil` indicating that
|
||||||
|
/// interpreter will determine the configuration options.
|
||||||
|
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
|
||||||
|
public init(modelPath: String, options: InterpreterOptions? = nil) throws {
|
||||||
|
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
|
||||||
|
|
||||||
|
let cInterpreterOptions: OpaquePointer? = try options.map { options in
|
||||||
|
guard let cOptions = TFL_NewInterpreterOptions() else {
|
||||||
|
throw InterpreterError.failedToCreateInterpreter
|
||||||
|
}
|
||||||
|
if let threadCount = options.threadCount, threadCount > 0 {
|
||||||
|
TFL_InterpreterOptionsSetNumThreads(cOptions, Int32(threadCount))
|
||||||
|
}
|
||||||
|
if options.isErrorLoggingEnabled {
|
||||||
|
TFL_InterpreterOptionsSetErrorReporter(
|
||||||
|
cOptions,
|
||||||
|
{ (_, format, arguments) in
|
||||||
|
guard let cFormat = format,
|
||||||
|
let message = String(cFormat: cFormat, arguments: arguments)
|
||||||
|
else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
print(String(describing: InterpreterError.tensorFlowLiteError(message)))
|
||||||
|
},
|
||||||
|
nil
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return cOptions
|
||||||
|
}
|
||||||
|
defer { TFL_DeleteInterpreterOptions(cInterpreterOptions) }
|
||||||
|
|
||||||
|
guard let cInterpreter = TFL_NewInterpreter(model.cModel, cInterpreterOptions) else {
|
||||||
|
throw InterpreterError.failedToCreateInterpreter
|
||||||
|
}
|
||||||
|
self.cInterpreter = cInterpreter
|
||||||
|
}
|
||||||
|
|
||||||
|
deinit {
|
||||||
|
TFL_DeleteInterpreter(cInterpreter)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the interpreter to perform inference from the loaded graph.
|
||||||
|
///
|
||||||
|
/// - Throws: An error if the model was not ready because tensors were not allocated.
|
||||||
|
public func invoke() throws {
|
||||||
|
guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else {
|
||||||
|
// TODO(b/117510052): Determine which error to throw.
|
||||||
|
throw InterpreterError.allocateTensorsRequired
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the input tensor at the given index.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - index: The index for the input tensor.
|
||||||
|
/// - Throws: An error if the index is invalid or the tensors have not been allocated.
|
||||||
|
/// - Returns: The input tensor at the given index.
|
||||||
|
public func input(at index: Int) throws -> Tensor {
|
||||||
|
let maxIndex = inputTensorCount - 1
|
||||||
|
guard case 0...maxIndex = index else {
|
||||||
|
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||||
|
}
|
||||||
|
guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)),
|
||||||
|
let bytes = TFL_TensorData(cTensor),
|
||||||
|
let nameCString = TFL_TensorName(cTensor)
|
||||||
|
else {
|
||||||
|
throw InterpreterError.allocateTensorsRequired
|
||||||
|
}
|
||||||
|
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
||||||
|
throw InterpreterError.invalidTensorDataType
|
||||||
|
}
|
||||||
|
|
||||||
|
let name = String(cString: nameCString)
|
||||||
|
let rank = TFL_TensorNumDims(cTensor)
|
||||||
|
let dimensions = (0..<rank).map { Int(TFL_TensorDim(cTensor, $0)) }
|
||||||
|
let shape = TensorShape(dimensions)
|
||||||
|
let byteCount = TFL_TensorByteSize(cTensor)
|
||||||
|
let data = Data(bytes: bytes, count: byteCount)
|
||||||
|
let cQuantizationParams = TFL_TensorQuantizationParams(cTensor)
|
||||||
|
let scale = cQuantizationParams.scale
|
||||||
|
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||||
|
var quantizationParameters: QuantizationParameters? = nil
|
||||||
|
if scale != 0.0 {
|
||||||
|
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
||||||
|
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||||
|
}
|
||||||
|
let tensor = Tensor(
|
||||||
|
name: name,
|
||||||
|
dataType: dataType,
|
||||||
|
shape: shape,
|
||||||
|
data: data,
|
||||||
|
quantizationParameters: quantizationParameters
|
||||||
|
)
|
||||||
|
return tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the output tensor at the given index.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - index: The index for the output tensor.
|
||||||
|
/// - Throws: An error if the index is invalid, tensors haven't been allocated, or interpreter
|
||||||
|
/// hasn't been invoked for models that dynamically compute output tensors based on the values
|
||||||
|
/// of its input tensors.
|
||||||
|
/// - Returns: The output tensor at the given index.
|
||||||
|
public func output(at index: Int) throws -> Tensor {
|
||||||
|
let maxIndex = outputTensorCount - 1
|
||||||
|
guard case 0...maxIndex = index else {
|
||||||
|
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||||
|
}
|
||||||
|
guard let cTensor = TFL_InterpreterGetOutputTensor(cInterpreter, Int32(index)),
|
||||||
|
let bytes = TFL_TensorData(cTensor),
|
||||||
|
let nameCString = TFL_TensorName(cTensor)
|
||||||
|
else {
|
||||||
|
// TODO(b/117510052): Determine which error to throw.
|
||||||
|
throw InterpreterError.invokeInterpreterRequired
|
||||||
|
}
|
||||||
|
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
||||||
|
throw InterpreterError.invalidTensorDataType
|
||||||
|
}
|
||||||
|
|
||||||
|
let name = String(cString: nameCString)
|
||||||
|
let rank = TFL_TensorNumDims(cTensor)
|
||||||
|
let dimensions = (0..<rank).map { Int(TFL_TensorDim(cTensor, $0)) }
|
||||||
|
let shape = TensorShape(dimensions)
|
||||||
|
let byteCount = TFL_TensorByteSize(cTensor)
|
||||||
|
let data = Data(bytes: bytes, count: byteCount)
|
||||||
|
let cQuantizationParams = TFL_TensorQuantizationParams(cTensor)
|
||||||
|
let scale = cQuantizationParams.scale
|
||||||
|
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||||
|
var quantizationParameters: QuantizationParameters? = nil
|
||||||
|
if scale != 0.0 {
|
||||||
|
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
||||||
|
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||||
|
}
|
||||||
|
let tensor = Tensor(
|
||||||
|
name: name,
|
||||||
|
dataType: dataType,
|
||||||
|
shape: shape,
|
||||||
|
data: data,
|
||||||
|
quantizationParameters: quantizationParameters
|
||||||
|
)
|
||||||
|
return tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resizes the input tensor at the given index to the specified tensor shape.
|
||||||
|
///
|
||||||
|
/// - Note: After resizing an input tensor, the client **must** explicitly call
|
||||||
|
/// `allocateTensors()` before attempting to access the resized tensor data or invoking the
|
||||||
|
/// interpreter to perform inference.
|
||||||
|
/// - Parameters:
|
||||||
|
/// - index: The index for the input tensor.
|
||||||
|
/// - shape: The shape that the input tensor should be resized to.
|
||||||
|
/// - Throws: An error if the input tensor at the given index could not be resized.
|
||||||
|
public func resizeInput(at index: Int, to shape: TensorShape) throws {
|
||||||
|
let maxIndex = inputTensorCount - 1
|
||||||
|
guard case 0...maxIndex = index else {
|
||||||
|
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||||
|
}
|
||||||
|
guard TFL_InterpreterResizeInputTensor(
|
||||||
|
cInterpreter,
|
||||||
|
Int32(index),
|
||||||
|
shape.int32Dimensions,
|
||||||
|
Int32(shape.rank)
|
||||||
|
) == kTfLiteOk
|
||||||
|
else {
|
||||||
|
throw InterpreterError.failedToResizeInputTensor(index: index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copies the given data to the input tensor at the given index.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - data: The data to be copied to the input tensor's data buffer.
|
||||||
|
/// - index: The index for the input tensor.
|
||||||
|
/// - Throws: An error if the `data.count` does not match the input tensor's `data.count` or if
|
||||||
|
/// the given index is invalid.
|
||||||
|
/// - Returns: The input tensor with the copied data.
|
||||||
|
@discardableResult
|
||||||
|
public func copy(_ data: Data, toInputAt index: Int) throws -> Tensor {
|
||||||
|
let maxIndex = inputTensorCount - 1
|
||||||
|
guard case 0...maxIndex = index else {
|
||||||
|
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||||
|
}
|
||||||
|
guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)) else {
|
||||||
|
throw InterpreterError.allocateTensorsRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
let byteCount = TFL_TensorByteSize(cTensor)
|
||||||
|
guard data.count == byteCount else {
|
||||||
|
throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) }
|
||||||
|
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
|
||||||
|
return try input(at: index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allocates memory for all input tensors based on their `TensorShape`s.
|
||||||
|
///
|
||||||
|
/// - Note: This is a relatively expensive operation and should only be called after creating the
|
||||||
|
/// interpreter and/or resizing any input tensors.
|
||||||
|
/// - Throws: An error if memory could not be allocated for the input tensors.
|
||||||
|
public func allocateTensors() throws {
|
||||||
|
guard TFL_InterpreterAllocateTensors(cInterpreter) == kTfLiteOk else {
|
||||||
|
throw InterpreterError.failedToAllocateTensors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Extensions
|
||||||
|
|
||||||
|
extension String {
|
||||||
|
/// Returns a new `String` initialized by using the given format C array as a template into which
|
||||||
|
/// the remaining argument values are substituted according to the user’s default locale.
|
||||||
|
///
|
||||||
|
/// - Note: Returns `nil` if a new `String` could not be constructed from the given values.
|
||||||
|
/// - Parameters:
|
||||||
|
/// - cFormat: The format C array as a template for substituting values.
|
||||||
|
/// - arguments: A C pointer to a `va_list` of arguments to substitute into `cFormat`.
|
||||||
|
init?(cFormat: UnsafePointer<CChar>, arguments: CVaListPointer) {
|
||||||
|
var buffer: UnsafeMutablePointer<CChar>?
|
||||||
|
guard vasprintf(&buffer, cFormat, arguments) != 0, let cString = buffer else { return nil }
|
||||||
|
self.init(validatingUTF8: cString)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,99 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// TensorFlow Lite interpreter errors.
|
||||||
|
public enum InterpreterError: Error {
|
||||||
|
case invalidTensorIndex(index: Int, maxIndex: Int)
|
||||||
|
case invalidTensorDataCount(provided: Int, required: Int)
|
||||||
|
case invalidTensorDataType
|
||||||
|
case failedToLoadModel
|
||||||
|
case failedToCreateInterpreter
|
||||||
|
case failedToResizeInputTensor(index: Int)
|
||||||
|
case failedToCopyDataToInputTensor
|
||||||
|
case failedToAllocateTensors
|
||||||
|
case allocateTensorsRequired
|
||||||
|
case invokeInterpreterRequired
|
||||||
|
case tensorFlowLiteError(String)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Extensions
|
||||||
|
|
||||||
|
extension InterpreterError: LocalizedError {
|
||||||
|
/// Localized description of the interpreter error.
|
||||||
|
public var errorDescription: String? {
|
||||||
|
switch self {
|
||||||
|
case .invalidTensorIndex(let index, let maxIndex):
|
||||||
|
return "Invalid tensor index \(index), max index is \(maxIndex)."
|
||||||
|
case .invalidTensorDataCount(let providedCount, let requiredCount):
|
||||||
|
return "Provided data count \(providedCount) must match the required count \(requiredCount)."
|
||||||
|
case .invalidTensorDataType:
|
||||||
|
return "Tensor data type is unsupported or could not be determined because of a model error."
|
||||||
|
case .failedToLoadModel:
|
||||||
|
return "Failed to load the given model."
|
||||||
|
case .failedToCreateInterpreter:
|
||||||
|
return "Failed to create the interpreter."
|
||||||
|
case .failedToResizeInputTensor(let index):
|
||||||
|
return "Failed to resize input tesnor at index \(index)."
|
||||||
|
case .failedToCopyDataToInputTensor:
|
||||||
|
return "Failed to copy data to input tensor."
|
||||||
|
case .failedToAllocateTensors:
|
||||||
|
return "Failed to allocate memory for input tensors."
|
||||||
|
case .allocateTensorsRequired:
|
||||||
|
return "Must call allocateTensors()."
|
||||||
|
case .invokeInterpreterRequired:
|
||||||
|
return "Must call invoke()."
|
||||||
|
case .tensorFlowLiteError(let message):
|
||||||
|
return "TensorFlow Lite Error: \(message)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension InterpreterError: CustomStringConvertible {
|
||||||
|
/// Textual representation of the TensorFlow Lite interpreter error.
|
||||||
|
public var description: String {
|
||||||
|
return errorDescription ?? "Unknown error."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if swift(>=4.2)
|
||||||
|
extension InterpreterError: Equatable {}
|
||||||
|
#else
|
||||||
|
extension InterpreterError: Equatable {
|
||||||
|
public static func == (lhs: InterpreterError, rhs: InterpreterError) -> Bool {
|
||||||
|
switch (lhs, rhs) {
|
||||||
|
case (.invalidTensorDataType, .invalidTensorDataType),
|
||||||
|
(.failedToLoadModel, .failedToLoadModel),
|
||||||
|
(.failedToCreateInterpreter, .failedToCreateInterpreter),
|
||||||
|
(.failedToAllocateTensors, .failedToAllocateTensors),
|
||||||
|
(.allocateTensorsRequired, .allocateTensorsRequired),
|
||||||
|
(.invokeInterpreterRequired, .invokeInterpreterRequired):
|
||||||
|
return true
|
||||||
|
case (.invalidTensorIndex(let lhsIndex, let lhsMaxIndex),
|
||||||
|
.invalidTensorIndex(let rhsIndex, let rhsMaxIndex)):
|
||||||
|
return lhsIndex == rhsIndex && lhsMaxIndex == rhsMaxIndex
|
||||||
|
case (.invalidTensorDataCount(let lhsProvidedCount, let lhsRequiredCount),
|
||||||
|
.invalidTensorDataCount(let rhsProvidedCount, let rhsRequiredCount)):
|
||||||
|
return lhsProvidedCount == rhsProvidedCount && lhsRequiredCount == rhsRequiredCount
|
||||||
|
case (.failedToResizeInputTensor(let lhsIndex), .failedToResizeInputTensor(let rhsIndex)):
|
||||||
|
return lhsIndex == rhsIndex
|
||||||
|
case (.tensorFlowLiteError(let lhsMessage), .tensorFlowLiteError(let rhsMessage)):
|
||||||
|
return lhsMessage == rhsMessage
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif // swift(>=4.2)
|
@ -0,0 +1,29 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Custom configuration options for a TensorFlow Lite interpreter.
|
||||||
|
public struct InterpreterOptions: Equatable {
|
||||||
|
|
||||||
|
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` which
|
||||||
|
/// indicates that the `Interpreter` will decide the number of threads to use.
|
||||||
|
public var threadCount: Int? = nil
|
||||||
|
|
||||||
|
/// Whether error logging to the console is enabled. The default is `false`.
|
||||||
|
public var isErrorLoggingEnabled = false
|
||||||
|
|
||||||
|
/// Creates a new instance of interpreter options.
|
||||||
|
public init() {}
|
||||||
|
}
|
40
tensorflow/lite/experimental/swift/Sources/Model.swift
Normal file
40
tensorflow/lite/experimental/swift/Sources/Model.swift
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import TensorFlowLiteCAPI
|
||||||
|
|
||||||
|
/// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
|
||||||
|
final class Model {
|
||||||
|
|
||||||
|
/// The `TFL_Model` C pointer type represented as an `UnsafePointer<TFL_Model>`.
|
||||||
|
typealias CModel = OpaquePointer
|
||||||
|
|
||||||
|
/// The underlying `TFL_Model` C pointer.
|
||||||
|
let cModel: CModel?
|
||||||
|
|
||||||
|
/// Creates a new model instance.
|
||||||
|
///
|
||||||
|
/// - Precondition: Initialization can fail if the given `filePath` is invalid.
|
||||||
|
/// - Parameters:
|
||||||
|
/// - filePath: Local file path to a TensorFlow Lite model.
|
||||||
|
init?(filePath: String) {
|
||||||
|
guard !filePath.isEmpty, let cModel = TFL_NewModelFromFile(filePath) else { return nil }
|
||||||
|
self.cModel = cModel
|
||||||
|
}
|
||||||
|
|
||||||
|
deinit {
|
||||||
|
TFL_DeleteModel(cModel)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Parameters that determine the mapping of quantized values to real values. Quantized values can
|
||||||
|
/// be mapped to float values using the following conversion:
|
||||||
|
/// `realValue = scale * (quantizedValue - zeroPoint)`.
|
||||||
|
public struct QuantizationParameters {
|
||||||
|
|
||||||
|
/// Difference between real values corresponding to consecutive quantized values differing by 1.
|
||||||
|
/// For example, the range of quantized values for `UInt8` data type is [0, 255].
|
||||||
|
public let scale: Float
|
||||||
|
|
||||||
|
/// Quantized value that corresponds to the real 0 value.
|
||||||
|
public let zeroPoint: Int
|
||||||
|
|
||||||
|
/// Creates a new quantization parameters instance.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - scale: Scale value for asymmetric quantization.
|
||||||
|
/// - zeroPoint: Zero point for asymmetric quantization.
|
||||||
|
init(scale: Float, zeroPoint: Int) {
|
||||||
|
self.scale = scale
|
||||||
|
self.zeroPoint = zeroPoint
|
||||||
|
}
|
||||||
|
}
|
138
tensorflow/lite/experimental/swift/Sources/Tensor.swift
Normal file
138
tensorflow/lite/experimental/swift/Sources/Tensor.swift
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import TensorFlowLiteCAPI
|
||||||
|
|
||||||
|
/// An input or output tensor in a TensorFlow Lite graph.
|
||||||
|
public struct Tensor {
|
||||||
|
|
||||||
|
/// Name of the tensor.
|
||||||
|
public let name: String
|
||||||
|
|
||||||
|
/// Data type of the tensor.
|
||||||
|
public let dataType: TensorDataType
|
||||||
|
|
||||||
|
/// Shape of the tensor.
|
||||||
|
public let shape: TensorShape
|
||||||
|
|
||||||
|
/// Data in the input or output tensor.
|
||||||
|
public let data: Data
|
||||||
|
|
||||||
|
/// Quantization parameters for the tensor if using a quantized model.
|
||||||
|
public let quantizationParameters: QuantizationParameters?
|
||||||
|
|
||||||
|
/// Creates a new input or output tensor instance.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - name: Name of the tensor.
|
||||||
|
/// - dataType: Data type of the tensor.
|
||||||
|
/// - data: Data in the input tensor.
|
||||||
|
/// - quantizationParameters Quantization parameters for the tensor if using a quantized model.
|
||||||
|
/// The default is `nil`.
|
||||||
|
init(
|
||||||
|
name: String,
|
||||||
|
dataType: TensorDataType,
|
||||||
|
shape: TensorShape,
|
||||||
|
data: Data,
|
||||||
|
quantizationParameters: QuantizationParameters? = nil
|
||||||
|
) {
|
||||||
|
self.name = name
|
||||||
|
self.dataType = dataType
|
||||||
|
self.shape = shape
|
||||||
|
self.data = data
|
||||||
|
self.quantizationParameters = quantizationParameters
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Supported TensorFlow Lite tensor data types.
|
||||||
|
public enum TensorDataType: Equatable {
|
||||||
|
/// 32-bit single precision floating point tensor data type.
|
||||||
|
case float32
|
||||||
|
/// 8-bit unsigned integer tensor data type.
|
||||||
|
case uInt8
|
||||||
|
/// 16-bit signed integer tensor data type.
|
||||||
|
case int16
|
||||||
|
/// 32-bit signed integer tensor data type.
|
||||||
|
case int32
|
||||||
|
/// 64-bit signed integer tensor data type.
|
||||||
|
case int64
|
||||||
|
/// Boolean tensor data type.
|
||||||
|
case bool
|
||||||
|
|
||||||
|
/// Creates a new tensor data type from the given `TFL_Type` or `nil` if the data type is
|
||||||
|
/// unsupported or could not be determined because there was an error.
|
||||||
|
///
|
||||||
|
/// - Parameter type: A data type supported by a tensor.
|
||||||
|
init?(type: TFL_Type) {
|
||||||
|
switch type {
|
||||||
|
case kTfLiteFloat32:
|
||||||
|
self = .float32
|
||||||
|
case kTfLiteUInt8:
|
||||||
|
self = .uInt8
|
||||||
|
case kTfLiteInt16:
|
||||||
|
self = .int16
|
||||||
|
case kTfLiteInt32:
|
||||||
|
self = .int32
|
||||||
|
case kTfLiteInt64:
|
||||||
|
self = .int64
|
||||||
|
case kTfLiteBool:
|
||||||
|
self = .bool
|
||||||
|
case kTfLiteNoType:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The shape of a TensorFlow Lite tensor.
|
||||||
|
public struct TensorShape {
|
||||||
|
|
||||||
|
/// The number of dimensions of the tensor.
|
||||||
|
public let rank: Int
|
||||||
|
|
||||||
|
/// Array of dimensions for the tensor.
|
||||||
|
public let dimensions: [Int]
|
||||||
|
|
||||||
|
/// Array of `Int32` dimensions for the tensor.
|
||||||
|
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
|
||||||
|
|
||||||
|
/// Creates a new tensor shape instance with the given array of dimensions.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - dimensions: Dimensions for the tensor.
|
||||||
|
public init(_ dimensions: [Int]) {
|
||||||
|
self.rank = dimensions.count
|
||||||
|
self.dimensions = dimensions
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new tensor shape instance with the given elements representing the dimensions.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - elements: Dimensions for the tensor.
|
||||||
|
public init(_ elements: Int...) {
|
||||||
|
self.init(elements)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension TensorShape: ExpressibleByArrayLiteral {
|
||||||
|
/// Creates a new tensor shape instance with the given array literal representing the dimensions.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - arrayLiteral: Dimensions for the tensor.
|
||||||
|
public init(arrayLiteral: Int...) {
|
||||||
|
self.init(arrayLiteral)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
{
|
||||||
|
"sourceFilters" : [
|
||||||
|
"third_party/tensorflow/lite/experimental/c",
|
||||||
|
"third_party/tensorflow/lite/experimental/swift",
|
||||||
|
"third_party/tensorflow/lite/experimental/swift/Sources",
|
||||||
|
"third_party/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp",
|
||||||
|
"third_party/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj",
|
||||||
|
"third_party/tensorflow/lite/experimental/swift/Tests",
|
||||||
|
],
|
||||||
|
"buildTargets" : [
|
||||||
|
"//third_party/tensorflow/lite/experimental/swift:TensorFlowLite",
|
||||||
|
"//third_party/tensorflow/lite/experimental/swift:TensorFlowLiteApp",
|
||||||
|
"//third_party/tensorflow/lite/experimental/swift:TensorFlowLiteTests",
|
||||||
|
],
|
||||||
|
"projectName" : "TensorFlowLite",
|
||||||
|
"optionSet" : {
|
||||||
|
"LaunchActionPreActionScript" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"BazelBuildStartupOptionsRelease" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"BazelBuildOptionsRelease" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"BazelBuildOptionsDebug" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"EnvironmentVariables" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"BuildActionPreActionScript" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"CommandlineArguments" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"TestActionPreActionScript" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"BazelBuildStartupOptionsDebug" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"BuildActionPostActionScript" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"TestActionPostActionScript" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
},
|
||||||
|
"LaunchActionPostActionScript" : {
|
||||||
|
"p" : "$(inherited)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalFilePaths" : [
|
||||||
|
"third_party/tensorflow/lite/experimental/swift/BUILD"
|
||||||
|
]
|
||||||
|
}
|
@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"configDefaults" : {
|
||||||
|
"optionSet" : {
|
||||||
|
"ProjectPrioritizesSwift" : {
|
||||||
|
"p" : "YES"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"projectName" : "TensorFlowLite",
|
||||||
|
"packages" : [
|
||||||
|
"third_party/tensorflow/lite/experimental/swift"
|
||||||
|
],
|
||||||
|
"workspaceRoot" : "../../../../../.."
|
||||||
|
}
|
@ -0,0 +1,345 @@
|
|||||||
|
// !$*UTF8*$!
|
||||||
|
{
|
||||||
|
archiveVersion = 1;
|
||||||
|
classes = {
|
||||||
|
};
|
||||||
|
objectVersion = 50;
|
||||||
|
objects = {
|
||||||
|
|
||||||
|
/* Begin PBXBuildFile section */
|
||||||
|
4A7304B421500B8400C90B21 /* Data+TensorFlowLite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */; };
|
||||||
|
4AA72B732146ED64006C3AEF /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA72B722146ED64006C3AEF /* AppDelegate.swift */; };
|
||||||
|
4AA72B752146ED64006C3AEF /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA72B742146ED64006C3AEF /* ViewController.swift */; };
|
||||||
|
4AA72B782146ED64006C3AEF /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B762146ED64006C3AEF /* Main.storyboard */; };
|
||||||
|
4AA72B7A2146ED66006C3AEF /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B792146ED66006C3AEF /* Assets.xcassets */; };
|
||||||
|
4AA72B7D2146ED66006C3AEF /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */; };
|
||||||
|
4ADDE0CE2176600E00FF07A2 /* Array+TensorFlowLite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */; };
|
||||||
|
/* End PBXBuildFile section */
|
||||||
|
|
||||||
|
/* Begin PBXFileReference section */
|
||||||
|
4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Data+TensorFlowLite.swift"; sourceTree = "<group>"; };
|
||||||
|
4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TensorFlowLiteApp.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||||
|
4AA72B722146ED64006C3AEF /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; };
|
||||||
|
4AA72B742146ED64006C3AEF /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = "<group>"; };
|
||||||
|
4AA72B772146ED64006C3AEF /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = "<group>"; };
|
||||||
|
4AA72B792146ED66006C3AEF /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
|
||||||
|
4AA72B7C2146ED66006C3AEF /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; };
|
||||||
|
4AA72B7E2146ED66006C3AEF /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
|
||||||
|
4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Array+TensorFlowLite.swift"; sourceTree = "<group>"; };
|
||||||
|
/* End PBXFileReference section */
|
||||||
|
|
||||||
|
/* Begin PBXFrameworksBuildPhase section */
|
||||||
|
4AA72B6C2146ED64006C3AEF /* Frameworks */ = {
|
||||||
|
isa = PBXFrameworksBuildPhase;
|
||||||
|
buildActionMask = 2147483647;
|
||||||
|
files = (
|
||||||
|
);
|
||||||
|
runOnlyForDeploymentPostprocessing = 0;
|
||||||
|
};
|
||||||
|
/* End PBXFrameworksBuildPhase section */
|
||||||
|
|
||||||
|
/* Begin PBXGroup section */
|
||||||
|
4AA72B662146ED64006C3AEF = {
|
||||||
|
isa = PBXGroup;
|
||||||
|
children = (
|
||||||
|
4AA72B712146ED64006C3AEF /* TensorFlowLiteApp */,
|
||||||
|
4AA72B702146ED64006C3AEF /* Products */,
|
||||||
|
);
|
||||||
|
sourceTree = "<group>";
|
||||||
|
};
|
||||||
|
4AA72B702146ED64006C3AEF /* Products */ = {
|
||||||
|
isa = PBXGroup;
|
||||||
|
children = (
|
||||||
|
4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */,
|
||||||
|
);
|
||||||
|
name = Products;
|
||||||
|
sourceTree = "<group>";
|
||||||
|
};
|
||||||
|
4AA72B712146ED64006C3AEF /* TensorFlowLiteApp */ = {
|
||||||
|
isa = PBXGroup;
|
||||||
|
children = (
|
||||||
|
4AA72B722146ED64006C3AEF /* AppDelegate.swift */,
|
||||||
|
4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */,
|
||||||
|
4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */,
|
||||||
|
4AA72B742146ED64006C3AEF /* ViewController.swift */,
|
||||||
|
4AA72B762146ED64006C3AEF /* Main.storyboard */,
|
||||||
|
4AA72B792146ED66006C3AEF /* Assets.xcassets */,
|
||||||
|
4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */,
|
||||||
|
4AA72B7E2146ED66006C3AEF /* Info.plist */,
|
||||||
|
);
|
||||||
|
path = TensorFlowLiteApp;
|
||||||
|
sourceTree = "<group>";
|
||||||
|
};
|
||||||
|
/* End PBXGroup section */
|
||||||
|
|
||||||
|
/* Begin PBXNativeTarget section */
|
||||||
|
4AA72B6E2146ED64006C3AEF /* TensorFlowLiteApp */ = {
|
||||||
|
isa = PBXNativeTarget;
|
||||||
|
buildConfigurationList = 4AA72B812146ED66006C3AEF /* Build configuration list for PBXNativeTarget "TensorFlowLiteApp" */;
|
||||||
|
buildPhases = (
|
||||||
|
4AA72B6B2146ED64006C3AEF /* Sources */,
|
||||||
|
4AA72B6C2146ED64006C3AEF /* Frameworks */,
|
||||||
|
4AA72B6D2146ED64006C3AEF /* Resources */,
|
||||||
|
);
|
||||||
|
buildRules = (
|
||||||
|
);
|
||||||
|
dependencies = (
|
||||||
|
);
|
||||||
|
name = TensorFlowLiteApp;
|
||||||
|
productName = TensorFlowLiteApp;
|
||||||
|
productReference = 4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */;
|
||||||
|
productType = "com.apple.product-type.application";
|
||||||
|
};
|
||||||
|
/* End PBXNativeTarget section */
|
||||||
|
|
||||||
|
/* Begin PBXProject section */
|
||||||
|
4AA72B672146ED64006C3AEF /* Project object */ = {
|
||||||
|
isa = PBXProject;
|
||||||
|
attributes = {
|
||||||
|
LastSwiftUpdateCheck = 0940;
|
||||||
|
LastUpgradeCheck = 0940;
|
||||||
|
ORGANIZATIONNAME = Google;
|
||||||
|
TargetAttributes = {
|
||||||
|
4AA72B6E2146ED64006C3AEF = {
|
||||||
|
CreatedOnToolsVersion = 9.4.1;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
buildConfigurationList = 4AA72B6A2146ED64006C3AEF /* Build configuration list for PBXProject "TensorFlowLiteApp" */;
|
||||||
|
compatibilityVersion = "Xcode 9.3";
|
||||||
|
developmentRegion = en;
|
||||||
|
hasScannedForEncodings = 0;
|
||||||
|
knownRegions = (
|
||||||
|
en,
|
||||||
|
Base,
|
||||||
|
);
|
||||||
|
mainGroup = 4AA72B662146ED64006C3AEF;
|
||||||
|
productRefGroup = 4AA72B702146ED64006C3AEF /* Products */;
|
||||||
|
projectDirPath = "";
|
||||||
|
projectRoot = "";
|
||||||
|
targets = (
|
||||||
|
4AA72B6E2146ED64006C3AEF /* TensorFlowLiteApp */,
|
||||||
|
);
|
||||||
|
};
|
||||||
|
/* End PBXProject section */
|
||||||
|
|
||||||
|
/* Begin PBXResourcesBuildPhase section */
|
||||||
|
4AA72B6D2146ED64006C3AEF /* Resources */ = {
|
||||||
|
isa = PBXResourcesBuildPhase;
|
||||||
|
buildActionMask = 2147483647;
|
||||||
|
files = (
|
||||||
|
4AA72B7D2146ED66006C3AEF /* LaunchScreen.storyboard in Resources */,
|
||||||
|
4AA72B7A2146ED66006C3AEF /* Assets.xcassets in Resources */,
|
||||||
|
4AA72B782146ED64006C3AEF /* Main.storyboard in Resources */,
|
||||||
|
);
|
||||||
|
runOnlyForDeploymentPostprocessing = 0;
|
||||||
|
};
|
||||||
|
/* End PBXResourcesBuildPhase section */
|
||||||
|
|
||||||
|
/* Begin PBXSourcesBuildPhase section */
|
||||||
|
4AA72B6B2146ED64006C3AEF /* Sources */ = {
|
||||||
|
isa = PBXSourcesBuildPhase;
|
||||||
|
buildActionMask = 2147483647;
|
||||||
|
files = (
|
||||||
|
4AA72B732146ED64006C3AEF /* AppDelegate.swift in Sources */,
|
||||||
|
4ADDE0CE2176600E00FF07A2 /* Array+TensorFlowLite.swift in Sources */,
|
||||||
|
4A7304B421500B8400C90B21 /* Data+TensorFlowLite.swift in Sources */,
|
||||||
|
4AA72B752146ED64006C3AEF /* ViewController.swift in Sources */,
|
||||||
|
);
|
||||||
|
runOnlyForDeploymentPostprocessing = 0;
|
||||||
|
};
|
||||||
|
/* End PBXSourcesBuildPhase section */
|
||||||
|
|
||||||
|
/* Begin PBXVariantGroup section */
|
||||||
|
4AA72B762146ED64006C3AEF /* Main.storyboard */ = {
|
||||||
|
isa = PBXVariantGroup;
|
||||||
|
children = (
|
||||||
|
4AA72B772146ED64006C3AEF /* Base */,
|
||||||
|
);
|
||||||
|
name = Main.storyboard;
|
||||||
|
sourceTree = "<group>";
|
||||||
|
};
|
||||||
|
4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */ = {
|
||||||
|
isa = PBXVariantGroup;
|
||||||
|
children = (
|
||||||
|
4AA72B7C2146ED66006C3AEF /* Base */,
|
||||||
|
);
|
||||||
|
name = LaunchScreen.storyboard;
|
||||||
|
sourceTree = "<group>";
|
||||||
|
};
|
||||||
|
/* End PBXVariantGroup section */
|
||||||
|
|
||||||
|
/* Begin XCBuildConfiguration section */
|
||||||
|
4AA72B7F2146ED66006C3AEF /* Debug */ = {
|
||||||
|
isa = XCBuildConfiguration;
|
||||||
|
buildSettings = {
|
||||||
|
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||||
|
CLANG_ANALYZER_NONNULL = YES;
|
||||||
|
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||||
|
CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
|
||||||
|
CLANG_CXX_LIBRARY = "libc++";
|
||||||
|
CLANG_ENABLE_MODULES = YES;
|
||||||
|
CLANG_ENABLE_OBJC_ARC = YES;
|
||||||
|
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||||
|
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||||
|
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||||
|
CLANG_WARN_COMMA = YES;
|
||||||
|
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||||
|
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||||
|
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||||
|
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||||
|
CLANG_WARN_EMPTY_BODY = YES;
|
||||||
|
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||||
|
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||||
|
CLANG_WARN_INT_CONVERSION = YES;
|
||||||
|
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||||
|
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||||
|
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||||
|
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||||
|
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||||
|
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||||
|
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||||
|
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||||
|
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||||
|
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||||
|
CODE_SIGN_IDENTITY = "iPhone Developer";
|
||||||
|
COPY_PHASE_STRIP = NO;
|
||||||
|
DEBUG_INFORMATION_FORMAT = dwarf;
|
||||||
|
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||||
|
ENABLE_TESTABILITY = YES;
|
||||||
|
GCC_C_LANGUAGE_STANDARD = gnu11;
|
||||||
|
GCC_DYNAMIC_NO_PIC = NO;
|
||||||
|
GCC_NO_COMMON_BLOCKS = YES;
|
||||||
|
GCC_OPTIMIZATION_LEVEL = 0;
|
||||||
|
GCC_PREPROCESSOR_DEFINITIONS = (
|
||||||
|
"DEBUG=1",
|
||||||
|
"$(inherited)",
|
||||||
|
);
|
||||||
|
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||||
|
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||||
|
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||||
|
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||||
|
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||||
|
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||||
|
IPHONEOS_DEPLOYMENT_TARGET = 11.4;
|
||||||
|
MTL_ENABLE_DEBUG_INFO = YES;
|
||||||
|
ONLY_ACTIVE_ARCH = YES;
|
||||||
|
SDKROOT = iphoneos;
|
||||||
|
SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
|
||||||
|
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||||
|
};
|
||||||
|
name = Debug;
|
||||||
|
};
|
||||||
|
4AA72B802146ED66006C3AEF /* Release */ = {
|
||||||
|
isa = XCBuildConfiguration;
|
||||||
|
buildSettings = {
|
||||||
|
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||||
|
CLANG_ANALYZER_NONNULL = YES;
|
||||||
|
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||||
|
CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
|
||||||
|
CLANG_CXX_LIBRARY = "libc++";
|
||||||
|
CLANG_ENABLE_MODULES = YES;
|
||||||
|
CLANG_ENABLE_OBJC_ARC = YES;
|
||||||
|
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||||
|
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||||
|
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||||
|
CLANG_WARN_COMMA = YES;
|
||||||
|
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||||
|
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||||
|
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||||
|
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||||
|
CLANG_WARN_EMPTY_BODY = YES;
|
||||||
|
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||||
|
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||||
|
CLANG_WARN_INT_CONVERSION = YES;
|
||||||
|
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||||
|
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||||
|
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||||
|
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||||
|
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||||
|
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||||
|
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||||
|
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||||
|
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||||
|
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||||
|
CODE_SIGN_IDENTITY = "iPhone Developer";
|
||||||
|
COPY_PHASE_STRIP = NO;
|
||||||
|
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
|
||||||
|
ENABLE_NS_ASSERTIONS = NO;
|
||||||
|
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||||
|
GCC_C_LANGUAGE_STANDARD = gnu11;
|
||||||
|
GCC_NO_COMMON_BLOCKS = YES;
|
||||||
|
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||||
|
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||||
|
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||||
|
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||||
|
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||||
|
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||||
|
IPHONEOS_DEPLOYMENT_TARGET = 11.4;
|
||||||
|
MTL_ENABLE_DEBUG_INFO = NO;
|
||||||
|
SDKROOT = iphoneos;
|
||||||
|
SWIFT_COMPILATION_MODE = wholemodule;
|
||||||
|
SWIFT_OPTIMIZATION_LEVEL = "-O";
|
||||||
|
VALIDATE_PRODUCT = YES;
|
||||||
|
};
|
||||||
|
name = Release;
|
||||||
|
};
|
||||||
|
4AA72B822146ED66006C3AEF /* Debug */ = {
|
||||||
|
isa = XCBuildConfiguration;
|
||||||
|
buildSettings = {
|
||||||
|
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||||
|
CODE_SIGN_STYLE = Automatic;
|
||||||
|
INFOPLIST_FILE = TensorFlowLiteApp/Info.plist;
|
||||||
|
LD_RUNPATH_SEARCH_PATHS = (
|
||||||
|
"$(inherited)",
|
||||||
|
"@executable_path/Frameworks",
|
||||||
|
);
|
||||||
|
PRODUCT_BUNDLE_IDENTIFIER = com.tensorflow.lite.swift.TensorFlowLite;
|
||||||
|
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||||
|
SWIFT_VERSION = 4.0;
|
||||||
|
TARGETED_DEVICE_FAMILY = "1,2";
|
||||||
|
};
|
||||||
|
name = Debug;
|
||||||
|
};
|
||||||
|
4AA72B832146ED66006C3AEF /* Release */ = {
|
||||||
|
isa = XCBuildConfiguration;
|
||||||
|
buildSettings = {
|
||||||
|
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||||
|
CODE_SIGN_STYLE = Automatic;
|
||||||
|
INFOPLIST_FILE = TensorFlowLiteApp/Info.plist;
|
||||||
|
LD_RUNPATH_SEARCH_PATHS = (
|
||||||
|
"$(inherited)",
|
||||||
|
"@executable_path/Frameworks",
|
||||||
|
);
|
||||||
|
PRODUCT_BUNDLE_IDENTIFIER = com.tensorflow.lite.swift.TensorFlowLite;
|
||||||
|
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||||
|
SWIFT_VERSION = 4.0;
|
||||||
|
TARGETED_DEVICE_FAMILY = "1,2";
|
||||||
|
};
|
||||||
|
name = Release;
|
||||||
|
};
|
||||||
|
/* End XCBuildConfiguration section */
|
||||||
|
|
||||||
|
/* Begin XCConfigurationList section */
|
||||||
|
4AA72B6A2146ED64006C3AEF /* Build configuration list for PBXProject "TensorFlowLiteApp" */ = {
|
||||||
|
isa = XCConfigurationList;
|
||||||
|
buildConfigurations = (
|
||||||
|
4AA72B7F2146ED66006C3AEF /* Debug */,
|
||||||
|
4AA72B802146ED66006C3AEF /* Release */,
|
||||||
|
);
|
||||||
|
defaultConfigurationIsVisible = 0;
|
||||||
|
defaultConfigurationName = Release;
|
||||||
|
};
|
||||||
|
4AA72B812146ED66006C3AEF /* Build configuration list for PBXNativeTarget "TensorFlowLiteApp" */ = {
|
||||||
|
isa = XCConfigurationList;
|
||||||
|
buildConfigurations = (
|
||||||
|
4AA72B822146ED66006C3AEF /* Debug */,
|
||||||
|
4AA72B832146ED66006C3AEF /* Release */,
|
||||||
|
);
|
||||||
|
defaultConfigurationIsVisible = 0;
|
||||||
|
defaultConfigurationName = Release;
|
||||||
|
};
|
||||||
|
/* End XCConfigurationList section */
|
||||||
|
};
|
||||||
|
rootObject = 4AA72B672146ED64006C3AEF /* Project object */;
|
||||||
|
}
|
@ -0,0 +1,24 @@
|
|||||||
|
import UIKit
|
||||||
|
|
||||||
|
@UIApplicationMain
|
||||||
|
|
||||||
|
final class AppDelegate: UIResponder, UIApplicationDelegate {
|
||||||
|
|
||||||
|
/// The main window of the app.
|
||||||
|
var window: UIWindow?
|
||||||
|
|
||||||
|
func application(
|
||||||
|
_ application: UIApplication,
|
||||||
|
didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]? = nil
|
||||||
|
) -> Bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Extensions
|
||||||
|
|
||||||
|
#if !swift(>=4.2)
|
||||||
|
extension UIApplication {
|
||||||
|
typealias LaunchOptionsKey = UIApplicationLaunchOptionsKey
|
||||||
|
}
|
||||||
|
#endif // !swift(>=4.2)
|
@ -0,0 +1,22 @@
|
|||||||
|
import Foundation
|
||||||
|
|
||||||
|
extension Array {
|
||||||
|
/// Creates a new array from the bytes of the given unsafe data.
|
||||||
|
///
|
||||||
|
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
|
||||||
|
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
|
||||||
|
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
|
||||||
|
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
||||||
|
/// `MemoryLayout<Element>.stride`.
|
||||||
|
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||||
|
init?(unsafeData: Data) {
|
||||||
|
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||||
|
let elements = unsafeData.withUnsafeBytes {
|
||||||
|
UnsafeBufferPointer<Element>(
|
||||||
|
start: $0,
|
||||||
|
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.init(elements)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,98 @@
|
|||||||
|
{
|
||||||
|
"images" : [
|
||||||
|
{
|
||||||
|
"idiom" : "iphone",
|
||||||
|
"size" : "20x20",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "iphone",
|
||||||
|
"size" : "20x20",
|
||||||
|
"scale" : "3x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "iphone",
|
||||||
|
"size" : "29x29",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "iphone",
|
||||||
|
"size" : "29x29",
|
||||||
|
"scale" : "3x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "iphone",
|
||||||
|
"size" : "40x40",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "iphone",
|
||||||
|
"size" : "40x40",
|
||||||
|
"scale" : "3x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "iphone",
|
||||||
|
"size" : "60x60",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "iphone",
|
||||||
|
"size" : "60x60",
|
||||||
|
"scale" : "3x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "20x20",
|
||||||
|
"scale" : "1x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "20x20",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "29x29",
|
||||||
|
"scale" : "1x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "29x29",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "40x40",
|
||||||
|
"scale" : "1x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "40x40",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "76x76",
|
||||||
|
"scale" : "1x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "76x76",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ipad",
|
||||||
|
"size" : "83.5x83.5",
|
||||||
|
"scale" : "2x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idiom" : "ios-marketing",
|
||||||
|
"size" : "1024x1024",
|
||||||
|
"scale" : "1x"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"info" : {
|
||||||
|
"version" : 1,
|
||||||
|
"author" : "xcode"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"info" : {
|
||||||
|
"version" : 1,
|
||||||
|
"author" : "xcode"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,44 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14109" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" launchScreen="YES" useTraitCollections="YES" colorMatched="YES" initialViewController="01J-lp-oVM">
|
||||||
|
<device id="retina4_7" orientation="portrait">
|
||||||
|
<adaptation id="fullscreen"/>
|
||||||
|
</device>
|
||||||
|
<dependencies>
|
||||||
|
<deployment identifier="iOS"/>
|
||||||
|
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14088"/>
|
||||||
|
<capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
|
||||||
|
</dependencies>
|
||||||
|
<scenes>
|
||||||
|
<!--View Controller-->
|
||||||
|
<scene sceneID="EHf-IW-A2E">
|
||||||
|
<objects>
|
||||||
|
<viewController id="01J-lp-oVM" sceneMemberID="viewController">
|
||||||
|
<layoutGuides>
|
||||||
|
<viewControllerLayoutGuide type="top" id="Llm-lL-Icb"/>
|
||||||
|
<viewControllerLayoutGuide type="bottom" id="xb3-aO-Qok"/>
|
||||||
|
</layoutGuides>
|
||||||
|
<view key="view" contentMode="scaleToFill" id="Ze5-6b-2t3">
|
||||||
|
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
|
||||||
|
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
|
||||||
|
<subviews>
|
||||||
|
<label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" text="TensorFlowLite" textAlignment="center" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="3Gq-PV-hia">
|
||||||
|
<rect key="frame" x="16" y="315" width="343" height="38.5"/>
|
||||||
|
<fontDescription key="fontDescription" type="boldSystem" pointSize="32"/>
|
||||||
|
<nil key="textColor"/>
|
||||||
|
<nil key="highlightedColor"/>
|
||||||
|
</label>
|
||||||
|
</subviews>
|
||||||
|
<color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||||
|
<constraints>
|
||||||
|
<constraint firstItem="3Gq-PV-hia" firstAttribute="leading" secondItem="Ze5-6b-2t3" secondAttribute="leading" constant="16" id="aXL-9T-5Pf"/>
|
||||||
|
<constraint firstItem="3Gq-PV-hia" firstAttribute="centerY" secondItem="Ze5-6b-2t3" secondAttribute="centerY" id="cDf-Go-1FR"/>
|
||||||
|
<constraint firstAttribute="trailing" secondItem="3Gq-PV-hia" secondAttribute="trailing" constant="16" id="fB9-BX-A3B"/>
|
||||||
|
</constraints>
|
||||||
|
</view>
|
||||||
|
</viewController>
|
||||||
|
<placeholder placeholderIdentifier="IBFirstResponder" id="iYj-Kq-Ea1" userLabel="First Responder" sceneMemberID="firstResponder"/>
|
||||||
|
</objects>
|
||||||
|
<point key="canvasLocation" x="52" y="374.66266866566718"/>
|
||||||
|
</scene>
|
||||||
|
</scenes>
|
||||||
|
</document>
|
@ -0,0 +1,95 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14313.18" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" useTraitCollections="YES" colorMatched="YES" initialViewController="BYZ-38-t0r">
|
||||||
|
<device id="retina4_7" orientation="portrait">
|
||||||
|
<adaptation id="fullscreen"/>
|
||||||
|
</device>
|
||||||
|
<dependencies>
|
||||||
|
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14283.14"/>
|
||||||
|
<capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
|
||||||
|
</dependencies>
|
||||||
|
<scenes>
|
||||||
|
<!--View Controller-->
|
||||||
|
<scene sceneID="tne-QT-ifu">
|
||||||
|
<objects>
|
||||||
|
<viewController storyboardIdentifier="viewController" useStoryboardIdentifierAsRestorationIdentifier="YES" id="BYZ-38-t0r" customClass="ViewController" customModule="third_party_tensorflow_lite_experimental_swift_TensorFlowLiteAppLib" sceneMemberID="viewController">
|
||||||
|
<layoutGuides>
|
||||||
|
<viewControllerLayoutGuide type="top" id="y3c-jy-aDJ"/>
|
||||||
|
<viewControllerLayoutGuide type="bottom" id="wfy-db-euE"/>
|
||||||
|
</layoutGuides>
|
||||||
|
<view key="view" contentMode="scaleToFill" id="8bC-Xf-vdC">
|
||||||
|
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
|
||||||
|
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
|
||||||
|
<subviews>
|
||||||
|
<textView clipsSubviews="YES" multipleTouchEnabled="YES" contentMode="scaleToFill" editable="NO" selectable="NO" translatesAutoresizingMaskIntoConstraints="NO" id="7Mj-sL-hrd">
|
||||||
|
<rect key="frame" x="0.0" y="367" width="375" height="300"/>
|
||||||
|
<color key="backgroundColor" red="0.0" green="0.47843137250000001" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||||
|
<constraints>
|
||||||
|
<constraint firstAttribute="height" constant="300" id="YUb-MC-D5w"/>
|
||||||
|
</constraints>
|
||||||
|
<color key="textColor" cocoaTouchSystemColor="tableCellGroupedBackgroundColor"/>
|
||||||
|
<fontDescription key="fontDescription" type="system" pointSize="14"/>
|
||||||
|
<textInputTraits key="textInputTraits" autocapitalizationType="sentences"/>
|
||||||
|
</textView>
|
||||||
|
<toolbar opaque="NO" clearsContextBeforeDrawing="NO" contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="Qwg-EP-bd6" userLabel="Bottom Toolbar">
|
||||||
|
<rect key="frame" x="0.0" y="323" width="375" height="44"/>
|
||||||
|
<constraints>
|
||||||
|
<constraint firstAttribute="height" constant="44" id="jhT-Q0-E9N"/>
|
||||||
|
</constraints>
|
||||||
|
<items>
|
||||||
|
<barButtonItem style="plain" systemItem="flexibleSpace" id="P3q-uA-YUa"/>
|
||||||
|
<barButtonItem title="Invoke Interpreter" id="A4J-Mg-nmd" userLabel="Invoke Button">
|
||||||
|
<connections>
|
||||||
|
<action selector="invokeInterpreter:" destination="BYZ-38-t0r" id="lZU-x7-PsJ"/>
|
||||||
|
</connections>
|
||||||
|
</barButtonItem>
|
||||||
|
<barButtonItem style="plain" systemItem="flexibleSpace" id="Qad-Pa-ySg"/>
|
||||||
|
</items>
|
||||||
|
</toolbar>
|
||||||
|
<toolbar opaque="NO" clearsContextBeforeDrawing="NO" contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="Gkb-TR-PCB" userLabel="Top Toolbar">
|
||||||
|
<rect key="frame" x="0.0" y="28" width="375" height="44"/>
|
||||||
|
<constraints>
|
||||||
|
<constraint firstAttribute="height" constant="44" id="hSD-2q-fUE"/>
|
||||||
|
</constraints>
|
||||||
|
<items>
|
||||||
|
<barButtonItem style="plain" id="LKw-TX-bbH">
|
||||||
|
<segmentedControl key="customView" opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="left" contentVerticalAlignment="top" segmentControlStyle="bar" selectedSegmentIndex="0" id="rhA-nW-xzT">
|
||||||
|
<rect key="frame" x="16" y="7" width="343" height="30"/>
|
||||||
|
<autoresizingMask key="autoresizingMask" flexibleMaxX="YES" flexibleMaxY="YES"/>
|
||||||
|
<segments>
|
||||||
|
<segment title="Add"/>
|
||||||
|
<segment title="AddQuantized"/>
|
||||||
|
<segment title="MultiAdd"/>
|
||||||
|
</segments>
|
||||||
|
<connections>
|
||||||
|
<action selector="modelChanged:" destination="BYZ-38-t0r" eventType="valueChanged" id="YnG-Ov-B5D"/>
|
||||||
|
</connections>
|
||||||
|
</segmentedControl>
|
||||||
|
</barButtonItem>
|
||||||
|
</items>
|
||||||
|
</toolbar>
|
||||||
|
</subviews>
|
||||||
|
<color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
|
||||||
|
<constraints>
|
||||||
|
<constraint firstAttribute="trailing" secondItem="Gkb-TR-PCB" secondAttribute="trailing" id="4Cr-Sf-I7n"/>
|
||||||
|
<constraint firstItem="7Mj-sL-hrd" firstAttribute="bottom" secondItem="wfy-db-euE" secondAttribute="top" id="6ot-zD-sze"/>
|
||||||
|
<constraint firstItem="7Mj-sL-hrd" firstAttribute="top" secondItem="Qwg-EP-bd6" secondAttribute="bottom" id="ELA-C6-NiG"/>
|
||||||
|
<constraint firstAttribute="trailing" secondItem="7Mj-sL-hrd" secondAttribute="trailing" id="HDO-xr-mBl"/>
|
||||||
|
<constraint firstItem="Gkb-TR-PCB" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="Kmo-6K-gS4"/>
|
||||||
|
<constraint firstItem="Qwg-EP-bd6" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="hGu-lm-fMG"/>
|
||||||
|
<constraint firstAttribute="trailing" secondItem="Qwg-EP-bd6" secondAttribute="trailing" id="iXR-LK-nTO"/>
|
||||||
|
<constraint firstItem="7Mj-sL-hrd" firstAttribute="leading" secondItem="8bC-Xf-vdC" secondAttribute="leading" id="nr7-jW-ZYf"/>
|
||||||
|
<constraint firstItem="Gkb-TR-PCB" firstAttribute="top" secondItem="y3c-jy-aDJ" secondAttribute="bottom" constant="8" id="uCF-VW-rR0"/>
|
||||||
|
</constraints>
|
||||||
|
</view>
|
||||||
|
<connections>
|
||||||
|
<outlet property="invokeButton" destination="A4J-Mg-nmd" id="UxZ-Ft-E45"/>
|
||||||
|
<outlet property="modelControl" destination="rhA-nW-xzT" id="KKf-TT-BQ2"/>
|
||||||
|
<outlet property="resultsTextView" destination="7Mj-sL-hrd" id="T4I-z4-tYA"/>
|
||||||
|
</connections>
|
||||||
|
</viewController>
|
||||||
|
<placeholder placeholderIdentifier="IBFirstResponder" id="dkx-z0-nzr" sceneMemberID="firstResponder"/>
|
||||||
|
</objects>
|
||||||
|
<point key="canvasLocation" x="125.59999999999999" y="133.5832083958021"/>
|
||||||
|
</scene>
|
||||||
|
</scenes>
|
||||||
|
</document>
|
@ -0,0 +1,13 @@
|
|||||||
|
import Foundation
|
||||||
|
|
||||||
|
extension Data {
|
||||||
|
/// Creates a new buffer by copying the buffer pointer of the given array.
|
||||||
|
///
|
||||||
|
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
|
||||||
|
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
|
||||||
|
/// data from the resulting buffer has undefined behavior.
|
||||||
|
/// - Parameter array: An array with elements of type `T`.
|
||||||
|
init<T>(copyingBufferOf array: [T]) {
|
||||||
|
self = array.withUnsafeBufferPointer(Data.init)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,46 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||||
|
<plist version="1.0">
|
||||||
|
<dict>
|
||||||
|
<key>CFBundleDevelopmentRegion</key>
|
||||||
|
<string>en</string>
|
||||||
|
<key>CFBundleExecutable</key>
|
||||||
|
<string>$(EXECUTABLE_NAME)</string>
|
||||||
|
<key>CFBundleIdentifier</key>
|
||||||
|
<string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
|
||||||
|
<key>CFBundleInfoDictionaryVersion</key>
|
||||||
|
<string>6.0</string>
|
||||||
|
<key>CFBundleName</key>
|
||||||
|
<string>$(PRODUCT_NAME)</string>
|
||||||
|
<key>CFBundlePackageType</key>
|
||||||
|
<string>APPL</string>
|
||||||
|
<key>CFBundleShortVersionString</key>
|
||||||
|
<string>1.0</string>
|
||||||
|
<key>CFBundleVersion</key>
|
||||||
|
<string>0.0.1</string>
|
||||||
|
<key>LSRequiresIPhoneOS</key>
|
||||||
|
<true/>
|
||||||
|
<key>NSCameraUsageDescription</key>
|
||||||
|
<string>NSCameraUsageDescription</string>
|
||||||
|
<key>NSPhotoLibraryUsageDescription</key>
|
||||||
|
<string>Select a photo to detect objects in.</string>
|
||||||
|
<key>UILaunchStoryboardName</key>
|
||||||
|
<string>LaunchScreen</string>
|
||||||
|
<key>UIMainStoryboardFile</key>
|
||||||
|
<string>Main</string>
|
||||||
|
<key>UIRequiredDeviceCapabilities</key>
|
||||||
|
<array>
|
||||||
|
<string>armv7</string>
|
||||||
|
</array>
|
||||||
|
<key>UISupportedInterfaceOrientations</key>
|
||||||
|
<array>
|
||||||
|
<string>UIInterfaceOrientationPortrait</string>
|
||||||
|
<string>UIInterfaceOrientationPortraitUpsideDown</string>
|
||||||
|
</array>
|
||||||
|
<key>UISupportedInterfaceOrientations~ipad</key>
|
||||||
|
<array>
|
||||||
|
<string>UIInterfaceOrientationPortrait</string>
|
||||||
|
<string>UIInterfaceOrientationPortraitUpsideDown</string>
|
||||||
|
</array>
|
||||||
|
</dict>
|
||||||
|
</plist>
|
@ -0,0 +1,299 @@
|
|||||||
|
import TensorFlowLite
|
||||||
|
import UIKit
|
||||||
|
|
||||||
|
class ViewController: UIViewController {
|
||||||
|
|
||||||
|
// MARK: - Properties
|
||||||
|
|
||||||
|
/// TensorFlowLite interpreter object for performing inference from a given model.
|
||||||
|
private var interpreter: Interpreter?
|
||||||
|
|
||||||
|
/// Serial dispatch queue for managing `Interpreter` calls.
|
||||||
|
private let interpreterQueue = DispatchQueue(
|
||||||
|
label: Constant.dispatchQueueLabel,
|
||||||
|
qos: .userInitiated
|
||||||
|
)
|
||||||
|
|
||||||
|
/// The currently selected model.
|
||||||
|
private var currentModel: Model {
|
||||||
|
guard let currentModel = Model(rawValue: modelControl.selectedSegmentIndex) else {
|
||||||
|
preconditionFailure("Invalid model for selected segment index.")
|
||||||
|
}
|
||||||
|
return currentModel
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A description of the current model.
|
||||||
|
private var modelDescription: String {
|
||||||
|
guard let interpreter = interpreter else { return "" }
|
||||||
|
let inputCount = interpreter.inputTensorCount
|
||||||
|
let outputCount = interpreter.outputTensorCount
|
||||||
|
let inputTensors = (0..<inputCount).map { index in
|
||||||
|
var tensorInfo = " Input \(index + 1): "
|
||||||
|
do {
|
||||||
|
let tensor = try interpreter.input(at: index)
|
||||||
|
tensorInfo += "\(tensor)"
|
||||||
|
} catch let error {
|
||||||
|
tensorInfo += "\(error.localizedDescription)"
|
||||||
|
}
|
||||||
|
return tensorInfo
|
||||||
|
}.joined(separator: "\n")
|
||||||
|
let outputTensors = (0..<outputCount).map { index in
|
||||||
|
var tensorInfo = " Output \(index + 1): "
|
||||||
|
do {
|
||||||
|
let tensor = try interpreter.output(at: index)
|
||||||
|
tensorInfo += "\(tensor)"
|
||||||
|
} catch let error {
|
||||||
|
tensorInfo += "\(error.localizedDescription)"
|
||||||
|
}
|
||||||
|
return tensorInfo
|
||||||
|
}.joined(separator: "\n")
|
||||||
|
return "Model Description:\n" +
|
||||||
|
" Input Tensor Count = \(inputCount)\n\(inputTensors)\n\n" +
|
||||||
|
" Output Tensor Count = \(outputCount)\n\(outputTensors)"
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - IBOutlets
|
||||||
|
|
||||||
|
/// A segmented control for changing models. See the `Model` enum for available models.
|
||||||
|
@IBOutlet private var modelControl: UISegmentedControl!
|
||||||
|
|
||||||
|
@IBOutlet private var resultsTextView: UITextView!
|
||||||
|
@IBOutlet private var invokeButton: UIBarButtonItem!
|
||||||
|
|
||||||
|
// MARK: - UIViewController
|
||||||
|
|
||||||
|
override func viewDidLoad() {
|
||||||
|
super.viewDidLoad()
|
||||||
|
|
||||||
|
invokeButton.isEnabled = false
|
||||||
|
loadModel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - IBActions
|
||||||
|
|
||||||
|
@IBAction func modelChanged(_ sender: Any) {
|
||||||
|
invokeButton.isEnabled = false
|
||||||
|
updateResultsText("Switched to the \(currentModel.description) model.")
|
||||||
|
loadModel()
|
||||||
|
}
|
||||||
|
|
||||||
|
@IBAction func invokeInterpreter(_ sender: Any) {
|
||||||
|
switch currentModel {
|
||||||
|
case .add:
|
||||||
|
invokeAdd()
|
||||||
|
case .addQuantized:
|
||||||
|
invokeAddQuantized()
|
||||||
|
case .multiAdd:
|
||||||
|
invokeMultiAdd()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Private
|
||||||
|
|
||||||
|
private func loadModel() {
|
||||||
|
let fileInfo = currentModel.fileInfo
|
||||||
|
guard let modelPath = Bundle.main.path(forResource: fileInfo.name, ofType: fileInfo.extension)
|
||||||
|
else {
|
||||||
|
updateResultsText("Failed to load the \(currentModel.description) model.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setUpInterpreter(withModelPath: modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func setUpInterpreter(withModelPath modelPath: String) {
|
||||||
|
interpreterQueue.async {
|
||||||
|
do {
|
||||||
|
var options = InterpreterOptions()
|
||||||
|
options.isErrorLoggingEnabled = true
|
||||||
|
self.interpreter = try Interpreter(modelPath: modelPath, options: options)
|
||||||
|
} catch let error {
|
||||||
|
self.updateResultsText(
|
||||||
|
"Failed to create the interpreter with error: \(error.localizedDescription)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
safeDispatchOnMain { self.invokeButton.isEnabled = true }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func invokeAdd() {
|
||||||
|
interpreterQueue.async {
|
||||||
|
guard let interpreter = self.interpreter else {
|
||||||
|
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
try interpreter.resizeInput(at: 0, to: [2])
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
let input: [Float32] = [1, 3]
|
||||||
|
let resultsText = self.modelDescription + "\n\n" +
|
||||||
|
"Performing 2 add operations on input \(input.description) equals: "
|
||||||
|
self.updateResultsText(resultsText)
|
||||||
|
let data = Data(copyingBufferOf: input)
|
||||||
|
try interpreter.copy(data, toInputAt: 0)
|
||||||
|
try interpreter.invoke()
|
||||||
|
let outputTensor = try interpreter.output(at: 0)
|
||||||
|
let results: () -> String = {
|
||||||
|
guard let results = [Float32](unsafeData: outputTensor.data) else { return "No results." }
|
||||||
|
return resultsText + results.description
|
||||||
|
}
|
||||||
|
self.updateResultsText(results())
|
||||||
|
} catch let error {
|
||||||
|
self.updateResultsText(
|
||||||
|
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func invokeAddQuantized() {
|
||||||
|
interpreterQueue.async {
|
||||||
|
guard let interpreter = self.interpreter else {
|
||||||
|
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
try interpreter.resizeInput(at: 0, to: [2])
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
let input: [UInt8] = [1, 3]
|
||||||
|
let resultsText = self.modelDescription + "\n\n" +
|
||||||
|
"Performing 2 add operations on quantized input \(input.description) equals: "
|
||||||
|
self.updateResultsText(resultsText)
|
||||||
|
let data = Data(input)
|
||||||
|
try interpreter.copy(data, toInputAt: 0)
|
||||||
|
try interpreter.invoke()
|
||||||
|
let outputTensor = try interpreter.output(at: 0)
|
||||||
|
let results: () -> String = {
|
||||||
|
guard let quantizationParameters = outputTensor.quantizationParameters else {
|
||||||
|
return "No results."
|
||||||
|
}
|
||||||
|
let quantizedResults = [UInt8](outputTensor.data)
|
||||||
|
let dequantizedResults = quantizedResults.map {
|
||||||
|
quantizationParameters.scale * Float(Int($0) - quantizationParameters.zeroPoint)
|
||||||
|
}
|
||||||
|
return resultsText + quantizedResults.description +
|
||||||
|
", dequantized results: " + dequantizedResults.description
|
||||||
|
}
|
||||||
|
self.updateResultsText(results())
|
||||||
|
} catch let error {
|
||||||
|
self.updateResultsText(
|
||||||
|
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func invokeMultiAdd() {
|
||||||
|
interpreterQueue.async {
|
||||||
|
guard let interpreter = self.interpreter else {
|
||||||
|
self.updateResultsText(Constant.nilInterpreterErrorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
let shape = TensorShape(2)
|
||||||
|
try (0..<interpreter.inputTensorCount).forEach { index in
|
||||||
|
try interpreter.resizeInput(at: index, to: shape)
|
||||||
|
}
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
let inputs = try (0..<interpreter.inputTensorCount).map { index -> [Float32] in
|
||||||
|
let input = [Float32(index + 1), Float32(index + 2)]
|
||||||
|
let data = Data(copyingBufferOf: input)
|
||||||
|
try interpreter.copy(data, toInputAt: index)
|
||||||
|
return input
|
||||||
|
}
|
||||||
|
let resultsText = self.modelDescription + "\n\n" +
|
||||||
|
"Performing 3 add operations on inputs \(inputs.description) equals: "
|
||||||
|
self.updateResultsText(resultsText)
|
||||||
|
try interpreter.invoke()
|
||||||
|
let results = try (0..<interpreter.outputTensorCount).map { index -> [Float32] in
|
||||||
|
let tensor = try interpreter.output(at: index)
|
||||||
|
return [Float32](unsafeData: tensor.data) ?? []
|
||||||
|
}
|
||||||
|
self.updateResultsText(resultsText + results.description)
|
||||||
|
} catch let error {
|
||||||
|
self.updateResultsText(
|
||||||
|
"Failed to invoke the interpreter with error: \(error.localizedDescription)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func updateResultsText(_ text: String? = nil) {
|
||||||
|
safeDispatchOnMain { self.resultsTextView.text = text }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Constants
|
||||||
|
|
||||||
|
private enum Constant {
|
||||||
|
static let dispatchQueueLabel = "TensorFlowLiteInterpreterQueue"
|
||||||
|
static let nilInterpreterErrorMessage =
|
||||||
|
"Failed to invoke the interpreter because the interpreter was nil."
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Models that can be loaded by the TensorFlow Lite `Interpreter`.
|
||||||
|
private enum Model: Int, CustomStringConvertible {
|
||||||
|
/// A float model that performs two add operations on one input tensor and returns the result in
|
||||||
|
/// one output tensor.
|
||||||
|
case add = 0
|
||||||
|
/// A quantized model that performs two add operations on one input tensor and returns the result
|
||||||
|
/// in one output tensor.
|
||||||
|
case addQuantized = 1
|
||||||
|
/// A float model that performs three add operations on four input tensors and returns the results
|
||||||
|
/// in 2 output tensors.
|
||||||
|
case multiAdd = 2
|
||||||
|
|
||||||
|
var fileInfo: (name: String, extension: String) {
|
||||||
|
switch self {
|
||||||
|
case .add:
|
||||||
|
return Add.fileInfo
|
||||||
|
case .addQuantized:
|
||||||
|
return AddQuantized.fileInfo
|
||||||
|
case .multiAdd:
|
||||||
|
return MultiAdd.fileInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - CustomStringConvertible
|
||||||
|
|
||||||
|
var description: String {
|
||||||
|
switch self {
|
||||||
|
case .add:
|
||||||
|
return Add.name
|
||||||
|
case .addQuantized:
|
||||||
|
return AddQuantized.name
|
||||||
|
case .multiAdd:
|
||||||
|
return MultiAdd.name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Values for the `Add` model.
|
||||||
|
private enum Add {
|
||||||
|
static let name = "Add"
|
||||||
|
static let fileInfo = (name: "add", extension: "bin")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Values for the `AddQuantized` model.
|
||||||
|
private enum AddQuantized {
|
||||||
|
static let name = "AddQuantized"
|
||||||
|
static let fileInfo = (name: "add_quantized", extension: "bin")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Values for the `MultiAdd` model.
|
||||||
|
private enum MultiAdd {
|
||||||
|
static let name = "MultiAdd"
|
||||||
|
static let fileInfo = (name: "multi_add", extension: "bin")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Fileprivate
|
||||||
|
|
||||||
|
/// Safely dispatches the given block on the main queue. If the current thread is `main`, the block
|
||||||
|
/// is executed synchronously; otherwise, the block is executed asynchronously on the main thread.
|
||||||
|
fileprivate func safeDispatchOnMain(_ block: @escaping () -> Void) {
|
||||||
|
if Thread.isMainThread { block(); return }
|
||||||
|
DispatchQueue.main.async { block() }
|
||||||
|
}
|
@ -0,0 +1,54 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
class InterpreterOptionsTests: XCTestCase {
|
||||||
|
|
||||||
|
func testInterpreterOptions_InitWithDefaultValues() {
|
||||||
|
let options = InterpreterOptions()
|
||||||
|
XCTAssertNil(options.threadCount)
|
||||||
|
XCTAssertFalse(options.isErrorLoggingEnabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreterOptions_InitWithCustomValues() {
|
||||||
|
var options = InterpreterOptions()
|
||||||
|
options.threadCount = 2
|
||||||
|
XCTAssertEqual(options.threadCount, 2)
|
||||||
|
options.isErrorLoggingEnabled = true
|
||||||
|
XCTAssertTrue(options.isErrorLoggingEnabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreterOptions_Equatable() {
|
||||||
|
var options1 = InterpreterOptions()
|
||||||
|
var options2 = InterpreterOptions()
|
||||||
|
XCTAssertEqual(options1, options2)
|
||||||
|
|
||||||
|
options1.threadCount = 2
|
||||||
|
options2.threadCount = 2
|
||||||
|
XCTAssertEqual(options1, options2)
|
||||||
|
|
||||||
|
options2.threadCount = 3
|
||||||
|
XCTAssertNotEqual(options1, options2)
|
||||||
|
options2.threadCount = 2
|
||||||
|
|
||||||
|
options1.isErrorLoggingEnabled = true
|
||||||
|
options2.isErrorLoggingEnabled = true
|
||||||
|
XCTAssertEqual(options1, options2)
|
||||||
|
|
||||||
|
options2.isErrorLoggingEnabled = false
|
||||||
|
XCTAssertNotEqual(options1, options2)
|
||||||
|
}
|
||||||
|
}
|
315
tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift
Normal file
315
tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift
Normal file
@ -0,0 +1,315 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
class InterpreterTests: XCTestCase {
|
||||||
|
|
||||||
|
var interpreter: Interpreter!
|
||||||
|
|
||||||
|
override func setUp() {
|
||||||
|
super.setUp()
|
||||||
|
|
||||||
|
interpreter = try! Interpreter(modelPath: AddModel.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
override func tearDown() {
|
||||||
|
interpreter = nil
|
||||||
|
|
||||||
|
super.tearDown()
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_InitWithModelPath() {
|
||||||
|
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_Init_ThrowsFailedToLoadModel() {
|
||||||
|
XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
|
||||||
|
self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_InitWithModelPathAndOptions() {
|
||||||
|
var options = InterpreterOptions()
|
||||||
|
options.threadCount = 2
|
||||||
|
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_InputTensorCount() {
|
||||||
|
XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_OutputTensorCount() {
|
||||||
|
XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_Invoke() throws {
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
XCTAssertNoThrow(try interpreter.invoke())
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
|
||||||
|
XCTAssertThrowsError(try interpreter.invoke()) { error in
|
||||||
|
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_InputTensorAtIndex() throws {
|
||||||
|
try setUpAddModelInputTensor()
|
||||||
|
let inputTensor = try interpreter.input(at: AddModel.validIndex)
|
||||||
|
XCTAssertEqual(inputTensor, AddModel.inputTensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_InputTensorAtIndex_QuantizedModel() throws {
|
||||||
|
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
||||||
|
try setUpAddQuantizedModelInputTensor()
|
||||||
|
let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
|
||||||
|
XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws {
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
|
||||||
|
let maxIndex = AddModel.inputTensorCount - 1
|
||||||
|
self.assertEqualErrors(
|
||||||
|
actual: error,
|
||||||
|
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() {
|
||||||
|
XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
|
||||||
|
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_OutputTensorAtIndex() throws {
|
||||||
|
try setUpAddModelInputTensor()
|
||||||
|
try interpreter.invoke()
|
||||||
|
let outputTensor = try interpreter.output(at: AddModel.validIndex)
|
||||||
|
XCTAssertEqual(outputTensor, AddModel.outputTensor)
|
||||||
|
let expectedResults = [Float32](unsafeData: outputTensor.data)
|
||||||
|
XCTAssertEqual(expectedResults, AddModel.results)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws {
|
||||||
|
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
||||||
|
try setUpAddQuantizedModelInputTensor()
|
||||||
|
try interpreter.invoke()
|
||||||
|
let outputTensor = try interpreter.output(at: AddQuantizedModel.inputOutputIndex)
|
||||||
|
XCTAssertEqual(outputTensor, AddQuantizedModel.outputTensor)
|
||||||
|
let expectedResults = [UInt8](outputTensor.data)
|
||||||
|
XCTAssertEqual(expectedResults, AddQuantizedModel.results)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws {
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
try interpreter.invoke()
|
||||||
|
XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
|
||||||
|
let maxIndex = AddModel.outputTensorCount - 1
|
||||||
|
self.assertEqualErrors(
|
||||||
|
actual: error,
|
||||||
|
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
|
||||||
|
XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
|
||||||
|
self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_ResizeInputTensorAtIndexToShape() {
|
||||||
|
XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
|
||||||
|
XCTAssertNoThrow(try interpreter.allocateTensors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
|
||||||
|
XCTAssertThrowsError(try interpreter.resizeInput(
|
||||||
|
at: AddModel.invalidIndex,
|
||||||
|
to: [2, 2, 3]
|
||||||
|
)) { error in
|
||||||
|
let maxIndex = AddModel.inputTensorCount - 1
|
||||||
|
self.assertEqualErrors(
|
||||||
|
actual: error,
|
||||||
|
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_CopyDataToInputTensorAtIndex() throws {
|
||||||
|
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
|
||||||
|
XCTAssertEqual(inputTensor.data, AddModel.inputData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
|
||||||
|
XCTAssertThrowsError(try interpreter.copy(
|
||||||
|
AddModel.inputData,
|
||||||
|
toInputAt: AddModel.invalidIndex
|
||||||
|
)) { error in
|
||||||
|
let maxIndex = AddModel.inputTensorCount - 1
|
||||||
|
self.assertEqualErrors(
|
||||||
|
actual: error,
|
||||||
|
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
|
||||||
|
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
let invalidData = Data(count: AddModel.dataCount - 1)
|
||||||
|
XCTAssertThrowsError(try interpreter.copy(
|
||||||
|
invalidData,
|
||||||
|
toInputAt: AddModel.validIndex
|
||||||
|
)) { error in
|
||||||
|
self.assertEqualErrors(
|
||||||
|
actual: error,
|
||||||
|
expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterpreter_AllocateTensors() {
|
||||||
|
XCTAssertNoThrow(try interpreter.allocateTensors())
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Private
|
||||||
|
|
||||||
|
private func setUpAddModelInputTensor() throws {
|
||||||
|
precondition(interpreter != nil)
|
||||||
|
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func setUpAddQuantizedModelInputTensor() throws {
|
||||||
|
precondition(interpreter != nil)
|
||||||
|
try interpreter.resizeInput(at: AddQuantizedModel.inputOutputIndex, to: AddQuantizedModel.shape)
|
||||||
|
try interpreter.allocateTensors()
|
||||||
|
try interpreter.copy(AddQuantizedModel.inputData, toInputAt: AddQuantizedModel.inputOutputIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func assertEqualErrors(actual: Error, expected: InterpreterError) {
|
||||||
|
guard let actual = actual as? InterpreterError else {
|
||||||
|
XCTFail("Actual error should be of type InterpreterError.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
XCTAssertEqual(actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Constants
|
||||||
|
|
||||||
|
/// Values for the `add.bin` model.
|
||||||
|
private enum AddModel {
|
||||||
|
static let info = (name: "add", extension: "bin")
|
||||||
|
static let inputTensorCount = 1
|
||||||
|
static let outputTensorCount = 1
|
||||||
|
static let invalidIndex = 1
|
||||||
|
static let validIndex = 0
|
||||||
|
static let shape: TensorShape = [2]
|
||||||
|
static let dataCount = inputData.count
|
||||||
|
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
|
||||||
|
static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])
|
||||||
|
static let results = [Float32(3.0), Float32(9.0)]
|
||||||
|
|
||||||
|
static let inputTensor = Tensor(
|
||||||
|
name: "input",
|
||||||
|
dataType: .float32,
|
||||||
|
shape: shape,
|
||||||
|
data: inputData
|
||||||
|
)
|
||||||
|
static let outputTensor = Tensor(
|
||||||
|
name: "output",
|
||||||
|
dataType: .float32,
|
||||||
|
shape: shape,
|
||||||
|
data: outputData
|
||||||
|
)
|
||||||
|
|
||||||
|
static var path: String = {
|
||||||
|
let bundle = Bundle(for: InterpreterTests.self)
|
||||||
|
guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" }
|
||||||
|
return path
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Values for the `add_quantized.bin` model.
|
||||||
|
private enum AddQuantizedModel {
|
||||||
|
static let info = (name: "add_quantized", extension: "bin")
|
||||||
|
static let inputOutputIndex = 0
|
||||||
|
static let shape: TensorShape = [2]
|
||||||
|
static let inputData = Data([1, 3])
|
||||||
|
static let outputData = Data([3, 9])
|
||||||
|
static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)
|
||||||
|
static let results: [UInt8] = [3, 9]
|
||||||
|
|
||||||
|
static let inputTensor = Tensor(
|
||||||
|
name: "input",
|
||||||
|
dataType: .uInt8,
|
||||||
|
shape: shape,
|
||||||
|
data: inputData,
|
||||||
|
quantizationParameters: quantizationParameters
|
||||||
|
)
|
||||||
|
static let outputTensor = Tensor(
|
||||||
|
name: "output",
|
||||||
|
dataType: .uInt8,
|
||||||
|
shape: shape,
|
||||||
|
data: outputData,
|
||||||
|
quantizationParameters: quantizationParameters
|
||||||
|
)
|
||||||
|
|
||||||
|
static var path: String = {
|
||||||
|
let bundle = Bundle(for: InterpreterTests.self)
|
||||||
|
guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" }
|
||||||
|
return path
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Extensions
|
||||||
|
|
||||||
|
extension Array {
|
||||||
|
/// Creates a new array from the bytes of the given unsafe data.
|
||||||
|
///
|
||||||
|
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
||||||
|
/// `MemoryLayout<Element>.stride`.
|
||||||
|
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||||
|
init?(unsafeData: Data) {
|
||||||
|
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||||
|
let elements = unsafeData.withUnsafeBytes {
|
||||||
|
UnsafeBufferPointer<Element>(
|
||||||
|
start: $0,
|
||||||
|
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.init(elements)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension Data {
|
||||||
|
/// Creates a new buffer by copying the buffer pointer of the given array.
|
||||||
|
///
|
||||||
|
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
|
||||||
|
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
|
||||||
|
/// data from the resulting buffer has undefined behavior.
|
||||||
|
/// - Parameter array: An array with elements of type `T`.
|
||||||
|
init<T>(copyingBufferOf array: [T]) {
|
||||||
|
self = array.withUnsafeBufferPointer(Data.init)
|
||||||
|
}
|
||||||
|
}
|
59
tensorflow/lite/experimental/swift/Tests/ModelTests.swift
Normal file
59
tensorflow/lite/experimental/swift/Tests/ModelTests.swift
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
class ModelTests: XCTestCase {
|
||||||
|
|
||||||
|
var modelPath: String!
|
||||||
|
|
||||||
|
override func setUp() {
|
||||||
|
super.setUp()
|
||||||
|
|
||||||
|
let bundle = Bundle(for: type(of: self))
|
||||||
|
guard let modelPath = bundle.path(
|
||||||
|
forResource: Constant.modelInfo.name,
|
||||||
|
ofType: Constant.modelInfo.extension)
|
||||||
|
else {
|
||||||
|
XCTFail("Failed to get the model file path.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
self.modelPath = modelPath
|
||||||
|
}
|
||||||
|
|
||||||
|
override func tearDown() {
|
||||||
|
modelPath = nil
|
||||||
|
|
||||||
|
super.tearDown()
|
||||||
|
}
|
||||||
|
|
||||||
|
func testModel_InitWithFilePath() {
|
||||||
|
XCTAssertNotNil(Model(filePath: modelPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testModel_InitWithEmptyFilePath_FailsInitialization() {
|
||||||
|
XCTAssertNil(Model(filePath: ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testModel_InitWithInvalidFilePath_FailsInitialization() {
|
||||||
|
XCTAssertNil(Model(filePath: "invalid/path"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Constants
|
||||||
|
|
||||||
|
private enum Constant {
|
||||||
|
static let modelInfo = (name: "add", extension: "bin")
|
||||||
|
}
|
@ -0,0 +1,43 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
class QuantizationParametersTests: XCTestCase {
|
||||||
|
|
||||||
|
func testQuantizationParameters_InitWithCustomValues() {
|
||||||
|
let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
|
XCTAssertEqual(parameters.scale, 0.5)
|
||||||
|
XCTAssertEqual(parameters.zeroPoint, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testQuantizationParameters_Equatable() {
|
||||||
|
let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
|
let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
|
XCTAssertEqual(parameters1, parameters2)
|
||||||
|
|
||||||
|
let parameters3 = QuantizationParameters(scale: 0.4, zeroPoint: 1)
|
||||||
|
XCTAssertNotEqual(parameters1, parameters3)
|
||||||
|
XCTAssertNotEqual(parameters2, parameters3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Extensions
|
||||||
|
|
||||||
|
extension QuantizationParameters: Equatable {
|
||||||
|
public static func == (lhs: QuantizationParameters, rhs: QuantizationParameters) -> Bool {
|
||||||
|
return lhs.scale == rhs.scale && lhs.zeroPoint == rhs.zeroPoint
|
||||||
|
}
|
||||||
|
}
|
83
tensorflow/lite/experimental/swift/Tests/TensorTests.swift
Normal file
83
tensorflow/lite/experimental/swift/Tests/TensorTests.swift
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
// Copyright 2018 Google Inc. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
class TensorTests: XCTestCase {
|
||||||
|
|
||||||
|
// MARK: - Tensor
|
||||||
|
|
||||||
|
func testTensor_Init() {
|
||||||
|
let name = "InputTensor"
|
||||||
|
let dataType: TensorDataType = .uInt8
|
||||||
|
let shape = TensorShape(Constant.dimensions)
|
||||||
|
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
|
||||||
|
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
|
let inputTensor = Tensor(
|
||||||
|
name: name,
|
||||||
|
dataType: dataType,
|
||||||
|
shape: shape,
|
||||||
|
data: data,
|
||||||
|
quantizationParameters: quantizationParameters
|
||||||
|
)
|
||||||
|
XCTAssertEqual(inputTensor.name, name)
|
||||||
|
XCTAssertEqual(inputTensor.dataType, dataType)
|
||||||
|
XCTAssertEqual(inputTensor.shape, shape)
|
||||||
|
XCTAssertEqual(inputTensor.data, data)
|
||||||
|
XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - TensorShape
|
||||||
|
|
||||||
|
func testTensorShape_InitWithArray() {
|
||||||
|
let shape = TensorShape(Constant.dimensions)
|
||||||
|
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||||
|
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTensorShape_InitWithElements() {
|
||||||
|
let shape = TensorShape(2, 2, 3)
|
||||||
|
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||||
|
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTensorShape_InitWithArrayLiteral() {
|
||||||
|
let shape: TensorShape = [2, 2, 3]
|
||||||
|
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||||
|
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Constants
|
||||||
|
|
||||||
|
private enum Constant {
|
||||||
|
/// Array of 2 arrays of 2 arrays of 3 numbers: [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]].
|
||||||
|
static let dimensions = [2, 2, 3]
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Extensions
|
||||||
|
|
||||||
|
extension TensorShape: Equatable {
|
||||||
|
public static func == (lhs: TensorShape, rhs: TensorShape) -> Bool {
|
||||||
|
return lhs.rank == rhs.rank && lhs.dimensions == rhs.dimensions
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension Tensor: Equatable {
|
||||||
|
public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
|
||||||
|
return lhs.name == rhs.name && lhs.dataType == rhs.dataType && lhs.shape == rhs.shape &&
|
||||||
|
lhs.data == rhs.data && lhs.quantizationParameters == rhs.quantizationParameters
|
||||||
|
}
|
||||||
|
}
|
@ -30,14 +30,19 @@ os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
|
|||||||
PIP_PACKAGE_QUERY_EXPRESSION = (
|
PIP_PACKAGE_QUERY_EXPRESSION = (
|
||||||
"deps(//tensorflow/tools/pip_package:build_pip_package)")
|
"deps(//tensorflow/tools/pip_package:build_pip_package)")
|
||||||
|
|
||||||
|
# List of file paths containing BUILD files that should not be included for the
|
||||||
|
# pip smoke test.
|
||||||
|
BUILD_BLACKLIST = [
|
||||||
|
"tensorflow/lite/examples/android",
|
||||||
|
"tensorflow/lite/experimental/swift",
|
||||||
|
]
|
||||||
|
|
||||||
def GetBuild(dir_base):
|
def GetBuild(dir_base):
|
||||||
"""Get the list of BUILD file all targets recursively startind at dir_base."""
|
"""Get the list of BUILD file all targets recursively startind at dir_base."""
|
||||||
items = []
|
items = []
|
||||||
for root, _, files in os.walk(dir_base):
|
for root, _, files in os.walk(dir_base):
|
||||||
for name in files:
|
for name in files:
|
||||||
if (name == "BUILD" and
|
if (name == "BUILD" and root not in BUILD_BLACKLIST):
|
||||||
root.find("tensorflow/lite/examples/android") == -1):
|
|
||||||
items.append("//" + root + ":all")
|
items.append("//" + root + ":all")
|
||||||
return items
|
return items
|
||||||
|
|
||||||
@ -67,9 +72,9 @@ def BuildPyTestDependencies():
|
|||||||
|
|
||||||
PYTHON_TARGETS, PY_TEST_QUERY_EXPRESSION = BuildPyTestDependencies()
|
PYTHON_TARGETS, PY_TEST_QUERY_EXPRESSION = BuildPyTestDependencies()
|
||||||
|
|
||||||
# Hard-coded blacklist of files if not included in pip package
|
|
||||||
# TODO(amitpatankar): Clean up blacklist.
|
# TODO(amitpatankar): Clean up blacklist.
|
||||||
BLACKLIST = [
|
# List of dependencies that should not included in the pip package.
|
||||||
|
DEPENDENCY_BLACKLIST = [
|
||||||
"//tensorflow/python:extra_py_tests_deps",
|
"//tensorflow/python:extra_py_tests_deps",
|
||||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||||
"//tensorflow:no_tensorflow_py_deps",
|
"//tensorflow:no_tensorflow_py_deps",
|
||||||
@ -82,9 +87,7 @@ BLACKLIST = [
|
|||||||
"//tensorflow/core/kernels/cloud:bigquery_reader_ops",
|
"//tensorflow/core/kernels/cloud:bigquery_reader_ops",
|
||||||
"//tensorflow/python/feature_column:vocabulary_testdata",
|
"//tensorflow/python/feature_column:vocabulary_testdata",
|
||||||
"//tensorflow/python:framework/test_file_system.so",
|
"//tensorflow/python:framework/test_file_system.so",
|
||||||
# contrib
|
# lite
|
||||||
"//tensorflow/contrib/session_bundle:session_bundle_half_plus_two",
|
|
||||||
"//tensorflow/contrib/keras:testing_utils",
|
|
||||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm",
|
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm",
|
||||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm.py",
|
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm.py",
|
||||||
"//tensorflow/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test", # pylint:disable=line-too-long
|
"//tensorflow/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test", # pylint:disable=line-too-long
|
||||||
@ -93,6 +96,9 @@ BLACKLIST = [
|
|||||||
"//tensorflow/lite/python:interpreter_test",
|
"//tensorflow/lite/python:interpreter_test",
|
||||||
"//tensorflow/lite/python:interpreter.py",
|
"//tensorflow/lite/python:interpreter.py",
|
||||||
"//tensorflow/lite/python:interpreter_test.py",
|
"//tensorflow/lite/python:interpreter_test.py",
|
||||||
|
# contrib
|
||||||
|
"//tensorflow/contrib/session_bundle:session_bundle_half_plus_two",
|
||||||
|
"//tensorflow/contrib/keras:testing_utils",
|
||||||
"//tensorflow/contrib/ffmpeg:test_data",
|
"//tensorflow/contrib/ffmpeg:test_data",
|
||||||
"//tensorflow/contrib/fused_conv:fused_conv2d_bias_activation_op_test_base",
|
"//tensorflow/contrib/fused_conv:fused_conv2d_bias_activation_op_test_base",
|
||||||
"//tensorflow/contrib/hadoop:test_data",
|
"//tensorflow/contrib/hadoop:test_data",
|
||||||
@ -149,8 +155,8 @@ def main():
|
|||||||
# File extensions and endings to ignore
|
# File extensions and endings to ignore
|
||||||
ignore_extensions = ["_test", "_test.py", "_test_gpu", "_test_gpu.py"]
|
ignore_extensions = ["_test", "_test.py", "_test_gpu", "_test_gpu.py"]
|
||||||
|
|
||||||
ignored_files = 0
|
ignored_files_count = 0
|
||||||
blacklisted_files = len(BLACKLIST)
|
blacklisted_dependencies_count = len(DEPENDENCY_BLACKLIST)
|
||||||
# Compare dependencies
|
# Compare dependencies
|
||||||
for dependency in tf_py_test_dependencies_list:
|
for dependency in tf_py_test_dependencies_list:
|
||||||
if dependency and dependency.startswith("//tensorflow"):
|
if dependency and dependency.startswith("//tensorflow"):
|
||||||
@ -158,16 +164,16 @@ def main():
|
|||||||
# Ignore extensions
|
# Ignore extensions
|
||||||
if any(dependency.endswith(ext) for ext in ignore_extensions):
|
if any(dependency.endswith(ext) for ext in ignore_extensions):
|
||||||
ignore = True
|
ignore = True
|
||||||
ignored_files += 1
|
ignored_files_count += 1
|
||||||
|
|
||||||
# Check if the dependency is in the pip package, the blacklist, or
|
# Check if the dependency is in the pip package, the dependency blacklist,
|
||||||
# should be ignored because of its file extension
|
# or should be ignored because of its file extension.
|
||||||
if not (ignore or dependency in pip_package_dependencies_list or
|
if not (ignore or dependency in pip_package_dependencies_list or
|
||||||
dependency in BLACKLIST):
|
dependency in DEPENDENCY_BLACKLIST):
|
||||||
missing_dependencies.append(dependency)
|
missing_dependencies.append(dependency)
|
||||||
|
|
||||||
print("Ignored files: %d" % ignored_files)
|
print("Ignored files count: %d" % ignored_files_count)
|
||||||
print("Blacklisted files: %d" % blacklisted_files)
|
print("Blacklisted dependencies count: %d" % blacklisted_dependencies_count)
|
||||||
if missing_dependencies:
|
if missing_dependencies:
|
||||||
print("Missing the following dependencies from pip_packages:")
|
print("Missing the following dependencies from pip_packages:")
|
||||||
for missing_dependency in missing_dependencies:
|
for missing_dependency in missing_dependencies:
|
||||||
|
Loading…
Reference in New Issue
Block a user