Merge branch 'embed-beam-width' (Fixes #2744)
This commit is contained in:
commit
88a1048322
@ -797,6 +797,7 @@ def export():
|
||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
||||
|
||||
if FLAGS.export_language:
|
||||
@ -875,11 +876,6 @@ def package_zip():
|
||||
with open(os.path.join(export_dir, 'info.json'), 'w') as f:
|
||||
json.dump({
|
||||
'name': FLAGS.export_language,
|
||||
'parameters': {
|
||||
'beamWidth': FLAGS.beam_width,
|
||||
'lmAlpha': FLAGS.lm_alpha,
|
||||
'lmBeta': FLAGS.lm_beta
|
||||
}
|
||||
}, f)
|
||||
|
||||
shutil.copy(FLAGS.scorer_path, export_dir)
|
||||
|
@ -30,15 +30,9 @@ This module should be self-contained:
|
||||
Then run with a TF Lite model, a scorer and a CSV test file
|
||||
'''
|
||||
|
||||
BEAM_WIDTH = 500
|
||||
LM_ALPHA = 0.75
|
||||
LM_BETA = 1.85
|
||||
|
||||
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
|
||||
ds = Model(model, BEAM_WIDTH)
|
||||
ds.enableExternalScorer(scorer)
|
||||
ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)
|
||||
ds = Model(model)
|
||||
|
||||
while True:
|
||||
try:
|
||||
|
@ -16,7 +16,9 @@ char* scorer = NULL;
|
||||
|
||||
char* audio = NULL;
|
||||
|
||||
int beam_width = 500;
|
||||
bool set_beamwidth = false;
|
||||
|
||||
int beam_width = 0;
|
||||
|
||||
bool set_alphabeta = false;
|
||||
|
||||
@ -98,6 +100,7 @@ bool ProcessArgs(int argc, char** argv)
|
||||
break;
|
||||
|
||||
case 'b':
|
||||
set_beamwidth = true;
|
||||
beam_width = atoi(optarg);
|
||||
break;
|
||||
|
||||
|
@ -368,14 +368,22 @@ main(int argc, char **argv)
|
||||
|
||||
// Initialise DeepSpeech
|
||||
ModelState* ctx;
|
||||
int status = DS_CreateModel(model, beam_width, &ctx);
|
||||
int status = DS_CreateModel(model, &ctx);
|
||||
if (status != 0) {
|
||||
fprintf(stderr, "Could not create model.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (set_beamwidth) {
|
||||
status = DS_SetModelBeamWidth(ctx, beam_width);
|
||||
if (status != 0) {
|
||||
fprintf(stderr, "Could not set model beam width.\n");
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (scorer) {
|
||||
int status = DS_EnableExternalScorer(ctx, scorer);
|
||||
status = DS_EnableExternalScorer(ctx, scorer);
|
||||
if (status != 0) {
|
||||
fprintf(stderr, "Could not enable external scorer.\n");
|
||||
return 1;
|
||||
|
@ -257,7 +257,6 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
|
||||
|
||||
int
|
||||
DS_CreateModel(const char* aModelPath,
|
||||
unsigned int aBeamWidth,
|
||||
ModelState** retval)
|
||||
{
|
||||
*retval = nullptr;
|
||||
@ -282,7 +281,7 @@ DS_CreateModel(const char* aModelPath,
|
||||
return DS_ERR_FAIL_CREATE_MODEL;
|
||||
}
|
||||
|
||||
int err = model->init(aModelPath, aBeamWidth);
|
||||
int err = model->init(aModelPath);
|
||||
if (err != DS_ERR_OK) {
|
||||
return err;
|
||||
}
|
||||
@ -291,6 +290,19 @@ DS_CreateModel(const char* aModelPath,
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
unsigned int
|
||||
DS_GetModelBeamWidth(ModelState* aCtx)
|
||||
{
|
||||
return aCtx->beam_width_;
|
||||
}
|
||||
|
||||
int
|
||||
DS_SetModelBeamWidth(ModelState* aCtx, unsigned int aBeamWidth)
|
||||
{
|
||||
aCtx->beam_width_ = aBeamWidth;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int
|
||||
DS_GetModelSampleRate(ModelState* aCtx)
|
||||
{
|
||||
|
@ -78,18 +78,39 @@ enum DeepSpeech_Error_Codes
|
||||
* @brief An object providing an interface to a trained DeepSpeech model.
|
||||
*
|
||||
* @param aModelPath The path to the frozen model graph.
|
||||
* @param aBeamWidth The beam width used by the decoder. A larger beam
|
||||
* width generates better results at the cost of decoding
|
||||
* time.
|
||||
* @param[out] retval a ModelState pointer
|
||||
*
|
||||
* @return Zero on success, non-zero on failure.
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_CreateModel(const char* aModelPath,
|
||||
unsigned int aBeamWidth,
|
||||
ModelState** retval);
|
||||
|
||||
/**
|
||||
* @brief Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
|
||||
* was not called before, will return the default value loaded from the
|
||||
* model file.
|
||||
*
|
||||
* @param aCtx A ModelState pointer created with {@link DS_CreateModel}.
|
||||
*
|
||||
* @return Beam width value used by the model.
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
unsigned int DS_GetModelBeamWidth(ModelState* aCtx);
|
||||
|
||||
/**
|
||||
* @brief Set beam width value used by the model.
|
||||
*
|
||||
* @param aCtx A ModelState pointer created with {@link DS_CreateModel}.
|
||||
* @param aBeamWidth The beam width used by the model. A larger beam width value
|
||||
* generates better results at the cost of decoding time.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure.
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_SetModelBeamWidth(ModelState* aCtx,
|
||||
unsigned int aBeamWidth);
|
||||
|
||||
/**
|
||||
* @brief Return the sample rate expected by a model.
|
||||
*
|
||||
|
@ -19,11 +19,10 @@ namespace DeepSpeechClient
|
||||
/// Initializes a new instance of <see cref="DeepSpeech"/> class and creates a new acoustic model.
|
||||
/// </summary>
|
||||
/// <param name="aModelPath">The path to the frozen model graph.</param>
|
||||
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
|
||||
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
|
||||
public DeepSpeech(string aModelPath, uint aBeamWidth)
|
||||
public DeepSpeech(string aModelPath)
|
||||
{
|
||||
CreateModel(aModelPath, aBeamWidth);
|
||||
CreateModel(aModelPath);
|
||||
}
|
||||
|
||||
#region IDeepSpeech
|
||||
@ -32,10 +31,8 @@ namespace DeepSpeechClient
|
||||
/// Create an object providing an interface to a trained DeepSpeech model.
|
||||
/// </summary>
|
||||
/// <param name="aModelPath">The path to the frozen model graph.</param>
|
||||
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
|
||||
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
|
||||
private unsafe void CreateModel(string aModelPath,
|
||||
uint aBeamWidth)
|
||||
private unsafe void CreateModel(string aModelPath)
|
||||
{
|
||||
string exceptionMessage = null;
|
||||
if (string.IsNullOrWhiteSpace(aModelPath))
|
||||
@ -52,11 +49,31 @@ namespace DeepSpeechClient
|
||||
throw new FileNotFoundException(exceptionMessage);
|
||||
}
|
||||
var resultCode = NativeImp.DS_CreateModel(aModelPath,
|
||||
aBeamWidth,
|
||||
ref _modelStatePP);
|
||||
EvaluateResultCode(resultCode);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get beam width value used by the model. If SetModelBeamWidth was not
|
||||
/// called before, will return the default value loaded from the model file.
|
||||
/// </summary>
|
||||
/// <returns>Beam width value used by the model.</returns>
|
||||
public unsafe uint GetModelBeamWidth()
|
||||
{
|
||||
return NativeImp.DS_GetModelBeamWidth(_modelStatePP);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Set beam width value used by the model.
|
||||
/// </summary>
|
||||
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.</param>
|
||||
/// <exception cref="ArgumentException">Thrown on failure.</exception>
|
||||
public unsafe void SetModelBeamWidth(uint aBeamWidth)
|
||||
{
|
||||
var resultCode = NativeImp.DS_SetModelBeamWidth(_modelStatePP, aBeamWidth);
|
||||
EvaluateResultCode(resultCode);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return the sample rate expected by the model.
|
||||
/// </summary>
|
||||
|
@ -20,6 +20,21 @@ namespace DeepSpeechClient.Interfaces
|
||||
/// <returns>Sample rate.</returns>
|
||||
unsafe int GetModelSampleRate();
|
||||
|
||||
/// <summary>
|
||||
/// Get beam width value used by the model. If SetModelBeamWidth was not
|
||||
/// called before, will return the default value loaded from the model
|
||||
/// file.
|
||||
/// </summary>
|
||||
/// <returns>Beam width value used by the model.</returns>
|
||||
unsafe uint GetModelBeamWidth();
|
||||
|
||||
/// <summary>
|
||||
/// Set beam width value used by the model.
|
||||
/// </summary>
|
||||
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.</param>
|
||||
/// <exception cref="ArgumentException">Thrown on failure.</exception>
|
||||
unsafe void SetModelBeamWidth(uint aBeamWidth);
|
||||
|
||||
/// <summary>
|
||||
/// Enable decoding using an external scorer.
|
||||
/// </summary>
|
||||
|
@ -14,6 +14,17 @@ namespace DeepSpeechClient
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal static extern void DS_PrintVersions();
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
|
||||
ref IntPtr** pint);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx,
|
||||
uint aBeamWidth);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
|
||||
uint aBeamWidth,
|
||||
|
@ -46,15 +46,12 @@ namespace CSharpExamples
|
||||
extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
|
||||
}
|
||||
|
||||
const uint BEAM_WIDTH = 500;
|
||||
|
||||
Stopwatch stopwatch = new Stopwatch();
|
||||
try
|
||||
{
|
||||
Console.WriteLine("Loading model...");
|
||||
stopwatch.Start();
|
||||
using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm",
|
||||
BEAM_WIDTH))
|
||||
using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm"))
|
||||
{
|
||||
stopwatch.Stop();
|
||||
|
||||
|
@ -49,7 +49,8 @@ public class DeepSpeechActivity extends AppCompatActivity {
|
||||
private void newModel(String tfliteModel) {
|
||||
this._tfliteStatus.setText("Creating model");
|
||||
if (this._m == null) {
|
||||
this._m = new DeepSpeechModel(tfliteModel, BEAM_WIDTH);
|
||||
this._m = new DeepSpeechModel(tfliteModel);
|
||||
this._m.setBeamWidth(BEAM_WIDTH);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -33,8 +33,6 @@ public class BasicTest {
|
||||
public static final String scorerFile = "/data/local/tmp/test/kenlm.scorer";
|
||||
public static final String wavFile = "/data/local/tmp/test/LDC93S1.wav";
|
||||
|
||||
public static final int BEAM_WIDTH = 50;
|
||||
|
||||
private char readLEChar(RandomAccessFile f) throws IOException {
|
||||
byte b1 = f.readByte();
|
||||
byte b2 = f.readByte();
|
||||
@ -59,7 +57,7 @@ public class BasicTest {
|
||||
|
||||
@Test
|
||||
public void loadDeepSpeech_basic() {
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||
m.freeModel();
|
||||
}
|
||||
|
||||
@ -116,7 +114,7 @@ public class BasicTest {
|
||||
|
||||
@Test
|
||||
public void loadDeepSpeech_stt_noLM() {
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||
|
||||
String decoded = doSTT(m, false);
|
||||
assertEquals("she had your dark suit in greasy wash water all year", decoded);
|
||||
@ -125,7 +123,7 @@ public class BasicTest {
|
||||
|
||||
@Test
|
||||
public void loadDeepSpeech_stt_withLM() {
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||
m.enableExternalScorer(scorerFile);
|
||||
|
||||
String decoded = doSTT(m, false);
|
||||
@ -135,7 +133,7 @@ public class BasicTest {
|
||||
|
||||
@Test
|
||||
public void loadDeepSpeech_sttWithMetadata_noLM() {
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||
|
||||
String decoded = doSTT(m, true);
|
||||
assertEquals("she had your dark suit in greasy wash water all year", decoded);
|
||||
@ -144,7 +142,7 @@ public class BasicTest {
|
||||
|
||||
@Test
|
||||
public void loadDeepSpeech_sttWithMetadata_withLM() {
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||
m.enableExternalScorer(scorerFile);
|
||||
|
||||
String decoded = doSTT(m, true);
|
||||
|
@ -20,16 +20,35 @@ public class DeepSpeechModel {
|
||||
* @constructor
|
||||
*
|
||||
* @param modelPath The path to the frozen model graph.
|
||||
* @param beam_width The beam width used by the decoder. A larger beam
|
||||
* width generates better results at the cost of decoding
|
||||
* time.
|
||||
*/
|
||||
public DeepSpeechModel(String modelPath, int beam_width) {
|
||||
public DeepSpeechModel(String modelPath) {
|
||||
this._mspp = impl.new_modelstatep();
|
||||
impl.CreateModel(modelPath, beam_width, this._mspp);
|
||||
impl.CreateModel(modelPath, this._mspp);
|
||||
this._msp = impl.modelstatep_value(this._mspp);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get beam width value used by the model. If setModelBeamWidth was not
|
||||
* called before, will return the default value loaded from the model file.
|
||||
*
|
||||
* @return Beam width value used by the model.
|
||||
*/
|
||||
public long beamWidth() {
|
||||
return impl.GetModelBeamWidth(this._msp);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set beam width value used by the model.
|
||||
*
|
||||
* @param aBeamWidth The beam width used by the model. A larger beam width value
|
||||
* generates better results at the cost of decoding time.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure.
|
||||
*/
|
||||
public int setBeamWidth(long beamWidth) {
|
||||
return impl.SetModelBeamWidth(this._msp, beamWidth);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the sample rate expected by the model.
|
||||
*
|
||||
|
@ -31,7 +31,7 @@ var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running D
|
||||
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
|
||||
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
|
||||
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
|
||||
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
|
||||
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'});
|
||||
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
|
||||
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
|
||||
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
|
||||
@ -53,10 +53,14 @@ function metadataToString(metadata) {
|
||||
|
||||
console.error('Loading model from file %s', args['model']);
|
||||
const model_load_start = process.hrtime();
|
||||
var model = new Ds.Model(args['model'], args['beam_width']);
|
||||
var model = new Ds.Model(args['model']);
|
||||
const model_load_end = process.hrtime(model_load_start);
|
||||
console.error('Loaded model in %ds.', totalTime(model_load_end));
|
||||
|
||||
if (args['beam_width']) {
|
||||
model.setBeamWidth(args['beam_width']);
|
||||
}
|
||||
|
||||
var desired_sample_rate = model.sampleRate();
|
||||
|
||||
if (args['scorer']) {
|
||||
|
@ -25,14 +25,13 @@ if (process.platform === 'win32') {
|
||||
* An object providing an interface to a trained DeepSpeech model.
|
||||
*
|
||||
* @param {string} aModelPath The path to the frozen model graph.
|
||||
* @param {number} aBeamWidth The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.
|
||||
*
|
||||
* @throws on error
|
||||
*/
|
||||
function Model() {
|
||||
function Model(aModelPath) {
|
||||
this._impl = null;
|
||||
|
||||
const rets = binding.CreateModel.apply(null, arguments);
|
||||
const rets = binding.CreateModel(aModelPath);
|
||||
const status = rets[0];
|
||||
const impl = rets[1];
|
||||
if (status !== 0) {
|
||||
@ -42,6 +41,27 @@ function Model() {
|
||||
this._impl = impl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get beam width value used by the model. If :js:func:Model.setBeamWidth was
|
||||
* not called before, will return the default value loaded from the model file.
|
||||
*
|
||||
* @return {number} Beam width value used by the model.
|
||||
*/
|
||||
Model.prototype.beamWidth = function() {
|
||||
return binding.GetModelBeamWidth(this._impl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set beam width value used by the model.
|
||||
*
|
||||
* @param {number} The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
|
||||
*
|
||||
* @return {number} Zero on success, non-zero on failure.
|
||||
*/
|
||||
Model.prototype.setBeamWidth = function(aBeamWidth) {
|
||||
return binding.SetModelBeamWidth(this._impl, aBeamWidth);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the sample rate expected by the model.
|
||||
*
|
||||
|
@ -24,10 +24,8 @@ ModelState::~ModelState()
|
||||
}
|
||||
|
||||
int
|
||||
ModelState::init(const char* model_path,
|
||||
unsigned int beam_width)
|
||||
ModelState::init(const char* model_path)
|
||||
{
|
||||
beam_width_ = beam_width;
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,7 @@ struct ModelState {
|
||||
ModelState();
|
||||
virtual ~ModelState();
|
||||
|
||||
virtual int init(const char* model_path, unsigned int beam_width);
|
||||
virtual int init(const char* model_path);
|
||||
|
||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
|
||||
|
||||
|
@ -28,15 +28,12 @@ class Model(object):
|
||||
|
||||
:param aModelPath: Path to model file to load
|
||||
:type aModelPath: str
|
||||
|
||||
:param aBeamWidth: Decoder beam width
|
||||
:type aBeamWidth: int
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, model_path):
|
||||
# make sure the attribute is there if CreateModel fails
|
||||
self._impl = None
|
||||
|
||||
status, impl = deepspeech.impl.CreateModel(*args, **kwargs)
|
||||
status, impl = deepspeech.impl.CreateModel(model_path)
|
||||
if status != 0:
|
||||
raise RuntimeError("CreateModel failed with error code {}".format(status))
|
||||
self._impl = impl
|
||||
@ -46,6 +43,28 @@ class Model(object):
|
||||
deepspeech.impl.FreeModel(self._impl)
|
||||
self._impl = None
|
||||
|
||||
def beamWidth(self):
|
||||
"""
|
||||
Get beam width value used by the model. If setModelBeamWidth was not
|
||||
called before, will return the default value loaded from the model file.
|
||||
|
||||
:return: Beam width value used by the model.
|
||||
:type: int
|
||||
"""
|
||||
return deepspeech.impl.GetModelBeamWidth(self._impl)
|
||||
|
||||
def setBeamWidth(self, beam_width):
|
||||
"""
|
||||
Set beam width value used by the model.
|
||||
|
||||
:param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
|
||||
:type beam_width: int
|
||||
|
||||
:return: Zero on success, non-zero on failure.
|
||||
:type: int
|
||||
"""
|
||||
return deepspeech.impl.SetModelBeamWidth(self._impl, beam_width)
|
||||
|
||||
def sampleRate(self):
|
||||
"""
|
||||
Return the sample rate expected by the model.
|
||||
|
@ -92,7 +92,7 @@ def main():
|
||||
help='Path to the external scorer file')
|
||||
parser.add_argument('--audio', required=True,
|
||||
help='Path to the audio file to run (WAV format)')
|
||||
parser.add_argument('--beam_width', type=int, default=500,
|
||||
parser.add_argument('--beam_width', type=int,
|
||||
help='Beam width for the CTC decoder')
|
||||
parser.add_argument('--lm_alpha', type=float,
|
||||
help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
|
||||
@ -108,10 +108,13 @@ def main():
|
||||
|
||||
print('Loading model from file {}'.format(args.model), file=sys.stderr)
|
||||
model_load_start = timer()
|
||||
ds = Model(args.model, args.beam_width)
|
||||
ds = Model(args.model)
|
||||
model_load_end = timer() - model_load_start
|
||||
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
|
||||
|
||||
if args.beam_width:
|
||||
ds.setModelBeamWidth(args.beam_width)
|
||||
|
||||
desired_sample_rate = ds.sampleRate()
|
||||
|
||||
if args.scorer:
|
||||
|
@ -9,12 +9,6 @@ import wave
|
||||
from deepspeech import Model
|
||||
|
||||
|
||||
# These constants control the beam search decoder
|
||||
|
||||
# Beam width used in the CTC decoder when building candidate transcriptions
|
||||
BEAM_WIDTH = 500
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
|
||||
parser.add_argument('--model', required=True,
|
||||
@ -27,7 +21,7 @@ def main():
|
||||
help='Second audio file to use in interleaved streams')
|
||||
args = parser.parse_args()
|
||||
|
||||
ds = Model(args.model, BEAM_WIDTH)
|
||||
ds = Model(args.model)
|
||||
|
||||
if args.scorer:
|
||||
ds.enableExternalScorer(args.scorer)
|
||||
|
@ -91,10 +91,9 @@ TFLiteModelState::~TFLiteModelState()
|
||||
}
|
||||
|
||||
int
|
||||
TFLiteModelState::init(const char* model_path,
|
||||
unsigned int beam_width)
|
||||
TFLiteModelState::init(const char* model_path)
|
||||
{
|
||||
int err = ModelState::init(model_path, beam_width);
|
||||
int err = ModelState::init(model_path);
|
||||
if (err != DS_ERR_OK) {
|
||||
return err;
|
||||
}
|
||||
@ -129,6 +128,7 @@ TFLiteModelState::init(const char* model_path,
|
||||
int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate");
|
||||
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
|
||||
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
|
||||
int metadata_beam_width_idx = get_output_tensor_by_name("metadata_beam_width");
|
||||
int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet");
|
||||
|
||||
std::vector<int> metadata_exec_plan;
|
||||
@ -136,6 +136,7 @@ TFLiteModelState::init(const char* model_path,
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]);
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_beam_width_idx)[0]);
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]);
|
||||
|
||||
for (int i = 0; i < metadata_exec_plan.size(); ++i) {
|
||||
@ -201,6 +202,9 @@ TFLiteModelState::init(const char* model_path,
|
||||
audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0);
|
||||
audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0);
|
||||
|
||||
int* const beam_width = interpreter_->typed_tensor<int>(metadata_beam_width_idx);
|
||||
beam_width_ = (unsigned int)(*beam_width);
|
||||
|
||||
tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
|
||||
err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len);
|
||||
if (err != 0) {
|
||||
@ -210,6 +214,8 @@ TFLiteModelState::init(const char* model_path,
|
||||
assert(sample_rate_ > 0);
|
||||
assert(audio_win_len_ > 0);
|
||||
assert(audio_win_step_ > 0);
|
||||
assert(beam_width_ > 0);
|
||||
assert(alphabet_.GetSize() > 0);
|
||||
|
||||
TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;
|
||||
|
||||
|
@ -30,8 +30,7 @@ struct TFLiteModelState : public ModelState
|
||||
TFLiteModelState();
|
||||
virtual ~TFLiteModelState();
|
||||
|
||||
virtual int init(const char* model_path,
|
||||
unsigned int beam_width) override;
|
||||
virtual int init(const char* model_path) override;
|
||||
|
||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
||||
std::vector<float>& mfcc_output) override;
|
||||
|
@ -23,10 +23,9 @@ TFModelState::~TFModelState()
|
||||
}
|
||||
|
||||
int
|
||||
TFModelState::init(const char* model_path,
|
||||
unsigned int beam_width)
|
||||
TFModelState::init(const char* model_path)
|
||||
{
|
||||
int err = ModelState::init(model_path, beam_width);
|
||||
int err = ModelState::init(model_path);
|
||||
if (err != DS_ERR_OK) {
|
||||
return err;
|
||||
}
|
||||
@ -103,6 +102,7 @@ TFModelState::init(const char* model_path,
|
||||
"metadata_sample_rate",
|
||||
"metadata_feature_win_len",
|
||||
"metadata_feature_win_step",
|
||||
"metadata_beam_width",
|
||||
"metadata_alphabet",
|
||||
}, {}, &metadata_outputs);
|
||||
if (!status.ok()) {
|
||||
@ -115,8 +115,10 @@ TFModelState::init(const char* model_path,
|
||||
int win_step_ms = metadata_outputs[2].scalar<int>()();
|
||||
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
|
||||
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
|
||||
int beam_width = metadata_outputs[3].scalar<int>()();
|
||||
beam_width_ = (unsigned int)(beam_width);
|
||||
|
||||
string serialized_alphabet = metadata_outputs[3].scalar<string>()();
|
||||
string serialized_alphabet = metadata_outputs[4].scalar<string>()();
|
||||
err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size());
|
||||
if (err != 0) {
|
||||
return DS_ERR_INVALID_ALPHABET;
|
||||
@ -125,6 +127,8 @@ TFModelState::init(const char* model_path,
|
||||
assert(sample_rate_ > 0);
|
||||
assert(audio_win_len_ > 0);
|
||||
assert(audio_win_step_ > 0);
|
||||
assert(beam_width_ > 0);
|
||||
assert(alphabet_.GetSize() > 0);
|
||||
|
||||
for (int i = 0; i < graph_def_.node_size(); ++i) {
|
||||
NodeDef node = graph_def_.node(i);
|
||||
|
@ -18,8 +18,7 @@ struct TFModelState : public ModelState
|
||||
TFModelState();
|
||||
virtual ~TFModelState();
|
||||
|
||||
virtual int init(const char* model_path,
|
||||
unsigned int beam_width) override;
|
||||
virtual int init(const char* model_path) override;
|
||||
|
||||
virtual void infer(const std::vector<float>& mfcc,
|
||||
unsigned int n_frames,
|
||||
|
@ -30,11 +30,11 @@ then:
|
||||
image: ${build.docker_image}
|
||||
|
||||
env:
|
||||
DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.7.0-alpha.1/models.tar.gz"
|
||||
DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.7.0-alpha.1/models_beam_width.tar.gz"
|
||||
DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz"
|
||||
PIP_DEFAULT_TIMEOUT: "60"
|
||||
EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples"
|
||||
EXAMPLES_CHECKOUT_TARGET: "4b97ac41d03ca0d23fa92526433db72a90f47d4a"
|
||||
EXAMPLES_CHECKOUT_TARGET: "embedded-beam-width"
|
||||
|
||||
command:
|
||||
- "/bin/bash"
|
||||
|
@ -44,7 +44,7 @@ payload:
|
||||
MSYS: 'winsymlinks:nativestrict'
|
||||
TENSORFLOW_BUILD_ARTIFACT: ${build.tensorflow}
|
||||
EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples"
|
||||
EXAMPLES_CHECKOUT_TARGET: "4b97ac41d03ca0d23fa92526433db72a90f47d4a"
|
||||
EXAMPLES_CHECKOUT_TARGET: "embedded-beam-width"
|
||||
|
||||
command:
|
||||
- >-
|
||||
|
@ -111,6 +111,7 @@ def create_flags():
|
||||
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
|
||||
f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
|
||||
f.DEFINE_string('export_name', 'output_graph', 'name for the export model')
|
||||
f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph')
|
||||
|
||||
# Reporting
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user