diff --git a/src/lib.rs b/src/lib.rs index c724f55..36aafec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,3 +15,4 @@ pub use whisper_params::{FullParams, SamplingStrategy}; pub type WhisperTokenData = whisper_rs_sys::whisper_token_data; pub type WhisperToken = whisper_rs_sys::whisper_token; pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback; +pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback; diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 6121a8c..ff20473 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -1,5 +1,6 @@ -use std::ffi::{c_int, CString}; +use std::ffi::{c_char, c_int, CString}; use std::marker::PhantomData; +use whisper_rs_sys::whisper_token; pub enum SamplingStrategy { Greedy { @@ -13,14 +14,15 @@ pub enum SamplingStrategy { }, } -pub struct FullParams<'a> { +pub struct FullParams<'a, 'b> { pub(crate) fp: whisper_rs_sys::whisper_full_params, - phantom: PhantomData<&'a str>, + phantom_lang: PhantomData<&'a str>, + phantom_tokens: PhantomData<&'b [c_int]>, } -impl<'a> FullParams<'a> { +impl<'a, 'b> FullParams<'a, 'b> { /// Create a new set of parameters for the decoder. - pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a> { + pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a, 'b> { let mut fp = unsafe { whisper_rs_sys::whisper_full_default_params(match sampling_strategy { SamplingStrategy::Greedy { .. } => { @@ -49,7 +51,8 @@ impl<'a> FullParams<'a> { Self { fp, - phantom: PhantomData, + phantom_lang: PhantomData, + phantom_tokens: PhantomData, } } @@ -60,13 +63,27 @@ impl<'a> FullParams<'a> { self.fp.n_threads = n_threads; } - /// Set the offset in milliseconds to use for decoding. + /// Set n_max_text_ctx. + /// + /// Defaults to 16384. + pub fn set_n_max_text_ctx(&mut self, n_max_text_ctx: c_int) { + self.fp.n_max_text_ctx = n_max_text_ctx; + } + + /// Set the start offset in milliseconds to use for decoding. /// /// Defaults to 0. pub fn set_offset_ms(&mut self, offset_ms: c_int) { self.fp.offset_ms = offset_ms; } + /// Set the audio duration to process in milliseconds. + /// + /// Defaults to 0. + pub fn set_duration_ms(&mut self, duration_ms: c_int) { + self.fp.duration_ms = duration_ms; + } + /// Set whether to translate the output to the language specified by `language`. /// /// Defaults to false. @@ -81,11 +98,18 @@ impl<'a> FullParams<'a> { self.fp.no_context = no_context; } - /// Set whether to print special tokens. + /// Force single segment output. This may be useful for streaming. /// /// Defaults to false. - pub fn set_print_special_tokens(&mut self, print_special_tokens: bool) { - self.fp.print_special_tokens = print_special_tokens; + pub fn set_single_segment(&mut self, single_segment: bool) { + self.fp.single_segment = single_segment; + } + + /// Set print_special. Usage unknown. + /// + /// Defaults to false. + pub fn set_print_special(&mut self, print_special: bool) { + self.fp.print_special = print_special; } /// Set whether to print progress. @@ -109,6 +133,87 @@ impl<'a> FullParams<'a> { self.fp.print_timestamps = print_timestamps; } + /// # EXPERIMENTAL + /// + /// Enable token-level timestamps. + /// + /// Defaults to false. + pub fn set_token_timestamps(&mut self, token_timestamps: bool) { + self.fp.token_timestamps = token_timestamps; + } + + /// # EXPERIMENTAL + /// + /// Set timestamp token probability threshold. + /// + /// Defaults to 0.01. + pub fn set_thold_pt(&mut self, thold_pt: f32) { + self.fp.thold_pt = thold_pt; + } + + /// # EXPERIMENTAL + /// + /// Set timestamp token sum probability threshold. + /// + /// Defaults to 0.01. + pub fn set_thold_ptsum(&mut self, thold_ptsum: f32) { + self.fp.thold_ptsum = thold_ptsum; + } + + /// # EXPERIMENTAL + /// + /// Set maximum segment length in characters. + /// + /// Defaults to 0. + pub fn set_max_len(&mut self, max_len: c_int) { + self.fp.max_len = max_len; + } + + /// # EXPERIMENTAL + /// + /// Set maximum tokens per segment. 0 means no limit. + /// + /// Defaults to 0. + pub fn set_max_tokens(&mut self, max_tokens: c_int) { + self.fp.max_tokens = max_tokens; + } + + /// # EXPERIMENTAL + /// + /// Speed up audio ~2x by using phase vocoder. + /// + /// Defaults to false. + pub fn set_speed_up(&mut self, speed_up: bool) { + self.fp.speed_up = speed_up; + } + + /// # EXPERIMENTAL + /// + /// Overwrite the audio context size. 0 = default. + /// + /// Defaults to 0. + pub fn set_audio_ctx(&mut self, audio_ctx: c_int) { + self.fp.audio_ctx = audio_ctx; + } + + /// Set tokens to provide the model as initial input. + /// + /// These tokens are prepended to any existing text content from a previous call. + /// + /// Calling this more than once will overwrite the previous tokens. + /// + /// Defaults to an empty vector. + pub fn set_tokens(&mut self, tokens: &'b [c_int]) { + // turn into ptr and len + let tokens_ptr: *const whisper_token = tokens.as_ptr(); + let tokens_len: c_int = tokens.len() as c_int; + + // set the tokens + self.fp.prompt_tokens = tokens_ptr; + self.fp.prompt_n_tokens = tokens_len; + } + + /// Set the target language. /// /// Defaults to "en". @@ -144,10 +249,38 @@ impl<'a> FullParams<'a> { pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { self.fp.new_segment_callback_user_data = user_data; } + + /// Set the callback for starting the encoder. + /// + /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). + /// It is still a C callback. + /// + /// # Safety + /// Do not use this function unless you know what you are doing. + /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. + /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. + /// + /// Defaults to None. + pub unsafe fn set_start_encoder_callback( + &mut self, + start_encoder_callback: crate::WhisperStartEncoderCallback, + ) { + self.fp.encoder_begin_callback = start_encoder_callback; + } + + /// Set the user data to be passed to the start encoder callback. + /// + /// # Safety + /// See the safety notes for `set_start_encoder_callback`. + /// + /// Defaults to None. + pub unsafe fn set_start_encoder_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { + self.fp.encoder_begin_callback_user_data = user_data; + } } // following implementations are safe // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 // concurrent usage is prevented by &mut self on methods that modify the struct -unsafe impl<'a> Send for FullParams<'a> {} -unsafe impl<'a> Sync for FullParams<'a> {} +unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {} +unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {}