STT/native_client/swift/stt_ios/STT.swift
Leon Kiefer fab1bbad73
fixed duplicate deallocation of stream
streamCtx must be unset after STT_FinishStreamWithMetadata is called, because that function already calls STT_FreeStream internally; otherwise STT_FreeStream is called a second time when the STTStream is deinitialized, resulting in EXC_BAD_ACCESS errors.
2021-09-26 12:56:28 +02:00

458 lines
17 KiB
Swift

//
// STT.swift
// stt_ios
//
// Created by Reuben Morais on 14.06.20.
// Copyright © 2020 Mozilla
// Copyright © 2021 Coqui GmbH
import stt_ios.libstt_Private
/// Errors reported by the native STT library.
///
/// There is one case per native error code, and every case carries the raw
/// `errorCode` value returned by the C API.
public enum STTError: Error {
    // Should be kept in sync with coqui-stt.h
    case noModel(errorCode: Int32),
         invalidAlphabet(errorCode: Int32),
         invalidShape(errorCode: Int32),
         invalidScorer(errorCode: Int32),
         modelIncompatible(errorCode: Int32),
         scorerNotEnabled(errorCode: Int32),
         scorerUnreadable(errorCode: Int32),
         scorerInvalidLm(errorCode: Int32),
         scorerNoTrie(errorCode: Int32),
         scorerInvalidTrie(errorCode: Int32),
         scorerVersionMismatch(errorCode: Int32),
         failInitMmap(errorCode: Int32),
         failInitSess(errorCode: Int32),
         failInterpreter(errorCode: Int32),
         failRunSess(errorCode: Int32),
         failCreateStream(errorCode: Int32),
         failReadProtobuf(errorCode: Int32),
         failCreateSess(errorCode: Int32),
         failCreateModel(errorCode: Int32)

    // Fallback for codes that match none of the cases above. This should never
    // happen unless the header and the binary come from different versions.
    case invalidErrorCode(errorCode: Int32)
}
extension STTError: LocalizedError {
    /// Human-readable description of the error, produced by the native
    /// library's `STT_ErrorCodeToErrorMessage`.
    ///
    /// Returns `nil` when the native library yields no message, so reporting an
    /// error can never itself crash on a null pointer.
    public var errorDescription: String? {
        guard let message = STT_ErrorCodeToErrorMessage(nativeErrorCode) else {
            return nil
        }
        // The message is allocated by the native library and must be released
        // with STT_FreeString once it has been copied into a Swift String.
        defer { STT_FreeString(message) }
        return String(cString: message)
    }

