From fbf7caf93eb18d817bed6bd405d914590ba409d2 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Sat, 19 Apr 2025 17:26:29 -0600 Subject: [PATCH] Default to fast model for thread summaries and titles + don't include system prompt / context / thinking segments (#29102) * Adds a fast / cheaper model to providers and defaults thread summarization to this model. Initial motivation for this was that https://github.com/zed-industries/zed/pull/29099 would cause these requests to fail when used with a thinking model. It doesn't seem correct to use a thinking model for summarization. * Skips system prompt, context, and thinking segments. * If tool use is happening, allows 2 tool uses + one more agent response before summarizing. Downside of this is that there was potential for some prefix cache reuse before, especially for title summarization (thread summarization omitted tool results and so would not share a prefix for those). This seems fine as these requests should typically be fairly small. Even for full thread summarization, skipping all tool use / context should greatly reduce the token use. Release Notes: - N/A --- crates/agent/src/active_thread.rs | 6 +- crates/agent/src/assistant.rs | 2 +- crates/agent/src/message_editor.rs | 15 +- crates/agent/src/thread.rs | 169 +++++++++--------- crates/anthropic/src/anthropic.rs | 4 + crates/bedrock/src/models.rs | 4 + crates/copilot/src/copilot_chat.rs | 4 + crates/deepseek/src/deepseek.rs | 4 + crates/eval/src/example.rs | 4 +- crates/google_ai/src/google_ai.rs | 4 + crates/language_model/src/fake_provider.rs | 4 + crates/language_model/src/language_model.rs | 1 + crates/language_model/src/registry.rs | 32 +++- .../language_models/src/provider/anthropic.rs | 15 +- .../language_models/src/provider/bedrock.rs | 39 ++-- crates/language_models/src/provider/cloud.rs | 16 +- .../src/provider/copilot_chat.rs | 24 +-- .../language_models/src/provider/deepseek.rs | 33 ++-- crates/language_models/src/provider/google.rs | 23 ++- .../language_models/src/provider/lmstudio.rs | 4 + .../language_models/src/provider/mistral.rs | 23 ++- crates/language_models/src/provider/ollama.rs | 4 + .../language_models/src/provider/open_ai.rs | 33 ++-- crates/mistral/src/mistral.rs | 4 + crates/open_ai/src/open_ai.rs | 4 + 25 files changed, 270 insertions(+), 205 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 4c3fbe878e..2a4a00cf23 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1,8 +1,8 @@ use crate::context::{AssistantContext, ContextId, format_context_as_string}; use crate::context_picker::MentionLink; use crate::thread::{ - LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError, - ThreadEvent, ThreadFeedback, + LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent, + ThreadFeedback, }; use crate::thread_store::{RulesLoadingError, ThreadStore}; use crate::tool_use::{PendingToolUseStatus, ToolUse}; @@ -1285,7 +1285,7 @@ impl ActiveThread { self.thread.update(cx, |thread, cx| { thread.advance_prompt_id(); - thread.send_to_model(model.model, RequestKind::Chat, cx) + thread.send_to_model(model.model, cx) }); cx.notify(); } diff --git a/crates/agent/src/assistant.rs b/crates/agent/src/assistant.rs index 1f067af734..03e13d6f68 100644 --- a/crates/agent/src/assistant.rs +++ b/crates/agent/src/assistant.rs @@ -40,7 +40,7 @@ pub use crate::active_thread::ActiveThread; use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal}; pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate}; pub use crate::inline_assistant::InlineAssistant; -pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent}; +pub use crate::thread::{Message, Thread, ThreadEvent}; pub use crate::thread_store::ThreadStore; pub use agent_diff::{AgentDiff, AgentDiffToolbar}; diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index d64c95b9b9..ac16df4c97 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -34,7 +34,7 @@ use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; use crate::context_store::{ContextStore, refresh_context_store_text}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::profile_selector::ProfileSelector; -use crate::thread::{RequestKind, Thread, TokenUsageRatio}; +use crate::thread::{Thread, TokenUsageRatio}; use crate::thread_store::ThreadStore; use crate::{ AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext, @@ -234,7 +234,7 @@ impl MessageEditor { } self.set_editor_is_expanded(false, cx); - self.send_to_model(RequestKind::Chat, window, cx); + self.send_to_model(window, cx); cx.notify(); } @@ -249,12 +249,7 @@ impl MessageEditor { .is_some() } - fn send_to_model( - &mut self, - request_kind: RequestKind, - window: &mut Window, - cx: &mut Context, - ) { + fn send_to_model(&mut self, window: &mut Window, cx: &mut Context) { let model_registry = LanguageModelRegistry::read_global(cx); let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else { return; @@ -331,7 +326,7 @@ impl MessageEditor { thread .update(cx, |thread, cx| { thread.advance_prompt_id(); - thread.send_to_model(model, request_kind, cx); + thread.send_to_model(model, cx); }) .log_err(); }) @@ -345,7 +340,7 @@ impl MessageEditor { if cancelled { self.set_editor_is_expanded(false, cx); - self.send_to_model(RequestKind::Chat, window, cx); + self.send_to_model(window, cx); } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index e8ca584fa0..17f7e20387 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -40,13 +40,6 @@ use crate::thread_store::{ }; use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER}; -#[derive(Debug, Clone, Copy)] -pub enum RequestKind { - Chat, - /// Used when summarizing a thread. - Summarize, -} - #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, )] @@ -949,13 +942,8 @@ impl Thread { }) } - pub fn send_to_model( - &mut self, - model: Arc, - request_kind: RequestKind, - cx: &mut Context, - ) { - let mut request = self.to_completion_request(request_kind, cx); + pub fn send_to_model(&mut self, model: Arc, cx: &mut Context) { + let mut request = self.to_completion_request(cx); if model.supports_tools() { request.tools = { let mut tools = Vec::new(); @@ -994,11 +982,7 @@ impl Thread { false } - pub fn to_completion_request( - &self, - request_kind: RequestKind, - cx: &mut Context, - ) -> LanguageModelRequest { + pub fn to_completion_request(&self, cx: &mut Context) -> LanguageModelRequest { let mut request = LanguageModelRequest { thread_id: Some(self.id.to_string()), prompt_id: Some(self.last_prompt_id.to_string()), @@ -1045,18 +1029,8 @@ impl Thread { cache: false, }; - match request_kind { - RequestKind::Chat => { - self.tool_use - .attach_tool_results(message.id, &mut request_message); - } - RequestKind::Summarize => { - // We don't care about tool use during summarization. - if self.tool_use.message_has_tool_results(message.id) { - continue; - } - } - } + self.tool_use + .attach_tool_results(message.id, &mut request_message); if !message.context.is_empty() { request_message @@ -1089,15 +1063,8 @@ impl Thread { }; } - match request_kind { - RequestKind::Chat => { - self.tool_use - .attach_tool_uses(message.id, &mut request_message); - } - RequestKind::Summarize => { - // We don't care about tool use during summarization. - } - }; + self.tool_use + .attach_tool_uses(message.id, &mut request_message); request.messages.push(request_message); } @@ -1112,6 +1079,54 @@ impl Thread { request } + fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest { + let mut request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + messages: vec![], + tools: Vec::new(), + stop: Vec::new(), + temperature: None, + }; + + for message in &self.messages { + let mut request_message = LanguageModelRequestMessage { + role: message.role, + content: Vec::new(), + cache: false, + }; + + // Skip tool results during summarization. + if self.tool_use.message_has_tool_results(message.id) { + continue; + } + + for segment in &message.segments { + match segment { + MessageSegment::Text(text) => request_message + .content + .push(MessageContent::Text(text.clone())), + MessageSegment::Thinking { .. } => {} + MessageSegment::RedactedThinking(_) => {} + } + } + + if request_message.content.is_empty() { + continue; + } + + request.messages.push(request_message); + } + + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(added_user_message)], + cache: false, + }); + + request + } + fn attached_tracked_files_state( &self, messages: &mut Vec, @@ -1293,7 +1308,12 @@ impl Thread { .pending_completions .retain(|completion| completion.id != pending_completion_id); - if thread.summary.is_none() && thread.messages.len() >= 2 { + // If there is a response without tool use, summarize the message. Otherwise, + // allow two tool uses before summarizing. + if thread.summary.is_none() + && thread.messages.len() >= 2 + && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6) + { thread.summarize(cx); } })?; @@ -1403,18 +1423,12 @@ impl Thread { return; } - let mut request = self.to_completion_request(RequestKind::Summarize, cx); - request.messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec![ - "Generate a concise 3-7 word title for this conversation, omitting punctuation. \ - Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \ - If the conversation is about a specific subject, include it in the title. \ - Be descriptive. DO NOT speak in the first person." - .into(), - ], - cache: false, - }); + let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \ + Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \ + If the conversation is about a specific subject, include it in the title. \ + Be descriptive. DO NOT speak in the first person."; + + let request = self.to_summarize_request(added_user_message.into()); self.pending_summary = cx.spawn(async move |this, cx| { async move { @@ -1476,21 +1490,14 @@ impl Thread { return None; } - let mut request = self.to_completion_request(RequestKind::Summarize, cx); + let added_user_message = "Generate a detailed summary of this conversation. Include:\n\ + 1. A brief overview of what was discussed\n\ + 2. Key facts or information discovered\n\ + 3. Outcomes or conclusions reached\n\ + 4. Any action items or next steps if any\n\ + Format it in Markdown with headings and bullet points."; - request.messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec![ - "Generate a detailed summary of this conversation. Include:\n\ - 1. A brief overview of what was discussed\n\ - 2. Key facts or information discovered\n\ - 3. Outcomes or conclusions reached\n\ - 4. Any action items or next steps if any\n\ - Format it in Markdown with headings and bullet points." - .into(), - ], - cache: false, - }); + let request = self.to_summarize_request(added_user_message.into()); let task = cx.spawn(async move |thread, cx| { let stream = model.stream_completion_text(request, &cx); @@ -1538,7 +1545,7 @@ impl Thread { pub fn use_pending_tools(&mut self, cx: &mut Context) -> Vec { self.auto_capture_telemetry(cx); - let request = self.to_completion_request(RequestKind::Chat, cx); + let request = self.to_completion_request(cx); let messages = Arc::new(request.messages); let pending_tool_uses = self .tool_use @@ -1650,7 +1657,7 @@ impl Thread { if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() { self.attach_tool_results(cx); if !canceled { - self.send_to_model(model, RequestKind::Chat, cx); + self.send_to_model(model, cx); } } } @@ -2275,9 +2282,7 @@ fn main() {{ assert_eq!(message.context, expected_context); // Check message in request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(RequestKind::Chat, cx) - }); + let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); assert_eq!(request.messages.len(), 2); let expected_full_message = format!("{}Please explain this code", expected_context); @@ -2367,9 +2372,7 @@ fn main() {{ assert!(message3.context.contains("file3.rs")); // Check entire request to make sure all contexts are properly included - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(RequestKind::Chat, cx) - }); + let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); // The request should contain all 3 messages assert_eq!(request.messages.len(), 4); @@ -2419,9 +2422,7 @@ fn main() {{ assert_eq!(message.context, ""); // Check message in request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(RequestKind::Chat, cx) - }); + let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); assert_eq!(request.messages.len(), 2); assert_eq!( @@ -2439,9 +2440,7 @@ fn main() {{ assert_eq!(message2.context, ""); // Check that both messages appear in the request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(RequestKind::Chat, cx) - }); + let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); assert_eq!(request.messages.len(), 3); assert_eq!( @@ -2481,9 +2480,7 @@ fn main() {{ }); // Create a request and check that it doesn't have a stale buffer warning yet - let initial_request = thread.update(cx, |thread, cx| { - thread.to_completion_request(RequestKind::Chat, cx) - }); + let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); // Make sure we don't have a stale file warning yet let has_stale_warning = initial_request.messages.iter().any(|msg| { @@ -2511,9 +2508,7 @@ fn main() {{ }); // Create a new request and check for the stale buffer warning - let new_request = thread.update(cx, |thread, cx| { - thread.to_completion_request(RequestKind::Chat, cx) - }); + let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); // We should have a stale file warning as the last message let last_message = new_request diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 684feaca3b..e7edb0e086 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -74,6 +74,10 @@ pub enum Model { } impl Model { + pub fn default_fast() -> Self { + Self::Claude3_5Haiku + } + pub fn from_id(id: &str) -> Result { if id.starts_with("claude-3-5-sonnet") { Ok(Self::Claude3_5Sonnet) diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index 052e5c2ca1..8ead77f9c4 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -84,6 +84,10 @@ pub enum Model { } impl Model { + pub fn default_fast() -> Self { + Self::Claude3_5Haiku + } + pub fn from_id(id: &str) -> anyhow::Result { if id.starts_with("claude-3-5-sonnet-v2") { Ok(Self::Claude3_5SonnetV2) diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index 255c39cb84..2bcb82c1ee 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot/src/copilot_chat.rs @@ -61,6 +61,10 @@ pub enum Model { } impl Model { + pub fn default_fast() -> Self { + Self::Claude3_7Sonnet + } + pub fn uses_streaming(&self) -> bool { match self { Self::Gpt4o diff --git a/crates/deepseek/src/deepseek.rs b/crates/deepseek/src/deepseek.rs index 07f6a959e1..9c19f1ae2f 100644 --- a/crates/deepseek/src/deepseek.rs +++ b/crates/deepseek/src/deepseek.rs @@ -64,6 +64,10 @@ pub enum Model { } impl Model { + pub fn default_fast() -> Self { + Model::Chat + } + pub fn from_id(id: &str) -> Result { match id { "deepseek-chat" => Ok(Self::Chat), diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 982daeaed7..78b7eb7af9 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -1,4 +1,4 @@ -use agent::{RequestKind, ThreadEvent, ThreadStore}; +use agent::{ThreadEvent, ThreadStore}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::ToolWorkingSet; use client::proto::LspWorkProgress; @@ -472,7 +472,7 @@ impl Example { thread.update(cx, |thread, cx| { let context = vec![]; thread.insert_user_message(this.prompt.clone(), context, None, cx); - thread.send_to_model(model, RequestKind::Chat, cx); + thread.send_to_model(model, cx); })?; event_handler_task.await?; diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 09278d6ed2..e26750936d 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -412,6 +412,10 @@ pub enum Model { } impl Model { + pub fn default_fast() -> Model { + Model::Gemini15Flash + } + pub fn id(&self) -> &str { match self { Model::Gemini15Pro => "gemini-1.5-pro", diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index 56df184d36..25f2a496e7 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -49,6 +49,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider { Some(Arc::new(FakeLanguageModel::default())) } + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(Arc::new(FakeLanguageModel::default())) + } + fn provided_models(&self, _: &App) -> Vec> { vec![Arc::new(FakeLanguageModel::default())] } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 206958e82f..71d8551bd5 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -370,6 +370,7 @@ pub trait LanguageModelProvider: 'static { IconName::ZedAssistant } fn default_model(&self, cx: &App) -> Option>; + fn default_fast_model(&self, cx: &App) -> Option>; fn provided_models(&self, cx: &App) -> Vec>; fn recommended_models(&self, _cx: &App) -> Vec> { Vec::new() diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 45be22457f..62f216094b 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -5,6 +5,7 @@ use crate::{ use collections::BTreeMap; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; use std::sync::Arc; +use util::maybe; pub fn init(cx: &mut App) { let registry = cx.new(|_cx| LanguageModelRegistry::default()); @@ -18,6 +19,7 @@ impl Global for GlobalLanguageModelRegistry {} #[derive(Default)] pub struct LanguageModelRegistry { default_model: Option, + default_fast_model: Option, inline_assistant_model: Option, commit_message_model: Option, thread_summary_model: Option, @@ -202,6 +204,14 @@ impl LanguageModelRegistry { (None, None) => {} _ => cx.emit(Event::DefaultModelChanged), } + self.default_fast_model = maybe!({ + let provider = &model.as_ref()?.provider; + let fast_model = provider.default_fast_model(cx)?; + Some(ConfiguredModel { + provider: provider.clone(), + model: fast_model, + }) + }); self.default_model = model; } @@ -254,21 +264,37 @@ impl LanguageModelRegistry { } pub fn inline_assistant_model(&self) -> Option { + #[cfg(debug_assertions)] + if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { + return None; + } + self.inline_assistant_model .clone() - .or_else(|| self.default_model()) + .or_else(|| self.default_model.clone()) } pub fn commit_message_model(&self) -> Option { + #[cfg(debug_assertions)] + if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { + return None; + } + self.commit_message_model .clone() - .or_else(|| self.default_model()) + .or_else(|| self.default_model.clone()) } pub fn thread_summary_model(&self) -> Option { + #[cfg(debug_assertions)] + if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { + return None; + } + self.thread_summary_model .clone() - .or_else(|| self.default_model()) + .or_else(|| self.default_fast_model.clone()) + .or_else(|| self.default_model.clone()) } /// The models to use for inline assists. Returns the union of the active diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 6a29976504..f998969bfe 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -201,7 +201,7 @@ impl AnthropicLanguageModelProvider { state: self.state.clone(), http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), - }) as Arc + }) } } @@ -227,14 +227,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { } fn default_model(&self, _cx: &App) -> Option> { - let model = anthropic::Model::default(); - Some(Arc::new(AnthropicModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - })) + Some(self.create_language_model(anthropic::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(anthropic::Model::default_fast())) } fn recommended_models(&self, _cx: &App) -> Vec> { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index c4ef48404f..a2748b45be 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -286,6 +286,18 @@ impl BedrockLanguageModelProvider { state, } } + + fn create_language_model(&self, model: bedrock::Model) -> Arc { + Arc::new(BedrockModel { + id: LanguageModelId::from(model.id().to_string()), + model, + http_client: self.http_client.clone(), + handler: self.handler.clone(), + state: self.state.clone(), + client: OnceCell::new(), + request_limiter: RateLimiter::new(4), + }) + } } impl LanguageModelProvider for BedrockLanguageModelProvider { @@ -302,16 +314,11 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { } fn default_model(&self, _cx: &App) -> Option> { - let model = bedrock::Model::default(); - Some(Arc::new(BedrockModel { - id: LanguageModelId::from(model.id().to_string()), - model, - http_client: self.http_client.clone(), - handler: self.handler.clone(), - state: self.state.clone(), - client: OnceCell::new(), - request_limiter: RateLimiter::new(4), - })) + Some(self.create_language_model(bedrock::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(bedrock::Model::default_fast())) } fn provided_models(&self, cx: &App) -> Vec> { @@ -343,17 +350,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { models .into_values() - .map(|model| { - Arc::new(BedrockModel { - id: LanguageModelId::from(model.id().to_string()), - model, - http_client: self.http_client.clone(), - handler: self.handler.clone(), - state: self.state.clone(), - client: OnceCell::new(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model| self.create_language_model(model)) .collect() } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 80c8d0dcc3..b9911d5d46 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -242,7 +242,7 @@ impl CloudLanguageModelProvider { llm_api_token: llm_api_token.clone(), client: self.client.clone(), request_limiter: RateLimiter::new(4), - }) as Arc + }) } } @@ -270,13 +270,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn default_model(&self, cx: &App) -> Option> { let llm_api_token = self.state.read(cx).llm_api_token.clone(); let model = CloudModel::Anthropic(anthropic::Model::default()); - Some(Arc::new(CloudLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - llm_api_token: llm_api_token.clone(), - client: self.client.clone(), - request_limiter: RateLimiter::new(4), - })) + Some(self.create_language_model(model, llm_api_token)) + } + + fn default_fast_model(&self, cx: &App) -> Option> { + let llm_api_token = self.state.read(cx).llm_api_token.clone(); + let model = CloudModel::Anthropic(anthropic::Model::default_fast()); + Some(self.create_language_model(model, llm_api_token)) } fn recommended_models(&self, cx: &App) -> Vec> { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 3d4924b890..255de2d536 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -70,6 +70,13 @@ impl CopilotChatLanguageModelProvider { Self { state } } + + fn create_language_model(&self, model: CopilotChatModel) -> Arc { + Arc::new(CopilotChatLanguageModel { + model, + request_limiter: RateLimiter::new(4), + }) + } } impl LanguageModelProviderState for CopilotChatLanguageModelProvider { @@ -94,21 +101,16 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { } fn default_model(&self, _cx: &App) -> Option> { - let model = CopilotChatModel::default(); - Some(Arc::new(CopilotChatLanguageModel { - model, - request_limiter: RateLimiter::new(4), - }) as Arc) + Some(self.create_language_model(CopilotChatModel::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(CopilotChatModel::default_fast())) } fn provided_models(&self, _cx: &App) -> Vec> { CopilotChatModel::iter() - .map(|model| { - Arc::new(CopilotChatLanguageModel { - model, - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model| self.create_language_model(model)) .collect() } diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index e4f1cd830a..9989e4c6b1 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -140,6 +140,16 @@ impl DeepSeekLanguageModelProvider { Self { http_client, state } } + + fn create_language_model(&self, model: deepseek::Model) -> Arc { + Arc::new(DeepSeekLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) as Arc + } } impl LanguageModelProviderState for DeepSeekLanguageModelProvider { @@ -164,14 +174,11 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { } fn default_model(&self, _cx: &App) -> Option> { - let model = deepseek::Model::Chat; - Some(Arc::new(DeepSeekLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - })) + Some(self.create_language_model(deepseek::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(deepseek::Model::default_fast())) } fn provided_models(&self, cx: &App) -> Vec> { @@ -198,15 +205,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { models .into_values() - .map(|model| { - Arc::new(DeepSeekLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model| self.create_language_model(model)) .collect() } diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index c754e63bbd..bbe6c58353 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -150,6 +150,16 @@ impl GoogleLanguageModelProvider { Self { http_client, state } } + + fn create_language_model(&self, model: google_ai::Model) -> Arc { + Arc::new(GoogleLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } } impl LanguageModelProviderState for GoogleLanguageModelProvider { @@ -174,14 +184,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { } fn default_model(&self, _cx: &App) -> Option> { - let model = google_ai::Model::default(); - Some(Arc::new(GoogleLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - })) + Some(self.create_language_model(google_ai::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(google_ai::Model::default_fast())) } fn provided_models(&self, cx: &App) -> Vec> { diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index bd8b6303f8..2f5ae9ebb6 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -157,6 +157,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { self.provided_models(cx).into_iter().next() } + fn default_fast_model(&self, cx: &App) -> Option> { + self.default_model(cx) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models: BTreeMap = BTreeMap::default(); diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index b9017398e6..a5009c76a6 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -144,6 +144,16 @@ impl MistralLanguageModelProvider { Self { http_client, state } } + + fn create_language_model(&self, model: mistral::Model) -> Arc { + Arc::new(MistralLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } } impl LanguageModelProviderState for MistralLanguageModelProvider { @@ -168,14 +178,11 @@ impl LanguageModelProvider for MistralLanguageModelProvider { } fn default_model(&self, _cx: &App) -> Option> { - let model = mistral::Model::default(); - Some(Arc::new(MistralLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - })) + Some(self.create_language_model(mistral::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(mistral::Model::default_fast())) } fn provided_models(&self, cx: &App) -> Vec> { diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 465dc0d659..17c50c8eaf 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -162,6 +162,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { self.provided_models(cx).into_iter().next() } + fn default_fast_model(&self, cx: &App) -> Option> { + self.default_model(cx) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models: BTreeMap = BTreeMap::default(); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 020c642520..188a219e2d 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -148,6 +148,16 @@ impl OpenAiLanguageModelProvider { Self { http_client, state } } + + fn create_language_model(&self, model: open_ai::Model) -> Arc { + Arc::new(OpenAiLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } } impl LanguageModelProviderState for OpenAiLanguageModelProvider { @@ -172,14 +182,11 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } fn default_model(&self, _cx: &App) -> Option> { - let model = open_ai::Model::default(); - Some(Arc::new(OpenAiLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - })) + Some(self.create_language_model(open_ai::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(open_ai::Model::default_fast())) } fn provided_models(&self, cx: &App) -> Vec> { @@ -211,15 +218,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { models .into_values() - .map(|model| { - Arc::new(OpenAiLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model| self.create_language_model(model)) .collect() } diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs index de2457e0bf..27de5fccc2 100644 --- a/crates/mistral/src/mistral.rs +++ b/crates/mistral/src/mistral.rs @@ -69,6 +69,10 @@ pub enum Model { } impl Model { + pub fn default_fast() -> Self { + Model::MistralSmallLatest + } + pub fn from_id(id: &str) -> Result { match id { "codestral-latest" => Ok(Self::CodestralLatest), diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 9284b4a9b2..f9e0b7d4e3 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -102,6 +102,10 @@ pub enum Model { } impl Model { + pub fn default_fast() -> Self { + Self::FourPointOneMini + } + pub fn from_id(id: &str) -> Result { match id { "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),