Initial commit

This commit is contained in:
0/0 2022-10-09 20:17:31 -06:00
commit e12122a6ed
No known key found for this signature in database
GPG Key ID: DE8D5010C0AAA3DC
15 changed files with 927 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/target
/Cargo.lock

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "sys/whisper.cpp"]
path = sys/whisper.cpp
url = https://github.com/ggerganov/whisper.cpp

15
Cargo.toml Normal file
View File

@ -0,0 +1,15 @@
[workspace]
members = ["sys"]
[package]
name = "whisper-cpp"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
whisper-cpp-sys = { path = "sys", version = "0.1" }
[features]
simd = []

57
examples/basic_use.rs Normal file
View File

@ -0,0 +1,57 @@
use whisper_cpp::{DecodeStrategy, FullParams, WhisperContext};
// note that running this example will not do anything, as it is just a
// demonstration of how to use the library, and actual usage requires
// more dependencies than the base library.
pub fn usage() {
// load a context and model
let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model");
// create a params object
// note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
// n_past defaults to 0
let mut params = FullParams::new(DecodeStrategy::Greedy { n_past: 0 });
// edit things as needed
// here we set the number of threads to use to 1
params.set_n_threads(1);
// we also enable translation
params.set_translate(true);
// and set the language to translate to to english
params.set_language("en");
// we also explicitly disable anything that prints to stdout
params.set_print_special_tokens(false);
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
// assume we have a buffer of audio data
// here we'll make a fake one, integer samples, 16 bit, 16KHz, stereo
let audio_data = vec![0_i16; 16000 * 2];
// we must convert to 16KHz mono f32 samples for the model
// some utilities exist for this
// note that you don't need to use these, you can do it yourself or any other way you want
// these are just provided for convenience
// SIMD variants of these functions are also available, but only on nightly Rust: see the docs
let audio_data = whisper_cpp::convert_stereo_to_mono_audio(
&whisper_cpp::convert_integer_to_float_audio(&audio_data),
);
// now we can run the model
ctx.full(params, &audio_data[..])
.expect("failed to run model");
// fetch the results
let num_segments = ctx.full_n_segments();
for i in 0..num_segments {
let segment = ctx.full_get_segment_text(i).expect("failed to get segment");
let start_timestamp = ctx.full_get_segment_t0(i);
let end_timestamp = ctx.full_get_segment_t1(i);
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
}
}
fn main() {
println!("running this example does nothing! see the source code for usage");
}

37
src/error.rs Normal file
View File

@ -0,0 +1,37 @@
use std::ffi::{c_int, NulError};
use std::str::Utf8Error;
#[derive(Debug, Copy, Clone)]
pub enum WhisperError {
InitError,
SpectrogramNotInitialized,
EncodeNotComplete,
DecodeNotComplete,
InvalidThreadCount,
InvalidUtf8 {
error_len: Option<usize>,
valid_up_to: usize,
},
NullByteInString {
idx: usize,
},
NullPointer,
GenericError(c_int),
}
impl From<Utf8Error> for WhisperError {
fn from(e: Utf8Error) -> Self {
Self::InvalidUtf8 {
error_len: e.error_len(),
valid_up_to: e.valid_up_to(),
}
}
}
impl From<NulError> for WhisperError {
fn from(e: NulError) -> Self {
Self::NullByteInString {
idx: e.nul_position(),
}
}
}

15
src/lib.rs Normal file
View File

@ -0,0 +1,15 @@
#![cfg_attr(feature = "simd", feature(portable_simd))]
mod error;
mod standalone;
mod utilities;
mod whisper_ctx;
mod whisper_params;
pub use error::WhisperError;
pub use standalone::*;
pub use utilities::*;
pub use whisper_ctx::WhisperContext;
pub use whisper_params::{DecodeStrategy, FullParams};
pub type WhisperToken = std::ffi::c_int;

44
src/standalone.rs Normal file
View File

@ -0,0 +1,44 @@
//! Standalone functions that have no associated type.
use crate::WhisperToken;
use std::ffi::{c_int, CString};
/// Return the id of the specified language, returns -1 if not found
///
/// # Arguments
/// * lang: The language to get the id for.
///
/// # Returns
/// The ID of the language, None if not found.
///
/// # Panics
/// Panics if the language contains a null byte.
///
/// # C++ equivalent
/// `int whisper_lang_id(const char * lang)`
pub fn get_lang_id(lang: &str) -> Option<c_int> {
let c_lang = CString::new(lang).expect("Language contains null byte");
let ret = unsafe { whisper_cpp_sys::whisper_lang_id(c_lang.as_ptr()) };
if ret == -1 {
None
} else {
Some(ret)
}
}
// task tokens
/// Get the ID of the translate task token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_translate ()`
pub fn token_translate() -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_translate() }
}
/// Get the ID of the transcribe task token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_transcribe()`
pub fn token_transcribe() -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_transcribe() }
}

129
src/utilities.rs Normal file
View File

@ -0,0 +1,129 @@
#[cfg(feature = "simd")]
use std::simd::{f32x16, i16x16};
/// Convert an array of 16 bit mono audio samples to a vector of 32 bit floats.
///
/// This variant does not use SIMD instructions.
///
/// # Arguments
/// * `samples` - The array of 16 bit mono audio samples.
///
/// # Returns
/// A vector of 32 bit floats.
pub fn convert_integer_to_float_audio(samples: &[i16]) -> Vec<f32> {
let mut floats = Vec::with_capacity(samples.len());
for sample in samples {
floats.push(*sample as f32 / 32768.0);
}
floats
}
/// Convert an array of 16 bit mono audio samples to a vector of 32 bit floats.
///
/// This variant uses SIMD instructions, and as such is only available on
/// nightly Rust.
///
/// # Arguments
/// * `samples` - The array of 16 bit mono audio samples.
///
/// # Returns
/// A vector of 32 bit floats.
#[cfg(feature = "simd")]
pub fn convert_integer_to_float_audio_simd(samples: &[i16]) -> Vec<f32> {
let mut floats = Vec::with_capacity(samples.len());
let div_arr = f32x16::splat(32768.0);
let chunks = samples.chunks_exact(16);
let remainder = chunks.remainder();
for chunk in chunks {
let simd = i16x16::from_slice(chunk).cast::<f32>();
let simd = simd / div_arr;
floats.extend(&simd.to_array()[..]);
}
// Handle the remainder.
// do this normally because it's only a few samples and the overhead of
// converting to SIMD is not worth it.
for sample in remainder {
floats.push(*sample as f32 / 32768.0);
}
floats
}
/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio.
///
/// This variant does not use SIMD instructions.
///
/// # Arguments
/// * `samples` - The array of 32 bit floating point stereo PCM audio samples.
///
/// # Returns
/// A vector of 32 bit floating point mono PCM audio samples.
pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Vec<f32> {
let mut mono = Vec::with_capacity(samples.len() / 2);
for i in (0..samples.len()).step_by(2) {
mono.push((samples[i] + samples[i + 1]) / 2.0);
}
mono
}
/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio.
///
/// This variant uses SIMD instructions, and as such is only available on
/// nightly Rust.
///
/// # Arguments
/// * `samples` - The array of 32 bit floating point stereo PCM audio samples.
///
/// # Returns
/// A vector of 32 bit floating point mono PCM audio samples.
#[cfg(feature = "simd")]
pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> {
let mut mono = Vec::with_capacity(samples.len() / 2);
let div_array = f32x16::splat(2.0);
let chunks = samples.chunks_exact(32);
let remainder = chunks.remainder();
for chunk in chunks {
let [c1, c2] = [0, 1].map(|offset| {
let mut arr = [0.0; 16];
std::iter::zip(&mut arr, chunk.iter().skip(offset).step_by(2).copied())
.for_each(|(a, c)| *a = c);
arr
});
let c1 = f32x16::from(c1);
let c2 = f32x16::from(c2);
let mono_simd = (c1 + c2) / div_array;
mono.extend(&mono_simd.to_array()[..]);
}
// Handle the remainder.
// do this normally because it's only a few samples and the overhead of
// converting to SIMD is not worth it.
for i in (0..remainder.len()).step_by(2) {
mono.push((remainder[i] + remainder[i + 1]) / 2.0);
}
mono
}
#[cfg(test)]
mod test {
use super::*;
#[test]
pub fn assert_stereo_to_mono_simd() {
// fake some sample data, of 1028 elements
let mut samples = Vec::with_capacity(1028);
for i in 0..1028 {
samples.push(i as f32);
}
let mono_simd = convert_stereo_to_mono_audio_simd(&samples);
let mono = convert_stereo_to_mono_audio(&samples);
assert_eq!(mono_simd, mono);
}
}

455
src/whisper_ctx.rs Normal file
View File

@ -0,0 +1,455 @@
use crate::error::WhisperError;
use crate::whisper_params::FullParams;
use crate::WhisperToken;
use std::ffi::{c_int, CStr, CString};
/// Safe Rust wrapper around a Whisper context.
///
/// You likely want to create this with [WhisperContext::new],
/// then run a full transcription with [WhisperContext::full].
#[derive(Debug)]
pub struct WhisperContext {
ctx: *mut whisper_cpp_sys::whisper_context,
/// has the spectrogram been initialized in at least one way?
spectrogram_initialized: bool,
/// has the data been encoded?
encode_complete: bool,
/// has decode been called at least once?
decode_once: bool,
}
impl WhisperContext {
/// Create a new WhisperContext.
///
/// # Arguments
/// * path: The path to the model file.
///
/// # Returns
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init(const char * path_model);`
pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_cpp_sys::whisper_init(path_cstr.as_ptr()) };
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
Ok(Self {
ctx,
spectrogram_initialized: false,
encode_complete: false,
decode_once: false,
})
}
}
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
/// The resulting spectrogram is stored in the context transparently.
///
/// # Arguments
/// * pcm: The raw PCM audio.
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let ret = unsafe {
whisper_cpp_sys::whisper_pcm_to_mel(
self.ctx,
pcm.as_ptr(),
pcm.len() as c_int,
threads as c_int,
)
};
if ret == 0 {
self.spectrogram_initialized = true;
Ok(())
} else {
Err(WhisperError::GenericError(ret))
}
}
/// This can be used to set a custom log mel spectrogram inside the provided whisper context.
/// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
///
/// # Note
/// This is a low-level function.
/// If you're a typical user, you probably don't want to use this function.
/// See instead [WhisperContext::pcm_to_mel].
///
/// # Arguments
/// * data: The log mel spectrogram.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> {
let ret = unsafe {
whisper_cpp_sys::whisper_set_mel(
self.ctx,
data.as_ptr(),
data.len() as c_int,
80 as c_int,
)
};
if ret == 0 {
self.spectrogram_initialized = true;
Ok(())
} else {
Err(WhisperError::GenericError(ret))
}
}
/// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
/// Make sure to call [WhisperContext::pcm_to_mel] or [[WhisperContext::set_mel] first.
///
/// # Arguments
/// * offset: Can be used to specify the offset of the first frame in the spectrogram. Usually 0.
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)`
pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> {
if !self.spectrogram_initialized {
return Err(WhisperError::SpectrogramNotInitialized);
}
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let ret =
unsafe { whisper_cpp_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) };
if ret == 0 {
self.encode_complete = true;
Ok(())
} else {
Err(WhisperError::GenericError(ret))
}
}
/// Run the Whisper decoder to obtain the logits and probabilities for the next token.
/// Make sure to call [WhisperContext::encode] first.
/// tokens + n_tokens is the provided context for the decoder.
///
/// # Arguments
/// * tokens: The tokens to decode.
/// * n_tokens: The number of tokens to decode.
/// * n_past: The number of past tokens to use for the decoding.
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
pub fn decode(
&mut self,
tokens: &[WhisperToken],
n_past: usize,
threads: usize,
) -> Result<(), WhisperError> {
if !self.encode_complete {
return Err(WhisperError::EncodeNotComplete);
}
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let ret = unsafe {
whisper_cpp_sys::whisper_decode(
self.ctx,
tokens.as_ptr(),
tokens.len() as c_int,
n_past as c_int,
threads as c_int,
)
};
if ret == 0 {
self.decode_once = true;
Ok(())
} else {
Err(WhisperError::GenericError(ret))
}
}
// Token sampling functions
/// Return the token with the highest probability.
/// Make sure to call [WhisperContext::decode] first.
///
/// # Arguments
/// * needs_timestamp
///
/// # Returns
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)`
pub fn sample_best(&mut self, needs_timestamp: bool) -> Result<WhisperToken, WhisperError> {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
}
let ret = unsafe { whisper_cpp_sys::whisper_sample_best(self.ctx, needs_timestamp) };
Ok(ret)
}
/// Return the token with the most probable timestamp.
/// Make sure to call [WhisperContext::decode] first.
///
/// # Returns
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `whisper_token whisper_sample_timestamp(struct whisper_context * ctx)`
pub fn sample_timestamp(&mut self) -> Result<WhisperToken, WhisperError> {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
}
let ret = unsafe { whisper_cpp_sys::whisper_sample_timestamp(self.ctx) };
Ok(ret)
}
// model attributes
/// Get the mel spectrogram length.
///
/// # Returns
/// Ok(c_int) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_n_len (struct whisper_context * ctx)`
pub fn n_len(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_len(self.ctx) };
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
Ok(ret as c_int)
}
}
/// Get n_vocab.
///
/// # Returns
/// Ok(c_int) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_n_vocab (struct whisper_context * ctx)`
pub fn n_vocab(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_vocab(self.ctx) };
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
Ok(ret as c_int)
}
}
/// Get n_text_ctx.
///
/// # Returns
/// Ok(c_int) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_n_text_ctx (struct whisper_context * ctx)`
pub fn n_text_ctx(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_text_ctx(self.ctx) };
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
Ok(ret as c_int)
}
}
/// Does this model support multiple languages?
///
/// # C++ equivalent
/// `int whisper_is_multilingual(struct whisper_context * ctx)`
pub fn is_multilingual(&self) -> bool {
unsafe { whisper_cpp_sys::whisper_is_multilingual(self.ctx) != 0 }
}
/// The probabilities for the next token.
/// Make sure to call [WhisperContext::decode] first.
///
/// # Returns
/// Ok(*const f32) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `float * whisper_get_probs(struct whisper_context * ctx)`
pub fn get_probs(&mut self) -> Result<*const f32, WhisperError> {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
}
let ret = unsafe { whisper_cpp_sys::whisper_get_probs(self.ctx) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
Ok(ret)
}
/// Convert a token ID to a string.
///
/// # Arguments
/// * token_id: ID of the token.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_str(&self, token_id: WhisperToken) -> Result<String, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_token_to_str(self.ctx, token_id) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
}
// special tokens
/// Get the ID of the eot token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
pub fn token_eot(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_eot(self.ctx) }
}
/// Get the ID of the sot token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
pub fn token_sot(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_sot(self.ctx) }
}
/// Get the ID of the prev token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
pub fn token_prev(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_prev(self.ctx) }
}
/// Get the ID of the solm token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
pub fn token_solm(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_solm(self.ctx) }
}
/// Get the ID of the not token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_not (struct whisper_context * ctx)`
pub fn token_not(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_not(self.ctx) }
}
/// Get the ID of the beg token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
pub fn token_beg(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_beg(self.ctx) }
}
/// Print performance statistics to stdout.
///
/// # C++ equivalent
/// `void whisper_print_timings(struct whisper_context * ctx)`
pub fn print_timings(&self) {
unsafe { whisper_cpp_sys::whisper_print_timings(self.ctx) }
}
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
/// Uses the specified decoding strategy to obtain the text.
///
/// This is usually the only function you need to call as an end user.
///
/// # Arguments
/// * params: [crate::FullParams] struct.
/// * pcm: PCM audio data.
///
/// # Returns
/// Ok(c_int) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
let ret = unsafe {
whisper_cpp_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int)
};
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
Ok(ret as c_int)
}
}
/// Number of generated text segments.
/// A segment can be a few words, a sentence, or even a paragraph.
///
/// # C++ equivalent
/// `int whisper_full_n_segments(struct whisper_context * ctx)`
pub fn full_n_segments(&self) -> c_int {
unsafe { whisper_cpp_sys::whisper_full_n_segments(self.ctx) }
}
/// Get the start time of the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # C++ equivalent
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_t0(&self, segment: c_int) -> i64 {
unsafe { whisper_cpp_sys::whisper_full_get_segment_t0(self.ctx, segment) }
}
/// Get the end time of the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # C++ equivalent
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_t1(&self, segment: c_int) -> i64 {
unsafe { whisper_cpp_sys::whisper_full_get_segment_t1(self.ctx, segment) }
}
/// Get the text of the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_full_get_segment_text(self.ctx, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
}
}
impl Drop for WhisperContext {
fn drop(&mut self) {
unsafe { whisper_cpp_sys::whisper_free(self.ctx) };
}
}