    /// The raw native error code carried by every case.
    private var nativeErrorCode: Int32 {
        switch self {
        case .noModel(let errorCode),
             .invalidAlphabet(let errorCode),
             .invalidShape(let errorCode),
             .invalidScorer(let errorCode),
             .modelIncompatible(let errorCode),
             .scorerNotEnabled(let errorCode),
             .scorerUnreadable(let errorCode),
             .scorerInvalidLm(let errorCode),
             .scorerNoTrie(let errorCode),
             .scorerInvalidTrie(let errorCode),
             .scorerVersionMismatch(let errorCode),
             .failInitMmap(let errorCode),
             .failInitSess(let errorCode),
             .failInterpreter(let errorCode),
             .failRunSess(let errorCode),
             .failCreateStream(let errorCode),
             .failReadProtobuf(let errorCode),
             .failCreateSess(let errorCode),
             .failCreateModel(let errorCode),
             .invalidErrorCode(let errorCode):
            return errorCode
        }
    }
}
/// Maps a raw native error code to the matching `STTError` case.
///
/// Unrecognised codes map to `.invalidErrorCode`.
private func errorCodeToEnum(errorCode: Int32) -> STTError {
    // Ordered (native code, case constructor) pairs. The first entry whose code
    // matches wins, mirroring a top-to-bottom switch.
    let table: [(code: Int, make: (Int32) -> STTError)] = [
        (Int(STT_ERR_NO_MODEL.rawValue), STTError.noModel),
        (Int(STT_ERR_INVALID_ALPHABET.rawValue), STTError.invalidAlphabet),
        (Int(STT_ERR_INVALID_SHAPE.rawValue), STTError.invalidShape),
        (Int(STT_ERR_INVALID_SCORER.rawValue), STTError.invalidScorer),
        (Int(STT_ERR_MODEL_INCOMPATIBLE.rawValue), STTError.modelIncompatible),
        (Int(STT_ERR_SCORER_NOT_ENABLED.rawValue), STTError.scorerNotEnabled),
        (Int(STT_ERR_SCORER_UNREADABLE.rawValue), STTError.scorerUnreadable),
        (Int(STT_ERR_SCORER_INVALID_LM.rawValue), STTError.scorerInvalidLm),
        (Int(STT_ERR_SCORER_NO_TRIE.rawValue), STTError.scorerNoTrie),
        (Int(STT_ERR_SCORER_INVALID_TRIE.rawValue), STTError.scorerInvalidTrie),
        (Int(STT_ERR_SCORER_VERSION_MISMATCH.rawValue), STTError.scorerVersionMismatch),
        (Int(STT_ERR_FAIL_INIT_MMAP.rawValue), STTError.failInitMmap),
        (Int(STT_ERR_FAIL_INIT_SESS.rawValue), STTError.failInitSess),
        (Int(STT_ERR_FAIL_INTERPRETER.rawValue), STTError.failInterpreter),
        (Int(STT_ERR_FAIL_RUN_SESS.rawValue), STTError.failRunSess),
        (Int(STT_ERR_FAIL_CREATE_STREAM.rawValue), STTError.failCreateStream),
        (Int(STT_ERR_FAIL_READ_PROTOBUF.rawValue), STTError.failReadProtobuf),
        (Int(STT_ERR_FAIL_CREATE_SESS.rawValue), STTError.failCreateSess),
        (Int(STT_ERR_FAIL_CREATE_MODEL.rawValue), STTError.failCreateModel),
    ]
    let code = Int(errorCode)
    guard let entry = table.first(where: { $0.code == code }) else {
        return .invalidErrorCode(errorCode: errorCode)
    }
    return entry.make(errorCode)
}
/// Throws the `STTError` matching `errorCode`, unless it equals `STT_ERR_OK`.
private func evaluateErrorCode(errorCode: Int32) throws {
    guard errorCode == Int32(STT_ERR_OK.rawValue) else {
        throw errorCodeToEnum(errorCode: errorCode)
    }
}
/// Text of a single decoded token together with its timing information.
public struct STTTokenMetadata {
    /// The text corresponding to this token.
    public let text: String
    /// Position of the token in units of 20ms.
    public let timestep: Int
    /// Position of the token in seconds.
    public let startTime: Float

    /// Copies a native `TokenMetadata` record into Swift-owned values.
    internal init(fromInternal token: TokenMetadata) {
        self.startTime = token.start_time
        self.timestep = Int(token.timestep)
        self.text = String(cString: token.text)
    }
}
/** A single transcript computed by the model, including a confidence value and
the metadata for its constituent tokens.
*/
public struct STTCandidateTranscript {
    /// Array of STTTokenMetadata objects, one per token, in the order the
    /// native library returned them.
    public private(set) var tokens: [STTTokenMetadata] = []
    /** Approximated confidence value for this transcript. This corresponds to
    both acoustic model and language model scores that contributed to the
    creation of this transcript.
    */
    // Public so that library clients can read the confidence. With the default
    // (internal) access it was only visible inside this module.
    public let confidence: Double

    /// Copies a native `CandidateTranscript` (and its token array) into Swift-owned values.
    internal init(fromInternal: CandidateTranscript) {
        let tokensBuffer = UnsafeBufferPointer<TokenMetadata>(
            start: fromInternal.tokens,
            count: Int(fromInternal.num_tokens))
        tokens = tokensBuffer.map { STTTokenMetadata(fromInternal: $0) }
        confidence = fromInternal.confidence
    }
}
/// The candidate transcripts computed by the model for one inference.
public struct STTMetadata {
    /// Array of STTCandidateTranscript objects, in the order the native library returned them.
    public private(set) var transcripts: [STTCandidateTranscript] = []

