diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 425caa2f45..89f8191dfa 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -330,41 +330,23 @@ impl LanguageModel for LmStudioLanguageModel { let future = self.request_limiter.stream(async move { let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(fragment) => { - // Skip empty deltas - if fragment.choices[0].delta.is_object() - && fragment.choices[0].delta.as_object().unwrap().is_empty() - { - return None; - } - // Try to parse the delta as ChatMessage - if let Ok(chat_message) = serde_json::from_value::<ChatMessage>( - fragment.choices[0].delta.clone(), - ) { - let content = match chat_message { - ChatMessage::User { content } => content, - ChatMessage::Assistant { content, .. } => { - content.unwrap_or_default() - } - ChatMessage::System { content } => content, - }; - if !content.is_empty() { - Some(Ok(content)) - } else { - None - } - } else { - None - } - } + // Create a stream mapper to handle content across multiple deltas + let stream_mapper = LmStudioStreamMapper::new(); + + let stream = response + .map(move |response| { + response.and_then(|fragment| stream_mapper.process_fragment(fragment)) + }) + .filter_map(|result| async move { + match result { + Ok(Some(content)) => Some(Ok(content)), + Ok(None) => None, Err(error) => Some(Err(error)), } }) + .boxed(); + Ok(stream) }); @@ -382,6 +364,40 @@ impl LanguageModel for LmStudioLanguageModel { } } +// This will be more useful when we implement tool calling. Currently keeping it empty. 
+struct LmStudioStreamMapper {} + +impl LmStudioStreamMapper { + fn new() -> Self { + Self {} + } + + fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> { + // Most of the time, there will be only one choice + let Some(choice) = fragment.choices.first() else { + return Ok(None); + }; + + // Extract the delta content + if let Ok(delta) = + serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone()) + { + if let Some(content) = delta.content { + if !content.is_empty() { + return Ok(Some(content)); + } + } + } + + // If there's a finish_reason, we're done + if choice.finish_reason.is_some() { + return Ok(None); + } + + Ok(None) + } +} + struct ConfigurationView { state: gpui::Entity<State>, loading_models_task: Option<Task<()>>, diff --git a/crates/lmstudio/src/lmstudio.rs b/crates/lmstudio/src/lmstudio.rs index 0d17db63c1..8cad0eccb7 100644 --- a/crates/lmstudio/src/lmstudio.rs +++ b/crates/lmstudio/src/lmstudio.rs @@ -221,6 +221,14 @@ pub enum CompatibilityType { Mlx, } +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessageDelta { + pub role: Option<Role>, + pub content: Option<String>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option<Vec<ToolCallChunk>>, +} + pub async fn complete( client: &dyn HttpClient, api_url: &str,