Fix consumers of DS_CreateModel
This commit is contained in:
		
							parent
							
								
									8e9b6ef7b3
								
							
						
					
					
						commit
						c512383aec
					
				@ -877,7 +877,7 @@ def package_zip():
 | 
				
			|||||||
        json.dump({
 | 
					        json.dump({
 | 
				
			||||||
            'name': FLAGS.export_language,
 | 
					            'name': FLAGS.export_language,
 | 
				
			||||||
            'parameters': {
 | 
					            'parameters': {
 | 
				
			||||||
                'beamWidth': FLAGS.beam_width,
 | 
					                'beamWidth': FLAGS.export_beam_width,
 | 
				
			||||||
                'lmAlpha': FLAGS.lm_alpha,
 | 
					                'lmAlpha': FLAGS.lm_alpha,
 | 
				
			||||||
                'lmBeta': FLAGS.lm_beta
 | 
					                'lmBeta': FLAGS.lm_beta
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
				
			|||||||
@ -36,7 +36,8 @@ LM_BETA = 1.85
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
 | 
					def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
 | 
				
			||||||
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
 | 
					    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
 | 
				
			||||||
    ds = Model(model, BEAM_WIDTH)
 | 
					    ds = Model(model)
 | 
				
			||||||
 | 
					    ds.setBeamWidth(BEAM_WIDTH)
 | 
				
			||||||
    ds.enableExternalScorer(scorer)
 | 
					    ds.enableExternalScorer(scorer)
 | 
				
			||||||
    ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)
 | 
					    ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -16,7 +16,9 @@ char* scorer = NULL;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
char* audio = NULL;
 | 
					char* audio = NULL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
int beam_width = 500;
 | 
					bool set_beamwidth = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int beam_width = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bool set_alphabeta = false;
 | 
					bool set_alphabeta = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -98,6 +100,7 @@ bool ProcessArgs(int argc, char** argv)
 | 
				
			|||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        case 'b':
 | 
					        case 'b':
 | 
				
			||||||
 | 
					            set_beamwidth = true;
 | 
				
			||||||
            beam_width = atoi(optarg);
 | 
					            beam_width = atoi(optarg);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -368,14 +368,22 @@ main(int argc, char **argv)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  // Initialise DeepSpeech
 | 
					  // Initialise DeepSpeech
 | 
				
			||||||
  ModelState* ctx;
 | 
					  ModelState* ctx;
 | 
				
			||||||
  int status = DS_CreateModel(model, beam_width, &ctx);
 | 
					  int status = DS_CreateModel(model, &ctx);
 | 
				
			||||||
  if (status != 0) {
 | 
					  if (status != 0) {
 | 
				
			||||||
    fprintf(stderr, "Could not create model.\n");
 | 
					    fprintf(stderr, "Could not create model.\n");
 | 
				
			||||||
    return 1;
 | 
					    return 1;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (set_beamwidth) {
 | 
				
			||||||
 | 
					    status = DS_SetModelBeamWidth(ctx, beam_width);
 | 
				
			||||||
 | 
					    if (status != 0) {
 | 
				
			||||||
 | 
					      fprintf(stderr, "Could not set model beam width.\n");
 | 
				
			||||||
 | 
					      return 1;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (scorer) {
 | 
					  if (scorer) {
 | 
				
			||||||
    int status = DS_EnableExternalScorer(ctx, scorer);
 | 
					    status = DS_EnableExternalScorer(ctx, scorer);
 | 
				
			||||||
    if (status != 0) {
 | 
					    if (status != 0) {
 | 
				
			||||||
      fprintf(stderr, "Could not enable external scorer.\n");
 | 
					      fprintf(stderr, "Could not enable external scorer.\n");
 | 
				
			||||||
      return 1;
 | 
					      return 1;
 | 
				
			||||||
 | 
				
			|||||||
@ -19,11 +19,10 @@ namespace DeepSpeechClient
 | 
				
			|||||||
        /// Initializes a new instance of <see cref="DeepSpeech"/> class and creates a new acoustic model.
 | 
					        /// Initializes a new instance of <see cref="DeepSpeech"/> class and creates a new acoustic model.
 | 
				
			||||||
        /// </summary>
 | 
					        /// </summary>
 | 
				
			||||||
        /// <param name="aModelPath">The path to the frozen model graph.</param>
 | 
					        /// <param name="aModelPath">The path to the frozen model graph.</param>
 | 
				
			||||||
        /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
 | 
					 | 
				
			||||||
        /// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
 | 
					        /// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
 | 
				
			||||||
        public DeepSpeech(string aModelPath, uint aBeamWidth)
 | 
					        public DeepSpeech(string aModelPath)
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            CreateModel(aModelPath, aBeamWidth);
 | 
					            CreateModel(aModelPath);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #region IDeepSpeech
 | 
					        #region IDeepSpeech
 | 
				
			||||||
@ -32,10 +31,8 @@ namespace DeepSpeechClient
 | 
				
			|||||||
        /// Create an object providing an interface to a trained DeepSpeech model.
 | 
					        /// Create an object providing an interface to a trained DeepSpeech model.
 | 
				
			||||||
        /// </summary>
 | 
					        /// </summary>
 | 
				
			||||||
        /// <param name="aModelPath">The path to the frozen model graph.</param>
 | 
					        /// <param name="aModelPath">The path to the frozen model graph.</param>
 | 
				
			||||||
        /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
 | 
					 | 
				
			||||||
        /// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
 | 
					        /// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
 | 
				
			||||||
        private unsafe void CreateModel(string aModelPath,
 | 
					        private unsafe void CreateModel(string aModelPath)
 | 
				
			||||||
            uint aBeamWidth)
 | 
					 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            string exceptionMessage = null;
 | 
					            string exceptionMessage = null;
 | 
				
			||||||
            if (string.IsNullOrWhiteSpace(aModelPath))
 | 
					            if (string.IsNullOrWhiteSpace(aModelPath))
 | 
				
			||||||
@ -52,11 +49,31 @@ namespace DeepSpeechClient
 | 
				
			|||||||
                throw new FileNotFoundException(exceptionMessage);
 | 
					                throw new FileNotFoundException(exceptionMessage);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            var resultCode = NativeImp.DS_CreateModel(aModelPath,
 | 
					            var resultCode = NativeImp.DS_CreateModel(aModelPath,
 | 
				
			||||||
                            aBeamWidth,
 | 
					 | 
				
			||||||
                            ref _modelStatePP);
 | 
					                            ref _modelStatePP);
 | 
				
			||||||
            EvaluateResultCode(resultCode);
 | 
					            EvaluateResultCode(resultCode);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /// <summary>
 | 
				
			||||||
 | 
					        /// Get beam width value used by the model. If SetModelBeamWidth was not
 | 
				
			||||||
 | 
					        /// called before, will return the default value loaded from the model file.
 | 
				
			||||||
 | 
					        /// </summary>
 | 
				
			||||||
 | 
					        /// <returns>Beam width value used by the model.</returns>
 | 
				
			||||||
 | 
					        public unsafe uint GetModelBeamWidth()
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            return NativeImp.DS_GetModelBeamWidth(_modelStatePP);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /// <summary>
 | 
				
			||||||
 | 
					        /// Set beam width value used by the model.
 | 
				
			||||||
 | 
					        /// </summary>
 | 
				
			||||||
 | 
					        /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.</param>
 | 
				
			||||||
 | 
					        /// <exception cref="ArgumentException">Thrown on failure.</exception>
 | 
				
			||||||
 | 
					        public unsafe void SetModelBeamWidth(uint aBeamWidth)
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            var resultCode = NativeImp.DS_SetModelBeamWidth(_modelStatePP, aBeamWidth);
 | 
				
			||||||
 | 
					            EvaluateResultCode(resultCode);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        /// <summary>
 | 
					        /// <summary>
 | 
				
			||||||
        /// Return the sample rate expected by the model.
 | 
					        /// Return the sample rate expected by the model.
 | 
				
			||||||
        /// </summary>
 | 
					        /// </summary>
 | 
				
			||||||
 | 
				
			|||||||
@ -20,6 +20,21 @@ namespace DeepSpeechClient.Interfaces
 | 
				
			|||||||
        /// <returns>Sample rate.</returns>
 | 
					        /// <returns>Sample rate.</returns>
 | 
				
			||||||
        unsafe int GetModelSampleRate();
 | 
					        unsafe int GetModelSampleRate();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /// <summary>
 | 
				
			||||||
 | 
					        /// Get beam width value used by the model. If SetModelBeamWidth was not
 | 
				
			||||||
 | 
					        /// called before, will return the default value loaded from the model
 | 
				
			||||||
 | 
					        /// file.
 | 
				
			||||||
 | 
					        /// </summary>
 | 
				
			||||||
 | 
					        /// <returns>Beam width value used by the model.</returns>
 | 
				
			||||||
 | 
					        unsafe uint GetModelBeamWidth();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /// <summary>
 | 
				
			||||||
 | 
					        /// Set beam width value used by the model.
 | 
				
			||||||
 | 
					        /// </summary>
 | 
				
			||||||
 | 
					        /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.</param>
 | 
				
			||||||
 | 
					        /// <exception cref="ArgumentException">Thrown on failure.</exception>
 | 
				
			||||||
 | 
					        unsafe void SetModelBeamWidth(uint aBeamWidth);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        /// <summary>
 | 
					        /// <summary>
 | 
				
			||||||
        /// Enable decoding using an external scorer.
 | 
					        /// Enable decoding using an external scorer.
 | 
				
			||||||
        /// </summary>
 | 
					        /// </summary>
 | 
				
			||||||
 | 
				
			|||||||
@ -14,6 +14,17 @@ namespace DeepSpeechClient
 | 
				
			|||||||
        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
 | 
					        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
 | 
				
			||||||
        internal static extern void DS_PrintVersions();
 | 
					        internal static extern void DS_PrintVersions();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
 | 
				
			||||||
 | 
					        internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
 | 
				
			||||||
 | 
					                   ref IntPtr** pint);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
 | 
				
			||||||
 | 
					        internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
 | 
				
			||||||
 | 
					        internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx,
 | 
				
			||||||
 | 
					                   uint aBeamWidth);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
 | 
					        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
 | 
				
			||||||
        internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
 | 
					        internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
 | 
				
			||||||
                   uint aBeamWidth,
 | 
					                   uint aBeamWidth,
 | 
				
			||||||
 | 
				
			|||||||
@ -46,15 +46,12 @@ namespace CSharpExamples
 | 
				
			|||||||
                extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
 | 
					                extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            const uint BEAM_WIDTH = 500;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            Stopwatch stopwatch = new Stopwatch();
 | 
					            Stopwatch stopwatch = new Stopwatch();
 | 
				
			||||||
            try
 | 
					            try
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                Console.WriteLine("Loading model...");
 | 
					                Console.WriteLine("Loading model...");
 | 
				
			||||||
                stopwatch.Start();
 | 
					                stopwatch.Start();
 | 
				
			||||||
                using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm",
 | 
					                using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm"))
 | 
				
			||||||
                    BEAM_WIDTH))
 | 
					 | 
				
			||||||
                {
 | 
					                {
 | 
				
			||||||
                    stopwatch.Stop();
 | 
					                    stopwatch.Stop();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -49,7 +49,8 @@ public class DeepSpeechActivity extends AppCompatActivity {
 | 
				
			|||||||
    private void newModel(String tfliteModel) {
 | 
					    private void newModel(String tfliteModel) {
 | 
				
			||||||
        this._tfliteStatus.setText("Creating model");
 | 
					        this._tfliteStatus.setText("Creating model");
 | 
				
			||||||
        if (this._m == null) {
 | 
					        if (this._m == null) {
 | 
				
			||||||
            this._m = new DeepSpeechModel(tfliteModel, BEAM_WIDTH);
 | 
					            this._m = new DeepSpeechModel(tfliteModel);
 | 
				
			||||||
 | 
					            this._m.setBeamWidth(BEAM_WIDTH);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -59,7 +59,7 @@ public class BasicTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void loadDeepSpeech_basic() {
 | 
					    public void loadDeepSpeech_basic() {
 | 
				
			||||||
        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
 | 
					        DeepSpeechModel m = new DeepSpeechModel(modelFile);
 | 
				
			||||||
        m.freeModel();
 | 
					        m.freeModel();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -116,7 +116,8 @@ public class BasicTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void loadDeepSpeech_stt_noLM() {
 | 
					    public void loadDeepSpeech_stt_noLM() {
 | 
				
			||||||
        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
 | 
					        DeepSpeechModel m = new DeepSpeechModel(modelFile);
 | 
				
			||||||
 | 
					        m.setBeamWidth(BEAM_WIDTH);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        String decoded = doSTT(m, false);
 | 
					        String decoded = doSTT(m, false);
 | 
				
			||||||
        assertEquals("she had your dark suit in greasy wash water all year", decoded);
 | 
					        assertEquals("she had your dark suit in greasy wash water all year", decoded);
 | 
				
			||||||
@ -125,7 +126,8 @@ public class BasicTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void loadDeepSpeech_stt_withLM() {
 | 
					    public void loadDeepSpeech_stt_withLM() {
 | 
				
			||||||
        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
 | 
					        DeepSpeechModel m = new DeepSpeechModel(modelFile);
 | 
				
			||||||
 | 
					        m.setBeamWidth(BEAM_WIDTH);
 | 
				
			||||||
        m.enableExternalScorer(scorerFile);
 | 
					        m.enableExternalScorer(scorerFile);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        String decoded = doSTT(m, false);
 | 
					        String decoded = doSTT(m, false);
 | 
				
			||||||
@ -135,7 +137,8 @@ public class BasicTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void loadDeepSpeech_sttWithMetadata_noLM() {
 | 
					    public void loadDeepSpeech_sttWithMetadata_noLM() {
 | 
				
			||||||
        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
 | 
					        DeepSpeechModel m = new DeepSpeechModel(modelFile);
 | 
				
			||||||
 | 
					        m.setBeamWidth(BEAM_WIDTH);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        String decoded = doSTT(m, true);
 | 
					        String decoded = doSTT(m, true);
 | 
				
			||||||
        assertEquals("she had your dark suit in greasy wash water all year", decoded);
 | 
					        assertEquals("she had your dark suit in greasy wash water all year", decoded);
 | 
				
			||||||
@ -144,7 +147,8 @@ public class BasicTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void loadDeepSpeech_sttWithMetadata_withLM() {
 | 
					    public void loadDeepSpeech_sttWithMetadata_withLM() {
 | 
				
			||||||
        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
 | 
					        DeepSpeechModel m = new DeepSpeechModel(modelFile);
 | 
				
			||||||
 | 
					        m.setBeamWidth(BEAM_WIDTH);
 | 
				
			||||||
        m.enableExternalScorer(scorerFile);
 | 
					        m.enableExternalScorer(scorerFile);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        String decoded = doSTT(m, true);
 | 
					        String decoded = doSTT(m, true);
 | 
				
			||||||
 | 
				
			|||||||
@ -20,16 +20,35 @@ public class DeepSpeechModel {
 | 
				
			|||||||
    * @constructor
 | 
					    * @constructor
 | 
				
			||||||
    *
 | 
					    *
 | 
				
			||||||
    * @param modelPath The path to the frozen model graph.
 | 
					    * @param modelPath The path to the frozen model graph.
 | 
				
			||||||
    * @param beam_width The beam width used by the decoder. A larger beam
 | 
					 | 
				
			||||||
    *                   width generates better results at the cost of decoding
 | 
					 | 
				
			||||||
    *                   time.
 | 
					 | 
				
			||||||
    */
 | 
					    */
 | 
				
			||||||
    public DeepSpeechModel(String modelPath, int beam_width) {
 | 
					    public DeepSpeechModel(String modelPath) {
 | 
				
			||||||
        this._mspp = impl.new_modelstatep();
 | 
					        this._mspp = impl.new_modelstatep();
 | 
				
			||||||
        impl.CreateModel(modelPath, beam_width, this._mspp);
 | 
					        impl.CreateModel(modelPath, this._mspp);
 | 
				
			||||||
        this._msp  = impl.modelstatep_value(this._mspp);
 | 
					        this._msp  = impl.modelstatep_value(this._mspp);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   /**
 | 
				
			||||||
 | 
					    * @brief Get beam width value used by the model. If setModelBeamWidth was not
 | 
				
			||||||
 | 
					    *        called before, will return the default value loaded from the model file.
 | 
				
			||||||
 | 
					    *
 | 
				
			||||||
 | 
					    * @return Beam width value used by the model.
 | 
				
			||||||
 | 
					    */
 | 
				
			||||||
 | 
					    public int beamWidth() {
 | 
				
			||||||
 | 
					        return impl.GetModelBeamWidth(this._msp);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * @brief Set beam width value used by the model.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param aBeamWidth The beam width used by the model. A larger beam width value
 | 
				
			||||||
 | 
					     *                   generates better results at the cost of decoding time.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @return Zero on success, non-zero on failure.
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public int setBeamWidth(int beamWidth) {
 | 
				
			||||||
 | 
					        return impl.SetModelBeamWidth(this._msp, beamWidth);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   /**
 | 
					   /**
 | 
				
			||||||
    * @brief Return the sample rate expected by the model.
 | 
					    * @brief Return the sample rate expected by the model.
 | 
				
			||||||
    *
 | 
					    *
 | 
				
			||||||
 | 
				
			|||||||
@ -31,7 +31,7 @@ var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running D
 | 
				
			|||||||
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
 | 
					parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
 | 
				
			||||||
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
 | 
					parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
 | 
				
			||||||
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
 | 
					parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
 | 
				
			||||||
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
 | 
					parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'});
 | 
				
			||||||
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
 | 
					parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
 | 
				
			||||||
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
 | 
					parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
 | 
				
			||||||
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
 | 
					parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
 | 
				
			||||||
@ -53,10 +53,14 @@ function metadataToString(metadata) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
console.error('Loading model from file %s', args['model']);
 | 
					console.error('Loading model from file %s', args['model']);
 | 
				
			||||||
const model_load_start = process.hrtime();
 | 
					const model_load_start = process.hrtime();
 | 
				
			||||||
var model = new Ds.Model(args['model'], args['beam_width']);
 | 
					var model = new Ds.Model(args['model']);
 | 
				
			||||||
const model_load_end = process.hrtime(model_load_start);
 | 
					const model_load_end = process.hrtime(model_load_start);
 | 
				
			||||||
console.error('Loaded model in %ds.', totalTime(model_load_end));
 | 
					console.error('Loaded model in %ds.', totalTime(model_load_end));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if (args['beam_width']) {
 | 
				
			||||||
 | 
					  model.setBeamWidth(args['beam_width']);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var desired_sample_rate = model.sampleRate();
 | 
					var desired_sample_rate = model.sampleRate();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if (args['scorer']) {
 | 
					if (args['scorer']) {
 | 
				
			||||||
 | 
				
			|||||||
@ -25,14 +25,13 @@ if (process.platform === 'win32') {
 | 
				
			|||||||
 * An object providing an interface to a trained DeepSpeech model.
 | 
					 * An object providing an interface to a trained DeepSpeech model.
 | 
				
			||||||
 *
 | 
					 *
 | 
				
			||||||
 * @param {string} aModelPath The path to the frozen model graph.
 | 
					 * @param {string} aModelPath The path to the frozen model graph.
 | 
				
			||||||
 * @param {number} aBeamWidth The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.
 | 
					 | 
				
			||||||
 *
 | 
					 *
 | 
				
			||||||
 * @throws on error
 | 
					 * @throws on error
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
function Model() {
 | 
					function Model(aModelPath) {
 | 
				
			||||||
    this._impl = null;
 | 
					    this._impl = null;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const rets = binding.CreateModel.apply(null, arguments);
 | 
					    const rets = binding.CreateModel(aModelPath);
 | 
				
			||||||
    const status = rets[0];
 | 
					    const status = rets[0];
 | 
				
			||||||
    const impl = rets[1];
 | 
					    const impl = rets[1];
 | 
				
			||||||
    if (status !== 0) {
 | 
					    if (status !== 0) {
 | 
				
			||||||
@ -42,6 +41,38 @@ function Model() {
 | 
				
			|||||||
    this._impl = impl;
 | 
					    this._impl = impl;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
 | 
				
			||||||
 | 
					 * was not called before, will return the default value loaded from the
 | 
				
			||||||
 | 
					 * model file.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @return {number} Beam width value used by the model.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					Model.prototype.beamWidth = function() {
 | 
				
			||||||
 | 
					    return binding.GetModelBeamWidth(this._impl);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Set beam width value used by the model.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @param {number} The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @return {number} Zero on success, non-zero on failure.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					Model.prototype.setBeamWidth = function(aBeamWidth) {
 | 
				
			||||||
 | 
					    return binding.SetModelBeamWidth(this._impl, aBeamWidth);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Return the sample rate expected by the model.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @return {number} Sample rate.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					Model.prototype.beamWidth = function() {
 | 
				
			||||||
 | 
					    return binding.GetModelBeamWidth(this._impl);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Return the sample rate expected by the model.
 | 
					 * Return the sample rate expected by the model.
 | 
				
			||||||
 *
 | 
					 *
 | 
				
			||||||
 | 
				
			|||||||
@ -28,15 +28,12 @@ class Model(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    :param aModelPath: Path to model file to load
 | 
					    :param aModelPath: Path to model file to load
 | 
				
			||||||
    :type aModelPath: str
 | 
					    :type aModelPath: str
 | 
				
			||||||
 | 
					 | 
				
			||||||
    :param aBeamWidth: Decoder beam width
 | 
					 | 
				
			||||||
    :type aBeamWidth: int
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self,  *args, **kwargs):
 | 
					    def __init__(self, model_path):
 | 
				
			||||||
        # make sure the attribute is there if CreateModel fails
 | 
					        # make sure the attribute is there if CreateModel fails
 | 
				
			||||||
        self._impl = None
 | 
					        self._impl = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        status, impl = deepspeech.impl.CreateModel(*args, **kwargs)
 | 
					        status, impl = deepspeech.impl.CreateModel(model_path)
 | 
				
			||||||
        if status != 0:
 | 
					        if status != 0:
 | 
				
			||||||
            raise RuntimeError("CreateModel failed with error code {}".format(status))
 | 
					            raise RuntimeError("CreateModel failed with error code {}".format(status))
 | 
				
			||||||
        self._impl = impl
 | 
					        self._impl = impl
 | 
				
			||||||
@ -46,6 +43,29 @@ class Model(object):
 | 
				
			|||||||
            deepspeech.impl.FreeModel(self._impl)
 | 
					            deepspeech.impl.FreeModel(self._impl)
 | 
				
			||||||
            self._impl = None
 | 
					            self._impl = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def beamWidth(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
 | 
				
			||||||
 | 
					        was not called before, will return the default value loaded from the
 | 
				
			||||||
 | 
					        model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :return: Beam width value used by the model.
 | 
				
			||||||
 | 
					        :type: int
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return deepspeech.impl.GetModelBeamWidth(self._impl)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setBeamWidth(self, beam_width):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Set beam width value used by the model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
 | 
				
			||||||
 | 
					        :type beam_width: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :return: Zero on success, non-zero on failure.
 | 
				
			||||||
 | 
					        :type: int
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return deepspeech.impl.SetModelBeamWidth(self._impl, beam_width)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def sampleRate(self):
 | 
					    def sampleRate(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Return the sample rate expected by the model.
 | 
					        Return the sample rate expected by the model.
 | 
				
			||||||
 | 
				
			|||||||
@ -92,7 +92,7 @@ def main():
 | 
				
			|||||||
                        help='Path to the external scorer file')
 | 
					                        help='Path to the external scorer file')
 | 
				
			||||||
    parser.add_argument('--audio', required=True,
 | 
					    parser.add_argument('--audio', required=True,
 | 
				
			||||||
                        help='Path to the audio file to run (WAV format)')
 | 
					                        help='Path to the audio file to run (WAV format)')
 | 
				
			||||||
    parser.add_argument('--beam_width', type=int, default=500,
 | 
					    parser.add_argument('--beam_width', type=int,
 | 
				
			||||||
                        help='Beam width for the CTC decoder')
 | 
					                        help='Beam width for the CTC decoder')
 | 
				
			||||||
    parser.add_argument('--lm_alpha', type=float,
 | 
					    parser.add_argument('--lm_alpha', type=float,
 | 
				
			||||||
                        help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
 | 
					                        help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
 | 
				
			||||||
@ -108,10 +108,13 @@ def main():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    print('Loading model from file {}'.format(args.model), file=sys.stderr)
 | 
					    print('Loading model from file {}'.format(args.model), file=sys.stderr)
 | 
				
			||||||
    model_load_start = timer()
 | 
					    model_load_start = timer()
 | 
				
			||||||
    ds = Model(args.model, args.beam_width)
 | 
					    ds = Model(args.model)
 | 
				
			||||||
    model_load_end = timer() - model_load_start
 | 
					    model_load_end = timer() - model_load_start
 | 
				
			||||||
    print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
 | 
					    print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if args.beam_width:
 | 
				
			||||||
 | 
					        ds.setModelBeamWidth(args.beam_width)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    desired_sample_rate = ds.sampleRate()
 | 
					    desired_sample_rate = ds.sampleRate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if args.scorer:
 | 
					    if args.scorer:
 | 
				
			||||||
 | 
				
			|||||||
@ -9,12 +9,6 @@ import wave
 | 
				
			|||||||
from deepspeech import Model
 | 
					from deepspeech import Model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# These constants control the beam search decoder
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Beam width used in the CTC decoder when building candidate transcriptions
 | 
					 | 
				
			||||||
BEAM_WIDTH = 500
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
 | 
					    parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
 | 
				
			||||||
    parser.add_argument('--model', required=True,
 | 
					    parser.add_argument('--model', required=True,
 | 
				
			||||||
@ -27,7 +21,7 @@ def main():
 | 
				
			|||||||
                        help='Second audio file to use in interleaved streams')
 | 
					                        help='Second audio file to use in interleaved streams')
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ds = Model(args.model, BEAM_WIDTH)
 | 
					    ds = Model(args.model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if args.scorer:
 | 
					    if args.scorer:
 | 
				
			||||||
        ds.enableExternalScorer(args.scorer)
 | 
					        ds.enableExternalScorer(args.scorer)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user