commit e12122a6ed64a0eb099b583be3fff8f10fdd6959 Author: 0/0 Date: Sun Oct 9 20:17:31 2022 -0600 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..0863136 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "sys/whisper.cpp"] + path = sys/whisper.cpp + url = https://github.com/ggerganov/whisper.cpp diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..58c4b0a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,15 @@ +[workspace] +members = ["sys"] + +[package] +name = "whisper-cpp" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +whisper-cpp-sys = { path = "sys", version = "0.1" } + +[features] +simd = [] \ No newline at end of file diff --git a/examples/basic_use.rs b/examples/basic_use.rs new file mode 100644 index 0000000..de25f03 --- /dev/null +++ b/examples/basic_use.rs @@ -0,0 +1,57 @@ +use whisper_cpp::{DecodeStrategy, FullParams, WhisperContext}; + +// note that running this example will not do anything, as it is just a +// demonstration of how to use the library, and actual usage requires +// more dependencies than the base library. 
+pub fn usage() { + // load a context and model + let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model"); + + // create a params object + // note that currently the only implemented strategy is Greedy, BeamSearch is a WIP + // n_past defaults to 0 + let mut params = FullParams::new(DecodeStrategy::Greedy { n_past: 0 }); + + // edit things as needed + // here we set the number of threads to use to 1 + params.set_n_threads(1); + // we also enable translation + params.set_translate(true); + // and set the language to translate to to english + params.set_language("en"); + // we also explicitly disable anything that prints to stdout + params.set_print_special_tokens(false); + params.set_print_progress(false); + params.set_print_realtime(false); + params.set_print_timestamps(false); + + // assume we have a buffer of audio data + // here we'll make a fake one, integer samples, 16 bit, 16KHz, stereo + let audio_data = vec![0_i16; 16000 * 2]; + + // we must convert to 16KHz mono f32 samples for the model + // some utilities exist for this + // note that you don't need to use these, you can do it yourself or any other way you want + // these are just provided for convenience + // SIMD variants of these functions are also available, but only on nightly Rust: see the docs + let audio_data = whisper_cpp::convert_stereo_to_mono_audio( + &whisper_cpp::convert_integer_to_float_audio(&audio_data), + ); + + // now we can run the model + ctx.full(params, &audio_data[..]) + .expect("failed to run model"); + + // fetch the results + let num_segments = ctx.full_n_segments(); + for i in 0..num_segments { + let segment = ctx.full_get_segment_text(i).expect("failed to get segment"); + let start_timestamp = ctx.full_get_segment_t0(i); + let end_timestamp = ctx.full_get_segment_t1(i); + println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); + } +} + +fn main() { + println!("running this example does nothing! 
see the source code for usage"); +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..1a8200f --- /dev/null +++ b/src/error.rs @@ -0,0 +1,37 @@ +use std::ffi::{c_int, NulError}; +use std::str::Utf8Error; + +#[derive(Debug, Copy, Clone)] +pub enum WhisperError { + InitError, + SpectrogramNotInitialized, + EncodeNotComplete, + DecodeNotComplete, + InvalidThreadCount, + InvalidUtf8 { + error_len: Option, + valid_up_to: usize, + }, + NullByteInString { + idx: usize, + }, + NullPointer, + GenericError(c_int), +} + +impl From for WhisperError { + fn from(e: Utf8Error) -> Self { + Self::InvalidUtf8 { + error_len: e.error_len(), + valid_up_to: e.valid_up_to(), + } + } +} + +impl From for WhisperError { + fn from(e: NulError) -> Self { + Self::NullByteInString { + idx: e.nul_position(), + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..7bb2fd6 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,15 @@ +#![cfg_attr(feature = "simd", feature(portable_simd))] + +mod error; +mod standalone; +mod utilities; +mod whisper_ctx; +mod whisper_params; + +pub use error::WhisperError; +pub use standalone::*; +pub use utilities::*; +pub use whisper_ctx::WhisperContext; +pub use whisper_params::{DecodeStrategy, FullParams}; + +pub type WhisperToken = std::ffi::c_int; diff --git a/src/standalone.rs b/src/standalone.rs new file mode 100644 index 0000000..1f798e2 --- /dev/null +++ b/src/standalone.rs @@ -0,0 +1,44 @@ +//! Standalone functions that have no associated type. + +use crate::WhisperToken; +use std::ffi::{c_int, CString}; + +/// Return the id of the specified language, returns -1 if not found +/// +/// # Arguments +/// * lang: The language to get the id for. +/// +/// # Returns +/// The ID of the language, None if not found. +/// +/// # Panics +/// Panics if the language contains a null byte. 
+/// +/// # C++ equivalent +/// `int whisper_lang_id(const char * lang)` +pub fn get_lang_id(lang: &str) -> Option { + let c_lang = CString::new(lang).expect("Language contains null byte"); + let ret = unsafe { whisper_cpp_sys::whisper_lang_id(c_lang.as_ptr()) }; + if ret == -1 { + None + } else { + Some(ret) + } +} + +// task tokens +/// Get the ID of the translate task token. +/// +/// # C++ equivalent +/// `whisper_token whisper_token_translate ()` +pub fn token_translate() -> WhisperToken { + unsafe { whisper_cpp_sys::whisper_token_translate() } +} + +/// Get the ID of the transcribe task token. +/// +/// # C++ equivalent +/// `whisper_token whisper_token_transcribe()` +pub fn token_transcribe() -> WhisperToken { + unsafe { whisper_cpp_sys::whisper_token_transcribe() } +} diff --git a/src/utilities.rs b/src/utilities.rs new file mode 100644 index 0000000..47b3374 --- /dev/null +++ b/src/utilities.rs @@ -0,0 +1,129 @@ +#[cfg(feature = "simd")] +use std::simd::{f32x16, i16x16}; + +/// Convert an array of 16 bit mono audio samples to a vector of 32 bit floats. +/// +/// This variant does not use SIMD instructions. +/// +/// # Arguments +/// * `samples` - The array of 16 bit mono audio samples. +/// +/// # Returns +/// A vector of 32 bit floats. +pub fn convert_integer_to_float_audio(samples: &[i16]) -> Vec { + let mut floats = Vec::with_capacity(samples.len()); + for sample in samples { + floats.push(*sample as f32 / 32768.0); + } + floats +} + +/// Convert an array of 16 bit mono audio samples to a vector of 32 bit floats. +/// +/// This variant uses SIMD instructions, and as such is only available on +/// nightly Rust. +/// +/// # Arguments +/// * `samples` - The array of 16 bit mono audio samples. +/// +/// # Returns +/// A vector of 32 bit floats. 
+#[cfg(feature = "simd")] +pub fn convert_integer_to_float_audio_simd(samples: &[i16]) -> Vec { + let mut floats = Vec::with_capacity(samples.len()); + + let div_arr = f32x16::splat(32768.0); + + let chunks = samples.chunks_exact(16); + let remainder = chunks.remainder(); + for chunk in chunks { + let simd = i16x16::from_slice(chunk).cast::(); + let simd = simd / div_arr; + floats.extend(&simd.to_array()[..]); + } + + // Handle the remainder. + // do this normally because it's only a few samples and the overhead of + // converting to SIMD is not worth it. + for sample in remainder { + floats.push(*sample as f32 / 32768.0); + } + + floats +} + +/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio. +/// +/// This variant does not use SIMD instructions. +/// +/// # Arguments +/// * `samples` - The array of 32 bit floating point stereo PCM audio samples. +/// +/// # Returns +/// A vector of 32 bit floating point mono PCM audio samples. +pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Vec { + let mut mono = Vec::with_capacity(samples.len() / 2); + for i in (0..samples.len()).step_by(2) { + mono.push((samples[i] + samples[i + 1]) / 2.0); + } + mono +} + +/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio. +/// +/// This variant uses SIMD instructions, and as such is only available on +/// nightly Rust. +/// +/// # Arguments +/// * `samples` - The array of 32 bit floating point stereo PCM audio samples. +/// +/// # Returns +/// A vector of 32 bit floating point mono PCM audio samples. 
#[cfg(feature = "simd")]
pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> {
    let mut mono = Vec::with_capacity(samples.len() / 2);

    let div_array = f32x16::splat(2.0);

    let chunks = samples.chunks_exact(32);
    let remainder = chunks.remainder();
    for chunk in chunks {
        // De-interleave the 32 samples into the even-indexed and odd-indexed
        // halves (offset 0 and offset 1 respectively).
        let [c1, c2] = [0, 1].map(|offset| {
            let mut arr = [0.0; 16];
            std::iter::zip(&mut arr, chunk.iter().skip(offset).step_by(2).copied())
                .for_each(|(a, c)| *a = c);
            arr
        });

        let c1 = f32x16::from(c1);
        let c2 = f32x16::from(c2);
        let mono_simd = (c1 + c2) / div_array;
        mono.extend_from_slice(&mono_simd.to_array());
    }

    // Handle the remainder.
    // do this normally because it's only a few samples and the overhead of
    // converting to SIMD is not worth it.
    // `chunks_exact(2)` ignores a trailing unpaired sample, so odd-length
    // input no longer indexes out of bounds.
    mono.extend(
        remainder
            .chunks_exact(2)
            .map(|frame| (frame[0] + frame[1]) / 2.0),
    );

    mono
}

// The only test here exercises the SIMD variant, which exists only with the
// `simd` feature; gating the module keeps `cargo test` compiling without it.
#[cfg(all(test, feature = "simd"))]
mod test {
    use super::*;

    #[test]
    pub fn assert_stereo_to_mono_simd() {
        // fake some sample data, of 1028 elements
        let samples: Vec<f32> = (0..1028).map(|i| i as f32).collect();
        let mono_simd = convert_stereo_to_mono_audio_simd(&samples);
        let mono = convert_stereo_to_mono_audio(&samples);
        assert_eq!(mono_simd, mono);
    }
}
diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs
new file mode 100644
index 0000000..318084b
--- /dev/null
+++ b/src/whisper_ctx.rs
@@ -0,0 +1,455 @@
use crate::error::WhisperError;
use crate::whisper_params::FullParams;
use crate::WhisperToken;
use std::ffi::{c_int, CStr, CString};

/// Safe Rust wrapper around a Whisper context.
///
/// You likely want to create this with [WhisperContext::new],
/// then run a full transcription with [WhisperContext::full].
#[derive(Debug)]
pub struct WhisperContext {
    /// Context pointer returned by `whisper_init`; checked for null in
    /// `new` and released in `Drop`.
    ctx: *mut whisper_cpp_sys::whisper_context,
    /// has the spectrogram been initialized in at least one way?
    spectrogram_initialized: bool,
    /// has the data been encoded?
    encode_complete: bool,
    /// has decode been called at least once?
    decode_once: bool,
}

impl WhisperContext {
    /// Create a new WhisperContext.
    ///
    /// # Arguments
    /// * path: The path to the model file.
    ///
    /// # Returns
    /// Ok(Self) on success, Err(WhisperError) on failure.
    /// A path containing a null byte yields `WhisperError::NullByteInString`;
    /// a null context from the C side yields `WhisperError::InitError`.
    ///
    /// # C++ equivalent
    /// `struct whisper_context * whisper_init(const char * path_model);`
    pub fn new(path: &str) -> Result<Self, WhisperError> {
        let path_cstr = CString::new(path)?;
        // SAFETY: `path_cstr` is a valid NUL-terminated string that outlives the call.
        let ctx = unsafe { whisper_cpp_sys::whisper_init(path_cstr.as_ptr()) };
        if ctx.is_null() {
            return Err(WhisperError::InitError);
        }
        Ok(Self {
            ctx,
            spectrogram_initialized: false,
            encode_complete: false,
            decode_once: false,
        })
    }

    /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
    /// The resulting spectrogram is stored in the context transparently.
    ///
    /// # Arguments
    /// * pcm: The raw PCM audio.
    /// * threads: How many threads to use. Must be at least 1, returns an error otherwise.
    ///
    /// # Returns
    /// Ok(()) on success, Err(WhisperError) on failure.
    ///
    /// # C++ equivalent
    /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
    pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
        if threads < 1 {
            return Err(WhisperError::InvalidThreadCount);
        }
        // SAFETY: `self.ctx` was checked for null in `new`; the pointer and
        // length both come from the same `pcm` slice.
        let ret = unsafe {
            whisper_cpp_sys::whisper_pcm_to_mel(
                self.ctx,
                pcm.as_ptr(),
                pcm.len() as c_int,
                threads as c_int,
            )
        };
        if ret != 0 {
            return Err(WhisperError::GenericError(ret));
        }
        self.spectrogram_initialized = true;
        Ok(())
    }

    /// This can be used to set a custom log mel spectrogram inside the provided whisper context.
    /// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
    ///
    /// # Note
    /// This is a low-level function.
+ /// If you're a typical user, you probably don't want to use this function. + /// See instead [WhisperContext::pcm_to_mel]. + /// + /// # Arguments + /// * data: The log mel spectrogram. + /// + /// # Returns + /// Ok(()) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)` + pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> { + let ret = unsafe { + whisper_cpp_sys::whisper_set_mel( + self.ctx, + data.as_ptr(), + data.len() as c_int, + 80 as c_int, + ) + }; + if ret == 0 { + self.spectrogram_initialized = true; + Ok(()) + } else { + Err(WhisperError::GenericError(ret)) + } + } + + /// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. + /// Make sure to call [WhisperContext::pcm_to_mel] or [[WhisperContext::set_mel] first. + /// + /// # Arguments + /// * offset: Can be used to specify the offset of the first frame in the spectrogram. Usually 0. + /// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. + /// + /// # Returns + /// Ok(()) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)` + pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> { + if !self.spectrogram_initialized { + return Err(WhisperError::SpectrogramNotInitialized); + } + if threads < 1 { + return Err(WhisperError::InvalidThreadCount); + } + let ret = + unsafe { whisper_cpp_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) }; + if ret == 0 { + self.encode_complete = true; + Ok(()) + } else { + Err(WhisperError::GenericError(ret)) + } + } + + /// Run the Whisper decoder to obtain the logits and probabilities for the next token. + /// Make sure to call [WhisperContext::encode] first. 
+ /// tokens + n_tokens is the provided context for the decoder. + /// + /// # Arguments + /// * tokens: The tokens to decode. + /// * n_tokens: The number of tokens to decode. + /// * n_past: The number of past tokens to use for the decoding. + /// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. + /// + /// # Returns + /// Ok(()) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)` + pub fn decode( + &mut self, + tokens: &[WhisperToken], + n_past: usize, + threads: usize, + ) -> Result<(), WhisperError> { + if !self.encode_complete { + return Err(WhisperError::EncodeNotComplete); + } + if threads < 1 { + return Err(WhisperError::InvalidThreadCount); + } + let ret = unsafe { + whisper_cpp_sys::whisper_decode( + self.ctx, + tokens.as_ptr(), + tokens.len() as c_int, + n_past as c_int, + threads as c_int, + ) + }; + if ret == 0 { + self.decode_once = true; + Ok(()) + } else { + Err(WhisperError::GenericError(ret)) + } + } + + // Token sampling functions + /// Return the token with the highest probability. + /// Make sure to call [WhisperContext::decode] first. + /// + /// # Arguments + /// * needs_timestamp + /// + /// # Returns + /// Ok(WhisperToken) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)` + pub fn sample_best(&mut self, needs_timestamp: bool) -> Result { + if !self.decode_once { + return Err(WhisperError::DecodeNotComplete); + } + let ret = unsafe { whisper_cpp_sys::whisper_sample_best(self.ctx, needs_timestamp) }; + Ok(ret) + } + + /// Return the token with the most probable timestamp. + /// Make sure to call [WhisperContext::decode] first. + /// + /// # Returns + /// Ok(WhisperToken) on success, Err(WhisperError) on failure. 
+ /// + /// # C++ equivalent + /// `whisper_token whisper_sample_timestamp(struct whisper_context * ctx)` + pub fn sample_timestamp(&mut self) -> Result { + if !self.decode_once { + return Err(WhisperError::DecodeNotComplete); + } + let ret = unsafe { whisper_cpp_sys::whisper_sample_timestamp(self.ctx) }; + Ok(ret) + } + + // model attributes + /// Get the mel spectrogram length. + /// + /// # Returns + /// Ok(c_int) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_n_len (struct whisper_context * ctx)` + pub fn n_len(&self) -> Result { + let ret = unsafe { whisper_cpp_sys::whisper_n_len(self.ctx) }; + if ret < 0 { + Err(WhisperError::GenericError(ret)) + } else { + Ok(ret as c_int) + } + } + + /// Get n_vocab. + /// + /// # Returns + /// Ok(c_int) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_n_vocab (struct whisper_context * ctx)` + pub fn n_vocab(&self) -> Result { + let ret = unsafe { whisper_cpp_sys::whisper_n_vocab(self.ctx) }; + if ret < 0 { + Err(WhisperError::GenericError(ret)) + } else { + Ok(ret as c_int) + } + } + + /// Get n_text_ctx. + /// + /// # Returns + /// Ok(c_int) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_n_text_ctx (struct whisper_context * ctx)` + pub fn n_text_ctx(&self) -> Result { + let ret = unsafe { whisper_cpp_sys::whisper_n_text_ctx(self.ctx) }; + if ret < 0 { + Err(WhisperError::GenericError(ret)) + } else { + Ok(ret as c_int) + } + } + + /// Does this model support multiple languages? + /// + /// # C++ equivalent + /// `int whisper_is_multilingual(struct whisper_context * ctx)` + pub fn is_multilingual(&self) -> bool { + unsafe { whisper_cpp_sys::whisper_is_multilingual(self.ctx) != 0 } + } + + /// The probabilities for the next token. + /// Make sure to call [WhisperContext::decode] first. + /// + /// # Returns + /// Ok(*const f32) on success, Err(WhisperError) on failure. 
+ /// + /// # C++ equivalent + /// `float * whisper_get_probs(struct whisper_context * ctx)` + pub fn get_probs(&mut self) -> Result<*const f32, WhisperError> { + if !self.decode_once { + return Err(WhisperError::DecodeNotComplete); + } + let ret = unsafe { whisper_cpp_sys::whisper_get_probs(self.ctx) }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + Ok(ret) + } + + /// Convert a token ID to a string. + /// + /// # Arguments + /// * token_id: ID of the token. + /// + /// # Returns + /// Ok(String) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` + pub fn token_to_str(&self, token_id: WhisperToken) -> Result { + let ret = unsafe { whisper_cpp_sys::whisper_token_to_str(self.ctx, token_id) }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + let c_str = unsafe { CStr::from_ptr(ret) }; + let r_str = c_str.to_str()?; + Ok(r_str.to_string()) + } + + // special tokens + /// Get the ID of the eot token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_eot (struct whisper_context * ctx)` + pub fn token_eot(&self) -> WhisperToken { + unsafe { whisper_cpp_sys::whisper_token_eot(self.ctx) } + } + + /// Get the ID of the sot token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_sot (struct whisper_context * ctx)` + pub fn token_sot(&self) -> WhisperToken { + unsafe { whisper_cpp_sys::whisper_token_sot(self.ctx) } + } + + /// Get the ID of the prev token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_prev(struct whisper_context * ctx)` + pub fn token_prev(&self) -> WhisperToken { + unsafe { whisper_cpp_sys::whisper_token_prev(self.ctx) } + } + + /// Get the ID of the solm token. 
+ /// + /// # C++ equivalent + /// `whisper_token whisper_token_solm(struct whisper_context * ctx)` + pub fn token_solm(&self) -> WhisperToken { + unsafe { whisper_cpp_sys::whisper_token_solm(self.ctx) } + } + + /// Get the ID of the not token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_not (struct whisper_context * ctx)` + pub fn token_not(&self) -> WhisperToken { + unsafe { whisper_cpp_sys::whisper_token_not(self.ctx) } + } + + /// Get the ID of the beg token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_beg (struct whisper_context * ctx)` + pub fn token_beg(&self) -> WhisperToken { + unsafe { whisper_cpp_sys::whisper_token_beg(self.ctx) } + } + + /// Print performance statistics to stdout. + /// + /// # C++ equivalent + /// `void whisper_print_timings(struct whisper_context * ctx)` + pub fn print_timings(&self) { + unsafe { whisper_cpp_sys::whisper_print_timings(self.ctx) } + } + + /// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + /// Uses the specified decoding strategy to obtain the text. + /// + /// This is usually the only function you need to call as an end user. + /// + /// # Arguments + /// * params: [crate::FullParams] struct. + /// * pcm: PCM audio data. + /// + /// # Returns + /// Ok(c_int) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)` + pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result { + let ret = unsafe { + whisper_cpp_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int) + }; + if ret < 0 { + Err(WhisperError::GenericError(ret)) + } else { + Ok(ret as c_int) + } + } + + /// Number of generated text segments. + /// A segment can be a few words, a sentence, or even a paragraph. 
+ /// + /// # C++ equivalent + /// `int whisper_full_n_segments(struct whisper_context * ctx)` + pub fn full_n_segments(&self) -> c_int { + unsafe { whisper_cpp_sys::whisper_full_n_segments(self.ctx) } + } + + /// Get the start time of the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// + /// # C++ equivalent + /// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)` + pub fn full_get_segment_t0(&self, segment: c_int) -> i64 { + unsafe { whisper_cpp_sys::whisper_full_get_segment_t0(self.ctx, segment) } + } + + /// Get the end time of the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// + /// # C++ equivalent + /// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)` + pub fn full_get_segment_t1(&self, segment: c_int) -> i64 { + unsafe { whisper_cpp_sys::whisper_full_get_segment_t1(self.ctx, segment) } + } + + /// Get the text of the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// + /// # Returns + /// Ok(String) on success, Err(WhisperError) on failure. 
+ /// + /// # C++ equivalent + /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` + pub fn full_get_segment_text(&self, segment: c_int) -> Result { + let ret = unsafe { whisper_cpp_sys::whisper_full_get_segment_text(self.ctx, segment) }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + let c_str = unsafe { CStr::from_ptr(ret) }; + let r_str = c_str.to_str()?; + Ok(r_str.to_string()) + } +} + +impl Drop for WhisperContext { + fn drop(&mut self) { + unsafe { whisper_cpp_sys::whisper_free(self.ctx) }; + } +} diff --git a/src/whisper_params.rs b/src/whisper_params.rs new file mode 100644 index 0000000..782a6af --- /dev/null +++ b/src/whisper_params.rs @@ -0,0 +1,114 @@ +use std::ffi::c_int; +use std::marker::PhantomData; + +pub enum DecodeStrategy { + Greedy { + n_past: c_int, + }, + /// not implemented yet, results of using this unknown + BeamSearch { + n_past: c_int, + beam_width: c_int, + n_best: c_int, + }, +} + +pub struct FullParams<'a> { + pub(crate) fp: whisper_cpp_sys::whisper_full_params, + phantom: PhantomData<&'a str>, +} + +impl<'a> FullParams<'a> { + /// Create a new set of parameters for the decoder. + pub fn new(decode_strategy: DecodeStrategy) -> FullParams<'a> { + let mut fp = unsafe { + whisper_cpp_sys::whisper_full_default_params(match decode_strategy { + DecodeStrategy::Greedy { .. } => 0, + DecodeStrategy::BeamSearch { .. } => 1, + } as _) + }; + + match decode_strategy { + DecodeStrategy::Greedy { n_past } => { + fp.__bindgen_anon_1.greedy.n_past = n_past; + } + DecodeStrategy::BeamSearch { + n_past, + beam_width, + n_best, + } => { + fp.__bindgen_anon_1.beam_search.n_past = n_past; + fp.__bindgen_anon_1.beam_search.beam_width = beam_width; + fp.__bindgen_anon_1.beam_search.n_best = n_best; + } + } + + Self { + fp, + phantom: PhantomData, + } + } + + /// Set the number of threads to use for decoding. + /// + /// Defaults to min(4, std::thread::hardware_concurrency()). 
+ pub fn set_n_threads(&mut self, n_threads: c_int) { + self.fp.n_threads = n_threads; + } + + /// Set the offset in milliseconds to use for decoding. + /// + /// Defaults to 0. + pub fn set_offset_ms(&mut self, offset_ms: c_int) { + self.fp.offset_ms = offset_ms; + } + + /// Set whether to translate the output to the language specified by `language`. + /// + /// Defaults to false. + pub fn set_translate(&mut self, translate: bool) { + self.fp.translate = translate; + } + + /// Set no_context. Usage unknown. + /// + /// Defaults to false. + pub fn set_no_context(&mut self, no_context: bool) { + self.fp.no_context = no_context; + } + + /// Set whether to print special tokens. + /// + /// Defaults to false. + pub fn set_print_special_tokens(&mut self, print_special_tokens: bool) { + self.fp.print_special_tokens = print_special_tokens; + } + + /// Set whether to print progress. + /// + /// Defaults to true. + pub fn set_print_progress(&mut self, print_progress: bool) { + self.fp.print_progress = print_progress; + } + + /// Set print_realtime. Usage unknown. + /// + /// Defaults to false. + pub fn set_print_realtime(&mut self, print_realtime: bool) { + self.fp.print_realtime = print_realtime; + } + + /// Set whether to print timestamps. + /// + /// Defaults to true. + pub fn set_print_timestamps(&mut self, print_timestamps: bool) { + self.fp.print_timestamps = print_timestamps; + } + + /// Set the target language. + /// + /// Defaults to "en". 
///
/// The string is copied into a NUL-terminated C string before its pointer
/// is stored. A Rust `&str` has no terminator, so storing its pointer
/// directly would make the C side read past the end of the string.
///
/// The copy is intentionally leaked: the parameter struct is later passed
/// to the C API by value and keeps only the raw pointer, so nothing on the
/// Rust side can free it safely. Each call leaks one short string.
///
/// # Panics
/// Panics if `language` contains a null byte.
pub fn set_language(&mut self, language: &'a str) {
    let c_language =
        std::ffi::CString::new(language).expect("language contains a null byte");
    self.fp.language = c_language.into_raw() as *const _;
}
}
diff --git a/sys/Cargo.toml b/sys/Cargo.toml
new file mode 100644
index 0000000..57db3d4
--- /dev/null
+++ b/sys/Cargo.toml
@@ -0,0 +1,11 @@
[package]
name = "whisper-cpp-sys"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]

