Add tool use support for OpenAI models (#28051)
This PR adds support for using tools to the OpenAI models.

Release Notes:

- agent: Added support for tool use with OpenAI models (Preview only).
parent 4d8df0a00b
commit 7492ec3f67
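For orientation before the diff: tool use over the OpenAI Chat Completions API means advertising function definitions in the request and, once the model responds with a tool call, sending the result back as a `tool`-role message. Below is a minimal sketch of those two payload shapes, built with `serde_json`; the `read_file` tool, its schema, and the call id are invented for illustration and are not part of this diff.

```rust
// Illustration only: the OpenAI Chat Completions payload shapes that tool
// use relies on. Field names follow OpenAI's public API; none of this is
// Zed code, and the tool/call-id values are made up.
use serde_json::json;

fn main() {
    // 1. The request advertises callable tools via `tools`.
    let request = json!({
        "model": "gpt-4o",
        "stream": true,
        "tools": [{
            "type": "function",
            "function": {
                "name": "read_file",
                "description": "Read a file from the worktree",
                "parameters": {
                    "type": "object",
                    "properties": { "path": { "type": "string" } },
                    "required": ["path"]
                }
            }
        }],
        "messages": [{ "role": "user", "content": "What does Cargo.toml contain?" }]
    });

    // 2. After the model emits a tool call, the result goes back as a
    //    `tool` message that references the call id.
    let tool_result = json!({
        "role": "tool",
        "tool_call_id": "call_abc123",
        "content": "[package]\nname = \"zed\""
    });

    println!("{request}\n{tool_result}");
}
```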
@@ -587,7 +587,7 @@ impl LanguageModel for CloudLanguageModel {
         match self.model {
             CloudModel::Anthropic(_) => true,
             CloudModel::Google(_) => true,
-            CloudModel::OpenAi(_) => false,
+            CloudModel::OpenAi(_) => true,
         }
     }
 
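The flipped boolean above is what callers consult before attaching tool definitions to a request. A standalone sketch of that gating, assuming a stand-in trait rather than Zed's actual `LanguageModel` trait:

```rust
// Sketch only; not Zed's API. Shows why the provider-side flag matters:
// callers drop tool definitions for models that report no tool support.
trait ToolCapable {
    fn supports_tools(&self) -> bool;
}

struct OpenAiModel;

impl ToolCapable for OpenAiModel {
    fn supports_tools(&self) -> bool {
        true // after this PR; previously false
    }
}

// Only attach tool definitions when the model can accept them.
fn tool_definitions<M: ToolCapable>(model: &M, tools: Vec<String>) -> Vec<String> {
    if model.supports_tools() { tools } else { Vec::new() }
}

fn main() {
    let tools = vec!["read_file".to_string()];
    assert_eq!(tool_definitions(&OpenAiModel, tools.clone()), tools);
}
```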
@@ -705,15 +705,13 @@ impl LanguageModel for CloudLanguageModel {
                     },
                 )
                 .await?;
-                Ok(open_ai::extract_text_from_events(response_lines(response)))
+                Ok(
+                    crate::provider::open_ai::map_to_language_model_completion_events(
+                        Box::pin(response_lines(response)),
+                    ),
+                )
             });
-            async move {
-                Ok(future
-                    .await?
-                    .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                    .boxed())
-            }
-            .boxed()
+            async move { Ok(future.await?.boxed()) }.boxed()
         }
         CloudModel::Google(model) => {
             let client = self.client.clone();
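Note the shape change on the cloud path: the old code yielded `Result<String>` chunks and wrapped each one in `LanguageModelCompletionEvent::Text` at the call site; the new code maps server events into completion events directly, which is what lets tool calls travel down the same stream as text. A minimal sketch of the difference, with local stand-in types (not Zed's):

```rust
// Stand-in types; illustrates why the `.map(Text)` wrapper disappears:
// the provider stream now carries structured events, so a tool call can
// ride alongside plain text chunks.
use futures::{StreamExt, executor::block_on, stream};

#[derive(Debug)]
enum CompletionEvent {
    Text(String),
    ToolUse { name: String },
}

fn main() {
    // Old shape (Result omitted for brevity): text-only chunks that the
    // caller had to wrap in Text events.
    let old = stream::iter(["Hello, ", "world"])
        .map(|chunk| CompletionEvent::Text(chunk.to_string()));

    // New shape: the mapper itself emits text and tool-use events.
    let new = stream::iter(vec![
        CompletionEvent::Text("Reading the file...".into()),
        CompletionEvent::ToolUse { name: "read_file".into() },
    ]);

    println!("{:?}", block_on(old.collect::<Vec<_>>()));
    println!("{:?}", block_on(new.collect::<Vec<_>>()));
}
```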
@@ -1,7 +1,8 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -10,17 +11,20 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
     LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
 use open_ai::{ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr as _;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::ResultExt;
+use util::{ResultExt, maybe};
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
@@ -289,7 +293,7 @@ impl LanguageModel for OpenAiLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        true
     }
 
     fn telemetry_id(&self) -> String {
@@ -322,12 +326,8 @@ impl LanguageModel for OpenAiLanguageModel {
     > {
         let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
        let completions = self.stream_completion(request, cx);
-        async move {
-            Ok(open_ai::extract_text_from_events(completions.await?)
-                .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
+            .boxed()
     }
 }
 
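The same substitution happens on the native OpenAI provider. Consumer code can now branch on structured events instead of receiving bare text; a simplified consumer-side sketch follows (stand-in types; Zed's real stream also carries `Result` and richer stop reasons):

```rust
// Consumer-side sketch with local stand-in types, not Zed's. With the
// provider emitting structured events, a caller can dispatch on text vs.
// tool use vs. stop instead of handling bare strings.
use futures::{StreamExt, executor::block_on, stream};

enum Event {
    Text(String),
    ToolUse { name: String, input: String },
    Stop,
}

fn main() {
    let events = stream::iter(vec![
        Event::Text("Let me check.".into()),
        Event::ToolUse {
            name: "read_file".into(),
            input: r#"{"path":"Cargo.toml"}"#.into(),
        },
        Event::Stop,
    ]);

    block_on(events.for_each(|event| async move {
        match event {
            Event::Text(text) => print!("{text}"),
            Event::ToolUse { name, input } => println!("\n[tool call] {name}({input})"),
            Event::Stop => println!("\n[done]"),
        }
    }));
}
```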
@@ -337,33 +337,186 @@ pub fn into_open_ai(
     max_output_tokens: Option<u32>,
 ) -> open_ai::Request {
     let stream = !model.starts_with("o1-");
+
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) => messages.push(match message.role {
+                    Role::User => open_ai::RequestMessage::User { content: text },
+                    Role::Assistant => open_ai::RequestMessage::Assistant {
+                        content: Some(text),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => open_ai::RequestMessage::System { content: text },
+                }),
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = open_ai::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: open_ai::ToolCallContent::Function {
+                            function: open_ai::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(last_assistant_message) = messages.iter_mut().rfind(|message| {
+                        matches!(message, open_ai::RequestMessage::Assistant { .. })
+                    }) {
+                        if let open_ai::RequestMessage::Assistant { tool_calls, .. } =
+                            last_assistant_message
+                        {
+                            tool_calls.push(tool_call);
+                        }
+                    } else {
+                        messages.push(open_ai::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
+                    }
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    messages.push(open_ai::RequestMessage::Tool {
+                        content: tool_result.content.to_string(),
+                        tool_call_id: tool_result.tool_use_id.to_string(),
+                    });
+                }
+            }
+        }
+    }
 
     open_ai::Request {
         model,
-        messages: request
-            .messages
-            .into_iter()
-            .map(|msg| match msg.role {
-                Role::User => open_ai::RequestMessage::User {
-                    content: msg.string_contents(),
-                },
-                Role::Assistant => open_ai::RequestMessage::Assistant {
-                    content: Some(msg.string_contents()),
-                    tool_calls: Vec::new(),
-                },
-                Role::System => open_ai::RequestMessage::System {
-                    content: msg.string_contents(),
-                },
-            })
-            .collect(),
+        messages,
         stream,
         stop: request.stop,
         temperature: request.temperature.unwrap_or(1.0),
         max_tokens: max_output_tokens,
-        tools: Vec::new(),
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| open_ai::ToolDefinition::Function {
+                function: open_ai::FunctionDefinition {
+                    name: tool.name,
+                    description: Some(tool.description),
+                    parameters: Some(tool.input_schema),
+                },
+            })
+            .collect(),
+        tool_choice: None,
     }
 }
+
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    #[derive(Default)]
+    struct RawToolCall {
+        id: String,
+        name: String,
+        arguments: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+        tool_calls_by_index: HashMap<usize, RawToolCall>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_calls_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            if let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => {
+                        let Some(choice) = event.choices.first() else {
+                            return Some((
+                                vec![Err(anyhow!("Response contained no choices"))],
+                                state,
+                            ));
+                        };
+
+                        let mut events = Vec::new();
+                        if let Some(content) = choice.delta.content.clone() {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                        }
+
+                        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+                            for tool_call in tool_calls {
+                                let entry = state
+                                    .tool_calls_by_index
+                                    .entry(tool_call.index)
+                                    .or_default();
+
+                                if let Some(tool_id) = tool_call.id.clone() {
+                                    entry.id = tool_id;
+                                }
+
+                                if let Some(function) = tool_call.function.as_ref() {
+                                    if let Some(name) = function.name.clone() {
+                                        entry.name = name;
+                                    }
+
+                                    if let Some(arguments) = function.arguments.clone() {
+                                        entry.arguments.push_str(&arguments);
+                                    }
+                                }
+                            }
+                        }
+
+                        match choice.finish_reason.as_deref() {
+                            Some("stop") => {
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            Some("tool_calls") => {
+                                events.extend(state.tool_calls_by_index.drain().map(
+                                    |(_, tool_call)| {
+                                        maybe!({
+                                            Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id: tool_call.id.into(),
+                                                    name: tool_call.name.as_str().into(),
+                                                    input: serde_json::Value::from_str(
+                                                        &tool_call.arguments,
+                                                    )?,
+                                                },
+                                            ))
+                                        })
+                                    },
+                                ));
+
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::ToolUse,
+                                )));
+                            }
+                            Some(stop_reason) => {
+                                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            None => {}
+                        }
+
+                        return Some((events, state));
+                    }
+                    Err(err) => return Some((vec![Err(err)], state)),
+                }
+            }
+
+            None
+        },
+    )
+    .flat_map(futures::stream::iter)
+}
 
 pub fn count_open_ai_tokens(
     request: LanguageModelRequest,
     model: open_ai::Model,
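The core of `map_to_language_model_completion_events` is its buffering strategy: OpenAI streams each tool call as fragments keyed by `index` (the id and function name arrive once, while the JSON `arguments` string arrives in pieces), and the mapper emits a complete `ToolUse` event only when `finish_reason` is `"tool_calls"`. The same strategy in standalone, runnable form, with the SSE deltas faked as tuples:

```rust
// Standalone sketch of the accumulation used above. Types are simplified
// stand-ins; the delta values are invented for illustration.
use std::collections::HashMap;

#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

fn main() {
    // (index, id, name, argument fragment) as they might arrive over SSE:
    // id/name once, arguments split across chunks.
    let deltas = [
        (0_usize, Some("call_1"), Some("read_file"), Some(r#"{"pa"#)),
        (0, None, None, Some(r#"th":"Cargo.toml"}"#)),
    ];

    let mut by_index: HashMap<usize, RawToolCall> = HashMap::new();
    for (index, id, name, fragment) in deltas {
        let entry = by_index.entry(index).or_default();
        if let Some(id) = id {
            entry.id = id.to_string();
        }
        if let Some(name) = name {
            entry.name = name.to_string();
        }
        if let Some(fragment) = fragment {
            entry.arguments.push_str(fragment);
        }
    }

    // On finish_reason == "tool_calls", drain and parse the buffered JSON.
    for (_, call) in by_index.drain() {
        let input: serde_json::Value =
            serde_json::from_str(&call.arguments).expect("fragments form valid JSON");
        println!("{} {}: {input}", call.id, call.name);
    }
}
```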
@@ -2,7 +2,7 @@ mod supported_countries;
 
 use anyhow::{Context as _, Result, anyhow};
 use futures::{
-    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
+    AsyncBufReadExt, AsyncReadExt, StreamExt,
     io::BufReader,
     stream::{self, BoxStream},
 };
@@ -618,14 +618,3 @@ pub fn embed<'a>(
         }
     }
 }
-
-pub fn extract_text_from_events(
-    response: impl Stream<Item = Result<ResponseStreamEvent>>,
-) -> impl Stream<Item = Result<String>> {
-    response.filter_map(|response| async move {
-        match response {
-            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-            Err(error) => Some(Err(error)),
-        }
-    })
-}
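The removal above also captures the motivation for the whole change: `filter_map` over `delta.content` drops any chunk whose content is `None`, and tool-call deltas are exactly such chunks, so text-only extraction could never surface them. A miniature demonstration with stand-in types:

```rust
// Why the old helper had to go, in miniature: filtering on `delta.content`
// silently discards tool-call deltas, which carry no text content.
// Stand-in types, not Zed's.
use futures::{StreamExt, executor::block_on, stream};

struct Delta {
    content: Option<String>,
    #[allow(dead_code)]
    tool_call: Option<&'static str>,
}

fn main() {
    let chunks = stream::iter(vec![
        Delta { content: Some("Hi".into()), tool_call: None },
        Delta { content: None, tool_call: Some("read_file") }, // dropped below
    ]);

    let texts: Vec<String> = block_on(
        chunks
            .filter_map(|delta| async move { delta.content })
            .collect::<Vec<_>>(),
    );

    // Only the text survived; the tool call vanished.
    assert_eq!(texts, vec!["Hi".to_string()]);
}
```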