Removes the isErrorLoggingEnabled
property from the InterpreterOptions
struct. Error logging is enabled by default to match ObjC behavior.
PiperOrigin-RevId: 248392591
This commit is contained in:
parent
45d1a97841
commit
18ff53fed4
@ -39,7 +39,7 @@ public final class Interpreter {
|
|||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - modelPath: Local file path to a TensorFlow Lite model.
|
/// - modelPath: Local file path to a TensorFlow Lite model.
|
||||||
/// - options: Custom configurations for the interpreter. The default is `nil` indicating that
|
/// - options: Custom configurations for the interpreter. The default is `nil` indicating that
|
||||||
/// interpreter will determine the configuration options.
|
/// the interpreter will determine the configuration options.
|
||||||
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
|
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
|
||||||
public init(modelPath: String, options: InterpreterOptions? = nil) throws {
|
public init(modelPath: String, options: InterpreterOptions? = nil) throws {
|
||||||
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
|
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
|
||||||
@ -51,23 +51,21 @@ public final class Interpreter {
|
|||||||
if let threadCount = options.threadCount, threadCount > 0 {
|
if let threadCount = options.threadCount, threadCount > 0 {
|
||||||
TFL_InterpreterOptionsSetNumThreads(cOptions, Int32(threadCount))
|
TFL_InterpreterOptionsSetNumThreads(cOptions, Int32(threadCount))
|
||||||
}
|
}
|
||||||
if options.isErrorLoggingEnabled {
|
TFL_InterpreterOptionsSetErrorReporter(
|
||||||
TFL_InterpreterOptionsSetErrorReporter(
|
cOptions,
|
||||||
cOptions,
|
{ (_, format, args) -> Void in
|
||||||
{ (_, format, args) -> Void in
|
// Workaround for Swift optionality bug: https://bugs.swift.org/browse/SR-3429.
|
||||||
// Workaround for Swift optionality bug: https://bugs.swift.org/browse/SR-3429.
|
let optionalArgs: CVaListPointer? = args
|
||||||
let optionalArgs: CVaListPointer? = args
|
guard let cFormat = format,
|
||||||
guard let cFormat = format,
|
let arguments = optionalArgs,
|
||||||
let arguments = optionalArgs,
|
let message = String(cFormat: cFormat, arguments: arguments)
|
||||||
let message = String(cFormat: cFormat, arguments: arguments)
|
else {
|
||||||
else {
|
return
|
||||||
return
|
}
|
||||||
}
|
print(String(describing: InterpreterError.tensorFlowLiteError(message)))
|
||||||
print(String(describing: InterpreterError.tensorFlowLiteError(message)))
|
},
|
||||||
},
|
nil
|
||||||
nil
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
return cOptions
|
return cOptions
|
||||||
}
|
}
|
||||||
defer { TFL_DeleteInterpreterOptions(cInterpreterOptions) }
|
defer { TFL_DeleteInterpreterOptions(cInterpreterOptions) }
|
||||||
|
@ -40,7 +40,7 @@ extension InterpreterError: LocalizedError {
|
|||||||
case .invalidTensorDataCount(let providedCount, let requiredCount):
|
case .invalidTensorDataCount(let providedCount, let requiredCount):
|
||||||
return "Provided data count \(providedCount) must match the required count \(requiredCount)."
|
return "Provided data count \(providedCount) must match the required count \(requiredCount)."
|
||||||
case .invalidTensorDataType:
|
case .invalidTensorDataType:
|
||||||
return "Tensor data type is unsupported or could not be determined because of a model error."
|
return "Tensor data type is unsupported or could not be determined due to a model error."
|
||||||
case .failedToLoadModel:
|
case .failedToLoadModel:
|
||||||
return "Failed to load the given model."
|
return "Failed to load the given model."
|
||||||
case .failedToCreateInterpreter:
|
case .failedToCreateInterpreter:
|
||||||
|
@ -12,16 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
/// Custom configuration options for a TensorFlow Lite interpreter.
|
/// Custom configuration options for a TensorFlow Lite `Interpreter`.
|
||||||
public struct InterpreterOptions: Equatable {
|
public struct InterpreterOptions: Equatable {
|
||||||
|
|
||||||
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` which
|
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` which
|
||||||
/// indicates that the `Interpreter` will decide the number of threads to use.
|
/// indicates that the `Interpreter` will decide the number of threads to use.
|
||||||
public var threadCount: Int? = nil
|
public var threadCount: Int? = nil
|
||||||
|
|
||||||
/// Whether error logging to the console is enabled. The default is `false`.
|
|
||||||
public var isErrorLoggingEnabled = false
|
|
||||||
|
|
||||||
/// Creates a new instance of interpreter options.
|
/// Creates a new instance of interpreter options.
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
@ -103,9 +103,7 @@ class ViewController: UIViewController {
|
|||||||
private func setUpInterpreter(withModelPath modelPath: String) {
|
private func setUpInterpreter(withModelPath modelPath: String) {
|
||||||
interpreterQueue.async {
|
interpreterQueue.async {
|
||||||
do {
|
do {
|
||||||
var options = InterpreterOptions()
|
self.interpreter = try Interpreter(modelPath: modelPath)
|
||||||
options.isErrorLoggingEnabled = true
|
|
||||||
self.interpreter = try Interpreter(modelPath: modelPath, options: options)
|
|
||||||
} catch let error {
|
} catch let error {
|
||||||
self.updateResultsText(
|
self.updateResultsText(
|
||||||
"Failed to create the interpreter with error: \(error.localizedDescription)"
|
"Failed to create the interpreter with error: \(error.localizedDescription)"
|
||||||
|
@ -20,15 +20,12 @@ class InterpreterOptionsTests: XCTestCase {
|
|||||||
func testInterpreterOptions_InitWithDefaultValues() {
|
func testInterpreterOptions_InitWithDefaultValues() {
|
||||||
let options = InterpreterOptions()
|
let options = InterpreterOptions()
|
||||||
XCTAssertNil(options.threadCount)
|
XCTAssertNil(options.threadCount)
|
||||||
XCTAssertFalse(options.isErrorLoggingEnabled)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreterOptions_InitWithCustomValues() {
|
func testInterpreterOptions_InitWithCustomValues() {
|
||||||
var options = InterpreterOptions()
|
var options = InterpreterOptions()
|
||||||
options.threadCount = 2
|
options.threadCount = 2
|
||||||
XCTAssertEqual(options.threadCount, 2)
|
XCTAssertEqual(options.threadCount, 2)
|
||||||
options.isErrorLoggingEnabled = true
|
|
||||||
XCTAssertTrue(options.isErrorLoggingEnabled)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreterOptions_Equatable() {
|
func testInterpreterOptions_Equatable() {
|
||||||
@ -42,13 +39,5 @@ class InterpreterOptionsTests: XCTestCase {
|
|||||||
|
|
||||||
options2.threadCount = 3
|
options2.threadCount = 3
|
||||||
XCTAssertNotEqual(options1, options2)
|
XCTAssertNotEqual(options1, options2)
|
||||||
options2.threadCount = 2
|
|
||||||
|
|
||||||
options1.isErrorLoggingEnabled = true
|
|
||||||
options2.isErrorLoggingEnabled = true
|
|
||||||
XCTAssertEqual(options1, options2)
|
|
||||||
|
|
||||||
options2.isErrorLoggingEnabled = false
|
|
||||||
XCTAssertNotEqual(options1, options2)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user