    /// Copies the native `Metadata` pointed to by `fromInternal` into Swift-owned values.
    /// The caller remains responsible for freeing the native structure.
    internal init(fromInternal: UnsafeMutablePointer<Metadata>) {
        let native = fromInternal.pointee
        let candidates = UnsafeBufferPointer<CandidateTranscript>(
            start: native.transcripts,
            count: Int(native.num_transcripts))
        transcripts = candidates.map { STTCandidateTranscript(fromInternal: $0) }
    }
}
/// State of an ongoing streaming inference, created by `STTModel.createStream()`.
///
/// Feed audio with `feedAudioContent(buffer:)`, optionally poll the
/// intermediate decode methods, and end the stream with `finishStream()` or
/// `finishStreamWithMetadata(numResults:)`. After either finish call the stream
/// is invalidated and every further method call traps.
public class STTStream {
    /// Native stream handle; `nil` once the stream has been finished.
    private var streamCtx: OpaquePointer!
    /// The model that created this stream, held strongly so the native model
    /// cannot be freed while this stream may still use it.
    /// NOTE(review): assumes the native stream references its model's state —
    /// confirm against the native sources.
    private let model: STTModel?

    /// - Parameters:
    ///   - streamContext: Native handle from `STT_CreateStream`; this object takes ownership.
    ///   - model: Model that created the stream, retained for the stream's lifetime.
    internal init(streamContext: OpaquePointer, model: STTModel? = nil) {
        streamCtx = streamContext
        self.model = model
    }

    deinit {
        // The finish methods already free the native stream and clear streamCtx,
        // so only an unfinished stream is freed here.
        if streamCtx != nil {
            STT_FreeStream(streamCtx)
            streamCtx = nil
        }
    }

    /// Hands the native handle to a native call that consumes (frees) it.
    /// `streamCtx` is cleared *before* that call, so `deinit` can never free the
    /// stream a second time.
    private func takeStreamContext() -> OpaquePointer {
        precondition(streamCtx != nil, "calling method on invalidated Stream")
        let ctx: OpaquePointer = streamCtx
        streamCtx = nil
        return ctx
    }

    /** Feed audio samples to an ongoing streaming inference.
    - Parameter buffer: A 16-bit, mono raw audio signal at the appropriate
    sample rate (matching what the model was trained on).
    - Precondition: Neither `finishStream()` nor `finishStreamWithMetadata(numResults:)`
    has been called on this stream.
    */
    public func feedAudioContent(buffer: [Int16]) {
        precondition(streamCtx != nil, "calling method on invalidated Stream")
        buffer.withUnsafeBufferPointer { unsafeBufferPointer in
            feedAudioContent(buffer: unsafeBufferPointer)
        }
    }

    /** Feed audio samples to an ongoing streaming inference.
    - Parameter buffer: A 16-bit, mono raw audio signal at the appropriate
    sample rate (matching what the model was trained on).
    - Precondition: Neither `finishStream()` nor `finishStreamWithMetadata(numResults:)`
    has been called on this stream.
    */
    public func feedAudioContent(buffer: UnsafeBufferPointer<Int16>) {
        precondition(streamCtx != nil, "calling method on invalidated Stream")
        STT_FeedAudioContent(streamCtx, buffer.baseAddress, UInt32(buffer.count))
    }

    /** Compute the intermediate decoding of an ongoing streaming inference.
    - Precondition: Neither `finishStream()` nor `finishStreamWithMetadata(numResults:)`
    has been called on this stream.
    - Returns: The STT intermediate result.
    */
    public func intermediateDecode() -> String {
        precondition(streamCtx != nil, "calling method on invalidated Stream")
        return takeNativeString(STT_IntermediateDecode(streamCtx), from: "STT_IntermediateDecode")
    }

    /** Compute the intermediate decoding of an ongoing streaming inference,
    return results including metadata.
    - Parameter numResults: The number of candidate transcripts to return.
    - Precondition: Neither `finishStream()` nor `finishStreamWithMetadata(numResults:)`
    has been called on this stream.
    - Returns: Metadata struct containing multiple CandidateTranscript structs.
    Each transcript has per-token metadata including timing information.
    */
    public func intermediateDecodeWithMetadata(numResults: Int) -> STTMetadata {
        precondition(streamCtx != nil, "calling method on invalidated Stream")
        return takeNativeMetadata(
            STT_IntermediateDecodeWithMetadata(streamCtx, UInt32(numResults)),
            from: "STT_IntermediateDecodeWithMetadata")
    }

