add new params
This commit is contained in:
parent
491eeeadde
commit
f810c63a39
@ -15,3 +15,4 @@ pub use whisper_params::{FullParams, SamplingStrategy};
|
||||
pub type WhisperTokenData = whisper_rs_sys::whisper_token_data;
|
||||
pub type WhisperToken = whisper_rs_sys::whisper_token;
|
||||
pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback;
|
||||
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
|
||||
|
@ -1,5 +1,6 @@
|
||||
use std::ffi::{c_int, CString};
|
||||
use std::ffi::{c_char, c_int, CString};
|
||||
use std::marker::PhantomData;
|
||||
use whisper_rs_sys::whisper_token;
|
||||
|
||||
pub enum SamplingStrategy {
|
||||
Greedy {
|
||||
@ -13,14 +14,15 @@ pub enum SamplingStrategy {
|
||||
},
|
||||
}
|
||||
|
||||
pub struct FullParams<'a> {
|
||||
pub struct FullParams<'a, 'b> {
|
||||
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
||||
phantom: PhantomData<&'a str>,
|
||||
phantom_lang: PhantomData<&'a str>,
|
||||
phantom_tokens: PhantomData<&'b [c_int]>,
|
||||
}
|
||||
|
||||
impl<'a> FullParams<'a> {
|
||||
impl<'a, 'b> FullParams<'a, 'b> {
|
||||
/// Create a new set of parameters for the decoder.
|
||||
pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a> {
|
||||
pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a, 'b> {
|
||||
let mut fp = unsafe {
|
||||
whisper_rs_sys::whisper_full_default_params(match sampling_strategy {
|
||||
SamplingStrategy::Greedy { .. } => {
|
||||
@ -49,7 +51,8 @@ impl<'a> FullParams<'a> {
|
||||
|
||||
Self {
|
||||
fp,
|
||||
phantom: PhantomData,
|
||||
phantom_lang: PhantomData,
|
||||
phantom_tokens: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
@ -60,13 +63,27 @@ impl<'a> FullParams<'a> {
|
||||
self.fp.n_threads = n_threads;
|
||||
}
|
||||
|
||||
/// Set the offset in milliseconds to use for decoding.
|
||||
/// Set n_max_text_ctx.
|
||||
///
|
||||
/// Defaults to 16384.
|
||||
pub fn set_n_max_text_ctx(&mut self, n_max_text_ctx: c_int) {
|
||||
self.fp.n_max_text_ctx = n_max_text_ctx;
|
||||
}
|
||||
|
||||
/// Set the start offset in milliseconds to use for decoding.
|
||||
///
|
||||
/// Defaults to 0.
|
||||
pub fn set_offset_ms(&mut self, offset_ms: c_int) {
|
||||
self.fp.offset_ms = offset_ms;
|
||||
}
|
||||
|
||||
/// Set the audio duration to process in milliseconds.
|
||||
///
|
||||
/// Defaults to 0.
|
||||
pub fn set_duration_ms(&mut self, duration_ms: c_int) {
|
||||
self.fp.duration_ms = duration_ms;
|
||||
}
|
||||
|
||||
/// Set whether to translate the output to the language specified by `language`.
|
||||
///
|
||||
/// Defaults to false.
|
||||
@ -81,11 +98,18 @@ impl<'a> FullParams<'a> {
|
||||
self.fp.no_context = no_context;
|
||||
}
|
||||
|
||||
/// Set whether to print special tokens.
|
||||
/// Force single segment output. This may be useful for streaming.
|
||||
///
|
||||
/// Defaults to false.
|
||||
pub fn set_print_special_tokens(&mut self, print_special_tokens: bool) {
|
||||
self.fp.print_special_tokens = print_special_tokens;
|
||||
pub fn set_single_segment(&mut self, single_segment: bool) {
|
||||
self.fp.single_segment = single_segment;
|
||||
}
|
||||
|
||||
/// Set print_special. Usage unknown.
|
||||
///
|
||||
/// Defaults to false.
|
||||
pub fn set_print_special(&mut self, print_special: bool) {
|
||||
self.fp.print_special = print_special;
|
||||
}
|
||||
|
||||
/// Set whether to print progress.
|
||||
@ -109,6 +133,87 @@ impl<'a> FullParams<'a> {
|
||||
self.fp.print_timestamps = print_timestamps;
|
||||
}
|
||||
|
||||
/// # EXPERIMENTAL
|
||||
///
|
||||
/// Enable token-level timestamps.
|
||||
///
|
||||
/// Defaults to false.
|
||||
pub fn set_token_timestamps(&mut self, token_timestamps: bool) {
|
||||
self.fp.token_timestamps = token_timestamps;
|
||||
}
|
||||
|
||||
/// # EXPERIMENTAL
|
||||
///
|
||||
/// Set timestamp token probability threshold.
|
||||
///
|
||||
/// Defaults to 0.01.
|
||||
pub fn set_thold_pt(&mut self, thold_pt: f32) {
|
||||
self.fp.thold_pt = thold_pt;
|
||||
}
|
||||
|
||||
/// # EXPERIMENTAL
|
||||
///
|
||||
/// Set timestamp token sum probability threshold.
|
||||
///
|
||||
/// Defaults to 0.01.
|
||||
pub fn set_thold_ptsum(&mut self, thold_ptsum: f32) {
|
||||
self.fp.thold_ptsum = thold_ptsum;
|
||||
}
|
||||
|
||||
/// # EXPERIMENTAL
|
||||
///
|
||||
/// Set maximum segment length in characters.
|
||||
///
|
||||
/// Defaults to 0.
|
||||
pub fn set_max_len(&mut self, max_len: c_int) {
|
||||
self.fp.max_len = max_len;
|
||||
}
|
||||
|
||||
/// # EXPERIMENTAL
|
||||
///
|
||||
/// Set maximum tokens per segment. 0 means no limit.
|
||||
///
|
||||
/// Defaults to 0.
|
||||
pub fn set_max_tokens(&mut self, max_tokens: c_int) {
|
||||
self.fp.max_tokens = max_tokens;
|
||||
}
|
||||
|
||||
/// # EXPERIMENTAL
|
||||
///
|
||||
/// Speed up audio ~2x by using phase vocoder.
|
||||
///
|
||||
/// Defaults to false.
|
||||
pub fn set_speed_up(&mut self, speed_up: bool) {
|
||||
self.fp.speed_up = speed_up;
|
||||
}
|
||||
|
||||
/// # EXPERIMENTAL
|
||||
///
|
||||
/// Overwrite the audio context size. 0 = default.
|
||||
///
|
||||
/// Defaults to 0.
|
||||
pub fn set_audio_ctx(&mut self, audio_ctx: c_int) {
|
||||
self.fp.audio_ctx = audio_ctx;
|
||||
}
|
||||
|
||||
/// Set tokens to provide the model as initial input.
|
||||
///
|
||||
/// These tokens are prepended to any existing text content from a previous call.
|
||||
///
|
||||
/// Calling this more than once will overwrite the previous tokens.
|
||||
///
|
||||
/// Defaults to an empty vector.
|
||||
pub fn set_tokens(&mut self, tokens: &'b [c_int]) {
|
||||
// turn into ptr and len
|
||||
let tokens_ptr: *const whisper_token = tokens.as_ptr();
|
||||
let tokens_len: c_int = tokens.len() as c_int;
|
||||
|
||||
// set the tokens
|
||||
self.fp.prompt_tokens = tokens_ptr;
|
||||
self.fp.prompt_n_tokens = tokens_len;
|
||||
}
|
||||
|
||||
|
||||
/// Set the target language.
|
||||
///
|
||||
/// Defaults to "en".
|
||||
@ -144,10 +249,38 @@ impl<'a> FullParams<'a> {
|
||||
pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
|
||||
self.fp.new_segment_callback_user_data = user_data;
|
||||
}
|
||||
|
||||
/// Set the callback for starting the encoder.
|
||||
///
|
||||
/// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so).
|
||||
/// It is still a C callback.
|
||||
///
|
||||
/// # Safety
|
||||
/// Do not use this function unless you know what you are doing.
|
||||
/// * Be careful not to mutate the state of the whisper_context pointer returned in the callback.
|
||||
/// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library.
|
||||
///
|
||||
/// Defaults to None.
|
||||
pub unsafe fn set_start_encoder_callback(
|
||||
&mut self,
|
||||
start_encoder_callback: crate::WhisperStartEncoderCallback,
|
||||
) {
|
||||
self.fp.encoder_begin_callback = start_encoder_callback;
|
||||
}
|
||||
|
||||
/// Set the user data to be passed to the start encoder callback.
|
||||
///
|
||||
/// # Safety
|
||||
/// See the safety notes for `set_start_encoder_callback`.
|
||||
///
|
||||
/// Defaults to None.
|
||||
pub unsafe fn set_start_encoder_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
|
||||
self.fp.encoder_begin_callback_user_data = user_data;
|
||||
}
|
||||
}
|
||||
|
||||
// following implementations are safe
|
||||
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
|
||||
// concurrent usage is prevented by &mut self on methods that modify the struct
|
||||
unsafe impl<'a> Send for FullParams<'a> {}
|
||||
unsafe impl<'a> Sync for FullParams<'a> {}
|
||||
unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {}
|
||||
unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user