Multi-stream support .NET

Adds multi-stream support for the .NET client using the same acoustic model.
This commit is contained in:
Carlos Fonseca M 2019-11-29 08:06:18 -06:00 committed by Reuben Morais
parent fe2477b25c
commit 923729d920
7 changed files with 120 additions and 69 deletions

View File

@ -15,6 +15,13 @@ DeepSpeech Class
:project: deepspeech-dotnet :project: deepspeech-dotnet
:members: :members:
DeepSpeechStream Class
----------------
.. doxygenclass:: DeepSpeechClient::DeepSpeechStream
:project: deepspeech-dotnet
:members:
ErrorCodes ErrorCodes
---------- ----------

View File

@ -4,6 +4,7 @@ using DeepSpeechClient.Extensions;
using System; using System;
using System.IO; using System.IO;
using DeepSpeechClient.Enums; using DeepSpeechClient.Enums;
using DeepSpeechClient.Models;
namespace DeepSpeechClient namespace DeepSpeechClient
{ {
@ -13,14 +14,16 @@ namespace DeepSpeechClient
public class DeepSpeech : IDeepSpeech public class DeepSpeech : IDeepSpeech
{ {
private unsafe IntPtr** _modelStatePP; private unsafe IntPtr** _modelStatePP;
private unsafe IntPtr** _streamingStatePP;
/// <summary>
/// Initializes a new instance of <see cref="DeepSpeech"/> class and creates a new acoustic model.
/// </summary>
public DeepSpeech() /// <param name="aModelPath">The path to the frozen model graph.</param>
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
public DeepSpeech(string aModelPath, uint aBeamWidth)
{ {
CreateModel(aModelPath, aBeamWidth);
} }
#region IDeepSpeech #region IDeepSpeech
@ -31,7 +34,7 @@ namespace DeepSpeechClient
/// <param name="aModelPath">The path to the frozen model graph.</param> /// <param name="aModelPath">The path to the frozen model graph.</param>
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param> /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception> /// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
public unsafe void CreateModel(string aModelPath, private unsafe void CreateModel(string aModelPath,
uint aBeamWidth) uint aBeamWidth)
{ {
string exceptionMessage = null; string exceptionMessage = null;
@ -118,10 +121,19 @@ namespace DeepSpeechClient
/// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param> /// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param>
/// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param> /// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception> /// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception>
/// <exception cref="FileNotFoundException">Thrown when cannot find the language model or trie file.</exception>
public unsafe void EnableDecoderWithLM(string aLMPath, string aTriePath, public unsafe void EnableDecoderWithLM(string aLMPath, string aTriePath,
float aLMAlpha, float aLMBeta) float aLMAlpha, float aLMBeta)
{ {
string exceptionMessage = null; string exceptionMessage = null;
if (string.IsNullOrWhiteSpace(aLMPath))
{
exceptionMessage = "Path to the language model file cannot be empty.";
}
if (!File.Exists(aLMPath))
{
exceptionMessage = $"Cannot find the language model file: {aLMPath}";
}
if (string.IsNullOrWhiteSpace(aTriePath)) if (string.IsNullOrWhiteSpace(aTriePath))
{ {
exceptionMessage = "Path to the trie file cannot be empty."; exceptionMessage = "Path to the trie file cannot be empty.";
@ -147,37 +159,41 @@ namespace DeepSpeechClient
/// <summary> /// <summary>
/// Feeds audio samples to an ongoing streaming inference. /// Feeds audio samples to an ongoing streaming inference.
/// </summary> /// </summary>
/// <param name="stream">Instance of the stream to feed the data.</param>
/// <param name="aBuffer">An array of 16-bit, mono raw audio samples at the appropriate sample rate (matching what the model was trained on).</param> /// <param name="aBuffer">An array of 16-bit, mono raw audio samples at the appropriate sample rate (matching what the model was trained on).</param>
public unsafe void FeedAudioContent(short[] aBuffer, uint aBufferSize) public unsafe void FeedAudioContent(DeepSpeechStream stream, short[] aBuffer, uint aBufferSize)
{ {
NativeImp.DS_FeedAudioContent(_streamingStatePP, aBuffer, aBufferSize); NativeImp.DS_FeedAudioContent(stream.GetNativePointer(), aBuffer, aBufferSize);
} }
/// <summary> /// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary> /// </summary>
/// <returns>The STT result. The user is responsible for freeing the string.</returns> /// <param name="stream">Instance of the stream to finish.</param>
public unsafe string FinishStream() /// <returns>The STT result.</returns>
public unsafe string FinishStream(DeepSpeechStream stream)
{ {
return NativeImp.DS_FinishStream(_streamingStatePP).PtrToString(); return NativeImp.DS_FinishStream(stream.GetNativePointer()).PtrToString();
} }
/// <summary> /// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary> /// </summary>
/// <returns>The extended metadata. The user is responsible for freeing the struct.</returns> /// <param name="stream">Instance of the stream to finish.</param>
public unsafe Models.Metadata FinishStreamWithMetadata() /// <returns>The extended metadata result.</returns>
public unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream)
{ {
return NativeImp.DS_FinishStreamWithMetadata(_streamingStatePP).PtrToMetadata(); return NativeImp.DS_FinishStreamWithMetadata(stream.GetNativePointer()).PtrToMetadata();
} }
/// <summary> /// <summary>
/// Computes the intermediate decoding of an ongoing streaming inference. /// Computes the intermediate decoding of an ongoing streaming inference.
/// </summary> /// </summary>
/// <returns>The STT intermediate result. The user is responsible for freeing the string.</returns> /// <param name="stream">Instance of the stream to decode.</param>
public unsafe string IntermediateDecode() /// <returns>The STT intermediate result.</returns>
public unsafe string IntermediateDecode(DeepSpeechStream stream)
{ {
return NativeImp.DS_IntermediateDecode(_streamingStatePP); return NativeImp.DS_IntermediateDecode(stream.GetNativePointer());
} }
/// <summary> /// <summary>
@ -191,11 +207,12 @@ namespace DeepSpeechClient
/// <summary> /// <summary>
/// Creates a new streaming inference state. /// Creates a new streaming inference state.
/// </summary> /// </summary>
/// <exception cref="ArgumentException">Thrown when the native binary failed to initialize the streaming mode.</exception> public unsafe DeepSpeechStream CreateStream()
public unsafe void CreateStream()
{ {
var resultCode = NativeImp.DS_CreateStream(_modelStatePP, ref _streamingStatePP); IntPtr** streamingStatePointer = null;
var resultCode = NativeImp.DS_CreateStream(_modelStatePP, ref streamingStatePointer);
EvaluateResultCode(resultCode); EvaluateResultCode(resultCode);
return new DeepSpeechStream(streamingStatePointer);
} }
/// <summary> /// <summary>
@ -203,13 +220,10 @@ namespace DeepSpeechClient
/// This can be used if you no longer need the result of an ongoing streaming /// This can be used if you no longer need the result of an ongoing streaming
/// inference and don't want to perform a costly decode operation. /// inference and don't want to perform a costly decode operation.
/// </summary> /// </summary>
public unsafe void FreeStream() public unsafe void FreeStream(DeepSpeechStream stream)
{ {
NativeImp.DS_FreeStream(ref _streamingStatePP); NativeImp.DS_FreeStream(stream.GetNativePointer());
} stream.Dispose();
{
NativeImp.DS_FreeMetadata(intPtr);
} }
/// <summary> /// <summary>
@ -217,7 +231,7 @@ namespace DeepSpeechClient
/// </summary> /// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param> /// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param> /// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <returns>The STT result. The user is responsible for freeing the string. Returns NULL on error.</returns> /// <returns>The STT result. Returns NULL on error.</returns>
public unsafe string SpeechToText(short[] aBuffer, uint aBufferSize) public unsafe string SpeechToText(short[] aBuffer, uint aBufferSize)
{ {
return NativeImp.DS_SpeechToText(_modelStatePP, aBuffer, aBufferSize).PtrToString(); return NativeImp.DS_SpeechToText(_modelStatePP, aBuffer, aBufferSize).PtrToString();
@ -228,8 +242,8 @@ namespace DeepSpeechClient
/// </summary> /// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param> /// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param> /// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <returns>The extended metadata. The user is responsible for freeing the struct. Returns NULL on error.</returns> /// <returns>The extended metadata. Returns NULL on error.</returns>
public unsafe Models.Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize) public unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize)
{ {
return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize).PtrToMetadata(); return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize).PtrToMetadata();
} }

View File

@ -48,6 +48,7 @@
<Compile Include="Enums\ErrorCodes.cs" /> <Compile Include="Enums\ErrorCodes.cs" />
<Compile Include="Interfaces\IDeepSpeech.cs" /> <Compile Include="Interfaces\IDeepSpeech.cs" />
<Compile Include="Extensions\NativeExtensions.cs" /> <Compile Include="Extensions\NativeExtensions.cs" />
<Compile Include="Models\DeepSpeechStream.cs" />
<Compile Include="Models\Metadata.cs" /> <Compile Include="Models\Metadata.cs" />
<Compile Include="Models\MetadataItem.cs" /> <Compile Include="Models\MetadataItem.cs" />
<Compile Include="NativeImp.cs" /> <Compile Include="NativeImp.cs" />

View File

@ -1,10 +1,11 @@
using DeepSpeechClient.Models; using DeepSpeechClient.Models;
using System; using System;
using System.IO;
namespace DeepSpeechClient.Interfaces namespace DeepSpeechClient.Interfaces
{ {
/// <summary> /// <summary>
/// Client interface of the Mozilla's deepspeech implementation. /// Client interface of the Mozilla's DeepSpeech implementation.
/// </summary> /// </summary>
public interface IDeepSpeech : IDisposable public interface IDeepSpeech : IDisposable
{ {
@ -13,15 +14,6 @@ namespace DeepSpeechClient.Interfaces
/// </summary> /// </summary>
void PrintVersions(); void PrintVersions();
/// <summary>
/// Create an object providing an interface to a trained DeepSpeech model.
/// </summary>
/// <param name="aModelPath">The path to the frozen model graph.</param>
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
unsafe void CreateModel(string aModelPath,
uint aBeamWidth);
/// <summary> /// <summary>
/// Return the sample rate expected by the model. /// Return the sample rate expected by the model.
/// </summary> /// </summary>
@ -36,6 +28,7 @@ namespace DeepSpeechClient.Interfaces
/// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param> /// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param>
/// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param> /// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception> /// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception>
/// <exception cref="FileNotFoundException">Thrown when cannot find the language model or trie file.</exception>
unsafe void EnableDecoderWithLM(string aLMPath, unsafe void EnableDecoderWithLM(string aLMPath,
string aTriePath, string aTriePath,
float aLMAlpha, float aLMAlpha,
@ -46,7 +39,7 @@ namespace DeepSpeechClient.Interfaces
/// </summary> /// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param> /// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param> /// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <returns>The STT result. The user is responsible for freeing the string. Returns NULL on error.</returns> /// <returns>The STT result. Returns NULL on error.</returns>
unsafe string SpeechToText(short[] aBuffer, unsafe string SpeechToText(short[] aBuffer,
uint aBufferSize); uint aBufferSize);
@ -55,7 +48,7 @@ namespace DeepSpeechClient.Interfaces
/// </summary> /// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param> /// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param> /// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <returns>The extended metadata result. The user is responsible for freeing the struct. Returns NULL on error.</returns> /// <returns>The extended metadata. Returns NULL on error.</returns>
unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer,
uint aBufferSize); uint aBufferSize);
@ -64,37 +57,39 @@ namespace DeepSpeechClient.Interfaces
/// This can be used if you no longer need the result of an ongoing streaming /// This can be used if you no longer need the result of an ongoing streaming
/// inference and don't want to perform a costly decode operation. /// inference and don't want to perform a costly decode operation.
/// </summary> /// </summary>
unsafe void FreeStream(); unsafe void FreeStream(DeepSpeechStream stream);
/// <summary> /// <summary>
/// Creates a new streaming inference state. /// Creates a new streaming inference state.
/// </summary> /// </summary>
/// <exception cref="ArgumentException">Thrown when the native binary failed to initialize the streaming mode.</exception> unsafe DeepSpeechStream CreateStream();
unsafe void CreateStream();
/// <summary> /// <summary>
/// Feeds audio samples to an ongoing streaming inference. /// Feeds audio samples to an ongoing streaming inference.
/// </summary> /// </summary>
/// <param name="stream">Instance of the stream to feed the data.</param>
/// <param name="aBuffer">An array of 16-bit, mono raw audio samples at the appropriate sample rate (matching what the model was trained on).</param> /// <param name="aBuffer">An array of 16-bit, mono raw audio samples at the appropriate sample rate (matching what the model was trained on).</param>
unsafe void FeedAudioContent(short[] aBuffer, uint aBufferSize); unsafe void FeedAudioContent(DeepSpeechStream stream, short[] aBuffer, uint aBufferSize);
/// <summary> /// <summary>
/// Computes the intermediate decoding of an ongoing streaming inference. /// Computes the intermediate decoding of an ongoing streaming inference.
/// </summary> /// </summary>
/// <returns>The STT intermediate result. The user is responsible for freeing the string.</returns> /// <param name="stream">Instance of the stream to decode.</param>
unsafe string IntermediateDecode(); /// <returns>The STT intermediate result.</returns>
unsafe string IntermediateDecode(DeepSpeechStream stream);
/// <summary> /// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary> /// </summary>
/// <returns>The STT result. The user is responsible for freeing the string.</returns> /// <param name="stream">Instance of the stream to finish.</param>
unsafe string FinishStream(); /// <returns>The STT result.</returns>
unsafe string FinishStream(DeepSpeechStream stream);
/// <summary> /// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary> /// </summary>
/// <returns>The extended metadata result. The user is responsible for freeing the struct.</returns> /// <param name="stream">Instance of the stream to finish.</param>
unsafe Metadata FinishStreamWithMetadata(); /// <returns>The extended metadata result.</returns>
unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream);
} }
} }

View File

@ -0,0 +1,35 @@
using System;
namespace DeepSpeechClient.Models
{
/// <summary>
/// Wrapper of the pointer used for the decoding stream.
/// </summary>
public class DeepSpeechStream : IDisposable
{
private unsafe IntPtr** _streamingStatePp;
/// <summary>
/// Initializes a new instance of <see cref="DeepSpeechStream"/>.
/// </summary>
/// <param name="streamingStatePP">Native pointer of the native stream.</param>
public unsafe DeepSpeechStream(IntPtr** streamingStatePP)
{
_streamingStatePp = streamingStatePP;
}
/// <summary>
/// Gets the native pointer.
/// </summary>
/// <exception cref="InvalidOperationException">Thrown when the stream has been disposed or not yet initialized.</exception>
/// <returns>Native pointer of the stream.</returns>
internal unsafe IntPtr** GetNativePointer()
{
if (_streamingStatePp == null)
throw new InvalidOperationException("Cannot use a disposed or uninitialized stream.");
return _streamingStatePp;
}
public unsafe void Dispose() => _streamingStatePp = null;
}
}

View File

@ -48,7 +48,7 @@ namespace DeepSpeechClient
ref IntPtr** retval); ref IntPtr** retval);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern void DS_FreeStream(ref IntPtr** aSctx); internal static unsafe extern void DS_FreeStream(IntPtr** aSctx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern void DS_FreeMetadata(IntPtr metadata); internal static unsafe extern void DS_FreeMetadata(IntPtr metadata);

View File

@ -53,16 +53,13 @@ namespace CSharpExamples
const float LM_BETA = 1.85f; const float LM_BETA = 1.85f;
Stopwatch stopwatch = new Stopwatch(); Stopwatch stopwatch = new Stopwatch();
using (IDeepSpeech sttClient = new DeepSpeech())
{
try try
{ {
Console.WriteLine("Loading model..."); Console.WriteLine("Loading model...");
stopwatch.Start(); stopwatch.Start();
sttClient.CreateModel( using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm",
model ?? "output_graph.pbmm", BEAM_WIDTH))
BEAM_WIDTH); {
stopwatch.Stop(); stopwatch.Stop();
Console.WriteLine($"Model loaded - {stopwatch.Elapsed.Milliseconds} ms"); Console.WriteLine($"Model loaded - {stopwatch.Elapsed.Milliseconds} ms");
@ -88,12 +85,14 @@ namespace CSharpExamples
string speechResult; string speechResult;
if (extended) if (extended)
{ {
Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer, Convert.ToUInt32(waveBuffer.MaxSize / 2)); Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer,
Convert.ToUInt32(waveBuffer.MaxSize / 2));
speechResult = MetadataToString(metaResult); speechResult = MetadataToString(metaResult);
} }
else else
{ {
speechResult = sttClient.SpeechToText(waveBuffer.ShortBuffer, Convert.ToUInt32(waveBuffer.MaxSize / 2)); speechResult = sttClient.SpeechToText(waveBuffer.ShortBuffer,
Convert.ToUInt32(waveBuffer.MaxSize / 2));
} }
stopwatch.Stop(); stopwatch.Stop();
@ -104,6 +103,7 @@ namespace CSharpExamples
} }
waveBuffer.Clear(); waveBuffer.Clear();
} }
}
catch (Exception ex) catch (Exception ex)
{ {
Console.WriteLine(ex.Message); Console.WriteLine(ex.Message);
@ -111,4 +111,3 @@ namespace CSharpExamples
} }
} }
} }
}