agent: Allow customizing temperature by provider/model (#30033)

Adds a new `agent.model_parameters` setting that lets the user specify a
custom temperature for a provider, a model, or both:

```json5
    "model_parameters": [
      // To set parameters for all requests to OpenAI models:
      {
        "provider": "openai",
        "temperature": 0.5
      },
      // To set parameters for all requests in general:
      {
        "temperature": 0
      },
      // To set parameters for a specific provider and model:
      {
        "provider": "zed.dev",
        "model": "claude-3-7-sonnet-latest",
        "temperature": 1.0
      }
    ],
```
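
Parameters resolve by "last match wins": a request takes its temperature from the last entry in `model_parameters` whose `provider` and `model` fields both match, with omitted fields matching anything (the real lookup is `AssistantSettings::temperature_for_model`, added in the diff below). The following is a minimal standalone sketch of that rule with simplified types, not Zed's actual API. Note that in the example list above, the catch-all entry shadows the earlier OpenAI entry for OpenAI requests, because it appears later in the list:

```rust
struct ModelParams {
    provider: Option<String>,
    model: Option<String>,
    temperature: Option<f32>,
}

// Scan from the end of the list; `None` fields act as wildcards.
fn temperature_for(params: &[ModelParams], provider: &str, model: &str) -> Option<f32> {
    params
        .iter()
        .rfind(|p| {
            p.provider.as_deref().map_or(true, |v| v == provider)
                && p.model.as_deref().map_or(true, |v| v == model)
        })
        .and_then(|p| p.temperature)
}

fn main() {
    let params = vec![
        ModelParams { provider: Some("openai".into()), model: None, temperature: Some(0.5) },
        ModelParams { provider: None, model: None, temperature: Some(0.0) },
        ModelParams {
            provider: Some("zed.dev".into()),
            model: Some("claude-3-7-sonnet-latest".into()),
            temperature: Some(1.0),
        },
    ];

    // The catch-all entry comes after the OpenAI entry, so it wins for
    // OpenAI requests; only the zed.dev/claude pairing hits the last entry.
    assert_eq!(temperature_for(&params, "openai", "gpt-4"), Some(0.0));
    assert_eq!(
        temperature_for(&params, "zed.dev", "claude-3-7-sonnet-latest"),
        Some(1.0)
    );
}
```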

Release Notes:

- agent: Allow customizing temperature by provider/model

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Agus Zubiaga, 2025-05-06 17:36:25 -03:00; committed by Joseph T. Lyons
parent 44ef5bd95f
commit 42894f6c8b
22 changed files with 348 additions and 106 deletions

Cargo.lock (generated)
View File

@@ -595,6 +595,7 @@ version = "0.1.0"
 dependencies = [
  "anthropic",
  "anyhow",
+ "collections",
  "deepseek",
  "feature_flags",
  "fs",
@@ -3006,6 +3007,7 @@ dependencies = [
  "anyhow",
  "assistant",
  "assistant_context_editor",
+ "assistant_settings",
  "assistant_slash_command",
  "assistant_tool",
  "async-stripe",

View File

@@ -605,13 +605,11 @@
   //
   // Default: main
   "fallback_branch_name": "main",
   // Whether to sort entries in the panel by path
   // or by status (the default).
   //
   // Default: false
   "sort_by_path": false,
   "scrollbar": {
     // When to show the scrollbar in the git panel.
     //
@@ -661,6 +659,28 @@
     // The model to use.
     "model": "claude-3-7-sonnet-latest"
   },
+  // Additional parameters for language model requests. When making a request to a model, parameters will be taken
+  // from the last entry in this list that matches the model's provider and name. In each entry, both provider
+  // and model are optional, so that you can specify parameters for either one.
+  "model_parameters": [
+    // To set parameters for all requests to OpenAI models:
+    // {
+    //   "provider": "openai",
+    //   "temperature": 0.5
+    // }
+    //
+    // To set parameters for all requests in general:
+    // {
+    //   "temperature": 0
+    // }
+    //
+    // To set parameters for a specific provider and model:
+    // {
+    //   "provider": "zed.dev",
+    //   "model": "claude-3-7-sonnet-latest",
+    //   "temperature": 1.0
+    // }
+  ],
   // When enabled, the agent can run potentially destructive actions without asking for your confirmation.
   "always_allow_tool_actions": false,
   // When enabled, the agent will stream edits.

View File

@@ -1417,7 +1417,10 @@ impl ActiveThread {
             messages: vec![request_message],
             tools: vec![],
             stop: vec![],
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(
+                &configured_model.model,
+                cx,
+            ),
         };
         Some(configured_model.model.count_tokens(request, cx))

View File

@@ -2,6 +2,7 @@ use crate::context::ContextLoadResult;
 use crate::inline_prompt_editor::CodegenStatus;
 use crate::{context::load_context, context_store::ContextStore};
 use anyhow::Result;
+use assistant_settings::AssistantSettings;
 use client::telemetry::Telemetry;
 use collections::HashSet;
 use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
@@ -383,7 +384,7 @@ impl CodegenAlternative {
         if user_prompt.trim().to_lowercase() == "delete" {
             async { Ok(LanguageModelTextStream::default()) }.boxed_local()
         } else {
-            let request = self.build_request(user_prompt, cx)?;
+            let request = self.build_request(&model, user_prompt, cx)?;
             cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
                 .boxed_local()
         };
@@ -393,6 +394,7 @@ impl CodegenAlternative {
     fn build_request(
         &self,
+        model: &Arc<dyn LanguageModel>,
         user_prompt: String,
         cx: &mut App,
     ) -> Result<Task<LanguageModelRequest>> {
@@ -441,6 +443,8 @@ impl CodegenAlternative {
             }
         });

+        let temperature = AssistantSettings::temperature_for_model(&model, cx);
+
         Ok(cx.spawn(async move |_cx| {
             let mut request_message = LanguageModelRequestMessage {
                 role: Role::User,
@@ -463,7 +467,7 @@ impl CodegenAlternative {
                 mode: None,
                 tools: Vec::new(),
                 stop: Vec::new(),
-                temperature: None,
+                temperature,
                 messages: vec![request_message],
             }
         }))

View File

@@ -8,7 +8,7 @@ use crate::ui::{
     AnimatedLabel, MaxModeTooltip,
     preview::{AgentPreview, UsageCallout},
 };
-use assistant_settings::CompletionMode;
+use assistant_settings::{AssistantSettings, CompletionMode};
 use buffer_diff::BufferDiff;
 use client::UserStore;
 use collections::{HashMap, HashSet};
@@ -1273,7 +1273,7 @@ impl MessageEditor {
             messages: vec![request_message],
             tools: vec![],
             stop: vec![],
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(&model.model, cx),
         };
         Some(model.model.count_tokens(request, cx))

View File

@@ -6,6 +6,7 @@ use crate::inline_prompt_editor::{
 use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
 use crate::thread_store::{TextThreadStore, ThreadStore};
 use anyhow::{Context as _, Result};
+use assistant_settings::AssistantSettings;
 use client::telemetry::Telemetry;
 use collections::{HashMap, VecDeque};
 use editor::{MultiBuffer, actions::SelectAll};
@@ -266,6 +267,12 @@ impl TerminalInlineAssistant {
             load_context(contexts, project, &assist.prompt_store, cx)
         })?;

+        let ConfiguredModel { model, .. } = LanguageModelRegistry::read_global(cx)
+            .inline_assistant_model()
+            .context("No inline assistant model")?;
+
+        let temperature = AssistantSettings::temperature_for_model(&model, cx);
+
         Ok(cx.background_spawn(async move {
             let mut request_message = LanguageModelRequestMessage {
                 role: Role::User,
@@ -287,7 +294,7 @@ impl TerminalInlineAssistant {
                 messages: vec![request_message],
                 tools: Vec::new(),
                 stop: Vec::new(),
-                temperature: None,
+                temperature,
             }
         }))
     }

View File

@@ -1145,7 +1145,7 @@ impl Thread {
             messages: vec![],
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(&model, cx),
         };

         let available_tools = self.available_tools(cx, model.clone());
@@ -1251,7 +1251,12 @@ impl Thread {
         request
     }

-    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
+    fn to_summarize_request(
+        &self,
+        model: &Arc<dyn LanguageModel>,
+        added_user_message: String,
+        cx: &App,
+    ) -> LanguageModelRequest {
         let mut request = LanguageModelRequest {
             thread_id: None,
             prompt_id: None,
@@ -1259,7 +1264,7 @@ impl Thread {
             messages: vec![],
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(model, cx),
         };

         for message in &self.messages {
@@ -1696,7 +1701,7 @@ impl Thread {
             If the conversation is about a specific subject, include it in the title. \
             Be descriptive. DO NOT speak in the first person.";

-        let request = self.to_summarize_request(added_user_message.into());
+        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

         self.pending_summary = cx.spawn(async move |this, cx| {
             async move {
@@ -1782,7 +1787,7 @@ impl Thread {
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

-        let request = self.to_summarize_request(added_user_message.into());
+        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

         *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
             message_id: last_message_id,
@@ -2655,7 +2660,7 @@ struct PendingCompletion {
 mod tests {
     use super::*;
     use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
-    use assistant_settings::AssistantSettings;
+    use assistant_settings::{AssistantSettings, LanguageModelParameters};
     use assistant_tool::ToolRegistry;
     use editor::EditorSettings;
     use gpui::TestAppContext;
@@ -3066,6 +3071,100 @@ fn main() {{
         );
     }

+    #[gpui::test]
+    async fn test_temperature_setting(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(
+            cx,
+            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
+        )
+        .await;
+
+        let (_workspace, _thread_store, thread, _context_store, model) =
+            setup_test_environment(cx, project.clone()).await;
+
+        // Both model and provider
+        cx.update(|cx| {
+            AssistantSettings::override_global(
+                AssistantSettings {
+                    model_parameters: vec![LanguageModelParameters {
+                        provider: Some(model.provider_id().0.to_string().into()),
+                        model: Some(model.id().0.clone()),
+                        temperature: Some(0.66),
+                    }],
+                    ..AssistantSettings::get_global(cx).clone()
+                },
+                cx,
+            );
+        });
+
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
+        assert_eq!(request.temperature, Some(0.66));
+
+        // Only model
+        cx.update(|cx| {
+            AssistantSettings::override_global(
+                AssistantSettings {
+                    model_parameters: vec![LanguageModelParameters {
+                        provider: None,
+                        model: Some(model.id().0.clone()),
+                        temperature: Some(0.66),
+                    }],
+                    ..AssistantSettings::get_global(cx).clone()
+                },
+                cx,
+            );
+        });
+
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
+        assert_eq!(request.temperature, Some(0.66));
+
+        // Only provider
+        cx.update(|cx| {
+            AssistantSettings::override_global(
+                AssistantSettings {
+                    model_parameters: vec![LanguageModelParameters {
+                        provider: Some(model.provider_id().0.to_string().into()),
+                        model: None,
+                        temperature: Some(0.66),
+                    }],
+                    ..AssistantSettings::get_global(cx).clone()
+                },
+                cx,
+            );
+        });
+
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
+        assert_eq!(request.temperature, Some(0.66));
+
+        // Same model name, different provider
+        cx.update(|cx| {
+            AssistantSettings::override_global(
+                AssistantSettings {
+                    model_parameters: vec![LanguageModelParameters {
+                        provider: Some("anthropic".into()),
+                        model: Some(model.id().0.clone()),
+                        temperature: Some(0.66),
+                    }],
+                    ..AssistantSettings::get_global(cx).clone()
+                },
+                cx,
+            );
+        });
+
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
+        assert_eq!(request.temperature, None);
+    }
+
     fn init_test_settings(cx: &mut TestAppContext) {
         cx.update(|cx| {
             let settings_store = SettingsStore::test(cx);

View File

@@ -163,7 +163,7 @@ fn update_active_language_model_from_settings(cx: &mut App) {
 fn to_selected_model(selection: &LanguageModelSelection) -> language_model::SelectedModel {
     language_model::SelectedModel {
-        provider: LanguageModelProviderId::from(selection.provider.clone()),
+        provider: LanguageModelProviderId::from(selection.provider.0.clone()),
         model: LanguageModelId::from(selection.model.clone()),
     }
 }

View File

@@ -2484,7 +2484,7 @@ impl InlineAssist {
                 .read(cx)
                 .active_context(cx)?
                 .read(cx)
-                .to_completion_request(RequestType::Chat, cx),
+                .to_completion_request(None, RequestType::Chat, cx),
             )
         } else {
             None
@@ -2870,7 +2870,8 @@ impl CodegenAlternative {
         if let Some(ConfiguredModel { model, .. }) =
             LanguageModelRegistry::read_global(cx).inline_assistant_model()
         {
-            let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
+            let request =
+                self.build_request(&model, user_prompt, assistant_panel_context.clone(), cx);
             match request {
                 Ok(request) => {
                     let total_count = model.count_tokens(request.clone(), cx);
@@ -2915,7 +2916,8 @@ impl CodegenAlternative {
         if user_prompt.trim().to_lowercase() == "delete" {
             async { Ok(LanguageModelTextStream::default()) }.boxed_local()
         } else {
-            let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
+            let request =
+                self.build_request(&model, user_prompt, assistant_panel_context, cx)?;
             self.request = Some(request.clone());

             cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
@@ -2927,6 +2929,7 @@ impl CodegenAlternative {
     fn build_request(
         &self,
+        model: &Arc<dyn LanguageModel>,
         user_prompt: String,
         assistant_panel_context: Option<LanguageModelRequest>,
         cx: &App,
@@ -2981,7 +2984,7 @@ impl CodegenAlternative {
             messages,
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(&model, cx),
         })
     }

View File

@@ -261,7 +261,7 @@ impl TerminalInlineAssistant {
                     .read(cx)
                     .active_context(cx)?
                     .read(cx)
-                    .to_completion_request(RequestType::Chat, cx),
+                    .to_completion_request(None, RequestType::Chat, cx),
                 )
             })
         } else {

View File

@@ -3,6 +3,7 @@ mod context_tests;

 use crate::patch::{AssistantEdit, AssistantPatch, AssistantPatchStatus};
 use anyhow::{Context as _, Result, anyhow};
+use assistant_settings::AssistantSettings;
 use assistant_slash_command::{
     SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection,
     SlashCommandResult, SlashCommandWorkingSet,
@@ -1273,10 +1274,10 @@ impl AssistantContext {
     pub(crate) fn count_remaining_tokens(&mut self, cx: &mut Context<Self>) {
         // Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit),
         // because otherwise you see in the UI that your empty message has a bunch of tokens already used.
-        let request = self.to_completion_request(RequestType::Chat, cx);
         let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
             return;
         };
+        let request = self.to_completion_request(Some(&model.model), RequestType::Chat, cx);
         let debounce = self.token_count.is_some();
         self.pending_token_count = cx.spawn(async move |this, cx| {
             async move {
@@ -1422,7 +1423,7 @@ impl AssistantContext {
         }

         let request = {
-            let mut req = self.to_completion_request(RequestType::Chat, cx);
+            let mut req = self.to_completion_request(Some(&model), RequestType::Chat, cx);
             // Skip the last message because it's likely to change and
             // therefore would be a waste to cache.
             req.messages.pop();
@@ -2321,7 +2322,7 @@ impl AssistantContext {
         // Compute which messages to cache, including the last one.
         self.mark_cache_anchors(&model.cache_configuration(), false, cx);

-        let request = self.to_completion_request(request_type, cx);
+        let request = self.to_completion_request(Some(&model), request_type, cx);

         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
@@ -2561,6 +2562,7 @@ impl AssistantContext {
     pub fn to_completion_request(
         &self,
+        model: Option<&Arc<dyn LanguageModel>>,
         request_type: RequestType,
         cx: &App,
     ) -> LanguageModelRequest {
@@ -2584,7 +2586,8 @@ impl AssistantContext {
             messages: Vec::new(),
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature: model
+                .and_then(|model| AssistantSettings::temperature_for_model(model, cx)),
         };

         for message in self.messages(cx) {
             if message.status != MessageStatus::Done {
@@ -2981,7 +2984,7 @@ impl AssistantContext {
             return;
         }

-        let mut request = self.to_completion_request(RequestType::Chat, cx);
+        let mut request = self.to_completion_request(Some(&model.model), RequestType::Chat, cx);
         request.messages.push(LanguageModelRequestMessage {
             role: Role::User,
             content: vec![

View File

@@ -43,9 +43,8 @@ use workspace::Workspace;

 #[gpui::test]
 fn test_inserting_and_removing_messages(cx: &mut App) {
-    let settings_store = SettingsStore::test(cx);
-    LanguageModelRegistry::test(cx);
-    cx.set_global(settings_store);
+    init_test(cx);
+
     let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context = cx.new(|cx| {
@@ -182,9 +181,8 @@ fn test_inserting_and_removing_messages(cx: &mut App) {

 #[gpui::test]
 fn test_message_splitting(cx: &mut App) {
-    let settings_store = SettingsStore::test(cx);
-    cx.set_global(settings_store);
-    LanguageModelRegistry::test(cx);
+    init_test(cx);
+
     let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
@@ -285,9 +283,8 @@ fn test_message_splitting(cx: &mut App) {

 #[gpui::test]
 fn test_messages_for_offsets(cx: &mut App) {
-    let settings_store = SettingsStore::test(cx);
-    LanguageModelRegistry::test(cx);
-    cx.set_global(settings_store);
+    init_test(cx);
+
     let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context = cx.new(|cx| {
@@ -378,10 +375,8 @@ fn test_messages_for_offsets(cx: &mut App) {

 #[gpui::test]
 async fn test_slash_commands(cx: &mut TestAppContext) {
-    let settings_store = cx.update(SettingsStore::test);
-    cx.set_global(settings_store);
-    cx.update(LanguageModelRegistry::test);
-    cx.update(Project::init_settings);
+    cx.update(init_test);
+
     let fs = FakeFs::new(cx.background_executor.clone());
     fs.insert_tree(
@@ -671,22 +666,19 @@ async fn test_slash_commands(cx: &mut TestAppContext) {

 #[gpui::test]
 async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
-    cx.update(prompt_store::init);
-    let mut settings_store = cx.update(SettingsStore::test);
     cx.update(|cx| {
-        settings_store
-            .set_user_settings(
-                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
-                cx,
-            )
-            .unwrap()
+        init_test(cx);
+        cx.update_global(|settings_store: &mut SettingsStore, cx| {
+            settings_store
+                .set_user_settings(
+                    r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
+                    cx,
+                )
+                .unwrap()
+        })
     });
-    cx.set_global(settings_store);
-    cx.update(language::init);
-    cx.update(Project::init_settings);

     let fs = FakeFs::new(cx.executor());
     let project = Project::test(fs, [Path::new("/root")], cx).await;
-    cx.update(LanguageModelRegistry::test);

     let registry = Arc::new(LanguageRegistry::test(cx.executor()));
@@ -1069,9 +1061,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {

 #[gpui::test]
 async fn test_serialization(cx: &mut TestAppContext) {
-    let settings_store = cx.update(SettingsStore::test);
-    cx.set_global(settings_store);
-    cx.update(LanguageModelRegistry::test);
+    cx.update(init_test);
+
     let registry = Arc::new(LanguageRegistry::test(cx.executor()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context = cx.new(|cx| {
@@ -1147,6 +1138,8 @@ async fn test_serialization(cx: &mut TestAppContext) {

 #[gpui::test(iterations = 100)]
 async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
+    cx.update(init_test);
+
     let min_peers = env::var("MIN_PEERS")
         .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
         .unwrap_or(2);
@@ -1157,10 +1150,6 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
         .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
         .unwrap_or(50);

-    let settings_store = cx.update(SettingsStore::test);
-    cx.set_global(settings_store);
-    cx.update(LanguageModelRegistry::test);
-
     let slash_commands = cx.update(SlashCommandRegistry::default_global);
     slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
     slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
@@ -1429,9 +1418,8 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {

 #[gpui::test]
 fn test_mark_cache_anchors(cx: &mut App) {
-    let settings_store = SettingsStore::test(cx);
-    LanguageModelRegistry::test(cx);
-    cx.set_global(settings_store);
+    init_test(cx);
+
     let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context = cx.new(|cx| {
@@ -1606,6 +1594,16 @@ fn messages_cache(
         .collect()
 }

+fn init_test(cx: &mut App) {
+    let settings_store = SettingsStore::test(cx);
+    prompt_store::init(cx);
+    LanguageModelRegistry::test(cx);
+    cx.set_global(settings_store);
+    language::init(cx);
+    assistant_settings::init(cx);
+    Project::init_settings(cx);
+}
+
 #[derive(Clone)]
 struct FakeSlashCommand(String);

View File

@@ -14,6 +14,7 @@ path = "src/assistant_settings.rs"

 [dependencies]
 anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
+collections.workspace = true
 feature_flags.workspace = true
 gpui.workspace = true
 indexmap.workspace = true

View File

@@ -1,7 +1,7 @@
 use std::sync::Arc;

+use collections::IndexMap;
 use gpui::SharedString;
-use indexmap::IndexMap;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

View File

@@ -5,10 +5,10 @@ use std::sync::Arc;
 use ::open_ai::Model as OpenAiModel;
 use anthropic::Model as AnthropicModel;
 use anyhow::{Result, bail};
+use collections::IndexMap;
 use deepseek::Model as DeepseekModel;
 use feature_flags::{AgentStreamEditsFeatureFlag, Assistant2FeatureFlag, FeatureFlagAppExt};
-use gpui::{App, Pixels};
+use gpui::{App, Pixels, SharedString};
-use indexmap::IndexMap;
 use language_model::{CloudModel, LanguageModel};
 use lmstudio::Model as LmStudioModel;
 use ollama::Model as OllamaModel;
@@ -18,6 +18,10 @@ use settings::{Settings, SettingsSources};

 pub use crate::agent_profile::*;

+pub fn init(cx: &mut App) {
+    AssistantSettings::register(cx);
+}
+
 #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(rename_all = "snake_case")]
 pub enum AssistantDockPosition {
@@ -89,10 +93,20 @@ pub struct AssistantSettings {
     pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
     pub stream_edits: bool,
     pub single_file_review: bool,
+    pub model_parameters: Vec<LanguageModelParameters>,
     pub preferred_completion_mode: CompletionMode,
 }

 impl AssistantSettings {
+    pub fn temperature_for_model(model: &Arc<dyn LanguageModel>, cx: &App) -> Option<f32> {
+        let settings = Self::get_global(cx);
+        settings
+            .model_parameters
+            .iter()
+            .rfind(|setting| setting.matches(model))
+            .and_then(|m| m.temperature)
+    }
+
     pub fn stream_edits(&self, cx: &App) -> bool {
         cx.has_flag::<AgentStreamEditsFeatureFlag>() || self.stream_edits
     }
@@ -106,15 +120,47 @@ impl AssistantSettings {
     }

     pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
-        self.inline_assistant_model = Some(LanguageModelSelection { provider, model });
+        self.inline_assistant_model = Some(LanguageModelSelection {
+            provider: provider.into(),
+            model,
+        });
     }

     pub fn set_commit_message_model(&mut self, provider: String, model: String) {
-        self.commit_message_model = Some(LanguageModelSelection { provider, model });
+        self.commit_message_model = Some(LanguageModelSelection {
+            provider: provider.into(),
+            model,
+        });
     }

     pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
-        self.thread_summary_model = Some(LanguageModelSelection { provider, model });
+        self.thread_summary_model = Some(LanguageModelSelection {
+            provider: provider.into(),
+            model,
+        });
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+pub struct LanguageModelParameters {
+    pub provider: Option<LanguageModelProviderSetting>,
+    pub model: Option<SharedString>,
+    pub temperature: Option<f32>,
+}
+
+impl LanguageModelParameters {
+    pub fn matches(&self, model: &Arc<dyn LanguageModel>) -> bool {
+        if let Some(provider) = &self.provider {
+            if provider.0 != model.provider_id().0 {
+                return false;
+            }
+        }
+        if let Some(setting_model) = &self.model {
+            if *setting_model != model.id().0 {
+                return false;
+            }
+        }
+        true
     }
 }
@@ -181,37 +227,37 @@ impl AssistantSettingsContent {
                 .and_then(|provider| match provider {
                     AssistantProviderContentV1::ZedDotDev { default_model } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "zed.dev".to_string(),
+                            provider: "zed.dev".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::OpenAi { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "openai".to_string(),
+                            provider: "openai".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::Anthropic { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "anthropic".to_string(),
+                            provider: "anthropic".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::Ollama { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "ollama".to_string(),
+                            provider: "ollama".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::LmStudio { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "lmstudio".to_string(),
+                            provider: "lmstudio".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::DeepSeek { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "deepseek".to_string(),
+                            provider: "deepseek".into(),
                             model: model.id().to_string(),
                         })
                     }
@@ -227,6 +273,7 @@ impl AssistantSettingsContent {
                     notify_when_agent_waiting: None,
                     stream_edits: None,
                     single_file_review: None,
+                    model_parameters: Vec::new(),
                     preferred_completion_mode: None,
                 },
                 VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
@@ -238,7 +285,7 @@ impl AssistantSettingsContent {
                     default_width: settings.default_width,
                     default_height: settings.default_height,
                     default_model: Some(LanguageModelSelection {
-                        provider: "openai".to_string(),
+                        provider: "openai".into(),
                         model: settings
                             .default_open_ai_model
                             .clone()
@@ -257,6 +304,7 @@ impl AssistantSettingsContent {
                     notify_when_agent_waiting: None,
                     stream_edits: None,
                     single_file_review: None,
+                    model_parameters: Vec::new(),
                     preferred_completion_mode: None,
                 },
                 None => AssistantSettingsContentV2::default(),
@@ -370,7 +418,10 @@ impl AssistantSettingsContent {
                     }
                 }
                 VersionedAssistantSettingsContent::V2(ref mut settings) => {
-                    settings.default_model = Some(LanguageModelSelection { provider, model });
+                    settings.default_model = Some(LanguageModelSelection {
+                        provider: provider.into(),
+                        model,
+                    });
                 }
             },
             Some(AssistantSettingsContentInner::Legacy(settings)) => {
@@ -381,7 +432,10 @@ impl AssistantSettingsContent {
             None => {
                 self.inner = Some(AssistantSettingsContentInner::for_v2(
                     AssistantSettingsContentV2 {
-                        default_model: Some(LanguageModelSelection { provider, model }),
+                        default_model: Some(LanguageModelSelection {
+                            provider: provider.into(),
+                            model,
+                        }),
                         ..Default::default()
                     },
                 ));
@@ -391,7 +445,10 @@ impl AssistantSettingsContent {

     pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
         self.v2_setting(|setting| {
-            setting.inline_assistant_model = Some(LanguageModelSelection { provider, model });
+            setting.inline_assistant_model = Some(LanguageModelSelection {
+                provider: provider.into(),
+                model,
+            });
             Ok(())
         })
         .ok();
@@ -399,7 +456,10 @@ impl AssistantSettingsContent {

     pub fn set_commit_message_model(&mut self, provider: String, model: String) {
         self.v2_setting(|setting| {
-            setting.commit_message_model = Some(LanguageModelSelection { provider, model });
+            setting.commit_message_model = Some(LanguageModelSelection {
+                provider: provider.into(),
+                model,
+            });
             Ok(())
         })
         .ok();
@@ -427,7 +487,10 @@ impl AssistantSettingsContent {

     pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
         self.v2_setting(|setting| {
-            setting.thread_summary_model = Some(LanguageModelSelection { provider, model });
+            setting.thread_summary_model = Some(LanguageModelSelection {
+                provider: provider.into(),
+                model,
+            });
             Ok(())
         })
         .ok();
@@ -523,6 +586,7 @@ impl Default for VersionedAssistantSettingsContent {
             notify_when_agent_waiting: None,
             stream_edits: None,
             single_file_review: None,
+            model_parameters: Vec::new(),
             preferred_completion_mode: None,
         })
     }
@@ -587,6 +651,15 @@ pub struct AssistantSettingsContentV2 {
     ///
     /// Default: true
     single_file_review: Option<bool>,
+    /// Additional parameters for language model requests. When making a request
+    /// to a model, parameters will be taken from the last entry in this list
+    /// that matches the model's provider and name. In each entry, both provider
+    /// and model are optional, so that you can specify parameters for either
+    /// one.
+    ///
+    /// Default: []
+    #[serde(default)]
+    model_parameters: Vec<LanguageModelParameters>,
     /// What completion mode to enable for new threads
     ///
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct LanguageModelSelection { pub struct LanguageModelSelection {
#[schemars(schema_with = "providers_schema")] pub provider: LanguageModelProviderSetting,
pub provider: String,
pub model: String, pub model: String,
} }
fn providers_schema(_: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema { #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
schemars::schema::SchemaObject { pub struct LanguageModelProviderSetting(pub String);
enum_values: Some(vec![
"anthropic".into(), impl JsonSchema for LanguageModelProviderSetting {
"bedrock".into(), fn schema_name() -> String {
"google".into(), "LanguageModelProviderSetting".into()
"lmstudio".into(), }
"ollama".into(),
"openai".into(), fn json_schema(_: &mut schemars::r#gen::SchemaGenerator) -> Schema {
"zed.dev".into(), schemars::schema::SchemaObject {
"copilot_chat".into(), enum_values: Some(vec![
"deepseek".into(), "anthropic".into(),
]), "bedrock".into(),
..Default::default() "google".into(),
"lmstudio".into(),
"ollama".into(),
"openai".into(),
"zed.dev".into(),
"copilot_chat".into(),
"deepseek".into(),
]),
..Default::default()
}
.into()
}
}
impl From<String> for LanguageModelProviderSetting {
fn from(provider: String) -> Self {
Self(provider)
}
}
impl From<&str> for LanguageModelProviderSetting {
fn from(provider: &str) -> Self {
Self(provider.to_string())
} }
.into()
} }
impl Default for LanguageModelSelection { impl Default for LanguageModelSelection {
fn default() -> Self { fn default() -> Self {
Self { Self {
provider: "openai".to_string(), provider: LanguageModelProviderSetting("openai".to_string()),
model: "gpt-4".to_string(), model: "gpt-4".to_string(),
} }
} }
@@ -781,6 +874,10 @@ impl Settings for AssistantSettings {
                 value.preferred_completion_mode,
             );

+            settings
+                .model_parameters
+                .extend_from_slice(&value.model_parameters);
+
             if let Some(profiles) = value.profiles {
                 settings
                     .profiles
@@ -913,6 +1010,7 @@ mod tests {
                         notify_when_agent_waiting: None,
                         stream_edits: None,
                         single_file_review: None,
+                        model_parameters: Vec::new(),
                         preferred_completion_mode: None,
                     },
                 )),
@@ -976,7 +1074,7 @@ mod tests {
                 AssistantSettingsContentV2 {
                     enabled: Some(false),
                     default_model: Some(LanguageModelSelection {
-                        provider: "xai".to_owned(),
+                        provider: "xai".to_owned().into(),
                         model: "grok".to_owned(),
                     }),
                     ..Default::default()
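
One consequence of the `extend_from_slice` merge above, inferred from the diff rather than stated in the commit message: each settings layer appends its `model_parameters` entries in load order, so entries from later sources (for example, user settings layered over the shipped defaults) land at the end of the list and win under the `rfind`-based lookup. A toy sketch of that interaction, with assumed layer names:

```rust
#[derive(Clone)]
struct Params {
    temperature: Option<f32>,
}

fn main() {
    // Hypothetical layers: defaults load first, user settings after.
    let default_layer = vec![Params { temperature: Some(1.0) }];
    let user_layer = vec![Params { temperature: Some(0.2) }];

    let mut merged: Vec<Params> = Vec::new();
    merged.extend_from_slice(&default_layer);
    merged.extend_from_slice(&user_layer);

    // A reverse scan (as in temperature_for_model's rfind) sees the
    // user-layer entry first, so it shadows the default.
    let winner = merged.iter().rev().find_map(|p| p.temperature);
    assert_eq!(winner, Some(0.2));
}
```

This keeps the precedence story consistent: within one file, later entries win; across settings sources, later-loaded sources win.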

View File

@@ -78,6 +78,7 @@ zed_llm_client.workspace = true

 [dev-dependencies]
 assistant = { workspace = true, features = ["test-support"] }
 assistant_context_editor.workspace = true
+assistant_settings.workspace = true
 assistant_slash_command.workspace = true
 assistant_tool.workspace = true
 async-trait.workspace = true

View File

@@ -307,6 +307,7 @@ impl TestServer {
             );
             language_model::LanguageModelRegistry::test(cx);
             assistant_context_editor::init(client.clone(), cx);
+            assistant_settings::init(cx);
         });

         client

View File

@@ -1735,6 +1735,8 @@ impl GitPanel {
             }
         });

+        let temperature = AssistantSettings::temperature_for_model(&model, cx);
+
         self.generate_commit_message_task = Some(cx.spawn(async move |this, cx| {
             async move {
                 let _defer = cx.on_drop(&this, |this, _cx| {
@@ -1773,7 +1775,7 @@ impl GitPanel {
                     }],
                     tools: Vec::new(),
                     stop: Vec::new(),
-                    temperature: None,
+                    temperature,
                 };

                 let stream = model.stream_completion_text(request, &cx);

View File

@ -87,8 +87,8 @@ pub struct AllLanguageModelSettingsContent {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)] #[serde(untagged)]
pub enum AnthropicSettingsContent { pub enum AnthropicSettingsContent {
Legacy(LegacyAnthropicSettingsContent),
Versioned(VersionedAnthropicSettingsContent), Versioned(VersionedAnthropicSettingsContent),
Legacy(LegacyAnthropicSettingsContent),
} }
impl AnthropicSettingsContent { impl AnthropicSettingsContent {
@ -197,8 +197,8 @@ pub struct MistralSettingsContent {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)] #[serde(untagged)]
pub enum OpenAiSettingsContent { pub enum OpenAiSettingsContent {
Legacy(LegacyOpenAiSettingsContent),
Versioned(VersionedOpenAiSettingsContent), Versioned(VersionedOpenAiSettingsContent),
Legacy(LegacyOpenAiSettingsContent),
} }
impl OpenAiSettingsContent { impl OpenAiSettingsContent {

View File

@@ -3522,7 +3522,7 @@ impl LspStore {
                 )
                 .detach();
             } else {
-                log::info!("No extension events global found. Skipping JSON schema auto-reload setup");
+                log::debug!("No extension events global found. Skipping JSON schema auto-reload setup");
             }
             cx.observe_global::<SettingsStore>(Self::on_settings_changed)
                 .detach();

View File

@@ -3871,7 +3871,7 @@ impl BackgroundScanner {
                 Some(ancestor_dot_git)
             });

-            log::info!("containing git repository: {containing_git_repository:?}");
+            log::trace!("containing git repository: {containing_git_repository:?}");

             let (scan_job_tx, scan_job_rx) = channel::unbounded();
             {

View File

@@ -3058,14 +3058,14 @@ Run the `theme selector: toggle` action in the command palette to see a current
 }
 ```

-## Assistant Panel
+## Agent

-- Description: Customize assistant panel
-- Setting: `assistant`
+- Description: Customize agent behavior
+- Setting: `agent`
 - Default:

 ```json
-"assistant": {
+"agent": {
   "version": "2",
   "enabled": true,
   "button": true,