Default to fast model for thread summaries and titles + don't include system prompt / context / thinking segments (#29102)

* Adds a fast / cheaper model to each provider and defaults thread
summarization to this model. The initial motivation was that
https://github.com/zed-industries/zed/pull/29099 would cause these
requests to fail when used with a thinking model, and it doesn't seem
correct to use a thinking model for summarization anyway. The resulting
fallback order is sketched after the notes below.

* Skips the system prompt, context, and thinking segments when building
summarization requests (see to_summarize_request in the diff).

* If tool use is happening, allows two tool uses plus one more agent
response before summarizing (a sketch of this check follows these notes).
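
A minimal sketch of that summarization trigger, mirroring the check added
to Thread below (standalone function for illustration; has_summary,
message_count, and has_pending_tool_uses stand in for the real thread
state):

    // A plain exchange is one user message plus one assistant response
    // (2 messages). With tool use, each round adds a tool-result message
    // and a follow-up assistant response, so two tool rounds plus one
    // more agent response land at 6 messages.
    fn should_summarize(
        has_summary: bool,
        message_count: usize,
        has_pending_tool_uses: bool,
    ) -> bool {
        !has_summary
            && message_count >= 2
            && (!has_pending_tool_uses || message_count >= 6)
    }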

The downside is that there was previously some potential for prefix cache
reuse, especially for title summarization (thread summarization already
omitted tool results and so would not share a prefix for those). This
seems fine, as these requests should typically be fairly small. Even for
full thread summarization, skipping all tool use and context should
greatly reduce token use.
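
For reference, the resulting model-selection order for thread
summarization, sketched as a hypothetical free function (the real logic
lives in LanguageModelRegistry::thread_summary_model in the diff below):

    // Selection order after this change: an explicitly configured thread
    // summary model wins, then the provider's fast model, then the
    // provider's default model. M is a stand-in for ConfiguredModel.
    fn pick_thread_summary_model<M>(
        thread_summary_model: Option<M>,
        default_fast_model: Option<M>,
        default_model: Option<M>,
    ) -> Option<M> {
        thread_summary_model
            .or(default_fast_model)
            .or(default_model)
    }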

Release Notes:

- N/A
Michael Sloan 2025-04-19 17:26:29 -06:00 committed by GitHub
parent d48152d958
commit fbf7caf93e
25 changed files with 270 additions and 205 deletions

View File

@@ -1,8 +1,8 @@
 use crate::context::{AssistantContext, ContextId, format_context_as_string};
 use crate::context_picker::MentionLink;
 use crate::thread::{
-    LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
-    ThreadEvent, ThreadFeedback,
+    LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
+    ThreadFeedback,
 };
 use crate::thread_store::{RulesLoadingError, ThreadStore};
 use crate::tool_use::{PendingToolUseStatus, ToolUse};
@@ -1285,7 +1285,7 @@ impl ActiveThread {
         self.thread.update(cx, |thread, cx| {
             thread.advance_prompt_id();
-            thread.send_to_model(model.model, RequestKind::Chat, cx)
+            thread.send_to_model(model.model, cx)
         });
         cx.notify();
     }

View File

@@ -40,7 +40,7 @@ pub use crate::active_thread::ActiveThread;
 use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
 pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
 pub use crate::inline_assistant::InlineAssistant;
-pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
+pub use crate::thread::{Message, Thread, ThreadEvent};
 pub use crate::thread_store::ThreadStore;
 pub use agent_diff::{AgentDiff, AgentDiffToolbar};

View File

@@ -34,7 +34,7 @@ use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
 use crate::context_store::{ContextStore, refresh_context_store_text};
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::profile_selector::ProfileSelector;
-use crate::thread::{RequestKind, Thread, TokenUsageRatio};
+use crate::thread::{Thread, TokenUsageRatio};
 use crate::thread_store::ThreadStore;
 use crate::{
     AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
@@ -234,7 +234,7 @@ impl MessageEditor {
         }
         self.set_editor_is_expanded(false, cx);
-        self.send_to_model(RequestKind::Chat, window, cx);
+        self.send_to_model(window, cx);
         cx.notify();
     }
@@ -249,12 +249,7 @@ impl MessageEditor {
             .is_some()
     }

-    fn send_to_model(
-        &mut self,
-        request_kind: RequestKind,
-        window: &mut Window,
-        cx: &mut Context<Self>,
-    ) {
+    fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
         let model_registry = LanguageModelRegistry::read_global(cx);
         let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
             return;
@@ -331,7 +326,7 @@ impl MessageEditor {
             thread
                 .update(cx, |thread, cx| {
                     thread.advance_prompt_id();
-                    thread.send_to_model(model, request_kind, cx);
+                    thread.send_to_model(model, cx);
                 })
                 .log_err();
         })
@@ -345,7 +340,7 @@ impl MessageEditor {
         if cancelled {
             self.set_editor_is_expanded(false, cx);
-            self.send_to_model(RequestKind::Chat, window, cx);
+            self.send_to_model(window, cx);
         }
     }

View File

@@ -40,13 +40,6 @@ use crate::thread_store::{
 };
 use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};

-#[derive(Debug, Clone, Copy)]
-pub enum RequestKind {
-    Chat,
-    /// Used when summarizing a thread.
-    Summarize,
-}
-
 #[derive(
     Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
 )]
@@ -949,13 +942,8 @@ impl Thread {
         })
     }

-    pub fn send_to_model(
-        &mut self,
-        model: Arc<dyn LanguageModel>,
-        request_kind: RequestKind,
-        cx: &mut Context<Self>,
-    ) {
-        let mut request = self.to_completion_request(request_kind, cx);
+    pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
+        let mut request = self.to_completion_request(cx);
         if model.supports_tools() {
             request.tools = {
                 let mut tools = Vec::new();
@@ -994,11 +982,7 @@ impl Thread {
         false
     }

-    pub fn to_completion_request(
-        &self,
-        request_kind: RequestKind,
-        cx: &mut Context<Self>,
-    ) -> LanguageModelRequest {
+    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
         let mut request = LanguageModelRequest {
             thread_id: Some(self.id.to_string()),
             prompt_id: Some(self.last_prompt_id.to_string()),
@@ -1045,18 +1029,8 @@ impl Thread {
                 cache: false,
             };

-            match request_kind {
-                RequestKind::Chat => {
-                    self.tool_use
-                        .attach_tool_results(message.id, &mut request_message);
-                }
-                RequestKind::Summarize => {
-                    // We don't care about tool use during summarization.
-                    if self.tool_use.message_has_tool_results(message.id) {
-                        continue;
-                    }
-                }
-            }
+            self.tool_use
+                .attach_tool_results(message.id, &mut request_message);

             if !message.context.is_empty() {
                 request_message
@@ -1089,15 +1063,8 @@ impl Thread {
                 };
             }

-            match request_kind {
-                RequestKind::Chat => {
-                    self.tool_use
-                        .attach_tool_uses(message.id, &mut request_message);
-                }
-                RequestKind::Summarize => {
-                    // We don't care about tool use during summarization.
-                }
-            };
+            self.tool_use
+                .attach_tool_uses(message.id, &mut request_message);

             request.messages.push(request_message);
         }
@@ -1112,6 +1079,54 @@ impl Thread {
         request
     }

+    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
+        let mut request = LanguageModelRequest {
+            thread_id: None,
+            prompt_id: None,
+            messages: vec![],
+            tools: Vec::new(),
+            stop: Vec::new(),
+            temperature: None,
+        };
+
+        for message in &self.messages {
+            let mut request_message = LanguageModelRequestMessage {
+                role: message.role,
+                content: Vec::new(),
+                cache: false,
+            };
+
+            // Skip tool results during summarization.
+            if self.tool_use.message_has_tool_results(message.id) {
+                continue;
+            }
+
+            for segment in &message.segments {
+                match segment {
+                    MessageSegment::Text(text) => request_message
+                        .content
+                        .push(MessageContent::Text(text.clone())),
+                    MessageSegment::Thinking { .. } => {}
+                    MessageSegment::RedactedThinking(_) => {}
+                }
+            }
+
+            if request_message.content.is_empty() {
+                continue;
+            }
+
+            request.messages.push(request_message);
+        }
+
+        request.messages.push(LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec![MessageContent::Text(added_user_message)],
+            cache: false,
+        });
+
+        request
+    }
+
     fn attached_tracked_files_state(
         &self,
         messages: &mut Vec<LanguageModelRequestMessage>,
@@ -1293,7 +1308,12 @@ impl Thread {
                     .pending_completions
                     .retain(|completion| completion.id != pending_completion_id);

-                if thread.summary.is_none() && thread.messages.len() >= 2 {
+                // If there is a response without tool use, summarize the message. Otherwise,
+                // allow two tool uses before summarizing.
+                if thread.summary.is_none()
+                    && thread.messages.len() >= 2
+                    && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
+                {
                     thread.summarize(cx);
                 }
             })?;
@@ -1403,18 +1423,12 @@ impl Thread {
             return;
         }

-        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
-        request.messages.push(LanguageModelRequestMessage {
-            role: Role::User,
-            content: vec![
-                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
-                Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
-                If the conversation is about a specific subject, include it in the title. \
-                Be descriptive. DO NOT speak in the first person."
-                    .into(),
-            ],
-            cache: false,
-        });
+        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
+            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
+            If the conversation is about a specific subject, include it in the title. \
+            Be descriptive. DO NOT speak in the first person.";
+
+        let request = self.to_summarize_request(added_user_message.into());

         self.pending_summary = cx.spawn(async move |this, cx| {
             async move {
@@ -1476,21 +1490,14 @@ impl Thread {
             return None;
         }

-        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
-
-        request.messages.push(LanguageModelRequestMessage {
-            role: Role::User,
-            content: vec![
-                "Generate a detailed summary of this conversation. Include:\n\
-                1. A brief overview of what was discussed\n\
-                2. Key facts or information discovered\n\
-                3. Outcomes or conclusions reached\n\
-                4. Any action items or next steps if any\n\
-                Format it in Markdown with headings and bullet points."
-                    .into(),
-            ],
-            cache: false,
-        });
+        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
+            1. A brief overview of what was discussed\n\
+            2. Key facts or information discovered\n\
+            3. Outcomes or conclusions reached\n\
+            4. Any action items or next steps if any\n\
+            Format it in Markdown with headings and bullet points.";
+
+        let request = self.to_summarize_request(added_user_message.into());

         let task = cx.spawn(async move |thread, cx| {
             let stream = model.stream_completion_text(request, &cx);
@@ -1538,7 +1545,7 @@ impl Thread {
     pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
         self.auto_capture_telemetry(cx);

-        let request = self.to_completion_request(RequestKind::Chat, cx);
+        let request = self.to_completion_request(cx);
         let messages = Arc::new(request.messages);
         let pending_tool_uses = self
             .tool_use
@@ -1650,7 +1657,7 @@ impl Thread {
         if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
             self.attach_tool_results(cx);
             if !canceled {
-                self.send_to_model(model, RequestKind::Chat, cx);
+                self.send_to_model(model, cx);
             }
         }
     }
@@ -2275,9 +2282,7 @@ fn main() {{
     assert_eq!(message.context, expected_context);

     // Check message in request
-    let request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
     assert_eq!(request.messages.len(), 2);

     let expected_full_message = format!("{}Please explain this code", expected_context);
@@ -2367,9 +2372,7 @@ fn main() {{
     assert!(message3.context.contains("file3.rs"));

     // Check entire request to make sure all contexts are properly included
-    let request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

     // The request should contain all 3 messages
     assert_eq!(request.messages.len(), 4);
@@ -2419,9 +2422,7 @@ fn main() {{
     assert_eq!(message.context, "");

     // Check message in request
-    let request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

     assert_eq!(request.messages.len(), 2);
     assert_eq!(
@@ -2439,9 +2440,7 @@ fn main() {{
     assert_eq!(message2.context, "");

     // Check that both messages appear in the request
-    let request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

     assert_eq!(request.messages.len(), 3);
     assert_eq!(
@@ -2481,9 +2480,7 @@ fn main() {{
     });

     // Create a request and check that it doesn't have a stale buffer warning yet
-    let initial_request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

     // Make sure we don't have a stale file warning yet
     let has_stale_warning = initial_request.messages.iter().any(|msg| {
@@ -2511,9 +2508,7 @@ fn main() {{
     });

     // Create a new request and check for the stale buffer warning
-    let new_request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

     // We should have a stale file warning as the last message
     let last_message = new_request

View File

@@ -74,6 +74,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_5Haiku
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         if id.starts_with("claude-3-5-sonnet") {
             Ok(Self::Claude3_5Sonnet)

View File

@@ -84,6 +84,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_5Haiku
+    }
+
     pub fn from_id(id: &str) -> anyhow::Result<Self> {
         if id.starts_with("claude-3-5-sonnet-v2") {
             Ok(Self::Claude3_5SonnetV2)

View File

@@ -61,6 +61,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_7Sonnet
+    }
+
     pub fn uses_streaming(&self) -> bool {
         match self {
             Self::Gpt4o

View File

@@ -64,6 +64,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Model::Chat
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "deepseek-chat" => Ok(Self::Chat),

View File

@@ -1,4 +1,4 @@
-use agent::{RequestKind, ThreadEvent, ThreadStore};
+use agent::{ThreadEvent, ThreadStore};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::ToolWorkingSet;
 use client::proto::LspWorkProgress;
@@ -472,7 +472,7 @@ impl Example {
         thread.update(cx, |thread, cx| {
             let context = vec![];
             thread.insert_user_message(this.prompt.clone(), context, None, cx);
-            thread.send_to_model(model, RequestKind::Chat, cx);
+            thread.send_to_model(model, cx);
         })?;

         event_handler_task.await?;

View File

@@ -412,6 +412,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Model {
+        Model::Gemini15Flash
+    }
+
     pub fn id(&self) -> &str {
         match self {
             Model::Gemini15Pro => "gemini-1.5-pro",

View File

@@ -49,6 +49,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
         Some(Arc::new(FakeLanguageModel::default()))
     }

+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(Arc::new(FakeLanguageModel::default()))
+    }
+
     fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
         vec![Arc::new(FakeLanguageModel::default())]
     }

View File

@@ -370,6 +370,7 @@ pub trait LanguageModelProvider: 'static {
         IconName::ZedAssistant
     }
     fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
     fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         Vec::new()

View File

@@ -5,6 +5,7 @@ use crate::{
 use collections::BTreeMap;
 use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
 use std::sync::Arc;
+use util::maybe;

 pub fn init(cx: &mut App) {
     let registry = cx.new(|_cx| LanguageModelRegistry::default());
@@ -18,6 +19,7 @@ impl Global for GlobalLanguageModelRegistry {}
 #[derive(Default)]
 pub struct LanguageModelRegistry {
     default_model: Option<ConfiguredModel>,
+    default_fast_model: Option<ConfiguredModel>,
     inline_assistant_model: Option<ConfiguredModel>,
     commit_message_model: Option<ConfiguredModel>,
     thread_summary_model: Option<ConfiguredModel>,
@@ -202,6 +204,14 @@ impl LanguageModelRegistry {
             (None, None) => {}
             _ => cx.emit(Event::DefaultModelChanged),
         }

+        self.default_fast_model = maybe!({
+            let provider = &model.as_ref()?.provider;
+            let fast_model = provider.default_fast_model(cx)?;
+            Some(ConfiguredModel {
+                provider: provider.clone(),
+                model: fast_model,
+            })
+        });
         self.default_model = model;
     }
@@ -254,21 +264,37 @@ impl LanguageModelRegistry {
     }

     pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.inline_assistant_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_model.clone())
     }

     pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.commit_message_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_model.clone())
     }

     pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.thread_summary_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_fast_model.clone())
+            .or_else(|| self.default_model.clone())
     }

     /// The models to use for inline assists. Returns the union of the active

View File

@@ -201,7 +201,7 @@ impl AnthropicLanguageModelProvider {
             state: self.state.clone(),
             http_client: self.http_client.clone(),
             request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>
+        })
     }
 }
@@ -227,14 +227,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = anthropic::Model::default();
-        Some(Arc::new(AnthropicModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(anthropic::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(anthropic::Model::default_fast()))
     }

     fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {

View File

@@ -286,6 +286,18 @@ impl BedrockLanguageModelProvider {
             state,
         }
     }
+
+    fn create_language_model(&self, model: bedrock::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(BedrockModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            http_client: self.http_client.clone(),
+            handler: self.handler.clone(),
+            state: self.state.clone(),
+            client: OnceCell::new(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProvider for BedrockLanguageModelProvider {
@@ -302,16 +314,11 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = bedrock::Model::default();
-        Some(Arc::new(BedrockModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            http_client: self.http_client.clone(),
-            handler: self.handler.clone(),
-            state: self.state.clone(),
-            client: OnceCell::new(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(bedrock::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(bedrock::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -343,17 +350,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
         models
             .into_values()
-            .map(|model| {
-                Arc::new(BedrockModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    http_client: self.http_client.clone(),
-                    handler: self.handler.clone(),
-                    state: self.state.clone(),
-                    client: OnceCell::new(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }

View File

@@ -242,7 +242,7 @@ impl CloudLanguageModelProvider {
             llm_api_token: llm_api_token.clone(),
             client: self.client.clone(),
             request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>
+        })
     }
 }
@@ -270,13 +270,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
         let llm_api_token = self.state.read(cx).llm_api_token.clone();
         let model = CloudModel::Anthropic(anthropic::Model::default());
-        Some(Arc::new(CloudLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            llm_api_token: llm_api_token.clone(),
-            client: self.client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(model, llm_api_token))
+    }
+
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let llm_api_token = self.state.read(cx).llm_api_token.clone();
+        let model = CloudModel::Anthropic(anthropic::Model::default_fast());
+        Some(self.create_language_model(model, llm_api_token))
     }

     fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {

View File

@@ -70,6 +70,13 @@ impl CopilotChatLanguageModelProvider {
         Self { state }
     }
+
+    fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
+        Arc::new(CopilotChatLanguageModel {
+            model,
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
@@ -94,21 +101,16 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = CopilotChatModel::default();
-        Some(Arc::new(CopilotChatLanguageModel {
-            model,
-            request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>)
+        Some(self.create_language_model(CopilotChatModel::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(CopilotChatModel::default_fast()))
     }

     fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         CopilotChatModel::iter()
-            .map(|model| {
-                Arc::new(CopilotChatLanguageModel {
-                    model,
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }

View File

@@ -140,6 +140,16 @@ impl DeepSeekLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: deepseek::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(DeepSeekLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }) as Arc<dyn LanguageModel>
+    }
 }

 impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
@@ -164,14 +174,11 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = deepseek::Model::Chat;
-        Some(Arc::new(DeepSeekLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(deepseek::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(deepseek::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -198,15 +205,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
         models
             .into_values()
-            .map(|model| {
-                Arc::new(DeepSeekLanguageModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    state: self.state.clone(),
-                    http_client: self.http_client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }

View File

@@ -150,6 +150,16 @@ impl GoogleLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(GoogleLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProviderState for GoogleLanguageModelProvider {
@@ -174,14 +184,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = google_ai::Model::default();
-        Some(Arc::new(GoogleLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(google_ai::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(google_ai::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {

View File

@@ -157,6 +157,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
         self.provided_models(cx).into_iter().next()
     }

+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        self.default_model(cx)
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

View File

@@ -144,6 +144,16 @@ impl MistralLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(MistralLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProviderState for MistralLanguageModelProvider {
@@ -168,14 +178,11 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = mistral::Model::default();
-        Some(Arc::new(MistralLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(mistral::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(mistral::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {

View File

@@ -162,6 +162,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
         self.provided_models(cx).into_iter().next()
     }

+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        self.default_model(cx)
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

View File

@@ -148,6 +148,16 @@ impl OpenAiLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(OpenAiLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProviderState for OpenAiLanguageModelProvider {
@@ -172,14 +182,11 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = open_ai::Model::default();
-        Some(Arc::new(OpenAiLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(open_ai::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(open_ai::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -211,15 +218,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         models
             .into_values()
-            .map(|model| {
-                Arc::new(OpenAiLanguageModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    state: self.state.clone(),
-                    http_client: self.http_client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }

View File

@@ -69,6 +69,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Model::MistralSmallLatest
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "codestral-latest" => Ok(Self::CodestralLatest),

View File

@@ -102,6 +102,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Self::FourPointOneMini
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),