Adapt Python bindings to new API

Reuben Morais 2020-02-25 12:50:06 +01:00
parent c74dcffe79
commit 6e88a37ad4
3 changed files with 116 additions and 49 deletions

View File

@@ -121,17 +121,20 @@ class Model(object):
         """
         return deepspeech.impl.SpeechToText(self._impl, audio_buffer)
 
-    def sttWithMetadata(self, audio_buffer):
+    def sttWithMetadata(self, audio_buffer, num_results=1):
         """
         Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results.
 
         :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
         :type audio_buffer: numpy.int16 array
 
+        :param num_results: Number of candidate transcripts to return.
+        :type num_results: int
+
         :return: Outputs a struct of individual letters along with their timing information.
         :type: :func:`Metadata`
         """
-        return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer)
+        return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results)
 
     def createStream(self):
         """
@@ -187,6 +190,19 @@ class Stream(object):
             raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
         return deepspeech.impl.IntermediateDecode(self._impl)
 
+    def intermediateDecodeWithMetadata(self, num_results=1):
+        """
+        Compute the intermediate decoding of an ongoing streaming inference.
+
+        :return: The STT intermediate result.
+        :type: str
+
+        :throws: RuntimeError if the stream object is not valid
+        """
+        if not self._impl:
+            raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
+        return deepspeech.impl.IntermediateDecodeWithMetadata(self._impl, num_results)
+
     def finishStream(self):
         """
         Signal the end of an audio signal to an ongoing streaming inference,
@@ -203,11 +219,14 @@ class Stream(object):
         self._impl = None
         return result
 
-    def finishStreamWithMetadata(self):
+    def finishStreamWithMetadata(self, num_results=1):
         """
         Signal the end of an audio signal to an ongoing streaming inference,
         returns per-letter metadata.
 
+        :param num_results: Number of candidate transcripts to return.
+        :type num_results: int
+
         :return: Outputs a struct of individual letters along with their timing information.
         :type: :func:`Metadata`
@@ -215,7 +234,7 @@ class Stream(object):
         """
         if not self._impl:
             raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?")
-        result = deepspeech.impl.FinishStreamWithMetadata(self._impl)
+        result = deepspeech.impl.FinishStreamWithMetadata(self._impl, num_results)
         self._impl = None
         return result
@@ -233,52 +252,43 @@ class Stream(object):
 # This is only for documentation purpose
-# Metadata and MetadataItem should be in sync with native_client/deepspeech.h
-class MetadataItem(object):
+# Metadata, CandidateTranscript and TokenMetadata should be in sync with native_client/deepspeech.h
+class TokenMetadata(object):
     """
     Stores each individual character, along with its timing information
     """
 
-    def character(self):
+    def text(self):
         """
-        The character generated for transcription
+        The text for this token
         """
 
     def timestep(self):
         """
-        Position of the character in units of 20ms
+        Position of the token in units of 20ms
         """
 
     def start_time(self):
         """
-        Position of the character in seconds
+        Position of the token in seconds
         """
 
-class Metadata(object):
+class CandidateTranscript(object):
     """
    Stores the entire CTC output as an array of character metadata objects
     """
-    def items(self):
+    def tokens(self):
         """
-        List of items
+        List of tokens
 
-        :return: A list of :func:`MetadataItem` elements
+        :return: A list of :func:`TokenMetadata` elements
         :type: list
         """
 
-    def num_items(self):
-        """
-        Size of the list of items
-
-        :return: Size of the list of items
-        :type: int
-        """
-
     def confidence(self):
         """
         Approximated confidence value for this transcription. This is roughly the
@@ -286,3 +296,12 @@ class Metadata(object):
         contributed to the creation of this transcription.
         """
+
+
+class Metadata(object):
+    def transcripts(self):
+        """
+        List of candidate transcripts
+
+        :return: A list of :func:`CandidateTranscript` objects
+        :type: list
+        """

View File

@@ -18,6 +18,7 @@ try:
 except ImportError:
     from pipes import quote
 
+
 def convert_samplerate(audio_path, desired_sample_rate):
     sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
     try:
@@ -31,25 +32,25 @@ def convert_samplerate(audio_path, desired_sample_rate):
 def metadata_to_string(metadata):
-    return ''.join(item.character for item in metadata.items)
+    return ''.join(token.text for token in metadata.tokens)
 
-def words_from_metadata(metadata):
+
+def words_from_candidate_transcript(metadata):
     word = ""
     word_list = []
     word_start_time = 0
     # Loop through each character
-    for i in range(0, metadata.num_items):
-        item = metadata.items[i]
+    for i, token in enumerate(metadata.tokens):
         # Append character to word if it's not a space
-        if item.character != " ":
+        if token.text != " ":
             if len(word) == 0:
                 # Log the start time of the new word
-                word_start_time = item.start_time
+                word_start_time = token.start_time
 
-            word = word + item.character
+            word = word + token.text
         # Word boundary is either a space or the last character in the array
-        if item.character == " " or i == metadata.num_items - 1:
-            word_duration = item.start_time - word_start_time
+        if token.text == " " or i == len(metadata.tokens) - 1:
+            word_duration = token.start_time - word_start_time
 
             if word_duration < 0:
                 word_duration = 0
@@ -69,9 +70,11 @@ def words_from_metadata(metadata):
 
 def metadata_json_output(metadata):
     json_result = dict()
-    json_result["words"] = words_from_metadata(metadata)
-    json_result["confidence"] = metadata.confidence
-    return json.dumps(json_result)
+    json_result["transcripts"] = [{
+        "confidence": transcript.confidence,
+        "words": words_from_candidate_transcript(transcript),
+    } for transcript in metadata.transcripts]
+    return json.dumps(json_result, indent=2)
@@ -141,9 +144,9 @@ def main():
     print('Running inference.', file=sys.stderr)
     inference_start = timer()
     if args.extended:
-        print(metadata_to_string(ds.sttWithMetadata(audio)))
+        print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0]))
     elif args.json:
-        print(metadata_json_output(ds.sttWithMetadata(audio)))
+        print(metadata_json_output(ds.sttWithMetadata(audio, 3)))
     else:
         print(ds.stt(audio))
     inference_end = timer() - inference_start
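
The example client only exercises the batch path; the streaming variants added in the bindings (intermediateDecodeWithMetadata and finishStreamWithMetadata) follow the same shape. A sketch under the assumption that the Stream wrapper exposes feedAudioContent and that the chunk size is arbitrary (neither appears in this diff):

    def transcribe_stream(model, audio, chunk_size=1024):
        # model: a deepspeech.Model; audio: numpy.int16 samples at the model's rate.
        stream = model.createStream()
        for start in range(0, len(audio), chunk_size):
            stream.feedAudioContent(audio[start:start + chunk_size])
            # Per-token timing for the audio seen so far.
            partial = stream.intermediateDecodeWithMetadata(num_results=1)
            print(''.join(token.text for token in partial.transcripts[0].tokens))
        metadata = stream.finishStreamWithMetadata(num_results=3)
        return [''.join(token.text for token in t.tokens) for t in metadata.transcripts]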

View File

@@ -38,30 +38,69 @@ import_array();
   %append_output(SWIG_NewPointerObj(%as_voidptr($1), $1_descriptor, SWIG_POINTER_OWN));
 }
 
-%typemap(out) MetadataItem* %{
-  $result = PyList_New(arg1->num_items);
-  for (int i = 0; i < arg1->num_items; ++i) {
-    PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->items[i]), SWIGTYPE_p_MetadataItem, 0);
+%fragment("parent_reference_init", "init") {
+  // Thread-safe initialization - initialize during Python module initialization
+  parent_reference();
+}
+
+%fragment("parent_reference_function", "header", fragment="parent_reference_init") {
+
+static PyObject *parent_reference() {
+  static PyObject *parent_reference_string = SWIG_Python_str_FromChar("__parent_reference");
+  return parent_reference_string;
+}
+
+}
+
+%typemap(out, fragment="parent_reference_function") CandidateTranscript* %{
+  $result = PyList_New(arg1->num_transcripts);
+  for (int i = 0; i < arg1->num_transcripts; ++i) {
+    PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->transcripts[i]), SWIGTYPE_p_CandidateTranscript, 0);
+    // Add a reference to Metadata in the returned elements to avoid premature
+    // garbage collection
+    PyObject_SetAttr(o, parent_reference(), $self);
     PyList_SetItem($result, i, o);
   }
 %}
 
-%extend struct MetadataItem {
+%typemap(out, fragment="parent_reference_function") TokenMetadata* %{
+  $result = PyList_New(arg1->num_tokens);
+  for (int i = 0; i < arg1->num_tokens; ++i) {
+    PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->tokens[i]), SWIGTYPE_p_TokenMetadata, 0);
+    // Add a reference to CandidateTranscript in the returned elements to avoid premature
+    // garbage collection
+    PyObject_SetAttr(o, parent_reference(), $self);
+    PyList_SetItem($result, i, o);
+  }
+%}
+
+%extend struct TokenMetadata {
 %pythoncode %{
   def __repr__(self):
-    return 'MetadataItem(character=\'{}\', timestep={}, start_time={})'.format(self.character, self.timestep, self.start_time)
+    return 'TokenMetadata(text=\'{}\', timestep={}, start_time={})'.format(self.text, self.timestep, self.start_time)
+%}
+}
+
+%extend struct CandidateTranscript {
+%pythoncode %{
+  def __repr__(self):
+    tokens_repr = ',\n'.join(repr(i) for i in self.tokens)
+    tokens_repr = '\n'.join(' ' + l for l in tokens_repr.split('\n'))
+    return 'CandidateTranscript(confidence={}, tokens=[\n{}\n])'.format(self.confidence, tokens_repr)
 %}
 }
 
 %extend struct Metadata {
 %pythoncode %{
   def __repr__(self):
-    items_repr = ', \n'.join(' ' + repr(i) for i in self.items)
-    return 'Metadata(confidence={}, items=[\n{}\n])'.format(self.confidence, items_repr)
+    transcripts_repr = ',\n'.join(repr(i) for i in self.transcripts)
+    transcripts_repr = '\n'.join(' ' + l for l in transcripts_repr.split('\n'))
+    return 'Metadata(transcripts=[\n{}\n])'.format(transcripts_repr)
 %}
 }
 
-%ignore Metadata::num_items;
+%ignore Metadata::num_transcripts;
+%ignore CandidateTranscript::num_tokens;
 
 %extend struct Metadata {
   ~Metadata() {
@@ -69,10 +108,16 @@ import_array();
   }
 }
 
-%nodefaultdtor Metadata;
+%immutable Metadata::transcripts;
+%immutable CandidateTranscript::tokens;
+%immutable TokenMetadata::text;
+
 %nodefaultctor Metadata;
-%nodefaultctor MetadataItem;
-%nodefaultdtor MetadataItem;
+%nodefaultdtor Metadata;
+%nodefaultctor CandidateTranscript;
+%nodefaultdtor CandidateTranscript;
+%nodefaultctor TokenMetadata;
+%nodefaultdtor TokenMetadata;
 
 %typemap(newfree) char* "DS_FreeString($1);";
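
The parent_reference machinery above is there so that the list elements produced by the transcripts and tokens typemaps keep their owning wrapper alive: only the Metadata proxy has a destructor that frees the C structure, and the child proxies are views into its memory. A sketch of the behaviour this is intended to guarantee (the helper name is hypothetical; illustration only):

    def demonstrate_ownership(model, audio):
        # The Metadata proxy created here is the only wrapper whose destructor
        # frees the underlying C structure; transcripts/tokens point into it.
        tokens = model.sttWithMetadata(audio, 1).transcripts[0].tokens
        # Even after the Metadata proxy itself goes out of scope, each token keeps
        # a __parent_reference chain back to it, so the memory is not freed while
        # the tokens are still reachable.
        return [token.text for token in tokens]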