diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index a6f7cf9e29..827ca3f190 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -254,6 +254,7 @@ impl LanguageModel for CopilotChatLanguageModel { Ok(request) => request, Err(err) => return futures::future::ready(Err(err)).boxed(), }; + let is_streaming = copilot_request.stream; let request_limiter = self.request_limiter.clone(); let future = cx.spawn(async move |cx| { @@ -261,7 +262,10 @@ impl LanguageModel for CopilotChatLanguageModel { request_limiter .stream(async move { let response = request.await?; - Ok(map_to_language_model_completion_events(response)) + Ok(map_to_language_model_completion_events( + response, + is_streaming, + )) }) .await }); @@ -271,6 +275,7 @@ impl LanguageModel for CopilotChatLanguageModel { pub fn map_to_language_model_completion_events( events: Pin>>>, + is_streaming: bool, ) -> impl Stream> { #[derive(Default)] struct RawToolCall { @@ -289,7 +294,7 @@ pub fn map_to_language_model_completion_events( events, tool_calls_by_index: HashMap::default(), }, - |mut state| async move { + move |mut state| async move { if let Some(event) = state.events.next().await { match event { Ok(event) => { @@ -300,7 +305,13 @@ pub fn map_to_language_model_completion_events( )); }; - let Some(delta) = choice.delta.as_ref() else { + let delta = if is_streaming { + choice.delta.as_ref() + } else { + choice.message.as_ref() + }; + + let Some(delta) = delta else { return Some(( vec![Err(anyhow!("Response contained no delta"))], state, @@ -312,26 +323,26 @@ pub fn map_to_language_model_completion_events( events.push(Ok(LanguageModelCompletionEvent::Text(content))); } - for tool_call in &delta.tool_calls { - let entry = state - .tool_calls_by_index - .entry(tool_call.index) - .or_default(); + for tool_call in &delta.tool_calls { + let entry = state + .tool_calls_by_index + .entry(tool_call.index) + .or_default(); - if let Some(tool_id) = tool_call.id.clone() { - entry.id = tool_id; + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; } - if let Some(function) = tool_call.function.as_ref() { - if let Some(name) = function.name.clone() { - entry.name = name; - } - - if let Some(arguments) = function.arguments.clone() { - entry.arguments.push_str(&arguments); - } + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); } } + } match choice.finish_reason.as_deref() { Some("stop") => { @@ -361,7 +372,7 @@ pub fn map_to_language_model_completion_events( ))); } Some(stop_reason) => { - log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",); + log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}"); events.push(Ok(LanguageModelCompletionEvent::Stop( StopReason::EndTurn, )));