    /** Compute the final decoding of an ongoing streaming inference and return
    the result. Signals the end of an ongoing streaming inference.
    - Precondition: Neither `finishStream()` nor `finishStreamWithMetadata(numResults:)`
    has been called on this stream.
    - Returns: The STT result.
    - Postcondition: This method will invalidate this streaming context.
    */
    public func finishStream() -> String {
        let ctx = takeStreamContext()
        return takeNativeString(STT_FinishStream(ctx), from: "STT_FinishStream")
    }

    /** Compute the final decoding of an ongoing streaming inference and return
    results including metadata. Signals the end of an ongoing streaming
    inference.
    - Parameter numResults: The number of candidate transcripts to return.
    - Precondition: Neither `finishStream()` nor `finishStreamWithMetadata(numResults:)`
    has been called on this stream.
    - Returns: Metadata struct containing multiple CandidateTranscript structs.
    Each transcript has per-token metadata including timing information.
    - Postcondition: This method will invalidate this streaming context.
    */
    public func finishStreamWithMetadata(numResults: Int) -> STTMetadata {
        let ctx = takeStreamContext()
        return takeNativeMetadata(
            STT_FinishStreamWithMetadata(ctx, UInt32(numResults)),
            from: "STT_FinishStreamWithMetadata")
    }
}

/// An object providing an interface to a trained STT model.
public class STTModel {
    /// Native model handle created by `STT_CreateModel`.
    private var modelCtx: OpaquePointer!

    /**
    - Parameter modelPath: The path to the model file.
    - Throws: `STTError` on failure.
    */
    public init(modelPath: String) throws {
        let err = STT_CreateModel(modelPath, &modelCtx)
        try evaluateErrorCode(errorCode: err)
    }

    deinit {
        // Skip freeing when no native model was ever stored (e.g. creation
        // failed) instead of passing nil to STT_FreeModel.
        if modelCtx != nil {
            STT_FreeModel(modelCtx)
            modelCtx = nil
        }
    }

    /** Get beam width value used by the model. If `setBeamWidth(beamWidth:)`
    was not called before, will return the default value loaded from the
    model file.
    - Returns: Beam width value used by the model.
    */
    public func getBeamWidth() -> Int {
        return Int(STT_GetModelBeamWidth(modelCtx))
    }

    /** Set beam width value used by the model.
    - Parameter beamWidth: The beam width used by the model. A larger beam
    width value generates better results at the cost
    of decoding time.
    - Throws: `STTError` on failure.
    */
    public func setBeamWidth(beamWidth: Int) throws {
        let err = STT_SetModelBeamWidth(modelCtx, UInt32(beamWidth))
        try evaluateErrorCode(errorCode: err)
    }

    /// The sample rate expected by the model.
    public var sampleRate: Int {
        return Int(STT_GetModelSampleRate(modelCtx))
    }

    /** Enable decoding using an external scorer.
    - Parameter scorerPath: The path to the external scorer file.
    - Throws: `STTError` on failure.
    */
    public func enableExternalScorer(scorerPath: String) throws {
        let err = STT_EnableExternalScorer(modelCtx, scorerPath)
        try evaluateErrorCode(errorCode: err)
    }

    /** Disable decoding using an external scorer.
    - Throws: `STTError` on failure.
    */
    public func disableExternalScorer() throws {
        let err = STT_DisableExternalScorer(modelCtx)
        try evaluateErrorCode(errorCode: err)
    }

    /** Set hyperparameters alpha and beta of the external scorer.
    - Parameter alpha: The alpha hyperparameter of the decoder. Language model weight.
    - Parameter beta: The beta hyperparameter of the decoder. Word insertion weight.
    - Throws: `STTError` on failure.
    */
    public func setScorerAlphaBeta(alpha: Float, beta: Float) throws {
        let err = STT_SetScorerAlphaBeta(modelCtx, alpha, beta)
        try evaluateErrorCode(errorCode: err)
    }

