From 6e88a37ad4367f1481e29472bf0a299881e96e63 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 25 Feb 2020 12:50:06 +0100 Subject: [PATCH] Adapt Python bindings to new API --- native_client/python/__init__.py | 65 +++++++++++++++++++----------- native_client/python/client.py | 31 +++++++------- native_client/python/impl.i | 69 ++++++++++++++++++++++++++------ 3 files changed, 116 insertions(+), 49 deletions(-) diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py index a6511efe..5d9072ec 100644 --- a/native_client/python/__init__.py +++ b/native_client/python/__init__.py @@ -121,17 +121,20 @@ class Model(object): """ return deepspeech.impl.SpeechToText(self._impl, audio_buffer) - def sttWithMetadata(self, audio_buffer): + def sttWithMetadata(self, audio_buffer, num_results=1): """ Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results. :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). :type audio_buffer: numpy.int16 array + :param num_results: Number of candidate transcripts to return. + :type num_results: int + :return: Outputs a struct of individual letters along with their timing information. :type: :func:`Metadata` """ - return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer) + return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results) def createStream(self): """ @@ -187,6 +190,19 @@ class Stream(object): raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") return deepspeech.impl.IntermediateDecode(self._impl) + def intermediateDecodeWithMetadata(self, num_results=1): + """ + Compute the intermediate decoding of an ongoing streaming inference. + + :return: The STT intermediate result. + :type: str + + :throws: RuntimeError if the stream object is not valid + """ + if not self._impl: + raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") + return deepspeech.impl.IntermediateDecodeWithMetadata(self._impl, num_results) + def finishStream(self): """ Signal the end of an audio signal to an ongoing streaming inference, @@ -203,11 +219,14 @@ class Stream(object): self._impl = None return result - def finishStreamWithMetadata(self): + def finishStreamWithMetadata(self, num_results=1): """ Signal the end of an audio signal to an ongoing streaming inference, returns per-letter metadata. + :param num_results: Number of candidate transcripts to return. + :type num_results: int + :return: Outputs a struct of individual letters along with their timing information. :type: :func:`Metadata` @@ -215,7 +234,7 @@ class Stream(object): """ if not self._impl: raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?") - result = deepspeech.impl.FinishStreamWithMetadata(self._impl) + result = deepspeech.impl.FinishStreamWithMetadata(self._impl, num_results) self._impl = None return result @@ -233,52 +252,43 @@ class Stream(object): # This is only for documentation purpose -# Metadata and MetadataItem should be in sync with native_client/deepspeech.h -class MetadataItem(object): +# Metadata, CandidateTranscript and TokenMetadata should be in sync with native_client/deepspeech.h +class TokenMetadata(object): """ Stores each individual character, along with its timing information """ - def character(self): + def text(self): """ - The character generated for transcription + The text for this token """ def timestep(self): """ - Position of the character in units of 20ms + Position of the token in units of 20ms """ def start_time(self): """ - Position of the character in seconds + Position of the token in seconds """ -class Metadata(object): +class CandidateTranscript(object): """ Stores the entire CTC output as an array of character metadata objects """ - def items(self): + def tokens(self): """ - List of items + List of tokens - :return: A list of :func:`MetadataItem` elements + :return: A list of :func:`TokenMetadata` elements :type: list """ - def num_items(self): - """ - Size of the list of items - - :return: Size of the list of items - :type: int - """ - - def confidence(self): """ Approximated confidence value for this transcription. This is roughly the @@ -286,3 +296,12 @@ class Metadata(object): contributed to the creation of this transcription. """ + +class Metadata(object): + def transcripts(self): + """ + List of candidate transcripts + + :return: A list of :func:`CandidateTranscript` objects + :type: list + """ diff --git a/native_client/python/client.py b/native_client/python/client.py index 671968b9..00fa2ff6 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -18,6 +18,7 @@ try: except ImportError: from pipes import quote + def convert_samplerate(audio_path, desired_sample_rate): sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate) try: @@ -31,25 +32,25 @@ def convert_samplerate(audio_path, desired_sample_rate): def metadata_to_string(metadata): - return ''.join(item.character for item in metadata.items) + return ''.join(token.text for token in metadata.tokens) -def words_from_metadata(metadata): + +def words_from_candidate_transcript(metadata): word = "" word_list = [] word_start_time = 0 # Loop through each character - for i in range(0, metadata.num_items): - item = metadata.items[i] + for i, token in enumerate(metadata.tokens): # Append character to word if it's not a space - if item.character != " ": + if token.text != " ": if len(word) == 0: # Log the start time of the new word - word_start_time = item.start_time + word_start_time = token.start_time - word = word + item.character + word = word + token.text # Word boundary is either a space or the last character in the array - if item.character == " " or i == metadata.num_items - 1: - word_duration = item.start_time - word_start_time + if token.text == " " or i == len(metadata.tokens) - 1: + word_duration = token.start_time - word_start_time if word_duration < 0: word_duration = 0 @@ -69,9 +70,11 @@ def words_from_metadata(metadata): def metadata_json_output(metadata): json_result = dict() - json_result["words"] = words_from_metadata(metadata) - json_result["confidence"] = metadata.confidence - return json.dumps(json_result) + json_result["transcripts"] = [{ + "confidence": transcript.confidence, + "words": words_from_candidate_transcript(transcript), + } for transcript in metadata.transcripts] + return json.dumps(json_result, indent=2) @@ -141,9 +144,9 @@ def main(): print('Running inference.', file=sys.stderr) inference_start = timer() if args.extended: - print(metadata_to_string(ds.sttWithMetadata(audio))) + print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0])) elif args.json: - print(metadata_json_output(ds.sttWithMetadata(audio))) + print(metadata_json_output(ds.sttWithMetadata(audio, 3))) else: print(ds.stt(audio)) inference_end = timer() - inference_start diff --git a/native_client/python/impl.i b/native_client/python/impl.i index d6c7ba19..001a6165 100644 --- a/native_client/python/impl.i +++ b/native_client/python/impl.i @@ -38,30 +38,69 @@ import_array(); %append_output(SWIG_NewPointerObj(%as_voidptr($1), $1_descriptor, SWIG_POINTER_OWN)); } -%typemap(out) MetadataItem* %{ - $result = PyList_New(arg1->num_items); - for (int i = 0; i < arg1->num_items; ++i) { - PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->items[i]), SWIGTYPE_p_MetadataItem, 0); +%fragment("parent_reference_init", "init") { + // Thread-safe initialization - initialize during Python module initialization + parent_reference(); +} + +%fragment("parent_reference_function", "header", fragment="parent_reference_init") { + +static PyObject *parent_reference() { + static PyObject *parent_reference_string = SWIG_Python_str_FromChar("__parent_reference"); + return parent_reference_string; +} + +} + +%typemap(out, fragment="parent_reference_function") CandidateTranscript* %{ + $result = PyList_New(arg1->num_transcripts); + for (int i = 0; i < arg1->num_transcripts; ++i) { + PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->transcripts[i]), SWIGTYPE_p_CandidateTranscript, 0); + // Add a reference to Metadata in the returned elements to avoid premature + // garbage collection + PyObject_SetAttr(o, parent_reference(), $self); PyList_SetItem($result, i, o); } %} -%extend struct MetadataItem { +%typemap(out, fragment="parent_reference_function") TokenMetadata* %{ + $result = PyList_New(arg1->num_tokens); + for (int i = 0; i < arg1->num_tokens; ++i) { + PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->tokens[i]), SWIGTYPE_p_TokenMetadata, 0); + // Add a reference to CandidateTranscript in the returned elements to avoid premature + // garbage collection + PyObject_SetAttr(o, parent_reference(), $self); + PyList_SetItem($result, i, o); + } +%} + +%extend struct TokenMetadata { %pythoncode %{ def __repr__(self): - return 'MetadataItem(character=\'{}\', timestep={}, start_time={})'.format(self.character, self.timestep, self.start_time) + return 'TokenMetadata(text=\'{}\', timestep={}, start_time={})'.format(self.text, self.timestep, self.start_time) +%} +} + +%extend struct CandidateTranscript { +%pythoncode %{ + def __repr__(self): + tokens_repr = ',\n'.join(repr(i) for i in self.tokens) + tokens_repr = '\n'.join(' ' + l for l in tokens_repr.split('\n')) + return 'CandidateTranscript(confidence={}, tokens=[\n{}\n])'.format(self.confidence, tokens_repr) %} } %extend struct Metadata { %pythoncode %{ def __repr__(self): - items_repr = ', \n'.join(' ' + repr(i) for i in self.items) - return 'Metadata(confidence={}, items=[\n{}\n])'.format(self.confidence, items_repr) + transcripts_repr = ',\n'.join(repr(i) for i in self.transcripts) + transcripts_repr = '\n'.join(' ' + l for l in transcripts_repr.split('\n')) + return 'Metadata(transcripts=[\n{}\n])'.format(transcripts_repr) %} } -%ignore Metadata::num_items; +%ignore Metadata::num_transcripts; +%ignore CandidateTranscript::num_tokens; %extend struct Metadata { ~Metadata() { @@ -69,10 +108,16 @@ import_array(); } } -%nodefaultdtor Metadata; +%immutable Metadata::transcripts; +%immutable CandidateTranscript::tokens; +%immutable TokenMetadata::text; + %nodefaultctor Metadata; -%nodefaultctor MetadataItem; -%nodefaultdtor MetadataItem; +%nodefaultdtor Metadata; +%nodefaultctor CandidateTranscript; +%nodefaultdtor CandidateTranscript; +%nodefaultctor TokenMetadata; +%nodefaultdtor TokenMetadata; %typemap(newfree) char* "DS_FreeString($1);";