Merge pull request #2679 from mozilla/pr-2548-multistream

Re-land PR #2548 multistream support for .NET bindings
This commit is contained in:
Reuben Morais 2020-01-18 15:04:21 +01:00 committed by GitHub
commit 3cea430f7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 120 additions and 90 deletions

View File

@ -15,6 +15,13 @@ DeepSpeech Class
:project: deepspeech-dotnet
:members:
DeepSpeechStream Class
----------------------
.. doxygenclass:: DeepSpeechClient::DeepSpeechStream
:project: deepspeech-dotnet
:members:
ErrorCodes
----------

View File

@ -4,6 +4,7 @@ using DeepSpeechClient.Extensions;
using System;
using System.IO;
using DeepSpeechClient.Enums;
using DeepSpeechClient.Models;
namespace DeepSpeechClient
{
@ -13,14 +14,16 @@ namespace DeepSpeechClient
public class DeepSpeech : IDeepSpeech
{
private unsafe IntPtr** _modelStatePP;
private unsafe IntPtr** _streamingStatePP;
public DeepSpeech()
/// <summary>
/// Initializes a new instance of <see cref="DeepSpeech"/> class and creates a new acoustic model.
/// </summary>
/// <param name="aModelPath">The path to the frozen model graph.</param>
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
public DeepSpeech(string aModelPath, uint aBeamWidth)
{
CreateModel(aModelPath, aBeamWidth);
}
#region IDeepSpeech
@ -31,7 +34,7 @@ namespace DeepSpeechClient
/// <param name="aModelPath">The path to the frozen model graph.</param>
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
public unsafe void CreateModel(string aModelPath,
private unsafe void CreateModel(string aModelPath,
uint aBeamWidth)
{
string exceptionMessage = null;
@ -118,10 +121,19 @@ namespace DeepSpeechClient
/// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param>
/// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception>
/// <exception cref="FileNotFoundException">Thrown when the language model or trie file cannot be found.</exception>
public unsafe void EnableDecoderWithLM(string aLMPath, string aTriePath,
float aLMAlpha, float aLMBeta)
{
string exceptionMessage = null;
if (string.IsNullOrWhiteSpace(aLMPath))
{
exceptionMessage = "Path to the language model file cannot be empty.";
}
if (!File.Exists(aLMPath))
{
exceptionMessage = $"Cannot find the language model file: {aLMPath}";
}
if (string.IsNullOrWhiteSpace(aTriePath))
{
exceptionMessage = "Path to the trie file cannot be empty.";
@ -147,37 +159,41 @@ namespace DeepSpeechClient
/// <summary>
/// Feeds audio samples to an ongoing streaming inference.
/// </summary>
/// <param name="stream">Instance of the stream to feed the data.</param>
/// <param name="aBuffer">An array of 16-bit, mono raw audio samples at the appropriate sample rate (matching what the model was trained on).</param>
public unsafe void FeedAudioContent(short[] aBuffer, uint aBufferSize)
public unsafe void FeedAudioContent(DeepSpeechStream stream, short[] aBuffer, uint aBufferSize)
{
NativeImp.DS_FeedAudioContent(_streamingStatePP, aBuffer, aBufferSize);
NativeImp.DS_FeedAudioContent(stream.GetNativePointer(), aBuffer, aBufferSize);
}
/// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary>
/// <returns>The STT result. The user is responsible for freeing the string.</returns>
public unsafe string FinishStream()
/// <param name="stream">Instance of the stream to finish.</param>
/// <returns>The STT result.</returns>
public unsafe string FinishStream(DeepSpeechStream stream)
{
return NativeImp.DS_FinishStream(_streamingStatePP).PtrToString();
return NativeImp.DS_FinishStream(stream.GetNativePointer()).PtrToString();
}
/// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary>
/// <returns>The extended metadata. The user is responsible for freeing the struct.</returns>
public unsafe Models.Metadata FinishStreamWithMetadata()
/// <param name="stream">Instance of the stream to finish.</param>
/// <returns>The extended metadata result.</returns>
public unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream)
{
return NativeImp.DS_FinishStreamWithMetadata(_streamingStatePP).PtrToMetadata();
return NativeImp.DS_FinishStreamWithMetadata(stream.GetNativePointer()).PtrToMetadata();
}
/// <summary>
/// Computes the intermediate decoding of an ongoing streaming inference.
/// </summary>
/// <returns>The STT intermediate result. The user is responsible for freeing the string.</returns>
public unsafe string IntermediateDecode()
/// <param name="stream">Instance of the stream to decode.</param>
/// <returns>The STT intermediate result.</returns>
public unsafe string IntermediateDecode(DeepSpeechStream stream)
{
return NativeImp.DS_IntermediateDecode(_streamingStatePP);
return NativeImp.DS_IntermediateDecode(stream.GetNativePointer());
}
/// <summary>
@ -191,11 +207,12 @@ namespace DeepSpeechClient
/// <summary>
/// Creates a new streaming inference state.
/// </summary>
/// <exception cref="ArgumentException">Thrown when the native binary failed to initialize the streaming mode.</exception>
public unsafe void CreateStream()
public unsafe DeepSpeechStream CreateStream()
{
var resultCode = NativeImp.DS_CreateStream(_modelStatePP, ref _streamingStatePP);
IntPtr** streamingStatePointer = null;
var resultCode = NativeImp.DS_CreateStream(_modelStatePP, ref streamingStatePointer);
EvaluateResultCode(resultCode);
return new DeepSpeechStream(streamingStatePointer);
}
/// <summary>
@ -203,25 +220,10 @@ namespace DeepSpeechClient
/// This can be used if you no longer need the result of an ongoing streaming
/// inference and don't want to perform a costly decode operation.
/// </summary>
public unsafe void FreeStream()
public unsafe void FreeStream(DeepSpeechStream stream)
{
NativeImp.DS_FreeStream(ref _streamingStatePP);
}
/// <summary>
/// Free a DeepSpeech allocated string
/// </summary>
public unsafe void FreeString(IntPtr intPtr)
{
NativeImp.DS_FreeString(intPtr);
}
/// <summary>
/// Free a DeepSpeech allocated Metadata struct
/// </summary>
public unsafe void FreeMetadata(IntPtr intPtr)
{
NativeImp.DS_FreeMetadata(intPtr);
NativeImp.DS_FreeStream(stream.GetNativePointer());
stream.Dispose();
}
/// <summary>
@ -229,7 +231,7 @@ namespace DeepSpeechClient
/// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <returns>The STT result. The user is responsible for freeing the string. Returns NULL on error.</returns>
/// <returns>The STT result. Returns NULL on error.</returns>
public unsafe string SpeechToText(short[] aBuffer, uint aBufferSize)
{
return NativeImp.DS_SpeechToText(_modelStatePP, aBuffer, aBufferSize).PtrToString();
@ -240,8 +242,8 @@ namespace DeepSpeechClient
/// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <returns>The extended metadata. The user is responsible for freeing the struct. Returns NULL on error.</returns>
public unsafe Models.Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize)
/// <returns>The extended metadata. Returns NULL on error.</returns>
public unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize)
{
return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize).PtrToMetadata();
}

View File

@ -48,6 +48,7 @@
<Compile Include="Enums\ErrorCodes.cs" />
<Compile Include="Interfaces\IDeepSpeech.cs" />
<Compile Include="Extensions\NativeExtensions.cs" />
<Compile Include="Models\DeepSpeechStream.cs" />
<Compile Include="Models\Metadata.cs" />
<Compile Include="Models\MetadataItem.cs" />
<Compile Include="NativeImp.cs" />

View File

@ -1,10 +1,11 @@
using DeepSpeechClient.Models;
using System;
using System.IO;
namespace DeepSpeechClient.Interfaces
{
/// <summary>
/// Client interface of the Mozilla's deepspeech implementation.
/// Client interface of Mozilla's DeepSpeech implementation.
/// </summary>
public interface IDeepSpeech : IDisposable
{
@ -13,15 +14,6 @@ namespace DeepSpeechClient.Interfaces
/// </summary>
void PrintVersions();
/// <summary>
/// Create an object providing an interface to a trained DeepSpeech model.
/// </summary>
/// <param name="aModelPath">The path to the frozen model graph.</param>
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
unsafe void CreateModel(string aModelPath,
uint aBeamWidth);
/// <summary>
/// Return the sample rate expected by the model.
/// </summary>
@ -36,6 +28,7 @@ namespace DeepSpeechClient.Interfaces
/// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param>
/// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param>
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception>
/// <exception cref="FileNotFoundException">Thrown when the language model or trie file cannot be found.</exception>
unsafe void EnableDecoderWithLM(string aLMPath,
string aTriePath,
float aLMAlpha,
@ -46,7 +39,7 @@ namespace DeepSpeechClient.Interfaces
/// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <returns>The STT result. The user is responsible for freeing the string. Returns NULL on error.</returns>
/// <returns>The STT result. Returns NULL on error.</returns>
unsafe string SpeechToText(short[] aBuffer,
uint aBufferSize);
@ -55,7 +48,7 @@ namespace DeepSpeechClient.Interfaces
/// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <returns>The extended metadata result. The user is responsible for freeing the struct. Returns NULL on error.</returns>
/// <returns>The extended metadata. Returns NULL on error.</returns>
unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer,
uint aBufferSize);
@ -64,46 +57,39 @@ namespace DeepSpeechClient.Interfaces
/// This can be used if you no longer need the result of an ongoing streaming
/// inference and don't want to perform a costly decode operation.
/// </summary>
unsafe void FreeStream();
/// <summary>
/// Free a DeepSpeech allocated string
/// </summary>
unsafe void FreeString(IntPtr intPtr);
/// <summary>
/// Free a DeepSpeech allocated Metadata struct
/// </summary>
unsafe void FreeMetadata(IntPtr intPtr);
unsafe void FreeStream(DeepSpeechStream stream);
/// <summary>
/// Creates a new streaming inference state.
/// </summary>
/// <exception cref="ArgumentException">Thrown when the native binary failed to initialize the streaming mode.</exception>
unsafe void CreateStream();
unsafe DeepSpeechStream CreateStream();
/// <summary>
/// Feeds audio samples to an ongoing streaming inference.
/// </summary>
/// <param name="stream">Instance of the stream to feed the data.</param>
/// <param name="aBuffer">An array of 16-bit, mono raw audio samples at the appropriate sample rate (matching what the model was trained on).</param>
unsafe void FeedAudioContent(short[] aBuffer, uint aBufferSize);
unsafe void FeedAudioContent(DeepSpeechStream stream, short[] aBuffer, uint aBufferSize);
/// <summary>
/// Computes the intermediate decoding of an ongoing streaming inference.
/// </summary>
/// <returns>The STT intermediate result. The user is responsible for freeing the string.</returns>
unsafe string IntermediateDecode();
/// <param name="stream">Instance of the stream to decode.</param>
/// <returns>The STT intermediate result.</returns>
unsafe string IntermediateDecode(DeepSpeechStream stream);
/// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary>
/// <returns>The STT result. The user is responsible for freeing the string.</returns>
unsafe string FinishStream();
/// <param name="stream">Instance of the stream to finish.</param>
/// <returns>The STT result.</returns>
unsafe string FinishStream(DeepSpeechStream stream);
/// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary>
/// <returns>The extended metadata result. The user is responsible for freeing the struct.</returns>
unsafe Metadata FinishStreamWithMetadata();
/// <param name="stream">Instance of the stream to finish.</param>
/// <returns>The extended metadata result.</returns>
unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream);
}
}

View File

@ -0,0 +1,35 @@
using System;
namespace DeepSpeechClient.Models
{
/// <summary>
/// Wrapper of the pointer used for the decoding stream.
/// </summary>
/// <summary>
/// Wrapper of the pointer used for the decoding stream.
/// </summary>
public class DeepSpeechStream : IDisposable
{
    // Raw native streaming-state handle; null once the stream is disposed.
    private unsafe IntPtr** _nativeStatePtr;

