diff --git a/DeepSpeech.py b/DeepSpeech.py
index 48b8edb6..9421e7f0 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -877,7 +877,7 @@ def package_zip():
json.dump({
'name': FLAGS.export_language,
'parameters': {
- 'beamWidth': FLAGS.beam_width,
+ 'beamWidth': FLAGS.export_beam_width,
'lmAlpha': FLAGS.lm_alpha,
'lmBeta': FLAGS.lm_beta
}
diff --git a/evaluate_tflite.py b/evaluate_tflite.py
index aba6fb68..d01db864 100644
--- a/evaluate_tflite.py
+++ b/evaluate_tflite.py
@@ -36,7 +36,8 @@ LM_BETA = 1.85
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
- ds = Model(model, BEAM_WIDTH)
+ ds = Model(model)
+ ds.setBeamWidth(BEAM_WIDTH)
ds.enableExternalScorer(scorer)
ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)
diff --git a/native_client/args.h b/native_client/args.h
index d5a0f869..2e7306c7 100644
--- a/native_client/args.h
+++ b/native_client/args.h
@@ -16,7 +16,9 @@ char* scorer = NULL;
char* audio = NULL;
-int beam_width = 500;
+bool set_beamwidth = false;
+
+int beam_width = 0;
bool set_alphabeta = false;
@@ -98,6 +100,7 @@ bool ProcessArgs(int argc, char** argv)
break;
case 'b':
+ set_beamwidth = true;
beam_width = atoi(optarg);
break;
diff --git a/native_client/client.cc b/native_client/client.cc
index 718fba75..abcadd8d 100644
--- a/native_client/client.cc
+++ b/native_client/client.cc
@@ -368,14 +368,22 @@ main(int argc, char **argv)
// Initialise DeepSpeech
ModelState* ctx;
- int status = DS_CreateModel(model, beam_width, &ctx);
+ int status = DS_CreateModel(model, &ctx);
if (status != 0) {
fprintf(stderr, "Could not create model.\n");
return 1;
}
+ if (set_beamwidth) {
+ status = DS_SetModelBeamWidth(ctx, beam_width);
+ if (status != 0) {
+ fprintf(stderr, "Could not set model beam width.\n");
+ return 1;
+ }
+ }
+
if (scorer) {
- int status = DS_EnableExternalScorer(ctx, scorer);
+ status = DS_EnableExternalScorer(ctx, scorer);
if (status != 0) {
fprintf(stderr, "Could not enable external scorer.\n");
return 1;
diff --git a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs
index e5e33370..15c2212c 100644
--- a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs
+++ b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs
@@ -19,11 +19,10 @@ namespace DeepSpeechClient
/// Initializes a new instance of class and creates a new acoustic model.
///
/// The path to the frozen model graph.
- /// The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.
/// Thrown when the native binary failed to create the model.
- public DeepSpeech(string aModelPath, uint aBeamWidth)
+ public DeepSpeech(string aModelPath)
{
- CreateModel(aModelPath, aBeamWidth);
+ CreateModel(aModelPath);
}
#region IDeepSpeech
@@ -32,10 +31,8 @@ namespace DeepSpeechClient
/// Create an object providing an interface to a trained DeepSpeech model.
///
/// The path to the frozen model graph.
- /// The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.
/// Thrown when the native binary failed to create the model.
- private unsafe void CreateModel(string aModelPath,
- uint aBeamWidth)
+ private unsafe void CreateModel(string aModelPath)
{
string exceptionMessage = null;
if (string.IsNullOrWhiteSpace(aModelPath))
@@ -52,11 +49,31 @@ namespace DeepSpeechClient
throw new FileNotFoundException(exceptionMessage);
}
var resultCode = NativeImp.DS_CreateModel(aModelPath,
- aBeamWidth,
ref _modelStatePP);
EvaluateResultCode(resultCode);
}
+ ///
+ /// Get beam width value used by the model. If SetModelBeamWidth was not
+ /// called before, will return the default value loaded from the model file.
+ ///
+ /// Beam width value used by the model.
+ public unsafe uint GetModelBeamWidth()
+ {
+ return NativeImp.DS_GetModelBeamWidth(_modelStatePP);
+ }
+
+ ///
+ /// Set beam width value used by the model.
+ ///
+ /// The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.
+ /// Thrown on failure.
+ public unsafe void SetModelBeamWidth(uint aBeamWidth)
+ {
+ var resultCode = NativeImp.DS_SetModelBeamWidth(_modelStatePP, aBeamWidth);
+ EvaluateResultCode(resultCode);
+ }
+
///
/// Return the sample rate expected by the model.
///
diff --git a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs
index ecbfb7e9..f00c188d 100644
--- a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs
+++ b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs
@@ -20,6 +20,21 @@ namespace DeepSpeechClient.Interfaces
/// Sample rate.
unsafe int GetModelSampleRate();
+ ///
+ /// Get beam width value used by the model. If SetModelBeamWidth was not
+ /// called before, will return the default value loaded from the model
+ /// file.
+ ///
+ /// Beam width value used by the model.
+ unsafe uint GetModelBeamWidth();
+
+ ///
+ /// Set beam width value used by the model.
+ ///
+ /// The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.
+ /// Thrown on failure.
+ unsafe void SetModelBeamWidth(uint aBeamWidth);
+
///
/// Enable decoding using an external scorer.
///
diff --git a/native_client/dotnet/DeepSpeechClient/NativeImp.cs b/native_client/dotnet/DeepSpeechClient/NativeImp.cs
index 1c49feec..af28618c 100644
--- a/native_client/dotnet/DeepSpeechClient/NativeImp.cs
+++ b/native_client/dotnet/DeepSpeechClient/NativeImp.cs
@@ -14,6 +14,17 @@ namespace DeepSpeechClient
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static extern void DS_PrintVersions();
+ [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
+ internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
+ ref IntPtr** pint);
+
+ [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
+ internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx);
+
+ [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
+ internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx,
+ uint aBeamWidth);
+
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
uint aBeamWidth,
diff --git a/native_client/dotnet/DeepSpeechConsole/Program.cs b/native_client/dotnet/DeepSpeechConsole/Program.cs
index 1f6e299b..b35c7046 100644
--- a/native_client/dotnet/DeepSpeechConsole/Program.cs
+++ b/native_client/dotnet/DeepSpeechConsole/Program.cs
@@ -46,15 +46,12 @@ namespace CSharpExamples
extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
}
- const uint BEAM_WIDTH = 500;
-
Stopwatch stopwatch = new Stopwatch();
try
{
Console.WriteLine("Loading model...");
stopwatch.Start();
- using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm",
- BEAM_WIDTH))
+ using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm"))
{
stopwatch.Stop();
diff --git a/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java b/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java
index 12e758df..f9d9a11e 100644
--- a/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java
+++ b/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java
@@ -49,7 +49,8 @@ public class DeepSpeechActivity extends AppCompatActivity {
private void newModel(String tfliteModel) {
this._tfliteStatus.setText("Creating model");
if (this._m == null) {
- this._m = new DeepSpeechModel(tfliteModel, BEAM_WIDTH);
+ this._m = new DeepSpeechModel(tfliteModel);
+ this._m.setBeamWidth(BEAM_WIDTH);
}
}
diff --git a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java
index bb6bbe42..60d21256 100644
--- a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java
+++ b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java
@@ -59,7 +59,7 @@ public class BasicTest {
@Test
public void loadDeepSpeech_basic() {
- DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+ DeepSpeechModel m = new DeepSpeechModel(modelFile);
m.freeModel();
}
@@ -116,7 +116,8 @@ public class BasicTest {
@Test
public void loadDeepSpeech_stt_noLM() {
- DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+ DeepSpeechModel m = new DeepSpeechModel(modelFile);
+ m.setBeamWidth(BEAM_WIDTH);
String decoded = doSTT(m, false);
assertEquals("she had your dark suit in greasy wash water all year", decoded);
@@ -125,7 +126,8 @@ public class BasicTest {
@Test
public void loadDeepSpeech_stt_withLM() {
- DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+ DeepSpeechModel m = new DeepSpeechModel(modelFile);
+ m.setBeamWidth(BEAM_WIDTH);
m.enableExternalScorer(scorerFile);
String decoded = doSTT(m, false);
@@ -135,7 +137,8 @@ public class BasicTest {
@Test
public void loadDeepSpeech_sttWithMetadata_noLM() {
- DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+ DeepSpeechModel m = new DeepSpeechModel(modelFile);
+ m.setBeamWidth(BEAM_WIDTH);
String decoded = doSTT(m, true);
assertEquals("she had your dark suit in greasy wash water all year", decoded);
@@ -144,7 +147,8 @@ public class BasicTest {
@Test
public void loadDeepSpeech_sttWithMetadata_withLM() {
- DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+ DeepSpeechModel m = new DeepSpeechModel(modelFile);
+ m.setBeamWidth(BEAM_WIDTH);
m.enableExternalScorer(scorerFile);
String decoded = doSTT(m, true);
diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java
index 0438ac10..1c26e2f9 100644
--- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java
+++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java
@@ -20,16 +20,35 @@ public class DeepSpeechModel {
* @constructor
*
* @param modelPath The path to the frozen model graph.
- * @param beam_width The beam width used by the decoder. A larger beam
- * width generates better results at the cost of decoding
- * time.
*/
- public DeepSpeechModel(String modelPath, int beam_width) {
+ public DeepSpeechModel(String modelPath) {
this._mspp = impl.new_modelstatep();
- impl.CreateModel(modelPath, beam_width, this._mspp);
+ impl.CreateModel(modelPath, this._mspp);
this._msp = impl.modelstatep_value(this._mspp);
}
+ /**
+     * @brief Get beam width value used by the model. If setBeamWidth was not
+     * called before, will return the default value loaded from the model file.
+ *
+ * @return Beam width value used by the model.
+ */
+ public int beamWidth() {
+ return impl.GetModelBeamWidth(this._msp);
+ }
+
+ /**
+ * @brief Set beam width value used by the model.
+ *
+     * @param beamWidth The beam width used by the model. A larger beam width value
+     *                  generates better results at the cost of decoding time.
+ *
+ * @return Zero on success, non-zero on failure.
+ */
+ public int setBeamWidth(int beamWidth) {
+ return impl.SetModelBeamWidth(this._msp, beamWidth);
+ }
+
/**
* @brief Return the sample rate expected by the model.
*
diff --git a/native_client/javascript/client.js b/native_client/javascript/client.js
index 7266b85d..09406ccc 100644
--- a/native_client/javascript/client.js
+++ b/native_client/javascript/client.js
@@ -31,7 +31,7 @@ var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running D
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
-parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
+parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'});
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
@@ -53,10 +53,14 @@ function metadataToString(metadata) {
console.error('Loading model from file %s', args['model']);
const model_load_start = process.hrtime();
-var model = new Ds.Model(args['model'], args['beam_width']);
+var model = new Ds.Model(args['model']);
const model_load_end = process.hrtime(model_load_start);
console.error('Loaded model in %ds.', totalTime(model_load_end));
+if (args['beam_width']) {
+ model.setBeamWidth(args['beam_width']);
+}
+
var desired_sample_rate = model.sampleRate();
if (args['scorer']) {
diff --git a/native_client/javascript/index.js b/native_client/javascript/index.js
index 772b1a82..38ecbf0a 100644
--- a/native_client/javascript/index.js
+++ b/native_client/javascript/index.js
@@ -25,14 +25,13 @@ if (process.platform === 'win32') {
* An object providing an interface to a trained DeepSpeech model.
*
* @param {string} aModelPath The path to the frozen model graph.
- * @param {number} aBeamWidth The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.
*
* @throws on error
*/
-function Model() {
+function Model(aModelPath) {
this._impl = null;
- const rets = binding.CreateModel.apply(null, arguments);
+ const rets = binding.CreateModel(aModelPath);
const status = rets[0];
const impl = rets[1];
if (status !== 0) {
@@ -42,6 +41,28 @@ function Model() {
this._impl = impl;
}
+/**
+ * Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
+ * was not called before, will return the default value loaded from the
+ * model file.
+ *
+ * @return {number} Beam width value used by the model.
+ */
+Model.prototype.beamWidth = function() {
+  return binding.GetModelBeamWidth(this._impl);
+}
+
+/**
+ * Set beam width value used by the model.
+ *
+ * @param {number} aBeamWidth The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
+ *
+ * @return {number} Zero on success, non-zero on failure.
+ */
+Model.prototype.setBeamWidth = function(aBeamWidth) {
+  return binding.SetModelBeamWidth(this._impl, aBeamWidth);
+}
+
/**
* Return the sample rate expected by the model.
*
diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py
index ccb53fc4..855a6eeb 100644
--- a/native_client/python/__init__.py
+++ b/native_client/python/__init__.py
@@ -28,15 +28,12 @@ class Model(object):
:param aModelPath: Path to model file to load
:type aModelPath: str
-
- :param aBeamWidth: Decoder beam width
- :type aBeamWidth: int
"""
- def __init__(self, *args, **kwargs):
+ def __init__(self, model_path):
# make sure the attribute is there if CreateModel fails
self._impl = None
- status, impl = deepspeech.impl.CreateModel(*args, **kwargs)
+ status, impl = deepspeech.impl.CreateModel(model_path)
if status != 0:
raise RuntimeError("CreateModel failed with error code {}".format(status))
self._impl = impl
@@ -46,6 +43,29 @@ class Model(object):
deepspeech.impl.FreeModel(self._impl)
self._impl = None
+ def beamWidth(self):
+ """
+ Get beam width value used by the model. If setBeamWidth was not
+ called before, will return the default value loaded from the
+ model file.
+
+ :return: Beam width value used by the model.
+ :type: int
+ """
+ return deepspeech.impl.GetModelBeamWidth(self._impl)
+
+ def setBeamWidth(self, beam_width):
+ """
+ Set beam width value used by the model.
+
+ :param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
+ :type beam_width: int
+
+ :return: Zero on success, non-zero on failure.
+ :type: int
+ """
+ return deepspeech.impl.SetModelBeamWidth(self._impl, beam_width)
+
def sampleRate(self):
"""
Return the sample rate expected by the model.
diff --git a/native_client/python/client.py b/native_client/python/client.py
index 2ef88caf..26db1e00 100644
--- a/native_client/python/client.py
+++ b/native_client/python/client.py
@@ -92,7 +92,7 @@ def main():
help='Path to the external scorer file')
parser.add_argument('--audio', required=True,
help='Path to the audio file to run (WAV format)')
- parser.add_argument('--beam_width', type=int, default=500,
+ parser.add_argument('--beam_width', type=int,
help='Beam width for the CTC decoder')
parser.add_argument('--lm_alpha', type=float,
help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
@@ -108,10 +108,13 @@ def main():
print('Loading model from file {}'.format(args.model), file=sys.stderr)
model_load_start = timer()
- ds = Model(args.model, args.beam_width)
+ ds = Model(args.model)
model_load_end = timer() - model_load_start
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
+ if args.beam_width:
+ ds.setBeamWidth(args.beam_width)
+
desired_sample_rate = ds.sampleRate()
if args.scorer:
diff --git a/native_client/test/concurrent_streams.py b/native_client/test/concurrent_streams.py
index d799de36..e435b43f 100644
--- a/native_client/test/concurrent_streams.py
+++ b/native_client/test/concurrent_streams.py
@@ -9,12 +9,6 @@ import wave
from deepspeech import Model
-# These constants control the beam search decoder
-
-# Beam width used in the CTC decoder when building candidate transcriptions
-BEAM_WIDTH = 500
-
-
def main():
parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
parser.add_argument('--model', required=True,
@@ -27,7 +21,7 @@ def main():
help='Second audio file to use in interleaved streams')
args = parser.parse_args()
- ds = Model(args.model, BEAM_WIDTH)
+ ds = Model(args.model)
if args.scorer:
ds.enableExternalScorer(args.scorer)