Adapt Python bindings to new API
This commit is contained in:
parent
c74dcffe79
commit
6e88a37ad4
@ -121,17 +121,20 @@ class Model(object):
|
||||
"""
|
||||
return deepspeech.impl.SpeechToText(self._impl, audio_buffer)
|
||||
|
||||
def sttWithMetadata(self, audio_buffer):
|
||||
def sttWithMetadata(self, audio_buffer, num_results=1):
|
||||
"""
|
||||
Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results.
|
||||
|
||||
:param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
|
||||
:type audio_buffer: numpy.int16 array
|
||||
|
||||
:param num_results: Number of candidate transcripts to return.
|
||||
:type num_results: int
|
||||
|
||||
:return: Outputs a struct of individual letters along with their timing information.
|
||||
:type: :func:`Metadata`
|
||||
"""
|
||||
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer)
|
||||
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results)
|
||||
|
||||
def createStream(self):
|
||||
"""
|
||||
@ -187,6 +190,19 @@ class Stream(object):
|
||||
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
|
||||
return deepspeech.impl.IntermediateDecode(self._impl)
|
||||
|
||||
def intermediateDecodeWithMetadata(self, num_results=1):
|
||||
"""
|
||||
Compute the intermediate decoding of an ongoing streaming inference.
|
||||
|
||||
:return: The STT intermediate result.
|
||||
:type: str
|
||||
|
||||
:throws: RuntimeError if the stream object is not valid
|
||||
"""
|
||||
if not self._impl:
|
||||
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
|
||||
return deepspeech.impl.IntermediateDecodeWithMetadata(self._impl, num_results)
|
||||
|
||||
def finishStream(self):
|
||||
"""
|
||||
Signal the end of an audio signal to an ongoing streaming inference,
|
||||
@ -203,11 +219,14 @@ class Stream(object):
|
||||
self._impl = None
|
||||
return result
|
||||
|
||||
def finishStreamWithMetadata(self):
|
||||
def finishStreamWithMetadata(self, num_results=1):
|
||||
"""
|
||||
Signal the end of an audio signal to an ongoing streaming inference,
|
||||
returns per-letter metadata.
|
||||
|
||||
:param num_results: Number of candidate transcripts to return.
|
||||
:type num_results: int
|
||||
|
||||
:return: Outputs a struct of individual letters along with their timing information.
|
||||
:type: :func:`Metadata`
|
||||
|
||||
@ -215,7 +234,7 @@ class Stream(object):
|
||||
"""
|
||||
if not self._impl:
|
||||
raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?")
|
||||
result = deepspeech.impl.FinishStreamWithMetadata(self._impl)
|
||||
result = deepspeech.impl.FinishStreamWithMetadata(self._impl, num_results)
|
||||
self._impl = None
|
||||
return result
|
||||
|
||||
@ -233,52 +252,43 @@ class Stream(object):
|
||||
|
||||
|
||||
# This is only for documentation purpose
|
||||
# Metadata and MetadataItem should be in sync with native_client/deepspeech.h
|
||||
class MetadataItem(object):
|
||||
# Metadata, CandidateTranscript and TokenMetadata should be in sync with native_client/deepspeech.h
|
||||
class TokenMetadata(object):
|
||||
"""
|
||||
Stores each individual character, along with its timing information
|
||||
"""
|
||||
|
||||
def character(self):
|
||||
def text(self):
|
||||
"""
|
||||
The character generated for transcription
|
||||
The text for this token
|
||||
"""
|
||||
|
||||
|
||||
def timestep(self):
|
||||
"""
|
||||
Position of the character in units of 20ms
|
||||
Position of the token in units of 20ms
|
||||
"""
|
||||
|
||||
|
||||
def start_time(self):
|
||||
"""
|
||||
Position of the character in seconds
|
||||
Position of the token in seconds
|
||||
"""
|
||||
|
||||
|
||||
class Metadata(object):
|
||||
class CandidateTranscript(object):
|
||||
"""
|
||||
Stores the entire CTC output as an array of character metadata objects
|
||||
"""
|
||||
def items(self):
|
||||
def tokens(self):
|
||||
"""
|
||||
List of items
|
||||
List of tokens
|
||||
|
||||
:return: A list of :func:`MetadataItem` elements
|
||||
:return: A list of :func:`TokenMetadata` elements
|
||||
:type: list
|
||||
"""
|
||||
|
||||
|
||||
def num_items(self):
|
||||
"""
|
||||
Size of the list of items
|
||||
|
||||
:return: Size of the list of items
|
||||
:type: int
|
||||
"""
|
||||
|
||||
|
||||
def confidence(self):
|
||||
"""
|
||||
Approximated confidence value for this transcription. This is roughly the
|
||||
@ -286,3 +296,12 @@ class Metadata(object):
|
||||
contributed to the creation of this transcription.
|
||||
"""
|
||||
|
||||
|
||||
class Metadata(object):
|
||||
def transcripts(self):
|
||||
"""
|
||||
List of candidate transcripts
|
||||
|
||||
:return: A list of :func:`CandidateTranscript` objects
|
||||
:type: list
|
||||
"""
|
||||
|
@ -18,6 +18,7 @@ try:
|
||||
except ImportError:
|
||||
from pipes import quote
|
||||
|
||||
|
||||
def convert_samplerate(audio_path, desired_sample_rate):
|
||||
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
|
||||
try:
|
||||
@ -31,25 +32,25 @@ def convert_samplerate(audio_path, desired_sample_rate):
|
||||
|
||||
|
||||
def metadata_to_string(metadata):
|
||||
return ''.join(item.character for item in metadata.items)
|
||||
return ''.join(token.text for token in metadata.tokens)
|
||||
|
||||
def words_from_metadata(metadata):
|
||||
|
||||
def words_from_candidate_transcript(metadata):
|
||||
word = ""
|
||||
word_list = []
|
||||
word_start_time = 0
|
||||
# Loop through each character
|
||||
for i in range(0, metadata.num_items):
|
||||
item = metadata.items[i]
|
||||
for i, token in enumerate(metadata.tokens):
|
||||
# Append character to word if it's not a space
|
||||
if item.character != " ":
|
||||
if token.text != " ":
|
||||
if len(word) == 0:
|
||||
# Log the start time of the new word
|
||||
word_start_time = item.start_time
|
||||
word_start_time = token.start_time
|
||||
|
||||
word = word + item.character
|
||||
word = word + token.text
|
||||
# Word boundary is either a space or the last character in the array
|
||||
if item.character == " " or i == metadata.num_items - 1:
|
||||
word_duration = item.start_time - word_start_time
|
||||
if token.text == " " or i == len(metadata.tokens) - 1:
|
||||
word_duration = token.start_time - word_start_time
|
||||
|
||||
if word_duration < 0:
|
||||
word_duration = 0
|
||||
@ -69,9 +70,11 @@ def words_from_metadata(metadata):
|
||||
|
||||
def metadata_json_output(metadata):
|
||||
json_result = dict()
|
||||
json_result["words"] = words_from_metadata(metadata)
|
||||
json_result["confidence"] = metadata.confidence
|
||||
return json.dumps(json_result)
|
||||
json_result["transcripts"] = [{
|
||||
"confidence": transcript.confidence,
|
||||
"words": words_from_candidate_transcript(transcript),
|
||||
} for transcript in metadata.transcripts]
|
||||
return json.dumps(json_result, indent=2)
|
||||
|
||||
|
||||
|
||||
@ -141,9 +144,9 @@ def main():
|
||||
print('Running inference.', file=sys.stderr)
|
||||
inference_start = timer()
|
||||
if args.extended:
|
||||
print(metadata_to_string(ds.sttWithMetadata(audio)))
|
||||
print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0]))
|
||||
elif args.json:
|
||||
print(metadata_json_output(ds.sttWithMetadata(audio)))
|
||||
print(metadata_json_output(ds.sttWithMetadata(audio, 3)))
|
||||
else:
|
||||
print(ds.stt(audio))
|
||||
inference_end = timer() - inference_start
|
||||
|
@ -38,30 +38,69 @@ import_array();
|
||||
%append_output(SWIG_NewPointerObj(%as_voidptr($1), $1_descriptor, SWIG_POINTER_OWN));
|
||||
}
|
||||
|
||||
%typemap(out) MetadataItem* %{
|
||||
$result = PyList_New(arg1->num_items);
|
||||
for (int i = 0; i < arg1->num_items; ++i) {
|
||||
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->items[i]), SWIGTYPE_p_MetadataItem, 0);
|
||||
%fragment("parent_reference_init", "init") {
|
||||
// Thread-safe initialization - initialize during Python module initialization
|
||||
parent_reference();
|
||||
}
|
||||
|
||||
%fragment("parent_reference_function", "header", fragment="parent_reference_init") {
|
||||
|
||||
static PyObject *parent_reference() {
|
||||
static PyObject *parent_reference_string = SWIG_Python_str_FromChar("__parent_reference");
|
||||
return parent_reference_string;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
%typemap(out, fragment="parent_reference_function") CandidateTranscript* %{
|
||||
$result = PyList_New(arg1->num_transcripts);
|
||||
for (int i = 0; i < arg1->num_transcripts; ++i) {
|
||||
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->transcripts[i]), SWIGTYPE_p_CandidateTranscript, 0);
|
||||
// Add a reference to Metadata in the returned elements to avoid premature
|
||||
// garbage collection
|
||||
PyObject_SetAttr(o, parent_reference(), $self);
|
||||
PyList_SetItem($result, i, o);
|
||||
}
|
||||
%}
|
||||
|
||||
%extend struct MetadataItem {
|
||||
%typemap(out, fragment="parent_reference_function") TokenMetadata* %{
|
||||
$result = PyList_New(arg1->num_tokens);
|
||||
for (int i = 0; i < arg1->num_tokens; ++i) {
|
||||
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->tokens[i]), SWIGTYPE_p_TokenMetadata, 0);
|
||||
// Add a reference to CandidateTranscript in the returned elements to avoid premature
|
||||
// garbage collection
|
||||
PyObject_SetAttr(o, parent_reference(), $self);
|
||||
PyList_SetItem($result, i, o);
|
||||
}
|
||||
%}
|
||||
|
||||
%extend struct TokenMetadata {
|
||||
%pythoncode %{
|
||||
def __repr__(self):
|
||||
return 'MetadataItem(character=\'{}\', timestep={}, start_time={})'.format(self.character, self.timestep, self.start_time)
|
||||
return 'TokenMetadata(text=\'{}\', timestep={}, start_time={})'.format(self.text, self.timestep, self.start_time)
|
||||
%}
|
||||
}
|
||||
|
||||
%extend struct CandidateTranscript {
|
||||
%pythoncode %{
|
||||
def __repr__(self):
|
||||
tokens_repr = ',\n'.join(repr(i) for i in self.tokens)
|
||||
tokens_repr = '\n'.join(' ' + l for l in tokens_repr.split('\n'))
|
||||
return 'CandidateTranscript(confidence={}, tokens=[\n{}\n])'.format(self.confidence, tokens_repr)
|
||||
%}
|
||||
}
|
||||
|
||||
%extend struct Metadata {
|
||||
%pythoncode %{
|
||||
def __repr__(self):
|
||||
items_repr = ', \n'.join(' ' + repr(i) for i in self.items)
|
||||
return 'Metadata(confidence={}, items=[\n{}\n])'.format(self.confidence, items_repr)
|
||||
transcripts_repr = ',\n'.join(repr(i) for i in self.transcripts)
|
||||
transcripts_repr = '\n'.join(' ' + l for l in transcripts_repr.split('\n'))
|
||||
return 'Metadata(transcripts=[\n{}\n])'.format(transcripts_repr)
|
||||
%}
|
||||
}
|
||||
|
||||
%ignore Metadata::num_items;
|
||||
%ignore Metadata::num_transcripts;
|
||||
%ignore CandidateTranscript::num_tokens;
|
||||
|
||||
%extend struct Metadata {
|
||||
~Metadata() {
|
||||
@ -69,10 +108,16 @@ import_array();
|
||||
}
|
||||
}
|
||||
|
||||
%nodefaultdtor Metadata;
|
||||
%immutable Metadata::transcripts;
|
||||
%immutable CandidateTranscript::tokens;
|
||||
%immutable TokenMetadata::text;
|
||||
|
||||
%nodefaultctor Metadata;
|
||||
%nodefaultctor MetadataItem;
|
||||
%nodefaultdtor MetadataItem;
|
||||
%nodefaultdtor Metadata;
|
||||
%nodefaultctor CandidateTranscript;
|
||||
%nodefaultdtor CandidateTranscript;
|
||||
%nodefaultctor TokenMetadata;
|
||||
%nodefaultdtor TokenMetadata;
|
||||
|
||||
%typemap(newfree) char* "DS_FreeString($1);";
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user