    /// <summary>
    /// Initializes a new instance of <see cref="DeepSpeechStream"/>.
    /// </summary>
    /// <param name="streamingStatePP">Native pointer of the native stream.</param>
    public unsafe DeepSpeechStream(IntPtr** streamingStatePP) => _nativeStatePtr = streamingStatePP;

    /// <summary>
    /// Gets the native pointer backing this stream.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the stream has been disposed or not yet initialized.</exception>
    /// <returns>Native pointer of the stream.</returns>
    internal unsafe IntPtr** GetNativePointer()
    {
        if (_nativeStatePtr == null)
        {
            throw new InvalidOperationException("Cannot use a disposed or uninitialized stream.");
        }

        return _nativeStatePtr;
    }

    /// <summary>
    /// Marks this wrapper as unusable by clearing the stored native pointer.
    /// Note: this does not release the native state itself; callers free it
    /// through the client (e.g. FreeStream/FinishStream), which then disposes
    /// this wrapper.
    /// </summary>
    public unsafe void Dispose()
    {
        _nativeStatePtr = null;
    }
}
}

View File

@ -48,7 +48,7 @@ namespace DeepSpeechClient
ref IntPtr** retval);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern void DS_FreeStream(ref IntPtr** aSctx);
internal static unsafe extern void DS_FreeStream(IntPtr** aSctx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern void DS_FreeMetadata(IntPtr metadata);

View File

@ -53,16 +53,13 @@ namespace CSharpExamples
const float LM_BETA = 1.85f;
Stopwatch stopwatch = new Stopwatch();
using (IDeepSpeech sttClient = new DeepSpeech())
{
try
{
Console.WriteLine("Loading model...");
stopwatch.Start();
sttClient.CreateModel(
model ?? "output_graph.pbmm",
BEAM_WIDTH);
using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm",
BEAM_WIDTH))
{
stopwatch.Stop();
Console.WriteLine($"Model loaded - {stopwatch.Elapsed.Milliseconds} ms");
@ -88,12 +85,14 @@ namespace CSharpExamples
string speechResult;
if (extended)
{
Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer, Convert.ToUInt32(waveBuffer.MaxSize / 2));
Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer,
Convert.ToUInt32(waveBuffer.MaxSize / 2));
speechResult = MetadataToString(metaResult);
}
else
{
speechResult = sttClient.SpeechToText(waveBuffer.ShortBuffer, Convert.ToUInt32(waveBuffer.MaxSize / 2));
speechResult = sttClient.SpeechToText(waveBuffer.ShortBuffer,
Convert.ToUInt32(waveBuffer.MaxSize / 2));
}
stopwatch.Stop();
@ -104,6 +103,7 @@ namespace CSharpExamples
}
waveBuffer.Clear();
}
}
catch (Exception ex)
{
Console.WriteLine(ex.Message);
@ -111,4 +111,3 @@ namespace CSharpExamples
}
}
}
}