Add tool use support for OpenAI models (#28051)

This PR adds tool use support to the OpenAI models.
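On the request side, the rewritten `into_open_ai` (see the diff below) walks each message's content instead of flattening it to a string: an assistant tool use is appended to the `tool_calls` of the most recent assistant message (creating one with `content: None` if there is none), and a tool result becomes a `role: "tool"` message carrying the originating `tool_call_id`. A minimal sketch of the Chat Completions payload this produces, using `serde_json` in place of the typed `open_ai` request structs (the `now` tool and all values here are hypothetical):

```rust
use serde_json::json;

fn main() {
    // Rough shape of the messages array for one tool round trip: the
    // assistant's tool call, then the tool result echoing its tool_call_id.
    let messages = json!([
        { "role": "user", "content": "What time is it in UTC?" },
        {
            "role": "assistant",
            "content": null,
            "tool_calls": [{
                "id": "call_1",
                "type": "function",
                "function": {
                    "name": "now",
                    "arguments": "{\"timezone\":\"UTC\"}"
                }
            }]
        },
        { "role": "tool", "tool_call_id": "call_1", "content": "2025-04-03T20:55:11Z" }
    ]);
    println!("{}", serde_json::to_string_pretty(&messages).unwrap());
}
```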

Release Notes:

- agent: Added support for tool use with OpenAI models (Preview only).
Marshall Bowers 2025-04-03 16:55:11 -04:00 committed by GitHub
parent 4d8df0a00b
commit 7492ec3f67
3 changed files with 188 additions and 48 deletions
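
On the response side, OpenAI streams a tool call as indexed fragments: the call id and function name arrive once, while the JSON arguments trickle in across many chunks, and nothing is parseable until `finish_reason` is `"tool_calls"`. That is why the new `map_to_language_model_completion_events` below accumulates fragments in a `HashMap<usize, RawToolCall>`. A standalone sketch of that accumulation, with a simplified, hypothetical `ToolCallDelta` standing in for the `open_ai` crate's delta type:

```rust
use std::collections::HashMap;

// Simplified stand-in for the streamed tool-call delta; the real type
// lives in the open_ai crate.
struct ToolCallDelta {
    index: usize,
    id: Option<String>,
    name: Option<String>,
    arguments: Option<String>,
}

#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

fn main() {
    // The id and name arrive in the first fragment; the JSON arguments
    // arrive split across subsequent fragments, all sharing index 0.
    let deltas = vec![
        ToolCallDelta {
            index: 0,
            id: Some("call_1".into()),
            name: Some("now".into()),
            arguments: None,
        },
        ToolCallDelta {
            index: 0,
            id: None,
            name: None,
            arguments: Some("{\"timezone\":".into()),
        },
        ToolCallDelta {
            index: 0,
            id: None,
            name: None,
            arguments: Some("\"UTC\"}".into()),
        },
    ];

    // Accumulate by index, as map_to_language_model_completion_events does.
    let mut tool_calls_by_index: HashMap<usize, RawToolCall> = HashMap::new();
    for delta in deltas {
        let entry = tool_calls_by_index.entry(delta.index).or_default();
        if let Some(id) = delta.id {
            entry.id = id;
        }
        if let Some(name) = delta.name {
            entry.name = name;
        }
        if let Some(arguments) = delta.arguments {
            entry.arguments.push_str(&arguments);
        }
    }

    // Once finish_reason == "tool_calls", each entry parses into a complete
    // tool use (requires the serde_json crate).
    for (_, call) in tool_calls_by_index {
        let input: serde_json::Value = serde_json::from_str(&call.arguments).unwrap();
        println!("{} {}({input})", call.id, call.name);
    }
}
```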


@@ -587,7 +587,7 @@ impl LanguageModel for CloudLanguageModel {
         match self.model {
             CloudModel::Anthropic(_) => true,
             CloudModel::Google(_) => true,
-            CloudModel::OpenAi(_) => false,
+            CloudModel::OpenAi(_) => true,
         }
     }
@@ -705,15 +705,13 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    Ok(open_ai::extract_text_from_events(response_lines(response)))
+                    Ok(
+                        crate::provider::open_ai::map_to_language_model_completion_events(
+                            Box::pin(response_lines(response)),
+                        ),
+                    )
                 });
-                async move {
-                    Ok(future
-                        .await?
-                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                        .boxed())
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();


@@ -1,7 +1,8 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -10,17 +11,20 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
     LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
 use open_ai::{ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr as _;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::ResultExt;
+use util::{ResultExt, maybe};
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
@@ -289,7 +293,7 @@ impl LanguageModel for OpenAiLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        true
     }
 
     fn telemetry_id(&self) -> String {
@@ -322,12 +326,8 @@ impl LanguageModel for OpenAiLanguageModel {
     > {
         let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
-        async move {
-            Ok(open_ai::extract_text_from_events(completions.await?)
-                .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
+            .boxed()
     }
 }
 
@@ -337,33 +337,186 @@ pub fn into_open_ai(
     max_output_tokens: Option<u32>,
 ) -> open_ai::Request {
     let stream = !model.starts_with("o1-");
+
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) => messages.push(match message.role {
+                    Role::User => open_ai::RequestMessage::User { content: text },
+                    Role::Assistant => open_ai::RequestMessage::Assistant {
+                        content: Some(text),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => open_ai::RequestMessage::System { content: text },
+                }),
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = open_ai::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: open_ai::ToolCallContent::Function {
+                            function: open_ai::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(last_assistant_message) = messages.iter_mut().rfind(|message| {
+                        matches!(message, open_ai::RequestMessage::Assistant { .. })
+                    }) {
+                        if let open_ai::RequestMessage::Assistant { tool_calls, .. } =
+                            last_assistant_message
+                        {
+                            tool_calls.push(tool_call);
+                        }
+                    } else {
+                        messages.push(open_ai::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
+                    }
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    messages.push(open_ai::RequestMessage::Tool {
+                        content: tool_result.content.to_string(),
+                        tool_call_id: tool_result.tool_use_id.to_string(),
+                    });
+                }
+            }
+        }
+    }
+
     open_ai::Request {
         model,
-        messages: request
-            .messages
-            .into_iter()
-            .map(|msg| match msg.role {
-                Role::User => open_ai::RequestMessage::User {
-                    content: msg.string_contents(),
-                },
-                Role::Assistant => open_ai::RequestMessage::Assistant {
-                    content: Some(msg.string_contents()),
-                    tool_calls: Vec::new(),
-                },
-                Role::System => open_ai::RequestMessage::System {
-                    content: msg.string_contents(),
-                },
-            })
-            .collect(),
+        messages,
         stream,
         stop: request.stop,
         temperature: request.temperature.unwrap_or(1.0),
         max_tokens: max_output_tokens,
-        tools: Vec::new(),
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| open_ai::ToolDefinition::Function {
+                function: open_ai::FunctionDefinition {
+                    name: tool.name,
+                    description: Some(tool.description),
+                    parameters: Some(tool.input_schema),
+                },
+            })
+            .collect(),
         tool_choice: None,
     }
 }
+
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    #[derive(Default)]
+    struct RawToolCall {
+        id: String,
+        name: String,
+        arguments: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+        tool_calls_by_index: HashMap<usize, RawToolCall>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_calls_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            if let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => {
+                        let Some(choice) = event.choices.first() else {
+                            return Some((
+                                vec![Err(anyhow!("Response contained no choices"))],
+                                state,
+                            ));
+                        };
+
+                        let mut events = Vec::new();
+                        if let Some(content) = choice.delta.content.clone() {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                        }
+
+                        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+                            for tool_call in tool_calls {
+                                let entry = state
+                                    .tool_calls_by_index
+                                    .entry(tool_call.index)
+                                    .or_default();
+
+                                if let Some(tool_id) = tool_call.id.clone() {
+                                    entry.id = tool_id;
+                                }
+
+                                if let Some(function) = tool_call.function.as_ref() {
+                                    if let Some(name) = function.name.clone() {
+                                        entry.name = name;
+                                    }
+
+                                    if let Some(arguments) = function.arguments.clone() {
+                                        entry.arguments.push_str(&arguments);
+                                    }
+                                }
+                            }
+                        }
+
+                        match choice.finish_reason.as_deref() {
+                            Some("stop") => {
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            Some("tool_calls") => {
+                                events.extend(state.tool_calls_by_index.drain().map(
+                                    |(_, tool_call)| {
+                                        maybe!({
+                                            Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id: tool_call.id.into(),
+                                                    name: tool_call.name.as_str().into(),
+                                                    input: serde_json::Value::from_str(
+                                                        &tool_call.arguments,
+                                                    )?,
+                                                },
+                                            ))
+                                        })
+                                    },
+                                ));
+
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::ToolUse,
+                                )));
+                            }
+                            Some(stop_reason) => {
+                                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            None => {}
+                        }
+
+                        return Some((events, state));
+                    }
+                    Err(err) => return Some((vec![Err(err)], state)),
+                }
+            }
+
+            None
+        },
+    )
+    .flat_map(futures::stream::iter)
+}
+
 pub fn count_open_ai_tokens(
     request: LanguageModelRequest,
     model: open_ai::Model,


@@ -2,7 +2,7 @@ mod supported_countries;
 
 use anyhow::{Context as _, Result, anyhow};
 use futures::{
-    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
+    AsyncBufReadExt, AsyncReadExt, StreamExt,
     io::BufReader,
     stream::{self, BoxStream},
 };
@@ -618,14 +618,3 @@ pub fn embed<'a>(
         }
     }
 }
-
-pub fn extract_text_from_events(
-    response: impl Stream<Item = Result<ResponseStreamEvent>>,
-) -> impl Stream<Item = Result<String>> {
-    response.filter_map(|response| async move {
-        match response {
-            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-            Err(error) => Some(Err(error)),
-        }
-    })
-}