    /** Use the STT model to convert speech to text.
    - Parameter buffer: A 16-bit, mono raw audio signal at the appropriate
    sample rate (matching what the model was trained on).
    - Returns: The STT result.
    */
    public func speechToText(buffer: [Int16]) -> String {
        return buffer.withUnsafeBufferPointer { unsafeBufferPointer -> String in
            return speechToText(buffer: unsafeBufferPointer)
        }
    }

    /** Use the STT model to convert speech to text.
    - Parameter buffer: A 16-bit, mono raw audio signal at the appropriate
    sample rate (matching what the model was trained on).
    - Returns: The STT result.
    */
    public func speechToText(buffer: UnsafeBufferPointer<Int16>) -> String {
        return takeNativeString(
            STT_SpeechToText(modelCtx, buffer.baseAddress, UInt32(buffer.count)),
            from: "STT_SpeechToText")
    }

    /** Use the STT model to convert speech to text and output results
    including metadata.
    - Parameter buffer: A 16-bit, mono raw audio signal at the appropriate
    sample rate (matching what the model was trained on).
    - Parameter numResults: The maximum number of STTCandidateTranscript
    structs to return. Returned value might be smaller than this.
    - Returns: Metadata struct containing multiple CandidateTranscript structs.
    Each transcript has per-token metadata including timing information.
    */
    public func speechToTextWithMetadata(buffer: [Int16], numResults: Int) -> STTMetadata {
        return buffer.withUnsafeBufferPointer { unsafeBufferPointer -> STTMetadata in
            return speechToTextWithMetadata(buffer: unsafeBufferPointer, numResults: numResults)
        }
    }

    /** Use the STT model to convert speech to text and output results
    including metadata.
    - Parameter buffer: A 16-bit, mono raw audio signal at the appropriate
    sample rate (matching what the model was trained on).
    - Parameter numResults: The maximum number of STTCandidateTranscript
    structs to return. Returned value might be smaller than this.
    - Returns: Metadata struct containing multiple CandidateTranscript structs.
    Each transcript has per-token metadata including timing information.
    */
    public func speechToTextWithMetadata(buffer: UnsafeBufferPointer<Int16>, numResults: Int) -> STTMetadata {
        return takeNativeMetadata(
            STT_SpeechToTextWithMetadata(
                modelCtx,
                buffer.baseAddress,
                UInt32(buffer.count),
                UInt32(numResults)),
            from: "STT_SpeechToTextWithMetadata")
    }

    /** Create a new streaming inference state.
    The returned stream keeps this model alive until the stream is released.
    - Returns: STTStream object representing the streaming state.
    - Throws: `STTError` on failure.
    */
    public func createStream() throws -> STTStream {
        var streamContext: OpaquePointer!
        let err = STT_CreateStream(modelCtx, &streamContext)
        try evaluateErrorCode(errorCode: err)
        // Pass self so the stream retains this model for its whole lifetime.
        return STTStream(streamContext: streamContext, model: self)
    }
}

/// Copies a C string allocated by the native library into a Swift `String`,
/// then releases the native allocation with `STT_FreeString`.
/// - Parameters:
///   - cString: String returned by a native STT call.
///   - apiName: Name of the native call, used in the trap message if it returned null.
private func takeNativeString(_ cString: UnsafeMutablePointer<CChar>?, from apiName: String) -> String {
    guard let cString = cString else {
        fatalError("\(apiName) returned a null string")
    }
    defer { STT_FreeString(cString) }
    return String(cString: cString)
}

/// Copies native `Metadata` into an `STTMetadata` value, then releases the
/// native structure with `STT_FreeMetadata`.
/// - Parameters:
///   - metadata: Metadata returned by a native STT call.
///   - apiName: Name of the native call, used in the trap message if it returned null.
private func takeNativeMetadata(_ metadata: UnsafeMutablePointer<Metadata>?, from apiName: String) -> STTMetadata {
    guard let metadata = metadata else {
        fatalError("\(apiName) returned null metadata")
    }
    defer { STT_FreeMetadata(metadata) }
    return STTMetadata(fromInternal: metadata)
}
/// Returns the version string reported by the native STT library.
public func STTVersion() -> String {
    let versionCString = STT_Version()
    // Copy into Swift-owned storage first, then release the native allocation.
    let version = String(cString: versionCString!)
    STT_FreeString(versionCString)
    return version
}