114
src/whisper_params.rs Normal file
View File

@ -0,0 +1,114 @@
use std::ffi::c_int;
use std::marker::PhantomData;
pub enum DecodeStrategy {
Greedy {
n_past: c_int,
},
/// not implemented yet, results of using this unknown
BeamSearch {
n_past: c_int,
beam_width: c_int,
n_best: c_int,
},
}
pub struct FullParams<'a> {
pub(crate) fp: whisper_cpp_sys::whisper_full_params,
phantom: PhantomData<&'a str>,
}
impl<'a> FullParams<'a> {
/// Create a new set of parameters for the decoder.
pub fn new(decode_strategy: DecodeStrategy) -> FullParams<'a> {
let mut fp = unsafe {
whisper_cpp_sys::whisper_full_default_params(match decode_strategy {
DecodeStrategy::Greedy { .. } => 0,
DecodeStrategy::BeamSearch { .. } => 1,
} as _)
};
match decode_strategy {
DecodeStrategy::Greedy { n_past } => {
fp.__bindgen_anon_1.greedy.n_past = n_past;
}
DecodeStrategy::BeamSearch {
n_past,
beam_width,
n_best,
} => {
fp.__bindgen_anon_1.beam_search.n_past = n_past;
fp.__bindgen_anon_1.beam_search.beam_width = beam_width;
fp.__bindgen_anon_1.beam_search.n_best = n_best;
}
}
Self {
fp,
phantom: PhantomData,
}
}
/// Set the number of threads to use for decoding.
///
/// Defaults to min(4, std::thread::hardware_concurrency()).
pub fn set_n_threads(&mut self, n_threads: c_int) {
self.fp.n_threads = n_threads;
}
/// Set the offset in milliseconds to use for decoding.
///
/// Defaults to 0.
pub fn set_offset_ms(&mut self, offset_ms: c_int) {
self.fp.offset_ms = offset_ms;
}
/// Set whether to translate the output to the language specified by `language`.
///
/// Defaults to false.
pub fn set_translate(&mut self, translate: bool) {
self.fp.translate = translate;
}
/// Set no_context. Usage unknown.
///
/// Defaults to false.
pub fn set_no_context(&mut self, no_context: bool) {
self.fp.no_context = no_context;
}
/// Set whether to print special tokens.
///
/// Defaults to false.
pub fn set_print_special_tokens(&mut self, print_special_tokens: bool) {
self.fp.print_special_tokens = print_special_tokens;
}
/// Set whether to print progress.
///
/// Defaults to true.
pub fn set_print_progress(&mut self, print_progress: bool) {
self.fp.print_progress = print_progress;
}
/// Set print_realtime. Usage unknown.
///
/// Defaults to false.
pub fn set_print_realtime(&mut self, print_realtime: bool) {
self.fp.print_realtime = print_realtime;
}
/// Set whether to print timestamps.
///
/// Defaults to true.
pub fn set_print_timestamps(&mut self, print_timestamps: bool) {
self.fp.print_timestamps = print_timestamps;
}
/// Set the target language.
///
/// Defaults to "en".
pub fn set_language(&mut self, language: &'a str) {
self.fp.language = language.as_ptr() as *const _;
}
}

