Adapt Python bindings to new API

This commit is contained in:
Reuben Morais 2020-02-25 12:50:06 +01:00
parent c74dcffe79
commit 6e88a37ad4
3 changed files with 116 additions and 49 deletions

View File

@ -121,17 +121,20 @@ class Model(object):
"""
return deepspeech.impl.SpeechToText(self._impl, audio_buffer)
def sttWithMetadata(self, audio_buffer):
def sttWithMetadata(self, audio_buffer, num_results=1):
"""
Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results.
:param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
:type audio_buffer: numpy.int16 array
:param num_results: Number of candidate transcripts to return.
:type num_results: int
:return: Outputs a struct of individual letters along with their timing information.
:type: :func:`Metadata`
"""
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer)
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results)
def createStream(self):
"""
@ -187,6 +190,19 @@ class Stream(object):
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
return deepspeech.impl.IntermediateDecode(self._impl)
def intermediateDecodeWithMetadata(self, num_results=1):
"""
Compute the intermediate decoding of an ongoing streaming inference.
:return: The STT intermediate result.
:type: str
:throws: RuntimeError if the stream object is not valid
"""
if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
return deepspeech.impl.IntermediateDecodeWithMetadata(self._impl, num_results)
def finishStream(self):
"""
Signal the end of an audio signal to an ongoing streaming inference,
@ -203,11 +219,14 @@ class Stream(object):
self._impl = None
return result
def finishStreamWithMetadata(self):
def finishStreamWithMetadata(self, num_results=1):
"""
Signal the end of an audio signal to an ongoing streaming inference,
returns per-letter metadata.
:param num_results: Number of candidate transcripts to return.
:type num_results: int
:return: Outputs a struct of individual letters along with their timing information.
:type: :func:`Metadata`
@ -215,7 +234,7 @@ class Stream(object):
"""
if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?")
result = deepspeech.impl.FinishStreamWithMetadata(self._impl)
result = deepspeech.impl.FinishStreamWithMetadata(self._impl, num_results)
self._impl = None
return result
@ -233,52 +252,43 @@ class Stream(object):
# This is only for documentation purpose
# Metadata and MetadataItem should be in sync with native_client/deepspeech.h
class MetadataItem(object):
# Metadata, CandidateTranscript and TokenMetadata should be in sync with native_client/deepspeech.h
class TokenMetadata(object):
"""
Stores each individual character, along with its timing information
"""
def character(self):
def text(self):
"""
The character generated for transcription
The text for this token
"""
def timestep(self):
"""
Position of the character in units of 20ms
Position of the token in units of 20ms
"""
def start_time(self):
"""
Position of the character in seconds
Position of the token in seconds
"""
class Metadata(object):
class CandidateTranscript(object):
"""
Stores the entire CTC output as an array of character metadata objects
"""
def items(self):
def tokens(self):
"""
List of items
List of tokens
:return: A list of :func:`MetadataItem` elements
:return: A list of :func:`TokenMetadata` elements
:type: list
"""
def num_items(self):
"""
Size of the list of items
:return: Size of the list of items
:type: int
"""
def confidence(self):
"""
Approximated confidence value for this transcription. This is roughly the
@ -286,3 +296,12 @@ class Metadata(object):
contributed to the creation of this transcription.
"""
class Metadata(object):
def transcripts(self):
"""
List of candidate transcripts
:return: A list of :func:`CandidateTranscript` objects
:type: list
"""

View File

@ -18,6 +18,7 @@ try:
except ImportError:
from pipes import quote
def convert_samplerate(audio_path, desired_sample_rate):
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
try:
@ -31,25 +32,25 @@ def convert_samplerate(audio_path, desired_sample_rate):
def metadata_to_string(metadata):
return ''.join(item.character for item in metadata.items)
return ''.join(token.text for token in metadata.tokens)
def words_from_metadata(metadata):
def words_from_candidate_transcript(metadata):
word = ""
word_list = []
word_start_time = 0
# Loop through each character
for i in range(0, metadata.num_items):
item = metadata.items[i]
for i, token in enumerate(metadata.tokens):
# Append character to word if it's not a space
if item.character != " ":
if token.text != " ":
if len(word) == 0:
# Log the start time of the new word
word_start_time = item.start_time
word_start_time = token.start_time
word = word + item.character
word = word + token.text
# Word boundary is either a space or the last character in the array
if item.character == " " or i == metadata.num_items - 1:
word_duration = item.start_time - word_start_time
if token.text == " " or i == len(metadata.tokens) - 1:
word_duration = token.start_time - word_start_time
if word_duration < 0:
word_duration = 0
@ -69,9 +70,11 @@ def words_from_metadata(metadata):
def metadata_json_output(metadata):
json_result = dict()
json_result["words"] = words_from_metadata(metadata)
json_result["confidence"] = metadata.confidence
return json.dumps(json_result)
json_result["transcripts"] = [{
"confidence": transcript.confidence,
"words": words_from_candidate_transcript(transcript),
} for transcript in metadata.transcripts]
return json.dumps(json_result, indent=2)
@ -141,9 +144,9 @@ def main():
print('Running inference.', file=sys.stderr)
inference_start = timer()
if args.extended:
print(metadata_to_string(ds.sttWithMetadata(audio)))
print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0]))
elif args.json:
print(metadata_json_output(ds.sttWithMetadata(audio)))
print(metadata_json_output(ds.sttWithMetadata(audio, 3)))
else:
print(ds.stt(audio))
inference_end = timer() - inference_start

View File

@ -38,30 +38,69 @@ import_array();
%append_output(SWIG_NewPointerObj(%as_voidptr($1), $1_descriptor, SWIG_POINTER_OWN));
}
%typemap(out) MetadataItem* %{
$result = PyList_New(arg1->num_items);
for (int i = 0; i < arg1->num_items; ++i) {
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->items[i]), SWIGTYPE_p_MetadataItem, 0);
%fragment("parent_reference_init", "init") {
// Thread-safe initialization - initialize during Python module initialization
parent_reference();
}
%fragment("parent_reference_function", "header", fragment="parent_reference_init") {
static PyObject *parent_reference() {
static PyObject *parent_reference_string = SWIG_Python_str_FromChar("__parent_reference");
return parent_reference_string;
}
}
%typemap(out, fragment="parent_reference_function") CandidateTranscript* %{
$result = PyList_New(arg1->num_transcripts);
for (int i = 0; i < arg1->num_transcripts; ++i) {
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->transcripts[i]), SWIGTYPE_p_CandidateTranscript, 0);
// Add a reference to Metadata in the returned elements to avoid premature
// garbage collection
PyObject_SetAttr(o, parent_reference(), $self);
PyList_SetItem($result, i, o);
}
%}
%extend struct MetadataItem {
%typemap(out, fragment="parent_reference_function") TokenMetadata* %{
$result = PyList_New(arg1->num_tokens);
for (int i = 0; i < arg1->num_tokens; ++i) {
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->tokens[i]), SWIGTYPE_p_TokenMetadata, 0);
// Add a reference to CandidateTranscript in the returned elements to avoid premature
// garbage collection
PyObject_SetAttr(o, parent_reference(), $self);
PyList_SetItem($result, i, o);
}
%}
%extend struct TokenMetadata {
%pythoncode %{
def __repr__(self):
return 'MetadataItem(character=\'{}\', timestep={}, start_time={})'.format(self.character, self.timestep, self.start_time)
return 'TokenMetadata(text=\'{}\', timestep={}, start_time={})'.format(self.text, self.timestep, self.start_time)
%}
}
%extend struct CandidateTranscript {
%pythoncode %{
def __repr__(self):
tokens_repr = ',\n'.join(repr(i) for i in self.tokens)
tokens_repr = '\n'.join(' ' + l for l in tokens_repr.split('\n'))
return 'CandidateTranscript(confidence={}, tokens=[\n{}\n])'.format(self.confidence, tokens_repr)
%}
}
%extend struct Metadata {
%pythoncode %{
def __repr__(self):
items_repr = ', \n'.join(' ' + repr(i) for i in self.items)
return 'Metadata(confidence={}, items=[\n{}\n])'.format(self.confidence, items_repr)
transcripts_repr = ',\n'.join(repr(i) for i in self.transcripts)
transcripts_repr = '\n'.join(' ' + l for l in transcripts_repr.split('\n'))
return 'Metadata(transcripts=[\n{}\n])'.format(transcripts_repr)
%}
}
%ignore Metadata::num_items;
%ignore Metadata::num_transcripts;
%ignore CandidateTranscript::num_tokens;
%extend struct Metadata {
~Metadata() {
@ -69,10 +108,16 @@ import_array();
}
}
%nodefaultdtor Metadata;
%immutable Metadata::transcripts;
%immutable CandidateTranscript::tokens;
%immutable TokenMetadata::text;
%nodefaultctor Metadata;
%nodefaultctor MetadataItem;
%nodefaultdtor MetadataItem;
%nodefaultdtor Metadata;
%nodefaultctor CandidateTranscript;
%nodefaultdtor CandidateTranscript;
%nodefaultctor TokenMetadata;
%nodefaultdtor TokenMetadata;
%typemap(newfree) char* "DS_FreeString($1);";