[build-dependencies]
bindgen = "0.60"
diff --git a/sys/build.rs b/sys/build.rs
new file mode 100644
index 0000000..8d5d3d5
--- /dev/null
+++ b/sys/build.rs
@@ -0,0 +1,38 @@
extern crate bindgen;

use std::env;
use std::path::PathBuf;

fn main() {
    // The whisper.cpp submodule is at `sys/whisper.cpp` (see .gitmodules).
    // Resolve it from the crate's manifest directory so the build does not
    // depend on any one developer's filesystem layout.
    let whisper_dir = PathBuf::from(
        env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is set by cargo"),
    )
    .join("whisper.cpp");

    // Tell cargo to look for libraries in the whisper.cpp directory
    println!("cargo:rustc-link-search={}", whisper_dir.display());

    // Tell cargo to tell rustc to link the static whisper library.
    println!("cargo:rustc-link-lib=static=whisper");

    // Tell cargo to invalidate the built crate whenever the wrapper changes
    println!("cargo:rerun-if-changed=wrapper.h");

    // The bindgen::Builder is the main entry point
    // to bindgen, and lets you build up options for
    // the resulting bindings.
    let bindings = bindgen::Builder::default()
        // The input header we would like to generate
        // bindings for.
        .header("wrapper.h")
        // Let clang find whisper.h inside the submodule.
        .clang_arg(format!("-I{}", whisper_dir.display()))
        // Tell cargo to invalidate the built crate whenever any of the
        // included header files changed.
        .parse_callbacks(Box::new(bindgen::CargoCallbacks))
        // Finish the builder and generate the bindings.
        .generate()
        // Unwrap the Result and panic on failure.
        .expect("Unable to generate bindings");

    // Write the bindings to the $OUT_DIR/bindings.rs file.
    let out_path = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set by cargo"));
    bindings
        .write_to_file(out_path.join("bindings.rs"))
        .expect("Couldn't write bindings!");
}
diff --git a/sys/src/lib.rs b/sys/src/lib.rs
new file mode 100644
index 0000000..a38a13a
--- /dev/null
+++ b/sys/src/lib.rs
@@ -0,0 +1,5 @@
// The generated bindings keep the C names, which do not follow Rust naming rules.
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
diff --git a/sys/whisper.cpp b/sys/whisper.cpp
new file mode 160000
index 0000000..7edaa7d
--- /dev/null
+++ b/sys/whisper.cpp
@@ -0,0 +1 @@
Subproject commit 7edaa7da4bdd072890c312ed764fc62b1fadb98f
diff --git a/sys/wrapper.h b/sys/wrapper.h
new file mode 100644
index 0000000..683d201
--- /dev/null
+++ b/sys/wrapper.h
@@ -0,0 +1 @@
#include <whisper.h>
\ No newline at end of file