From 7492ec3f67c1e0c34fe81032cac29d0327c5d60c Mon Sep 17 00:00:00 2001
From: Marshall Bowers
Date: Thu, 3 Apr 2025 16:55:11 -0400
Subject: [PATCH] Add tool use support for OpenAI models (#28051)

This PR adds support for using tools to the OpenAI models.

Release Notes:

- agent: Added support for tool use with OpenAI models (Preview only).
---
 crates/language_models/src/provider/cloud.rs  |  16 +-
 .../language_models/src/provider/open_ai.rs   | 207 +++++++++++++++---
 crates/open_ai/src/open_ai.rs                 |  13 +-
 3 files changed, 188 insertions(+), 48 deletions(-)

diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 05b3b23e4e..975e504580 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -587,7 +587,7 @@ impl LanguageModel for CloudLanguageModel {
         match self.model {
             CloudModel::Anthropic(_) => true,
             CloudModel::Google(_) => true,
-            CloudModel::OpenAi(_) => false,
+            CloudModel::OpenAi(_) => true,
         }
     }
 
@@ -705,15 +705,13 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    Ok(open_ai::extract_text_from_events(response_lines(response)))
+                    Ok(
+                        crate::provider::open_ai::map_to_language_model_completion_events(
+                            Box::pin(response_lines(response)),
+                        ),
+                    )
                 });
-                async move {
-                    Ok(future
-                        .await?
-                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                        .boxed())
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 170006a943..0f02642e25 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -1,7 +1,8 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -10,17 +11,20 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
     LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
 use open_ai::{ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr as _;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::ResultExt;
+use util::{ResultExt, maybe};
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
@@ -289,7 +293,7 @@ impl LanguageModel for OpenAiLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        true
     }
 
     fn telemetry_id(&self) -> String {
@@ -322,12 +326,8 @@ impl LanguageModel for OpenAiLanguageModel {
     > {
         let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
-        async move {
-            Ok(open_ai::extract_text_from_events(completions.await?)
-                .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
+            .boxed()
     }
 }
 
@@ -337,33 +337,186 @@ pub fn into_open_ai(
     max_output_tokens: Option<u32>,
 ) -> open_ai::Request {
     let stream = !model.starts_with("o1-");
+
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) => messages.push(match message.role {
+                    Role::User => open_ai::RequestMessage::User { content: text },
+                    Role::Assistant => open_ai::RequestMessage::Assistant {
+                        content: Some(text),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => open_ai::RequestMessage::System { content: text },
+                }),
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = open_ai::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: open_ai::ToolCallContent::Function {
+                            function: open_ai::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(last_assistant_message) = messages.iter_mut().rfind(|message| {
+                        matches!(message, open_ai::RequestMessage::Assistant { .. })
+                    }) {
+                        if let open_ai::RequestMessage::Assistant { tool_calls, .. } =
+                            last_assistant_message
+                        {
+                            tool_calls.push(tool_call);
+                        }
+                    } else {
+                        messages.push(open_ai::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
+                    }
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    messages.push(open_ai::RequestMessage::Tool {
+                        content: tool_result.content.to_string(),
+                        tool_call_id: tool_result.tool_use_id.to_string(),
+                    });
+                }
+            }
+        }
+    }
+
     open_ai::Request {
         model,
-        messages: request
-            .messages
-            .into_iter()
-            .map(|msg| match msg.role {
-                Role::User => open_ai::RequestMessage::User {
-                    content: msg.string_contents(),
-                },
-                Role::Assistant => open_ai::RequestMessage::Assistant {
-                    content: Some(msg.string_contents()),
-                    tool_calls: Vec::new(),
-                },
-                Role::System => open_ai::RequestMessage::System {
-                    content: msg.string_contents(),
-                },
-            })
-            .collect(),
+        messages,
         stream,
         stop: request.stop,
         temperature: request.temperature.unwrap_or(1.0),
         max_tokens: max_output_tokens,
-        tools: Vec::new(),
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| open_ai::ToolDefinition::Function {
+                function: open_ai::FunctionDefinition {
+                    name: tool.name,
+                    description: Some(tool.description),
+                    parameters: Some(tool.input_schema),
+                },
+            })
+            .collect(),
         tool_choice: None,
     }
 }
 
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    #[derive(Default)]
+    struct RawToolCall {
+        id: String,
+        name: String,
+        arguments: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+        tool_calls_by_index: HashMap<usize, RawToolCall>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_calls_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            if let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => {
+                        let Some(choice) = event.choices.first() else {
+                            return Some((
+                                vec![Err(anyhow!("Response contained no choices"))],
+                                state,
+                            ));
+                        };
+
+                        let mut events = Vec::new();
+                        if let Some(content) = choice.delta.content.clone() {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                        }
+
+                        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+                            for tool_call in tool_calls {
+                                let entry = state
+                                    .tool_calls_by_index
+                                    .entry(tool_call.index)
+                                    .or_default();
+
+                                if let Some(tool_id) = tool_call.id.clone() {
+                                    entry.id = tool_id;
+                                }
+
+                                if let Some(function) = tool_call.function.as_ref() {
+                                    if let Some(name) = function.name.clone() {
+                                        entry.name = name;
+                                    }
+
+                                    if let Some(arguments) = function.arguments.clone() {
+                                        entry.arguments.push_str(&arguments);
+                                    }
+                                }
+                            }
+                        }
+
+                        match choice.finish_reason.as_deref() {
+                            Some("stop") => {
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            Some("tool_calls") => {
+                                events.extend(state.tool_calls_by_index.drain().map(
+                                    |(_, tool_call)| {
+                                        maybe!({
+                                            Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id: tool_call.id.into(),
+                                                    name: tool_call.name.as_str().into(),
+                                                    input: serde_json::Value::from_str(
+                                                        &tool_call.arguments,
+                                                    )?,
+                                                },
+                                            ))
+                                        })
+                                    },
+                                ));
+
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::ToolUse,
+                                )));
+                            }
+                            Some(stop_reason) => {
+                                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            None => {}
+                        }
+
+                        return Some((events, state));
+                    }
+                    Err(err) => return Some((vec![Err(err)], state)),
+                }
+            }
+
+            None
+        },
+    )
+    .flat_map(futures::stream::iter)
+}
+
 pub fn count_open_ai_tokens(
     request: LanguageModelRequest,
     model: open_ai::Model,
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 4ad13b1d89..586d864da4 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -2,7 +2,7 @@ mod supported_countries;
 
 use anyhow::{Context as _, Result, anyhow};
 use futures::{
-    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
+    AsyncBufReadExt, AsyncReadExt, StreamExt,
     io::BufReader,
     stream::{self, BoxStream},
 };
@@ -618,14 +618,3 @@ pub fn embed<'a>(
         }
     }
 }
-
-pub fn extract_text_from_events(
-    response: impl Stream<Item = Result<ResponseStreamEvent>>,
-) -> impl Stream<Item = Result<String>> {
-    response.filter_map(|response| async move {
-        match response {
-            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-            Err(error) => Some(Err(error)),
-        }
-    })
-}
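
For readers who want the gist of the streaming half of this change without wiring up the surrounding crates, here is a minimal, self-contained Rust sketch (standard library only) of the delta-accumulation idea at the heart of map_to_language_model_completion_events. OpenAI streams each tool call as a sequence of fragments keyed by index: the first fragment usually carries the id and name, and later fragments append pieces of the arguments JSON. The ToolCallDelta type and the accumulate/main names below are hypothetical stand-ins for illustration, not types from the open_ai crate; the patch itself keeps the same per-index map in the unfold state and only parses the accumulated arguments once the stream reports finish_reason == "tool_calls".

use std::collections::HashMap;

#[derive(Default, Debug)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

// Hypothetical stand-in for one streamed tool-call fragment; the real delta
// type lives in crates/open_ai and carries more fields.
struct ToolCallDelta {
    index: usize,
    id: Option<String>,
    name: Option<String>,
    arguments: Option<String>,
}

fn accumulate(deltas: impl IntoIterator<Item = ToolCallDelta>) -> HashMap<usize, RawToolCall> {
    let mut calls: HashMap<usize, RawToolCall> = HashMap::new();
    for delta in deltas {
        // Mirrors the patch: entry(index).or_default(); ids and names
        // overwrite, argument fragments concatenate.
        let entry = calls.entry(delta.index).or_default();
        if let Some(id) = delta.id {
            entry.id = id;
        }
        if let Some(name) = delta.name {
            entry.name = name;
        }
        if let Some(arguments) = delta.arguments {
            entry.arguments.push_str(&arguments);
        }
    }
    calls
}

fn main() {
    let deltas = [
        ToolCallDelta {
            index: 0,
            id: Some("call_123".into()),
            name: Some("get_weather".into()),
            arguments: Some("{\"city\":".into()),
        },
        ToolCallDelta {
            index: 0,
            id: None,
            name: None,
            arguments: Some("\"Boston\"}".into()),
        },
    ];
    // The arguments string is only valid JSON once the stream signals
    // finish_reason == "tool_calls", which is when the patch emits ToolUse.
    for (index, call) in accumulate(deltas) {
        println!("tool call {index}: {} {}({})", call.id, call.name, call.arguments);
    }
}

Concatenating fragments rather than parsing each one individually is what lets this approach tolerate argument JSON that is split at arbitrary byte boundaries across streamed chunks.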