diff --git a/Cargo.toml b/Cargo.toml index 30cf26b..7ace970 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["sys"] [package] name = "whisper-rs" -version = "0.1.3" +version = "0.2.0" edition = "2021" description = "Rust bindings for whisper.cpp" license = "Unlicense" @@ -13,7 +13,7 @@ repository = "https://github.com/tazz4843/whisper-rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -whisper-rs-sys = { path = "sys", version = "0.1" } +whisper-rs-sys = { path = "sys", version = "0.2" } [features] simd = [] diff --git a/README.md b/README.md index 9d3fdfd..a04d5fb 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,14 @@ See the docs: https://docs.rs/whisper-rs/ for more details. * Windows/macOS/Android aren't working! * I don't have a way to test these platforms, so I can't really help you. * If you can get it working, please open a PR! +* I get a panic during binding generation build! + * You can attempt to fix it yourself, or you can set the `WHISPER_DONT_GENERATE_BINDINGS` environment variable. + This skips attempting to build the bindings whatsoever and copies the existing ones. They may be out of date, + but it's better than nothing. + * `WHISPER_DONT_GENERATE_BINDINGS=1 cargo build` + * If you can fix the issue, please open a PR! +* M1 build info: + * See [this issue](https://github.com/tazz4843/whisper-rs/pull/2) for more info. ## License [Unlicense](LICENSE) diff --git a/src/lib.rs b/src/lib.rs index 7bb2fd6..fcc614c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ pub use error::WhisperError; pub use standalone::*; pub use utilities::*; pub use whisper_ctx::WhisperContext; -pub use whisper_params::{DecodeStrategy, FullParams}; +pub use whisper_params::{FullParams, SamplingStrategy}; pub type WhisperToken = std::ffi::c_int; +pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback; diff --git a/src/standalone.rs b/src/standalone.rs index e111568..01262b0 100644 --- a/src/standalone.rs +++ b/src/standalone.rs @@ -42,3 +42,11 @@ pub fn token_translate() -> WhisperToken { pub fn token_transcribe() -> WhisperToken { unsafe { whisper_rs_sys::whisper_token_transcribe() } } + +/// Print system information. +/// +/// # C++ equivalent +/// `const char * whisper_print_system_info()` +pub fn print_system_info() { + unsafe { whisper_rs_sys::whisper_print_system_info() }; +} diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 7c3ee96..b479dcc 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -194,11 +194,11 @@ impl WhisperContext { /// /// # C++ equivalent /// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)` - pub fn sample_best(&mut self, needs_timestamp: bool) -> Result { + pub fn sample_best(&mut self) -> Result { if !self.decode_once { return Err(WhisperError::DecodeNotComplete); } - let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx, needs_timestamp) }; + let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx) }; Ok(ret) } @@ -446,6 +446,90 @@ impl WhisperContext { let r_str = c_str.to_str()?; Ok(r_str.to_string()) } + + /// Get number of tokens in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// + /// # Returns + /// Ok(c_int) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)` + pub fn full_n_tokens(&self, segment: c_int) -> Result { + let ret = unsafe { whisper_rs_sys::whisper_full_n_tokens(self.ctx, segment) }; + if ret < 0 { + Err(WhisperError::GenericError(ret)) + } else { + Ok(ret as c_int) + } + } + + /// Get the token text of the specified token in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// Ok(String) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn full_get_token_text( + &self, + segment: c_int, + token: c_int, + ) -> Result { + let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text(self.ctx, segment, token) }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + let c_str = unsafe { CStr::from_ptr(ret) }; + let r_str = c_str.to_str()?; + Ok(r_str.to_string()) + } + + /// Get the token ID of the specified token in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// Ok(WhisperToken) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)` + pub fn full_get_token_id( + &self, + segment: c_int, + token: c_int, + ) -> Result { + let ret = unsafe { whisper_rs_sys::whisper_full_get_token_id(self.ctx, segment, token) }; + if ret < 0 { + Err(WhisperError::GenericError(ret)) + } else { + Ok(ret as WhisperToken) + } + } + + /// Get the probability of the specified token in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// f32 + /// + /// # C++ equivalent + /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` + #[inline] + pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> f32 { + unsafe { whisper_rs_sys::whisper_full_get_token_p(self.ctx, segment, token) } + } } impl Drop for WhisperContext { diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 7d36c78..a70df18 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -1,7 +1,7 @@ use std::ffi::c_int; use std::marker::PhantomData; -pub enum DecodeStrategy { +pub enum SamplingStrategy { Greedy { n_past: c_int, }, @@ -20,26 +20,30 @@ pub struct FullParams<'a> { impl<'a> FullParams<'a> { /// Create a new set of parameters for the decoder. - pub fn new(decode_strategy: DecodeStrategy) -> FullParams<'a> { + pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a> { let mut fp = unsafe { - whisper_rs_sys::whisper_full_default_params(match decode_strategy { - DecodeStrategy::Greedy { .. } => 0, - DecodeStrategy::BeamSearch { .. } => 1, + whisper_rs_sys::whisper_full_default_params(match sampling_strategy { + SamplingStrategy::Greedy { .. } => { + whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY + } + SamplingStrategy::BeamSearch { .. } => { + whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH + } } as _) }; - match decode_strategy { - DecodeStrategy::Greedy { n_past } => { - fp.__bindgen_anon_1.greedy.n_past = n_past; + match sampling_strategy { + SamplingStrategy::Greedy { n_past } => { + fp.greedy.n_past = n_past; } - DecodeStrategy::BeamSearch { + SamplingStrategy::BeamSearch { n_past, beam_width, n_best, } => { - fp.__bindgen_anon_1.beam_search.n_past = n_past; - fp.__bindgen_anon_1.beam_search.beam_width = beam_width; - fp.__bindgen_anon_1.beam_search.n_best = n_best; + fp.beam_search.n_past = n_past; + fp.beam_search.beam_width = beam_width; + fp.beam_search.n_best = n_best; } } @@ -111,6 +115,34 @@ impl<'a> FullParams<'a> { pub fn set_language(&mut self, language: &'a str) { self.fp.language = language.as_ptr() as *const _; } + + /// Set the callback for new segments. + /// + /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). + /// It is still a C callback. + /// + /// # Safety + /// Do not use this function unless you know what you are doing. + /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. + /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. + /// + /// Defaults to None. + pub unsafe fn set_new_segment_callback( + &mut self, + new_segment_callback: crate::WhisperNewSegmentCallback, + ) { + self.fp.new_segment_callback = new_segment_callback; + } + + /// Set the user data to be passed to the new segment callback. + /// + /// # Safety + /// See the safety notes for `set_new_segment_callback`. + /// + /// Defaults to None. + pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { + self.fp.new_segment_callback_user_data = user_data; + } } // following implementations are safe diff --git a/sys/Cargo.toml b/sys/Cargo.toml index 86681fa..45ac4fd 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "whisper-rs-sys" -version = "0.1.3" +version = "0.2.0" edition = "2021" description = "Rust bindings for whisper.cpp (FFI bindings)" license = "Unlicense" @@ -13,4 +13,4 @@ links = "whisper" [dependencies] [build-dependencies] -bindgen = "0.60" +bindgen = "0.61" diff --git a/sys/build.rs b/sys/build.rs index 54a20f8..f009a98 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -8,29 +8,36 @@ fn main() { println!("cargo:rustc-link-lib=static=whisper"); println!("cargo:rerun-if-changed=wrapper.h"); - let bindings = bindgen::Builder::default() - .header("wrapper.h") - .clang_arg("-I./whisper.cpp") - .parse_callbacks(Box::new(bindgen::CargoCallbacks)) - .generate(); + if env::var("WHISPER_DONT_GENERATE_BINDINGS").is_ok() { + let _: u64 = std::fs::copy( + "src/bindings.rs", + env::var("OUT_DIR").unwrap() + "/bindings.rs", + ).expect("Failed to copy bindings.rs"); + } else { + let bindings = bindgen::Builder::default() + .header("wrapper.h") + .clang_arg("-I./whisper.cpp") + .parse_callbacks(Box::new(bindgen::CargoCallbacks)) + .generate(); - match bindings { - Ok(b) => { - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - b.write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); + match bindings { + Ok(b) => { + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + b.write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings!"); + } + Err(e) => { + println!("cargo:warning=Unable to generate bindings: {}", e); + println!("cargo:warning=Using bundled bindings.rs, which may be out of date"); + // copy src/bindings.rs to OUT_DIR + std::fs::copy( + "src/bindings.rs", + env::var("OUT_DIR").unwrap() + "/bindings.rs", + ) + .expect("Unable to copy bindings.rs"); + } } - Err(e) => { - println!("cargo:warning=Unable to generate bindings: {}", e); - println!("cargo:warning=Using bundled bindings.rs, which may be out of date"); - // copy src/bindings.rs to OUT_DIR - std::fs::copy( - "src/bindings.rs", - env::var("OUT_DIR").unwrap() + "/bindings.rs", - ) - .expect("Unable to copy bindings.rs"); - } - } + }; // stop if we're on docs.rs if env::var("DOCS_RS").is_ok() { diff --git a/sys/src/bindings.rs b/sys/src/bindings.rs index 449b3aa..7090277 100644 --- a/sys/src/bindings.rs +++ b/sys/src/bindings.rs @@ -1,5 +1,96 @@ -/* automatically generated by rust-bindgen 0.60.1 */ +/* automatically generated by rust-bindgen 0.59.2 */ +pub const _STDINT_H: u32 = 1; +pub const _FEATURES_H: u32 = 1; +pub const _DEFAULT_SOURCE: u32 = 1; +pub const __GLIBC_USE_ISOC2X: u32 = 0; +pub const __USE_ISOC11: u32 = 1; +pub const __USE_ISOC99: u32 = 1; +pub const __USE_ISOC95: u32 = 1; +pub const __USE_POSIX_IMPLICITLY: u32 = 1; +pub const _POSIX_SOURCE: u32 = 1; +pub const _POSIX_C_SOURCE: u32 = 200809; +pub const __USE_POSIX: u32 = 1; +pub const __USE_POSIX2: u32 = 1; +pub const __USE_POSIX199309: u32 = 1; +pub const __USE_POSIX199506: u32 = 1; +pub const __USE_XOPEN2K: u32 = 1; +pub const __USE_XOPEN2K8: u32 = 1; +pub const _ATFILE_SOURCE: u32 = 1; +pub const __USE_MISC: u32 = 1; +pub const __USE_ATFILE: u32 = 1; +pub const __USE_FORTIFY_LEVEL: u32 = 0; +pub const __GLIBC_USE_DEPRECATED_GETS: u32 = 0; +pub const __GLIBC_USE_DEPRECATED_SCANF: u32 = 0; +pub const _STDC_PREDEF_H: u32 = 1; +pub const __STDC_IEC_559__: u32 = 1; +pub const __STDC_IEC_559_COMPLEX__: u32 = 1; +pub const __STDC_ISO_10646__: u32 = 201706; +pub const __GNU_LIBRARY__: u32 = 6; +pub const __GLIBC__: u32 = 2; +pub const __GLIBC_MINOR__: u32 = 31; +pub const _SYS_CDEFS_H: u32 = 1; +pub const __glibc_c99_flexarr_available: u32 = 1; +pub const __WORDSIZE: u32 = 64; +pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; +pub const __SYSCALL_WORDSIZE: u32 = 64; +pub const __LONG_DOUBLE_USES_FLOAT128: u32 = 0; +pub const __HAVE_GENERIC_SELECTION: u32 = 1; +pub const __GLIBC_USE_LIB_EXT2: u32 = 0; +pub const __GLIBC_USE_IEC_60559_BFP_EXT: u32 = 0; +pub const __GLIBC_USE_IEC_60559_BFP_EXT_C2X: u32 = 0; +pub const __GLIBC_USE_IEC_60559_FUNCS_EXT: u32 = 0; +pub const __GLIBC_USE_IEC_60559_FUNCS_EXT_C2X: u32 = 0; +pub const __GLIBC_USE_IEC_60559_TYPES_EXT: u32 = 0; +pub const _BITS_TYPES_H: u32 = 1; +pub const __TIMESIZE: u32 = 64; +pub const _BITS_TYPESIZES_H: u32 = 1; +pub const __OFF_T_MATCHES_OFF64_T: u32 = 1; +pub const __INO_T_MATCHES_INO64_T: u32 = 1; +pub const __RLIM_T_MATCHES_RLIM64_T: u32 = 1; +pub const __STATFS_MATCHES_STATFS64: u32 = 1; +pub const __FD_SETSIZE: u32 = 1024; +pub const _BITS_TIME64_H: u32 = 1; +pub const _BITS_WCHAR_H: u32 = 1; +pub const _BITS_STDINT_INTN_H: u32 = 1; +pub const _BITS_STDINT_UINTN_H: u32 = 1; +pub const INT8_MIN: i32 = -128; +pub const INT16_MIN: i32 = -32768; +pub const INT32_MIN: i32 = -2147483648; +pub const INT8_MAX: u32 = 127; +pub const INT16_MAX: u32 = 32767; +pub const INT32_MAX: u32 = 2147483647; +pub const UINT8_MAX: u32 = 255; +pub const UINT16_MAX: u32 = 65535; +pub const UINT32_MAX: u32 = 4294967295; +pub const INT_LEAST8_MIN: i32 = -128; +pub const INT_LEAST16_MIN: i32 = -32768; +pub const INT_LEAST32_MIN: i32 = -2147483648; +pub const INT_LEAST8_MAX: u32 = 127; +pub const INT_LEAST16_MAX: u32 = 32767; +pub const INT_LEAST32_MAX: u32 = 2147483647; +pub const UINT_LEAST8_MAX: u32 = 255; +pub const UINT_LEAST16_MAX: u32 = 65535; +pub const UINT_LEAST32_MAX: u32 = 4294967295; +pub const INT_FAST8_MIN: i32 = -128; +pub const INT_FAST16_MIN: i64 = -9223372036854775808; +pub const INT_FAST32_MIN: i64 = -9223372036854775808; +pub const INT_FAST8_MAX: u32 = 127; +pub const INT_FAST16_MAX: u64 = 9223372036854775807; +pub const INT_FAST32_MAX: u64 = 9223372036854775807; +pub const UINT_FAST8_MAX: u32 = 255; +pub const UINT_FAST16_MAX: i32 = -1; +pub const UINT_FAST32_MAX: i32 = -1; +pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const UINTPTR_MAX: i32 = -1; +pub const PTRDIFF_MIN: i64 = -9223372036854775808; +pub const PTRDIFF_MAX: u64 = 9223372036854775807; +pub const SIG_ATOMIC_MIN: i32 = -2147483648; +pub const SIG_ATOMIC_MAX: u32 = 2147483647; +pub const SIZE_MAX: i32 = -1; +pub const WINT_MIN: u32 = 0; +pub const WINT_MAX: u32 = 4294967295; pub const true_: u32 = 1; pub const false_: u32 = 0; pub const __bool_true_false_are_defined: u32 = 1; @@ -8,24 +99,113 @@ pub const WHISPER_N_FFT: u32 = 400; pub const WHISPER_N_MEL: u32 = 80; pub const WHISPER_HOP_LENGTH: u32 = 160; pub const WHISPER_CHUNK_SIZE: u32 = 30; -pub type int_least64_t = i64; -pub type uint_least64_t = u64; -pub type int_fast64_t = i64; -pub type uint_fast64_t = u64; -pub type int_least32_t = i32; -pub type uint_least32_t = u32; -pub type int_fast32_t = i32; -pub type uint_fast32_t = u32; -pub type int_least16_t = i16; -pub type uint_least16_t = u16; -pub type int_fast16_t = i16; -pub type uint_fast16_t = u16; -pub type int_least8_t = i8; -pub type uint_least8_t = u8; -pub type int_fast8_t = i8; -pub type uint_fast8_t = u8; -pub type intmax_t = ::std::os::raw::c_long; -pub type uintmax_t = ::std::os::raw::c_ulong; +pub type __u_char = ::std::os::raw::c_uchar; +pub type __u_short = ::std::os::raw::c_ushort; +pub type __u_int = ::std::os::raw::c_uint; +pub type __u_long = ::std::os::raw::c_ulong; +pub type __int8_t = ::std::os::raw::c_schar; +pub type __uint8_t = ::std::os::raw::c_uchar; +pub type __int16_t = ::std::os::raw::c_short; +pub type __uint16_t = ::std::os::raw::c_ushort; +pub type __int32_t = ::std::os::raw::c_int; +pub type __uint32_t = ::std::os::raw::c_uint; +pub type __int64_t = ::std::os::raw::c_long; +pub type __uint64_t = ::std::os::raw::c_ulong; +pub type __int_least8_t = __int8_t; +pub type __uint_least8_t = __uint8_t; +pub type __int_least16_t = __int16_t; +pub type __uint_least16_t = __uint16_t; +pub type __int_least32_t = __int32_t; +pub type __uint_least32_t = __uint32_t; +pub type __int_least64_t = __int64_t; +pub type __uint_least64_t = __uint64_t; +pub type __quad_t = ::std::os::raw::c_long; +pub type __u_quad_t = ::std::os::raw::c_ulong; +pub type __intmax_t = ::std::os::raw::c_long; +pub type __uintmax_t = ::std::os::raw::c_ulong; +pub type __dev_t = ::std::os::raw::c_ulong; +pub type __uid_t = ::std::os::raw::c_uint; +pub type __gid_t = ::std::os::raw::c_uint; +pub type __ino_t = ::std::os::raw::c_ulong; +pub type __ino64_t = ::std::os::raw::c_ulong; +pub type __mode_t = ::std::os::raw::c_uint; +pub type __nlink_t = ::std::os::raw::c_ulong; +pub type __off_t = ::std::os::raw::c_long; +pub type __off64_t = ::std::os::raw::c_long; +pub type __pid_t = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct __fsid_t { + pub __val: [::std::os::raw::c_int; 2usize], +} +#[test] +fn bindgen_test_layout___fsid_t() { + assert_eq!( + ::std::mem::size_of::<__fsid_t>(), + 8usize, + concat!("Size of: ", stringify!(__fsid_t)) + ); + assert_eq!( + ::std::mem::align_of::<__fsid_t>(), + 4usize, + concat!("Alignment of ", stringify!(__fsid_t)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::<__fsid_t>())).__val as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(__fsid_t), + "::", + stringify!(__val) + ) + ); +} +pub type __clock_t = ::std::os::raw::c_long; +pub type __rlim_t = ::std::os::raw::c_ulong; +pub type __rlim64_t = ::std::os::raw::c_ulong; +pub type __id_t = ::std::os::raw::c_uint; +pub type __time_t = ::std::os::raw::c_long; +pub type __useconds_t = ::std::os::raw::c_uint; +pub type __suseconds_t = ::std::os::raw::c_long; +pub type __daddr_t = ::std::os::raw::c_int; +pub type __key_t = ::std::os::raw::c_int; +pub type __clockid_t = ::std::os::raw::c_int; +pub type __timer_t = *mut ::std::os::raw::c_void; +pub type __blksize_t = ::std::os::raw::c_long; +pub type __blkcnt_t = ::std::os::raw::c_long; +pub type __blkcnt64_t = ::std::os::raw::c_long; +pub type __fsblkcnt_t = ::std::os::raw::c_ulong; +pub type __fsblkcnt64_t = ::std::os::raw::c_ulong; +pub type __fsfilcnt_t = ::std::os::raw::c_ulong; +pub type __fsfilcnt64_t = ::std::os::raw::c_ulong; +pub type __fsword_t = ::std::os::raw::c_long; +pub type __ssize_t = ::std::os::raw::c_long; +pub type __syscall_slong_t = ::std::os::raw::c_long; +pub type __syscall_ulong_t = ::std::os::raw::c_ulong; +pub type __loff_t = __off64_t; +pub type __caddr_t = *mut ::std::os::raw::c_char; +pub type __intptr_t = ::std::os::raw::c_long; +pub type __socklen_t = ::std::os::raw::c_uint; +pub type __sig_atomic_t = ::std::os::raw::c_int; +pub type int_least8_t = __int_least8_t; +pub type int_least16_t = __int_least16_t; +pub type int_least32_t = __int_least32_t; +pub type int_least64_t = __int_least64_t; +pub type uint_least8_t = __uint_least8_t; +pub type uint_least16_t = __uint_least16_t; +pub type uint_least32_t = __uint_least32_t; +pub type uint_least64_t = __uint_least64_t; +pub type int_fast8_t = ::std::os::raw::c_schar; +pub type int_fast16_t = ::std::os::raw::c_long; +pub type int_fast32_t = ::std::os::raw::c_long; +pub type int_fast64_t = ::std::os::raw::c_long; +pub type uint_fast8_t = ::std::os::raw::c_uchar; +pub type uint_fast16_t = ::std::os::raw::c_ulong; +pub type uint_fast32_t = ::std::os::raw::c_ulong; +pub type uint_fast64_t = ::std::os::raw::c_ulong; +pub type intmax_t = __intmax_t; +pub type uintmax_t = __uintmax_t; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct whisper_context { @@ -71,7 +251,7 @@ extern "C" { ) -> ::std::os::raw::c_int; } extern "C" { - pub fn whisper_sample_best(ctx: *mut whisper_context, need_timestamp: bool) -> whisper_token; + pub fn whisper_sample_best(ctx: *mut whisper_context) -> whisper_token; } extern "C" { pub fn whisper_sample_timestamp(ctx: *mut whisper_context) -> whisper_token; @@ -127,14 +307,17 @@ extern "C" { extern "C" { pub fn whisper_print_timings(ctx: *mut whisper_context); } -pub const whisper_decode_strategy_WHISPER_DECODE_GREEDY: whisper_decode_strategy = 0; -pub const whisper_decode_strategy_WHISPER_DECODE_BEAM_SEARCH: whisper_decode_strategy = 1; +pub const whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY: whisper_sampling_strategy = 0; +pub const whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH: whisper_sampling_strategy = 1; #[doc = ""] -pub type whisper_decode_strategy = ::std::os::raw::c_uint; +pub type whisper_sampling_strategy = ::std::os::raw::c_uint; +pub type whisper_new_segment_callback = ::std::option::Option< + unsafe extern "C" fn(ctx: *mut whisper_context, user_data: *mut ::std::os::raw::c_void), +>; #[repr(C)] -#[derive(Copy, Clone)] +#[derive(Debug, Copy, Clone)] pub struct whisper_full_params { - pub strategy: whisper_decode_strategy, + pub strategy: whisper_sampling_strategy, pub n_threads: ::std::os::raw::c_int, pub offset_ms: ::std::os::raw::c_int, pub translate: bool, @@ -144,145 +327,21 @@ pub struct whisper_full_params { pub print_realtime: bool, pub print_timestamps: bool, pub language: *const ::std::os::raw::c_char, - pub __bindgen_anon_1: whisper_full_params__bindgen_ty_1, -} -#[repr(C)] -#[derive(Copy, Clone)] -pub union whisper_full_params__bindgen_ty_1 { - pub greedy: whisper_full_params__bindgen_ty_1__bindgen_ty_1, - pub beam_search: whisper_full_params__bindgen_ty_1__bindgen_ty_2, + pub greedy: whisper_full_params__bindgen_ty_1, + pub beam_search: whisper_full_params__bindgen_ty_2, + pub new_segment_callback: whisper_new_segment_callback, + pub new_segment_callback_user_data: *mut ::std::os::raw::c_void, } #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct whisper_full_params__bindgen_ty_1__bindgen_ty_1 { +pub struct whisper_full_params__bindgen_ty_1 { pub n_past: ::std::os::raw::c_int, } #[test] -fn bindgen_test_layout_whisper_full_params__bindgen_ty_1__bindgen_ty_1() { - assert_eq!( - ::std::mem::size_of::(), - 4usize, - concat!( - "Size of: ", - stringify!(whisper_full_params__bindgen_ty_1__bindgen_ty_1) - ) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!( - "Alignment of ", - stringify!(whisper_full_params__bindgen_ty_1__bindgen_ty_1) - ) - ); - fn test_field_n_past() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::< - whisper_full_params__bindgen_ty_1__bindgen_ty_1, - >::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).n_past) as usize - ptr as usize - }, - 0usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params__bindgen_ty_1__bindgen_ty_1), - "::", - stringify!(n_past) - ) - ); - } - test_field_n_past(); -} -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct whisper_full_params__bindgen_ty_1__bindgen_ty_2 { - pub n_past: ::std::os::raw::c_int, - pub beam_width: ::std::os::raw::c_int, - pub n_best: ::std::os::raw::c_int, -} -#[test] -fn bindgen_test_layout_whisper_full_params__bindgen_ty_1__bindgen_ty_2() { - assert_eq!( - ::std::mem::size_of::(), - 12usize, - concat!( - "Size of: ", - stringify!(whisper_full_params__bindgen_ty_1__bindgen_ty_2) - ) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!( - "Alignment of ", - stringify!(whisper_full_params__bindgen_ty_1__bindgen_ty_2) - ) - ); - fn test_field_n_past() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::< - whisper_full_params__bindgen_ty_1__bindgen_ty_2, - >::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).n_past) as usize - ptr as usize - }, - 0usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params__bindgen_ty_1__bindgen_ty_2), - "::", - stringify!(n_past) - ) - ); - } - test_field_n_past(); - fn test_field_beam_width() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::< - whisper_full_params__bindgen_ty_1__bindgen_ty_2, - >::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).beam_width) as usize - ptr as usize - }, - 4usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params__bindgen_ty_1__bindgen_ty_2), - "::", - stringify!(beam_width) - ) - ); - } - test_field_beam_width(); - fn test_field_n_best() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::< - whisper_full_params__bindgen_ty_1__bindgen_ty_2, - >::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).n_best) as usize - ptr as usize - }, - 8usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params__bindgen_ty_1__bindgen_ty_2), - "::", - stringify!(n_best) - ) - ); - } - test_field_n_best(); -} -#[test] fn bindgen_test_layout_whisper_full_params__bindgen_ty_1() { assert_eq!( ::std::mem::size_of::(), - 12usize, + 4usize, concat!("Size of: ", stringify!(whisper_full_params__bindgen_ty_1)) ); assert_eq!( @@ -293,46 +352,87 @@ fn bindgen_test_layout_whisper_full_params__bindgen_ty_1() { stringify!(whisper_full_params__bindgen_ty_1) ) ); - fn test_field_greedy() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).greedy) as usize - ptr as usize - }, - 0usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params__bindgen_ty_1), - "::", - stringify!(greedy) - ) - ); - } - test_field_greedy(); - fn test_field_beam_search() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).beam_search) as usize - ptr as usize - }, - 0usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params__bindgen_ty_1), - "::", - stringify!(beam_search) - ) - ); - } - test_field_beam_search(); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).n_past as *const _ + as usize + }, + 0usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params__bindgen_ty_1), + "::", + stringify!(n_past) + ) + ); +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_full_params__bindgen_ty_2 { + pub n_past: ::std::os::raw::c_int, + pub beam_width: ::std::os::raw::c_int, + pub n_best: ::std::os::raw::c_int, +} +#[test] +fn bindgen_test_layout_whisper_full_params__bindgen_ty_2() { + assert_eq!( + ::std::mem::size_of::(), + 12usize, + concat!("Size of: ", stringify!(whisper_full_params__bindgen_ty_2)) + ); + assert_eq!( + ::std::mem::align_of::(), + 4usize, + concat!( + "Alignment of ", + stringify!(whisper_full_params__bindgen_ty_2) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).n_past as *const _ + as usize + }, + 0usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params__bindgen_ty_2), + "::", + stringify!(n_past) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).beam_width as *const _ + as usize + }, + 4usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params__bindgen_ty_2), + "::", + stringify!(beam_width) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).n_best as *const _ + as usize + }, + 8usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params__bindgen_ty_2), + "::", + stringify!(n_best) + ) + ); } #[test] fn bindgen_test_layout_whisper_full_params() { assert_eq!( ::std::mem::size_of::(), - 48usize, + 64usize, concat!("Size of: ", stringify!(whisper_full_params)) ); assert_eq!( @@ -340,179 +440,164 @@ fn bindgen_test_layout_whisper_full_params() { 8usize, concat!("Alignment of ", stringify!(whisper_full_params)) ); - fn test_field_strategy() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).strategy) as usize - ptr as usize - }, - 0usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(strategy) - ) - ); - } - test_field_strategy(); - fn test_field_n_threads() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).n_threads) as usize - ptr as usize - }, - 4usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(n_threads) - ) - ); - } - test_field_n_threads(); - fn test_field_offset_ms() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).offset_ms) as usize - ptr as usize - }, - 8usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(offset_ms) - ) - ); - } - test_field_offset_ms(); - fn test_field_translate() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).translate) as usize - ptr as usize - }, - 12usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(translate) - ) - ); - } - test_field_translate(); - fn test_field_no_context() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).no_context) as usize - ptr as usize - }, - 13usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(no_context) - ) - ); - } - test_field_no_context(); - fn test_field_print_special_tokens() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).print_special_tokens) as usize - ptr as usize - }, - 14usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(print_special_tokens) - ) - ); - } - test_field_print_special_tokens(); - fn test_field_print_progress() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).print_progress) as usize - ptr as usize - }, - 15usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(print_progress) - ) - ); - } - test_field_print_progress(); - fn test_field_print_realtime() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).print_realtime) as usize - ptr as usize - }, - 16usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(print_realtime) - ) - ); - } - test_field_print_realtime(); - fn test_field_print_timestamps() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).print_timestamps) as usize - ptr as usize - }, - 17usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(print_timestamps) - ) - ); - } - test_field_print_timestamps(); - fn test_field_language() { - assert_eq!( - unsafe { - let uninit = ::std::mem::MaybeUninit::::uninit(); - let ptr = uninit.as_ptr(); - ::std::ptr::addr_of!((*ptr).language) as usize - ptr as usize - }, - 24usize, - concat!( - "Offset of field: ", - stringify!(whisper_full_params), - "::", - stringify!(language) - ) - ); - } - test_field_language(); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).strategy as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(strategy) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).n_threads as *const _ as usize }, + 4usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(n_threads) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).offset_ms as *const _ as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(offset_ms) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).translate as *const _ as usize }, + 12usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(translate) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).no_context as *const _ as usize }, + 13usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(no_context) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).print_special_tokens as *const _ + as usize + }, + 14usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(print_special_tokens) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).print_progress as *const _ as usize + }, + 15usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(print_progress) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).print_realtime as *const _ as usize + }, + 16usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(print_realtime) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).print_timestamps as *const _ as usize + }, + 17usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(print_timestamps) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).language as *const _ as usize }, + 24usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(language) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).greedy as *const _ as usize }, + 32usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(greedy) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).beam_search as *const _ as usize }, + 36usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(beam_search) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).new_segment_callback as *const _ + as usize + }, + 48usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(new_segment_callback) + ) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).new_segment_callback_user_data + as *const _ as usize + }, + 56usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(new_segment_callback_user_data) + ) + ); } extern "C" { - pub fn whisper_full_default_params(strategy: whisper_decode_strategy) -> whisper_full_params; + pub fn whisper_full_default_params(strategy: whisper_sampling_strategy) -> whisper_full_params; } extern "C" { pub fn whisper_full( @@ -543,3 +628,33 @@ extern "C" { i_segment: ::std::os::raw::c_int, ) -> *const ::std::os::raw::c_char; } +extern "C" { + pub fn whisper_full_n_tokens( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_full_get_token_text( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn whisper_full_get_token_id( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> whisper_token; +} +extern "C" { + pub fn whisper_full_get_token_p( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> f32; +} +extern "C" { + pub fn whisper_print_system_info() -> *const ::std::os::raw::c_char; +} diff --git a/sys/whisper.cpp b/sys/whisper.cpp index 0ad085f..2c281d1 160000 --- a/sys/whisper.cpp +++ b/sys/whisper.cpp @@ -1 +1 @@ -Subproject commit 0ad085f5e88576034adc871600da296ff8016803 +Subproject commit 2c281d190b7ec351b8128ba386d110f100993973