11
sys/Cargo.toml Normal file
View File

@ -0,0 +1,11 @@
[package]
name = "whisper-cpp-sys"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
[build-dependencies]
bindgen = "0.60"

38
sys/build.rs Normal file
View File

@ -0,0 +1,38 @@
extern crate bindgen;
use std::env;
use std::path::PathBuf;
fn main() {
// Tell cargo to look for shared libraries in the specified directory
println!("cargo:rustc-link-search=whisper.cpp");
// Tell cargo to tell rustc to link the system bzip2
// shared library.
println!("cargo:rustc-link-lib=static=whisper");
// Tell cargo to invalidate the built crate whenever the wrapper changes
println!("cargo:rerun-if-changed=wrapper.h");
// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// The input header we would like to generate
// bindings for.
.header("wrapper.h")
.clang_arg("-I/home/deck/CLionProjects/whisper.cpp")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");
// Write the bindings to the $OUT_DIR/bindings.rs file.
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}

5
sys/src/lib.rs Normal file
View File

@ -0,0 +1,5 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));

1
sys/whisper.cpp Submodule

@ -0,0 +1 @@
Subproject commit 7edaa7da4bdd072890c312ed764fc62b1fadb98f

1
sys/wrapper.h Normal file
View File

@ -0,0 +1 @@
#include <whisper.h>