Add tool use support for OpenAI models (#28051)
This PR adds support for using tools to the OpenAI models. Release Notes: - agent: Added support for tool use with OpenAI models (Preview only).
This commit is contained in:
parent
4d8df0a00b
commit
7492ec3f67
@ -587,7 +587,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
match self.model {
|
match self.model {
|
||||||
CloudModel::Anthropic(_) => true,
|
CloudModel::Anthropic(_) => true,
|
||||||
CloudModel::Google(_) => true,
|
CloudModel::Google(_) => true,
|
||||||
CloudModel::OpenAi(_) => false,
|
CloudModel::OpenAi(_) => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -705,15 +705,13 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(open_ai::extract_text_from_events(response_lines(response)))
|
Ok(
|
||||||
|
crate::provider::open_ai::map_to_language_model_completion_events(
|
||||||
|
Box::pin(response_lines(response)),
|
||||||
|
),
|
||||||
|
)
|
||||||
});
|
});
|
||||||
async move {
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
Ok(future
|
|
||||||
.await?
|
|
||||||
.map(|result| result.map(LanguageModelCompletionEvent::Text))
|
|
||||||
.boxed())
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use collections::BTreeMap;
|
use collections::{BTreeMap, HashMap};
|
||||||
use credentials_provider::CredentialsProvider;
|
use credentials_provider::CredentialsProvider;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
|
use futures::Stream;
|
||||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||||
@ -10,17 +11,20 @@ use http_client::HttpClient;
|
|||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
|
||||||
|
RateLimiter, Role, StopReason,
|
||||||
};
|
};
|
||||||
use open_ai::{ResponseStreamEvent, stream_completion};
|
use open_ai::{ResponseStreamEvent, stream_completion};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::str::FromStr as _;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||||
use util::ResultExt;
|
use util::{ResultExt, maybe};
|
||||||
|
|
||||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||||
|
|
||||||
@ -289,7 +293,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
false
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
fn telemetry_id(&self) -> String {
|
fn telemetry_id(&self) -> String {
|
||||||
@ -322,11 +326,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||||||
> {
|
> {
|
||||||
let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
|
let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
|
||||||
let completions = self.stream_completion(request, cx);
|
let completions = self.stream_completion(request, cx);
|
||||||
async move {
|
async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
|
||||||
Ok(open_ai::extract_text_from_events(completions.await?)
|
|
||||||
.map(|result| result.map(LanguageModelCompletionEvent::Text))
|
|
||||||
.boxed())
|
|
||||||
}
|
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -337,33 +337,186 @@ pub fn into_open_ai(
|
|||||||
max_output_tokens: Option<u32>,
|
max_output_tokens: Option<u32>,
|
||||||
) -> open_ai::Request {
|
) -> open_ai::Request {
|
||||||
let stream = !model.starts_with("o1-");
|
let stream = !model.starts_with("o1-");
|
||||||
open_ai::Request {
|
|
||||||
model,
|
let mut messages = Vec::new();
|
||||||
messages: request
|
for message in request.messages {
|
||||||
.messages
|
for content in message.content {
|
||||||
.into_iter()
|
match content {
|
||||||
.map(|msg| match msg.role {
|
MessageContent::Text(text) => messages.push(match message.role {
|
||||||
Role::User => open_ai::RequestMessage::User {
|
Role::User => open_ai::RequestMessage::User { content: text },
|
||||||
content: msg.string_contents(),
|
|
||||||
},
|
|
||||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
Role::Assistant => open_ai::RequestMessage::Assistant {
|
||||||
content: Some(msg.string_contents()),
|
content: Some(text),
|
||||||
tool_calls: Vec::new(),
|
tool_calls: Vec::new(),
|
||||||
},
|
},
|
||||||
Role::System => open_ai::RequestMessage::System {
|
Role::System => open_ai::RequestMessage::System { content: text },
|
||||||
content: msg.string_contents(),
|
}),
|
||||||
|
MessageContent::Image(_) => {}
|
||||||
|
MessageContent::ToolUse(tool_use) => {
|
||||||
|
let tool_call = open_ai::ToolCall {
|
||||||
|
id: tool_use.id.to_string(),
|
||||||
|
content: open_ai::ToolCallContent::Function {
|
||||||
|
function: open_ai::FunctionContent {
|
||||||
|
name: tool_use.name.to_string(),
|
||||||
|
arguments: serde_json::to_string(&tool_use.input)
|
||||||
|
.unwrap_or_default(),
|
||||||
},
|
},
|
||||||
})
|
},
|
||||||
.collect(),
|
};
|
||||||
|
|
||||||
|
if let Some(last_assistant_message) = messages.iter_mut().rfind(|message| {
|
||||||
|
matches!(message, open_ai::RequestMessage::Assistant { .. })
|
||||||
|
}) {
|
||||||
|
if let open_ai::RequestMessage::Assistant { tool_calls, .. } =
|
||||||
|
last_assistant_message
|
||||||
|
{
|
||||||
|
tool_calls.push(tool_call);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
messages.push(open_ai::RequestMessage::Assistant {
|
||||||
|
content: None,
|
||||||
|
tool_calls: vec![tool_call],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MessageContent::ToolResult(tool_result) => {
|
||||||
|
messages.push(open_ai::RequestMessage::Tool {
|
||||||
|
content: tool_result.content.to_string(),
|
||||||
|
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
open_ai::Request {
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
stream,
|
stream,
|
||||||
stop: request.stop,
|
stop: request.stop,
|
||||||
temperature: request.temperature.unwrap_or(1.0),
|
temperature: request.temperature.unwrap_or(1.0),
|
||||||
max_tokens: max_output_tokens,
|
max_tokens: max_output_tokens,
|
||||||
tools: Vec::new(),
|
tools: request
|
||||||
|
.tools
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool| open_ai::ToolDefinition::Function {
|
||||||
|
function: open_ai::FunctionDefinition {
|
||||||
|
name: tool.name,
|
||||||
|
description: Some(tool.description),
|
||||||
|
parameters: Some(tool.input_schema),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn map_to_language_model_completion_events(
|
||||||
|
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||||
|
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||||
|
#[derive(Default)]
|
||||||
|
struct RawToolCall {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct State {
|
||||||
|
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||||
|
tool_calls_by_index: HashMap<usize, RawToolCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
futures::stream::unfold(
|
||||||
|
State {
|
||||||
|
events,
|
||||||
|
tool_calls_by_index: HashMap::default(),
|
||||||
|
},
|
||||||
|
|mut state| async move {
|
||||||
|
if let Some(event) = state.events.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(event) => {
|
||||||
|
let Some(choice) = event.choices.first() else {
|
||||||
|
return Some((
|
||||||
|
vec![Err(anyhow!("Response contained no choices"))],
|
||||||
|
state,
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
if let Some(content) = choice.delta.content.clone() {
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
|
||||||
|
for tool_call in tool_calls {
|
||||||
|
let entry = state
|
||||||
|
.tool_calls_by_index
|
||||||
|
.entry(tool_call.index)
|
||||||
|
.or_default();
|
||||||
|
|
||||||
|
if let Some(tool_id) = tool_call.id.clone() {
|
||||||
|
entry.id = tool_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(function) = tool_call.function.as_ref() {
|
||||||
|
if let Some(name) = function.name.clone() {
|
||||||
|
entry.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(arguments) = function.arguments.clone() {
|
||||||
|
entry.arguments.push_str(&arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match choice.finish_reason.as_deref() {
|
||||||
|
Some("stop") => {
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||||
|
StopReason::EndTurn,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Some("tool_calls") => {
|
||||||
|
events.extend(state.tool_calls_by_index.drain().map(
|
||||||
|
|(_, tool_call)| {
|
||||||
|
maybe!({
|
||||||
|
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: tool_call.id.into(),
|
||||||
|
name: tool_call.name.as_str().into(),
|
||||||
|
input: serde_json::Value::from_str(
|
||||||
|
&tool_call.arguments,
|
||||||
|
)?,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
));
|
||||||
|
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||||
|
StopReason::ToolUse,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Some(stop_reason) => {
|
||||||
|
log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||||
|
StopReason::EndTurn,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Some((events, state));
|
||||||
|
}
|
||||||
|
Err(err) => return Some((vec![Err(err)], state)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.flat_map(futures::stream::iter)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn count_open_ai_tokens(
|
pub fn count_open_ai_tokens(
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
model: open_ai::Model,
|
model: open_ai::Model,
|
||||||
|
@ -2,7 +2,7 @@ mod supported_countries;
|
|||||||
|
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use futures::{
|
use futures::{
|
||||||
AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
|
AsyncBufReadExt, AsyncReadExt, StreamExt,
|
||||||
io::BufReader,
|
io::BufReader,
|
||||||
stream::{self, BoxStream},
|
stream::{self, BoxStream},
|
||||||
};
|
};
|
||||||
@ -618,14 +618,3 @@ pub fn embed<'a>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn extract_text_from_events(
|
|
||||||
response: impl Stream<Item = Result<ResponseStreamEvent>>,
|
|
||||||
) -> impl Stream<Item = Result<String>> {
|
|
||||||
response.filter_map(|response| async move {
|
|
||||||
match response {
|
|
||||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
|
||||||
Err(error) => Some(Err(error)),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user