From 108005f1b87c187683ada43ccf335d34fd8386c2 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Mon, 5 May 2025 21:36:12 +0200 Subject: [PATCH] context_store: Refactor state management (#29910) Because we instantiated `ContextServerManager` both in `agent` and `assistant-context-editor`, and these two entities track the running MCP servers separately, we were effectively running every MCP server twice. This PR moves the `ContextServerManager` into the project crate (now called `ContextServerStore`). The store can be accessed via a project instance. This ensures that we only instantiate one `ContextServerStore` per project. Also, this PR adds a bunch of tests to ensure that the `ContextServerStore` behaves correctly (Previously there were none). Closes #28714 Closes #29530 Release Notes: - N/A --- Cargo.lock | 31 +- Cargo.toml | 2 - crates/agent/src/active_thread.rs | 2 - crates/agent/src/agent_diff.rs | 3 - crates/agent/src/assistant.rs | 1 + crates/agent/src/assistant_configuration.rs | 78 +- .../add_context_server_modal.rs | 9 +- .../configure_context_server_modal.rs | 100 +- crates/agent/src/assistant_panel.rs | 4 +- crates/agent/src/buffer_codegen.rs | 26 +- .../agent/src/context_server_configuration.rs | 27 +- .../src/context_server_tool.rs | 24 +- crates/agent/src/thread.rs | 2 - crates/agent/src/thread_store.rs | 45 +- crates/assistant/Cargo.toml | 1 - crates/assistant/src/assistant.rs | 1 - crates/assistant/src/assistant_panel.rs | 22 +- .../src/context_store.rs | 48 +- .../src/context_server_command.rs | 26 +- crates/collab/src/tests/integration_tests.rs | 2 - crates/collab/src/tests/test_server.rs | 1 + crates/context_server/Cargo.toml | 13 +- crates/context_server/src/client.rs | 13 +- crates/context_server/src/context_server.rs | 131 +- crates/context_server/src/manager.rs | 584 --------- crates/context_server/src/protocol.rs | 2 +- crates/context_server/src/types.rs | 2 +- crates/context_server_settings/Cargo.toml | 22 - crates/context_server_settings/LICENSE-GPL | 1 - .../src/context_server_settings.rs | 99 -- crates/eval/Cargo.toml | 1 - crates/eval/src/eval.rs | 1 - crates/extension/src/extension_host_proxy.rs | 10 + crates/extension_host/Cargo.toml | 1 - crates/extension_host/src/extension_host.rs | 4 + .../src/wasm_host/wit/since_v0_5_0.rs | 17 +- crates/project/Cargo.toml | 1 + crates/project/src/context_server_store.rs | 1129 +++++++++++++++++ .../src/context_server_store/extension.rs} | 38 +- .../src/context_server_store}/registry.rs | 19 +- crates/project/src/git_store/git_traversal.rs | 5 +- crates/project/src/project.rs | 19 + crates/project/src/project_settings.rs | 52 + 43 files changed, 1570 insertions(+), 1049 deletions(-) rename crates/{context_server => agent}/src/context_server_tool.rs (88%) delete mode 100644 crates/context_server/src/manager.rs delete mode 100644 crates/context_server_settings/Cargo.toml delete mode 120000 crates/context_server_settings/LICENSE-GPL delete mode 100644 crates/context_server_settings/src/context_server_settings.rs create mode 100644 crates/project/src/context_server_store.rs rename crates/{context_server/src/extension_context_server.rs => project/src/context_server_store/extension.rs} (74%) rename crates/{context_server/src => project/src/context_server_store}/registry.rs (82%) diff --git a/Cargo.lock b/Cargo.lock index a92ff72c9c..aa0ceb043c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,7 +484,6 @@ dependencies = [ "client", "collections", "command_palette_hooks", - "context_server", "ctor", "db", "editor", @@ -3328,40 +3327,19 @@ name = "context_server" version = "0.1.0" dependencies = [ "anyhow", - "assistant_tool", "async-trait", "collections", - "command_palette_hooks", - "context_server_settings", - "extension", "futures 0.3.31", "gpui", - "icons", - "language_model", "log", "parking_lot", "postage", - "project", - "serde", - "serde_json", - "settings", - "smol", - "url", - "util", - "workspace-hack", -] - -[[package]] -name = "context_server_settings" -version = "0.1.0" -dependencies = [ - "anyhow", - "collections", - "gpui", "schemars", "serde", "serde_json", - "settings", + "smol", + "url", + "util", "workspace-hack", ] @@ -5028,7 +5006,6 @@ dependencies = [ "clap", "client", "collections", - "context_server", "dirs 4.0.0", "dotenv", "env_logger 0.11.8", @@ -5181,7 +5158,6 @@ dependencies = [ "async-trait", "client", "collections", - "context_server_settings", "ctor", "env_logger 0.11.8", "extension", @@ -11110,6 +11086,7 @@ dependencies = [ "client", "clock", "collections", + "context_server", "dap", "dap_adapters", "env_logger 0.11.8", diff --git a/Cargo.toml b/Cargo.toml index b971db3163..85d4babcfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,6 @@ members = [ "crates/component", "crates/component_preview", "crates/context_server", - "crates/context_server_settings", "crates/copilot", "crates/credentials_provider", "crates/dap", @@ -243,7 +242,6 @@ command_palette_hooks = { path = "crates/command_palette_hooks" } component = { path = "crates/component" } component_preview = { path = "crates/component_preview" } context_server = { path = "crates/context_server" } -context_server_settings = { path = "crates/context_server_settings" } copilot = { path = "crates/copilot" } credentials_provider = { path = "crates/credentials_provider" } dap = { path = "crates/dap" } diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index cd61b8e1b1..76d649d1ca 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -3487,7 +3487,6 @@ fn open_editor_at_position( #[cfg(test)] mod tests { use assistant_tool::{ToolRegistry, ToolWorkingSet}; - use context_server::ContextServerSettings; use editor::EditorSettings; use fs::FakeFs; use gpui::{TestAppContext, VisualTestContext}; @@ -3559,7 +3558,6 @@ mod tests { workspace::init_settings(cx); language_model::init_settings(cx); ThemeSettings::register(cx); - ContextServerSettings::register(cx); EditorSettings::register(cx); ToolRegistry::default_global(cx); }); diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index 894c0e5b93..437e180e02 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -1748,7 +1748,6 @@ mod tests { use crate::{Keep, ThreadStore, thread_store}; use assistant_settings::AssistantSettings; use assistant_tool::ToolWorkingSet; - use context_server::ContextServerSettings; use editor::EditorSettings; use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; use project::{FakeFs, Project}; @@ -1771,7 +1770,6 @@ mod tests { thread_store::init(cx); workspace::init_settings(cx); ThemeSettings::register(cx); - ContextServerSettings::register(cx); EditorSettings::register(cx); language_model::init_settings(cx); }); @@ -1928,7 +1926,6 @@ mod tests { thread_store::init(cx); workspace::init_settings(cx); ThemeSettings::register(cx); - ContextServerSettings::register(cx); EditorSettings::register(cx); language_model::init_settings(cx); workspace::register_project_item::(cx); diff --git a/crates/agent/src/assistant.rs b/crates/agent/src/assistant.rs index 241d5922e5..3e34797db8 100644 --- a/crates/agent/src/assistant.rs +++ b/crates/agent/src/assistant.rs @@ -7,6 +7,7 @@ mod buffer_codegen; mod context; mod context_picker; mod context_server_configuration; +mod context_server_tool; mod context_store; mod context_strip; mod debug; diff --git a/crates/agent/src/assistant_configuration.rs b/crates/agent/src/assistant_configuration.rs index 06331187bf..f089506223 100644 --- a/crates/agent/src/assistant_configuration.rs +++ b/crates/agent/src/assistant_configuration.rs @@ -8,13 +8,14 @@ use std::{sync::Arc, time::Duration}; use assistant_settings::AssistantSettings; use assistant_tool::{ToolSource, ToolWorkingSet}; use collections::HashMap; -use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus}; +use context_server::ContextServerId; use fs::Fs; use gpui::{ Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription, pulsating_between, }; use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry}; +use project::context_server_store::{ContextServerStatus, ContextServerStore}; use settings::{Settings, update_settings_file}; use ui::{ Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState, @@ -33,8 +34,8 @@ pub struct AssistantConfiguration { fs: Arc, focus_handle: FocusHandle, configuration_views_by_provider: HashMap, - context_server_manager: Entity, - expanded_context_server_tools: HashMap, bool>, + context_server_store: Entity, + expanded_context_server_tools: HashMap, tools: Entity, _registry_subscription: Subscription, scroll_handle: ScrollHandle, @@ -44,7 +45,7 @@ pub struct AssistantConfiguration { impl AssistantConfiguration { pub fn new( fs: Arc, - context_server_manager: Entity, + context_server_store: Entity, tools: Entity, window: &mut Window, cx: &mut Context, @@ -75,7 +76,7 @@ impl AssistantConfiguration { fs, focus_handle, configuration_views_by_provider: HashMap::default(), - context_server_manager, + context_server_store, expanded_context_server_tools: HashMap::default(), tools, _registry_subscription: registry_subscription, @@ -306,7 +307,7 @@ impl AssistantConfiguration { window: &mut Window, cx: &mut Context, ) -> impl IntoElement { - let context_servers = self.context_server_manager.read(cx).all_servers().clone(); + let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone(); const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly."; @@ -322,9 +323,9 @@ impl AssistantConfiguration { .child(Label::new(SUBHEADING).color(Color::Muted)), ) .children( - context_servers - .into_iter() - .map(|context_server| self.render_context_server(context_server, window, cx)), + context_server_ids.into_iter().map(|context_server_id| { + self.render_context_server(context_server_id, window, cx) + }), ) .child( h_flex() @@ -374,19 +375,20 @@ impl AssistantConfiguration { fn render_context_server( &self, - context_server: Arc, + context_server_id: ContextServerId, window: &mut Window, cx: &mut Context, ) -> impl use<> + IntoElement { let tools_by_source = self.tools.read(cx).tools_by_source(cx); let server_status = self - .context_server_manager + .context_server_store .read(cx) - .status_for_server(&context_server.id()); + .status_for_server(&context_server_id) + .unwrap_or(ContextServerStatus::Stopped); - let is_running = matches!(server_status, Some(ContextServerStatus::Running)); + let is_running = matches!(server_status, ContextServerStatus::Running); - let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() { + let error = if let ContextServerStatus::Error(error) = server_status.clone() { Some(error) } else { None @@ -394,13 +396,13 @@ impl AssistantConfiguration { let are_tools_expanded = self .expanded_context_server_tools - .get(&context_server.id()) + .get(&context_server_id) .copied() .unwrap_or_default(); let tools = tools_by_source .get(&ToolSource::ContextServer { - id: context_server.id().into(), + id: context_server_id.0.clone().into(), }) .map_or([].as_slice(), |tools| tools.as_slice()); let tool_count = tools.len(); @@ -408,7 +410,7 @@ impl AssistantConfiguration { let border_color = cx.theme().colors().border.opacity(0.6); v_flex() - .id(SharedString::from(context_server.id())) + .id(SharedString::from(context_server_id.0.clone())) .border_1() .rounded_md() .border_color(border_color) @@ -432,7 +434,7 @@ impl AssistantConfiguration { ) .disabled(tool_count == 0) .on_click(cx.listener({ - let context_server_id = context_server.id(); + let context_server_id = context_server_id.clone(); move |this, _event, _window, _cx| { let is_open = this .expanded_context_server_tools @@ -444,14 +446,14 @@ impl AssistantConfiguration { })), ) .child(match server_status { - Some(ContextServerStatus::Starting) => { + ContextServerStatus::Starting => { let color = Color::Success.color(cx); Indicator::dot() .color(Color::Success) .with_animation( SharedString::from(format!( "{}-starting", - context_server.id(), + context_server_id.0.clone(), )), Animation::new(Duration::from_secs(2)) .repeat() @@ -462,15 +464,17 @@ impl AssistantConfiguration { ) .into_any_element() } - Some(ContextServerStatus::Running) => { + ContextServerStatus::Running => { Indicator::dot().color(Color::Success).into_any_element() } - Some(ContextServerStatus::Error(_)) => { + ContextServerStatus::Error(_) => { Indicator::dot().color(Color::Error).into_any_element() } - None => Indicator::dot().color(Color::Muted).into_any_element(), + ContextServerStatus::Stopped => { + Indicator::dot().color(Color::Muted).into_any_element() + } }) - .child(Label::new(context_server.id()).ml_0p5()) + .child(Label::new(context_server_id.0.clone()).ml_0p5()) .when(is_running, |this| { this.child( Label::new(if tool_count == 1 { @@ -487,32 +491,22 @@ impl AssistantConfiguration { Switch::new("context-server-switch", is_running.into()) .color(SwitchColor::Accent) .on_click({ - let context_server_manager = self.context_server_manager.clone(); - let context_server = context_server.clone(); + let context_server_manager = self.context_server_store.clone(); + let context_server_id = context_server_id.clone(); move |state, _window, cx| match state { ToggleState::Unselected | ToggleState::Indeterminate => { context_server_manager.update(cx, |this, cx| { - this.stop_server(context_server.clone(), cx).log_err(); + this.stop_server(&context_server_id, cx).log_err(); }); } ToggleState::Selected => { - cx.spawn({ - let context_server_manager = - context_server_manager.clone(); - let context_server = context_server.clone(); - async move |cx| { - if let Some(start_server_task) = - context_server_manager - .update(cx, |this, cx| { - this.start_server(context_server, cx) - }) - .log_err() - { - start_server_task.await.log_err(); - } + context_server_manager.update(cx, |this, cx| { + if let Some(server) = + this.get_server(&context_server_id) + { + this.start_server(server, cx).log_err(); } }) - .detach(); } } }), diff --git a/crates/agent/src/assistant_configuration/add_context_server_modal.rs b/crates/agent/src/assistant_configuration/add_context_server_modal.rs index a1f686e029..6109b2d513 100644 --- a/crates/agent/src/assistant_configuration/add_context_server_modal.rs +++ b/crates/agent/src/assistant_configuration/add_context_server_modal.rs @@ -1,5 +1,6 @@ -use context_server::{ContextServerSettings, ServerCommand, ServerConfig}; +use context_server::ContextServerCommand; use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, prelude::*}; +use project::project_settings::{ContextServerConfiguration, ProjectSettings}; use serde_json::json; use settings::update_settings_file; use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip, prelude::*}; @@ -77,11 +78,11 @@ impl AddContextServerModal { if let Some(workspace) = self.workspace.upgrade() { workspace.update(cx, |workspace, cx| { let fs = workspace.app_state().fs.clone(); - update_settings_file::(fs.clone(), cx, |settings, _| { + update_settings_file::(fs.clone(), cx, |settings, _| { settings.context_servers.insert( name.into(), - ServerConfig { - command: Some(ServerCommand { + ContextServerConfiguration { + command: Some(ContextServerCommand { path, args, env: None, diff --git a/crates/agent/src/assistant_configuration/configure_context_server_modal.rs b/crates/agent/src/assistant_configuration/configure_context_server_modal.rs index 5f375782b4..8365fb577d 100644 --- a/crates/agent/src/assistant_configuration/configure_context_server_modal.rs +++ b/crates/agent/src/assistant_configuration/configure_context_server_modal.rs @@ -4,9 +4,8 @@ use std::{ }; use anyhow::Context as _; -use context_server::manager::{ContextServerManager, ContextServerStatus}; +use context_server::ContextServerId; use editor::{Editor, EditorElement, EditorStyle}; -use extension::ContextServerConfiguration; use gpui::{ Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage, @@ -14,6 +13,10 @@ use gpui::{ use language::{Language, LanguageRegistry}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; use notifications::status_toast::{StatusToast, ToastIcon}; +use project::{ + context_server_store::{ContextServerStatus, ContextServerStore}, + project_settings::{ContextServerConfiguration, ProjectSettings}, +}; use settings::{Settings as _, update_settings_file}; use theme::ThemeSettings; use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*}; @@ -23,11 +26,11 @@ use workspace::{ModalView, Workspace}; pub(crate) struct ConfigureContextServerModal { workspace: WeakEntity, context_servers_to_setup: Vec, - context_server_manager: Entity, + context_server_store: Entity, } struct ConfigureContextServer { - id: Arc, + id: ContextServerId, installation_instructions: Entity, settings_validator: Option, settings_editor: Entity, @@ -37,9 +40,9 @@ struct ConfigureContextServer { impl ConfigureContextServerModal { pub fn new( - configurations: impl Iterator, ContextServerConfiguration)>, + configurations: impl Iterator, + context_server_store: Entity, jsonc_language: Option>, - context_server_manager: Entity, language_registry: Arc, workspace: WeakEntity, window: &mut Window, @@ -85,7 +88,7 @@ impl ConfigureContextServerModal { Some(Self { workspace, context_servers_to_setup, - context_server_manager, + context_server_store, }) } } @@ -126,14 +129,14 @@ impl ConfigureContextServerModal { } let id = configuration.id.clone(); - let settings_changed = context_server::ContextServerSettings::get_global(cx) + let settings_changed = ProjectSettings::get_global(cx) .context_servers - .get(&id) + .get(&id.0) .map_or(true, |config| { config.settings.as_ref() != Some(&settings_value) }); - let is_running = self.context_server_manager.read(cx).status_for_server(&id) + let is_running = self.context_server_store.read(cx).status_for_server(&id) == Some(ContextServerStatus::Running); if !settings_changed && is_running { @@ -143,7 +146,7 @@ impl ConfigureContextServerModal { configuration.waiting_for_context_server = true; - let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx); + let task = wait_for_context_server(&self.context_server_store, id.clone(), cx); cx.spawn({ let id = id.clone(); async move |this, cx| { @@ -167,29 +170,25 @@ impl ConfigureContextServerModal { .detach(); // When we write the settings to the file, the context server will be restarted. - update_settings_file::( - workspace.read(cx).app_state().fs.clone(), - cx, - { - let id = id.clone(); - |settings, _| { - if let Some(server_config) = settings.context_servers.get_mut(&id) { - server_config.settings = Some(settings_value); - } else { - settings.context_servers.insert( - id, - context_server::ServerConfig { - settings: Some(settings_value), - ..Default::default() - }, - ); - } + update_settings_file::(workspace.read(cx).app_state().fs.clone(), cx, { + let id = id.clone(); + |settings, _| { + if let Some(server_config) = settings.context_servers.get_mut(&id.0) { + server_config.settings = Some(settings_value); + } else { + settings.context_servers.insert( + id.0, + ContextServerConfiguration { + settings: Some(settings_value), + ..Default::default() + }, + ); } - }, - ); + } + }); } - fn complete_setup(&mut self, id: Arc, cx: &mut Context) { + fn complete_setup(&mut self, id: ContextServerId, cx: &mut Context) { self.context_servers_to_setup.remove(0); cx.notify(); @@ -223,31 +222,40 @@ impl ConfigureContextServerModal { } fn wait_for_context_server( - context_server_manager: &Entity, - context_server_id: Arc, + context_server_store: &Entity, + context_server_id: ContextServerId, cx: &mut App, ) -> Task>> { let (tx, rx) = futures::channel::oneshot::channel(); let tx = Arc::new(Mutex::new(Some(tx))); - let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event { - context_server::manager::Event::ServerStatusChanged { server_id, status } => match status { - Some(ContextServerStatus::Running) => { - if server_id == &context_server_id { - if let Some(tx) = tx.lock().unwrap().take() { - let _ = tx.send(Ok(())); + let subscription = cx.subscribe(context_server_store, move |_, event, _cx| match event { + project::context_server_store::Event::ServerStatusChanged { server_id, status } => { + match status { + ContextServerStatus::Running => { + if server_id == &context_server_id { + if let Some(tx) = tx.lock().unwrap().take() { + let _ = tx.send(Ok(())); + } } } - } - Some(ContextServerStatus::Error(error)) => { - if server_id == &context_server_id { - if let Some(tx) = tx.lock().unwrap().take() { - let _ = tx.send(Err(error.clone())); + ContextServerStatus::Stopped => { + if server_id == &context_server_id { + if let Some(tx) = tx.lock().unwrap().take() { + let _ = tx.send(Err("Context server stopped running".into())); + } } } + ContextServerStatus::Error(error) => { + if server_id == &context_server_id { + if let Some(tx) = tx.lock().unwrap().take() { + let _ = tx.send(Err(error.clone())); + } + } + } + _ => {} } - _ => {} - }, + } }); cx.spawn(async move |_cx| { diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 74d7ea3197..4641d1efe0 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -1026,14 +1026,14 @@ impl AssistantPanel { } pub(crate) fn open_configuration(&mut self, window: &mut Window, cx: &mut Context) { - let context_server_manager = self.thread_store.read(cx).context_server_manager(); + let context_server_store = self.project.read(cx).context_server_store(); let tools = self.thread_store.read(cx).tools(); let fs = self.fs.clone(); self.set_active_view(ActiveView::Configuration, window, cx); self.configuration = Some(cx.new(|cx| { - AssistantConfiguration::new(fs, context_server_manager, tools, window, cx) + AssistantConfiguration::new(fs, context_server_store, tools, window, cx) })); if let Some(configuration) = self.configuration.as_ref() { diff --git a/crates/agent/src/buffer_codegen.rs b/crates/agent/src/buffer_codegen.rs index 4e0ee6a9c3..62c796f44b 100644 --- a/crates/agent/src/buffer_codegen.rs +++ b/crates/agent/src/buffer_codegen.rs @@ -1095,9 +1095,7 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_model::LanguageModelRegistry::test); - cx.update(language_settings::init); + init_test(cx); let text = indoc! {" fn main() { @@ -1167,8 +1165,7 @@ mod tests { cx: &mut TestAppContext, mut rng: StdRng, ) { - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); + init_test(cx); let text = indoc! {" fn main() { @@ -1237,9 +1234,7 @@ mod tests { cx: &mut TestAppContext, mut rng: StdRng, ) { - cx.update(LanguageModelRegistry::test); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); + init_test(cx); let text = concat!( "fn main() {\n", @@ -1305,9 +1300,7 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) { - cx.update(LanguageModelRegistry::test); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); + init_test(cx); let text = indoc! {" func main() { @@ -1367,9 +1360,7 @@ mod tests { #[gpui::test] async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) { - cx.update(LanguageModelRegistry::test); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); + init_test(cx); let text = indoc! {" fn main() { @@ -1473,6 +1464,13 @@ mod tests { } } + fn init_test(cx: &mut TestAppContext) { + cx.update(LanguageModelRegistry::test); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(Project::init_settings); + cx.update(language_settings::init); + } + fn simulate_response_stream( codegen: Entity, cx: &mut TestAppContext, diff --git a/crates/agent/src/context_server_configuration.rs b/crates/agent/src/context_server_configuration.rs index ec92b70149..ccf0204058 100644 --- a/crates/agent/src/context_server_configuration.rs +++ b/crates/agent/src/context_server_configuration.rs @@ -1,14 +1,15 @@ use std::sync::Arc; use anyhow::Context as _; -use context_server::ContextServerDescriptorRegistry; +use context_server::ContextServerId; use extension::ExtensionManifest; use language::LanguageRegistry; +use project::context_server_store::registry::ContextServerDescriptorRegistry; use ui::prelude::*; use util::ResultExt; use workspace::Workspace; -use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal}; +use crate::assistant_configuration::ConfigureContextServerModal; pub(crate) fn init(language_registry: Arc, cx: &mut App) { cx.observe_new(move |_: &mut Workspace, window, cx| { @@ -60,18 +61,10 @@ fn show_configure_mcp_modal( window: &mut Window, cx: &mut Context<'_, Workspace>, ) { - let Some(context_server_manager) = workspace.panel::(cx).map(|panel| { - panel - .read(cx) - .thread_store() - .read(cx) - .context_server_manager() - }) else { - return; - }; + let context_server_store = workspace.project().read(cx).context_server_store(); - let registry = ContextServerDescriptorRegistry::global(cx).read(cx); - let project = workspace.project().clone(); + let registry = ContextServerDescriptorRegistry::default_global(cx).read(cx); + let worktree_store = workspace.project().read(cx).worktree_store(); let configuration_tasks = manifest .context_servers .keys() @@ -80,15 +73,15 @@ fn show_configure_mcp_modal( |key| { let descriptor = registry.context_server_descriptor(&key)?; Some(cx.spawn({ - let project = project.clone(); + let worktree_store = worktree_store.clone(); async move |_, cx| { descriptor - .configuration(project, &cx) + .configuration(worktree_store.clone(), &cx) .await .context("Failed to resolve context server configuration") .log_err() .flatten() - .map(|config| (key, config)) + .map(|config| (ContextServerId(key), config)) } })) } @@ -104,8 +97,8 @@ fn show_configure_mcp_modal( this.update_in(cx, |this, window, cx| { let modal = ConfigureContextServerModal::new( descriptors.into_iter().flatten(), + context_server_store, jsonc_language, - context_server_manager, language_registry, cx.entity().downgrade(), window, diff --git a/crates/context_server/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs similarity index 88% rename from crates/context_server/src/context_server_tool.rs rename to crates/agent/src/context_server_tool.rs index 93422e87d8..69283b9b63 100644 --- a/crates/context_server/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -2,29 +2,27 @@ use std::sync::Arc; use anyhow::{Result, anyhow, bail}; use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource}; +use context_server::{ContextServerId, types}; use gpui::{AnyWindowHandle, App, Entity, Task}; -use icons::IconName; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; -use project::Project; - -use crate::manager::ContextServerManager; -use crate::types; +use project::{Project, context_server_store::ContextServerStore}; +use ui::IconName; pub struct ContextServerTool { - server_manager: Entity, - server_id: Arc, + store: Entity, + server_id: ContextServerId, tool: types::Tool, } impl ContextServerTool { pub fn new( - server_manager: Entity, - server_id: impl Into>, + store: Entity, + server_id: ContextServerId, tool: types::Tool, ) -> Self { Self { - server_manager, - server_id: server_id.into(), + store, + server_id, tool, } } @@ -45,7 +43,7 @@ impl Tool for ContextServerTool { fn source(&self) -> ToolSource { ToolSource::ContextServer { - id: self.server_id.clone().into(), + id: self.server_id.clone().0.into(), } } @@ -80,7 +78,7 @@ impl Tool for ContextServerTool { _window: Option, cx: &mut App, ) -> ToolResult { - if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) { + if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) { let tool_name = self.tool.name.clone(); let server_clone = server.clone(); let input_clone = input.clone(); diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 7e633389a8..6dc9d87dc2 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -2660,7 +2660,6 @@ mod tests { use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store}; use assistant_settings::AssistantSettings; use assistant_tool::ToolRegistry; - use context_server::ContextServerSettings; use editor::EditorSettings; use gpui::TestAppContext; use language_model::fake_provider::FakeLanguageModel; @@ -3082,7 +3081,6 @@ fn main() {{ workspace::init_settings(cx); language_model::init_settings(cx); ThemeSettings::register(cx); - ContextServerSettings::register(cx); EditorSettings::register(cx); ToolRegistry::default_global(cx); }); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 9ecd989139..09f8498af6 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -9,8 +9,7 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings}; use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; -use context_server::manager::{ContextServerManager, ContextServerStatus}; -use context_server::{ContextServerDescriptorRegistry, ContextServerTool}; +use context_server::ContextServerId; use futures::channel::{mpsc, oneshot}; use futures::future::{self, BoxFuture, Shared}; use futures::{FutureExt as _, StreamExt as _}; @@ -21,6 +20,7 @@ use gpui::{ use heed::Database; use heed::types::SerdeBincode; use language_model::{LanguageModelToolUseId, Role, TokenUsage}; +use project::context_server_store::{ContextServerStatus, ContextServerStore}; use project::{Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext, @@ -30,6 +30,7 @@ use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; +use crate::context_server_tool::ContextServerTool; use crate::thread::{ DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId, }; @@ -62,8 +63,7 @@ pub struct ThreadStore { tools: Entity, prompt_builder: Arc, prompt_store: Option>, - context_server_manager: Entity, - context_server_tool_ids: HashMap, Vec>, + context_server_tool_ids: HashMap>, threads: Vec, project_context: SharedProjectContext, reload_system_prompt_tx: mpsc::Sender<()>, @@ -108,11 +108,6 @@ impl ThreadStore { prompt_store: Option>, cx: &mut Context, ) -> (Self, oneshot::Receiver<()>) { - let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx); - let context_server_manager = cx.new(|cx| { - ContextServerManager::new(context_server_factory_registry, project.clone(), cx) - }); - let mut subscriptions = vec![ cx.observe_global::(move |this: &mut Self, cx| { this.load_default_profile(cx); @@ -159,7 +154,6 @@ impl ThreadStore { tools, prompt_builder, prompt_store, - context_server_manager, context_server_tool_ids: HashMap::default(), threads: Vec::new(), project_context: SharedProjectContext::default(), @@ -354,10 +348,6 @@ impl ThreadStore { }) } - pub fn context_server_manager(&self) -> Entity { - self.context_server_manager.clone() - } - pub fn prompt_store(&self) -> &Option> { &self.prompt_store } @@ -494,11 +484,17 @@ impl ThreadStore { }); if profile.enable_all_context_servers { - for context_server in self.context_server_manager.read(cx).all_servers() { + for context_server_id in self + .project + .read(cx) + .context_server_store() + .read(cx) + .all_server_ids() + { self.tools.update(cx, |tools, cx| { tools.enable_source( ToolSource::ContextServer { - id: context_server.id().into(), + id: context_server_id.0.into(), }, cx, ); @@ -541,7 +537,7 @@ impl ThreadStore { fn register_context_server_handlers(&self, cx: &mut Context) { cx.subscribe( - &self.context_server_manager.clone(), + &self.project.read(cx).context_server_store(), Self::handle_context_server_event, ) .detach(); @@ -549,18 +545,19 @@ impl ThreadStore { fn handle_context_server_event( &mut self, - context_server_manager: Entity, - event: &context_server::manager::Event, + context_server_store: Entity, + event: &project::context_server_store::Event, cx: &mut Context, ) { let tool_working_set = self.tools.clone(); match event { - context_server::manager::Event::ServerStatusChanged { server_id, status } => { + project::context_server_store::Event::ServerStatusChanged { server_id, status } => { match status { - Some(ContextServerStatus::Running) => { - if let Some(server) = context_server_manager.read(cx).get_server(server_id) + ContextServerStatus::Running => { + if let Some(server) = + context_server_store.read(cx).get_running_server(server_id) { - let context_server_manager = context_server_manager.clone(); + let context_server_manager = context_server_store.clone(); cx.spawn({ let server = server.clone(); let server_id = server_id.clone(); @@ -608,7 +605,7 @@ impl ThreadStore { .detach(); } } - None => { + ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { tool_working_set.update(cx, |tool_working_set, _| { tool_working_set.remove(&tool_ids); diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index ba2026561a..ea014f6c37 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -31,7 +31,6 @@ async-watch.workspace = true client.workspace = true collections.workspace = true command_palette_hooks.workspace = true -context_server.workspace = true db.workspace = true editor.workspace = true feature_flags.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 47dd6feab6..f54a761037 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -106,7 +106,6 @@ pub fn init( assistant_slash_command::init(cx); assistant_tool::init(cx); assistant_panel::init(cx); - context_server::init(cx); register_slash_commands(cx); inline_assistant::init( diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index d19bb9f396..6ed5edc902 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1192,21 +1192,19 @@ impl AssistantPanel { fn restart_context_servers( workspace: &mut Workspace, - _action: &context_server::Restart, + _action: &project::context_server_store::Restart, _: &mut Window, cx: &mut Context, ) { - let Some(assistant_panel) = workspace.panel::(cx) else { - return; - }; - - assistant_panel.update(cx, |assistant_panel, cx| { - assistant_panel - .context_store - .update(cx, |context_store, cx| { - context_store.restart_context_servers(cx); - }); - }); + workspace + .project() + .read(cx) + .context_server_store() + .update(cx, |store, cx| { + for server in store.running_servers() { + store.restart_server(&server.id(), cx).log_err(); + } + }); } } diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index f0f67506a6..2cf5a3b654 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -7,15 +7,17 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet}; use client::{Client, TypedEnvelope, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::HashMap; -use context_server::ContextServerDescriptorRegistry; -use context_server::manager::{ContextServerManager, ContextServerStatus}; +use context_server::ContextServerId; use fs::{Fs, RemoveOptions}; use futures::StreamExt; use fuzzy::StringMatchCandidate; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; use language::LanguageRegistry; use paths::contexts_dir; -use project::Project; +use project::{ + Project, + context_server_store::{ContextServerStatus, ContextServerStore}, +}; use prompt_store::PromptBuilder; use regex::Regex; use rpc::AnyProtoClient; @@ -40,8 +42,7 @@ pub struct RemoteContextMetadata { pub struct ContextStore { contexts: Vec, contexts_metadata: Vec, - context_server_manager: Entity, - context_server_slash_command_ids: HashMap, Vec>, + context_server_slash_command_ids: HashMap>, host_contexts: Vec, fs: Arc, languages: Arc, @@ -98,15 +99,9 @@ impl ContextStore { let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; let this = cx.new(|cx: &mut Context| { - let context_server_factory_registry = - ContextServerDescriptorRegistry::default_global(cx); - let context_server_manager = cx.new(|cx| { - ContextServerManager::new(context_server_factory_registry, project.clone(), cx) - }); let mut this = Self { contexts: Vec::new(), contexts_metadata: Vec::new(), - context_server_manager, context_server_slash_command_ids: HashMap::default(), host_contexts: Vec::new(), fs, @@ -802,22 +797,9 @@ impl ContextStore { }) } - pub fn restart_context_servers(&mut self, cx: &mut Context) { - cx.update_entity( - &self.context_server_manager, - |context_server_manager, cx| { - for server in context_server_manager.running_servers() { - context_server_manager - .restart_server(&server.id(), cx) - .detach_and_log_err(cx); - } - }, - ); - } - fn register_context_server_handlers(&self, cx: &mut Context) { cx.subscribe( - &self.context_server_manager.clone(), + &self.project.read(cx).context_server_store(), Self::handle_context_server_event, ) .detach(); @@ -825,16 +807,18 @@ impl ContextStore { fn handle_context_server_event( &mut self, - context_server_manager: Entity, - event: &context_server::manager::Event, + context_server_manager: Entity, + event: &project::context_server_store::Event, cx: &mut Context, ) { let slash_command_working_set = self.slash_commands.clone(); match event { - context_server::manager::Event::ServerStatusChanged { server_id, status } => { + project::context_server_store::Event::ServerStatusChanged { server_id, status } => { match status { - Some(ContextServerStatus::Running) => { - if let Some(server) = context_server_manager.read(cx).get_server(server_id) + ContextServerStatus::Running => { + if let Some(server) = context_server_manager + .read(cx) + .get_running_server(server_id) { let context_server_manager = context_server_manager.clone(); cx.spawn({ @@ -858,7 +842,7 @@ impl ContextStore { slash_command_working_set.insert(Arc::new( assistant_slash_commands::ContextServerSlashCommand::new( context_server_manager.clone(), - &server, + server.id(), prompt, ), )) @@ -877,7 +861,7 @@ impl ContextStore { .detach(); } } - None => { + ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { if let Some(slash_command_ids) = self.context_server_slash_command_ids.remove(server_id) { diff --git a/crates/assistant_slash_commands/src/context_server_command.rs b/crates/assistant_slash_commands/src/context_server_command.rs index 9ad25cadf9..5f9500a43f 100644 --- a/crates/assistant_slash_commands/src/context_server_command.rs +++ b/crates/assistant_slash_commands/src/context_server_command.rs @@ -4,12 +4,10 @@ use assistant_slash_command::{ SlashCommandOutputSection, SlashCommandResult, }; use collections::HashMap; -use context_server::{ - manager::{ContextServer, ContextServerManager}, - types::Prompt, -}; +use context_server::{ContextServerId, types::Prompt}; use gpui::{App, Entity, Task, WeakEntity, Window}; use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate}; +use project::context_server_store::ContextServerStore; use std::sync::Arc; use std::sync::atomic::AtomicBool; use text::LineEnding; @@ -19,21 +17,17 @@ use workspace::Workspace; use crate::create_label_for_command; pub struct ContextServerSlashCommand { - server_manager: Entity, - server_id: Arc, + store: Entity, + server_id: ContextServerId, prompt: Prompt, } impl ContextServerSlashCommand { - pub fn new( - server_manager: Entity, - server: &Arc, - prompt: Prompt, - ) -> Self { + pub fn new(store: Entity, id: ContextServerId, prompt: Prompt) -> Self { Self { - server_id: server.id(), + server_id: id, prompt, - server_manager, + store, } } } @@ -88,7 +82,7 @@ impl SlashCommand for ContextServerSlashCommand { let server_id = self.server_id.clone(); let prompt_name = self.prompt.name.clone(); - if let Some(server) = self.server_manager.read(cx).get_server(&server_id) { + if let Some(server) = self.store.read(cx).get_running_server(&server_id) { cx.foreground_executor().spawn(async move { let Some(protocol) = server.client() else { return Err(anyhow!("Context server not initialized")); @@ -142,8 +136,8 @@ impl SlashCommand for ContextServerSlashCommand { Err(e) => return Task::ready(Err(e)), }; - let manager = self.server_manager.read(cx); - if let Some(server) = manager.get_server(&server_id) { + let store = self.store.read(cx); + if let Some(server) = store.get_running_server(&server_id) { cx.foreground_executor().spawn(async move { let Some(protocol) = server.client() else { return Err(anyhow!("Context server not initialized")); diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 9655fcbcb8..008e9eeaf8 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -6709,8 +6709,6 @@ async fn test_context_collaboration_with_reconnect( assert_eq!(project.collaborators().len(), 1); }); - cx_a.update(context_server::init); - cx_b.update(context_server::init); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let context_store_a = cx_a .update(|cx| { diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index ca94312e0f..a3afcacbe9 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -712,6 +712,7 @@ impl TestClient { worktree .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete()) .await; + cx.run_until_parked(); (project, worktree.read_with(cx, |tree, _| tree.id())) } diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index f229ed3e65..62a5354b39 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -13,28 +13,17 @@ path = "src/context_server.rs" [dependencies] anyhow.workspace = true -assistant_tool.workspace = true async-trait.workspace = true collections.workspace = true -command_palette_hooks.workspace = true -context_server_settings.workspace = true -extension.workspace = true futures.workspace = true gpui.workspace = true -icons.workspace = true -language_model.workspace = true log.workspace = true parking_lot.workspace = true postage.workspace = true -project.workspace = true +schemars.workspace = true serde.workspace = true serde_json.workspace = true -settings.workspace = true smol.workspace = true url = { workspace = true, features = ["serde"] } util.workspace = true workspace-hack.workspace = true - -[dev-dependencies] -gpui = { workspace = true, features = ["test-support"] } -project = { workspace = true, features = ["test-support"] } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 7a2ae71c8b..79ee2f6c72 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -40,7 +40,7 @@ pub enum RequestId { Str(String), } -pub struct Client { +pub(crate) struct Client { server_id: ContextServerId, next_id: AtomicI32, outbound_tx: channel::Sender, @@ -59,7 +59,7 @@ pub struct Client { #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct ContextServerId(pub Arc); +pub(crate) struct ContextServerId(pub Arc); fn is_null_value(value: &T) -> bool { if let Ok(Value::Null) = serde_json::to_value(value) { @@ -367,6 +367,7 @@ impl Client { Ok(()) } + #[allow(unused)] pub fn on_notification(&self, method: &'static str, f: F) where F: 'static + Send + FnMut(Value, AsyncApp), @@ -375,14 +376,6 @@ impl Client { .lock() .insert(method, Box::new(f)); } - - pub fn name(&self) -> &str { - &self.name - } - - pub fn server_id(&self) -> ContextServerId { - self.server_id.clone() - } } impl fmt::Display for ContextServerId { diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 72b5a7cfc9..19f2f75541 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -1,30 +1,117 @@ pub mod client; -mod context_server_tool; -mod extension_context_server; -pub mod manager; pub mod protocol; -mod registry; -mod transport; +pub mod transport; pub mod types; -use command_palette_hooks::CommandPaletteFilter; -pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerConfig}; -use gpui::{App, actions}; +use std::fmt::Display; +use std::path::Path; +use std::sync::Arc; -pub use crate::context_server_tool::ContextServerTool; -pub use crate::registry::ContextServerDescriptorRegistry; +use anyhow::Result; +use client::Client; +use collections::HashMap; +use gpui::AsyncApp; +use parking_lot::RwLock; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; -actions!(context_servers, [Restart]); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ContextServerId(pub Arc); -/// The namespace for the context servers actions. -pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers"; - -pub fn init(cx: &mut App) { - context_server_settings::init(cx); - ContextServerDescriptorRegistry::default_global(cx); - extension_context_server::init(cx); - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); - }); +impl Display for ContextServerId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)] +pub struct ContextServerCommand { + pub path: String, + pub args: Vec, + pub env: Option>, +} + +enum ContextServerTransport { + Stdio(ContextServerCommand), + Custom(Arc), +} + +pub struct ContextServer { + id: ContextServerId, + client: RwLock>>, + configuration: ContextServerTransport, +} + +impl ContextServer { + pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self { + Self { + id, + client: RwLock::new(None), + configuration: ContextServerTransport::Stdio(command), + } + } + + pub fn new(id: ContextServerId, transport: Arc) -> Self { + Self { + id, + client: RwLock::new(None), + configuration: ContextServerTransport::Custom(transport), + } + } + + pub fn id(&self) -> ContextServerId { + self.id.clone() + } + + pub fn client(&self) -> Option> { + self.client.read().clone() + } + + pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { + let client = match &self.configuration { + ContextServerTransport::Stdio(command) => Client::stdio( + client::ContextServerId(self.id.0.clone()), + client::ModelContextServerBinary { + executable: Path::new(&command.path).to_path_buf(), + args: command.args.clone(), + env: command.env.clone(), + }, + cx.clone(), + )?, + ContextServerTransport::Custom(transport) => Client::new( + client::ContextServerId(self.id.0.clone()), + self.id().0, + transport.clone(), + cx.clone(), + )?, + }; + self.initialize(client).await + } + + async fn initialize(&self, client: Client) -> Result<()> { + log::info!("starting context server {}", self.id); + let protocol = crate::protocol::ModelContextProtocol::new(client); + let client_info = types::Implementation { + name: "Zed".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }; + let initialized_protocol = protocol.initialize(client_info).await?; + + log::debug!( + "context server {} initialized: {:?}", + self.id, + initialized_protocol.initialize, + ); + + *self.client.write() = Some(Arc::new(initialized_protocol)); + Ok(()) + } + + pub fn stop(&self) -> Result<()> { + let mut client = self.client.write(); + if let Some(protocol) = client.take() { + drop(protocol); + } + Ok(()) + } } diff --git a/crates/context_server/src/manager.rs b/crates/context_server/src/manager.rs deleted file mode 100644 index 7a74e4879e..0000000000 --- a/crates/context_server/src/manager.rs +++ /dev/null @@ -1,584 +0,0 @@ -//! This module implements a context server management system for Zed. -//! -//! It provides functionality to: -//! - Define and load context server settings -//! - Manage individual context servers (start, stop, restart) -//! - Maintain a global manager for all context servers -//! -//! Key components: -//! - `ContextServerSettings`: Defines the structure for server configurations -//! - `ContextServer`: Represents an individual context server -//! - `ContextServerManager`: Manages multiple context servers -//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager -//! -//! The module also includes initialization logic to set up the context server system -//! and react to changes in settings. - -use std::path::Path; -use std::sync::Arc; - -use anyhow::{Result, bail}; -use collections::HashMap; -use command_palette_hooks::CommandPaletteFilter; -use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity}; -use log; -use parking_lot::RwLock; -use project::Project; -use settings::{Settings, SettingsStore}; -use util::ResultExt as _; - -use crate::transport::Transport; -use crate::{ContextServerSettings, ServerConfig}; - -use crate::{ - CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry, - client::{self, Client}, - types, -}; - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum ContextServerStatus { - Starting, - Running, - Error(Arc), -} - -pub struct ContextServer { - pub id: Arc, - pub config: Arc, - pub client: RwLock>>, - transport: Option>, -} - -impl ContextServer { - pub fn new(id: Arc, config: Arc) -> Self { - Self { - id, - config, - client: RwLock::new(None), - transport: None, - } - } - - #[cfg(any(test, feature = "test-support"))] - pub fn test(id: Arc, transport: Arc) -> Arc { - Arc::new(Self { - id, - client: RwLock::new(None), - config: Arc::new(ServerConfig::default()), - transport: Some(transport), - }) - } - - pub fn id(&self) -> Arc { - self.id.clone() - } - - pub fn config(&self) -> Arc { - self.config.clone() - } - - pub fn client(&self) -> Option> { - self.client.read().clone() - } - - pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - let client = if let Some(transport) = self.transport.clone() { - Client::new( - client::ContextServerId(self.id.clone()), - self.id(), - transport, - cx.clone(), - )? - } else { - let Some(command) = &self.config.command else { - bail!("no command specified for server {}", self.id); - }; - Client::stdio( - client::ContextServerId(self.id.clone()), - client::ModelContextServerBinary { - executable: Path::new(&command.path).to_path_buf(), - args: command.args.clone(), - env: command.env.clone(), - }, - cx.clone(), - )? - }; - self.initialize(client).await - } - - async fn initialize(&self, client: Client) -> Result<()> { - log::info!("starting context server {}", self.id); - let protocol = crate::protocol::ModelContextProtocol::new(client); - let client_info = types::Implementation { - name: "Zed".to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), - }; - let initialized_protocol = protocol.initialize(client_info).await?; - - log::debug!( - "context server {} initialized: {:?}", - self.id, - initialized_protocol.initialize, - ); - - *self.client.write() = Some(Arc::new(initialized_protocol)); - Ok(()) - } - - pub fn stop(&self) -> Result<()> { - let mut client = self.client.write(); - if let Some(protocol) = client.take() { - drop(protocol); - } - Ok(()) - } -} - -pub struct ContextServerManager { - servers: HashMap, Arc>, - server_status: HashMap, ContextServerStatus>, - project: Entity, - registry: Entity, - update_servers_task: Option>>, - needs_server_update: bool, - _subscriptions: Vec, -} - -pub enum Event { - ServerStatusChanged { - server_id: Arc, - status: Option, - }, -} - -impl EventEmitter for ContextServerManager {} - -impl ContextServerManager { - pub fn new( - registry: Entity, - project: Entity, - cx: &mut Context, - ) -> Self { - let mut this = Self { - _subscriptions: vec![ - cx.observe(®istry, |this, _registry, cx| { - this.available_context_servers_changed(cx); - }), - cx.observe_global::(|this, cx| { - this.available_context_servers_changed(cx); - }), - ], - project, - registry, - needs_server_update: false, - servers: HashMap::default(), - server_status: HashMap::default(), - update_servers_task: None, - }; - this.available_context_servers_changed(cx); - this - } - - fn available_context_servers_changed(&mut self, cx: &mut Context) { - if self.update_servers_task.is_some() { - self.needs_server_update = true; - } else { - self.update_servers_task = Some(cx.spawn(async move |this, cx| { - this.update(cx, |this, _| { - this.needs_server_update = false; - })?; - - if let Err(err) = Self::maintain_servers(this.clone(), cx).await { - log::error!("Error maintaining context servers: {}", err); - } - - this.update(cx, |this, cx| { - let has_any_context_servers = !this.running_servers().is_empty(); - if has_any_context_servers { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.show_namespace(CONTEXT_SERVERS_NAMESPACE); - }); - } - - this.update_servers_task.take(); - if this.needs_server_update { - this.available_context_servers_changed(cx); - } - })?; - - Ok(()) - })); - } - } - - pub fn get_server(&self, id: &str) -> Option> { - self.servers - .get(id) - .filter(|server| server.client().is_some()) - .cloned() - } - - pub fn status_for_server(&self, id: &str) -> Option { - self.server_status.get(id).cloned() - } - - pub fn start_server( - &self, - server: Arc, - cx: &mut Context, - ) -> Task> { - cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await) - } - - pub fn stop_server( - &mut self, - server: Arc, - cx: &mut Context, - ) -> Result<()> { - server.stop().log_err(); - self.update_server_status(server.id().clone(), None, cx); - Ok(()) - } - - pub fn restart_server(&mut self, id: &Arc, cx: &mut Context) -> Task> { - let id = id.clone(); - cx.spawn(async move |this, cx| { - if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? { - let config = server.config(); - - this.update(cx, |this, cx| this.stop_server(server, cx))??; - let new_server = Arc::new(ContextServer::new(id.clone(), config)); - Self::run_server(this, new_server, cx).await?; - } - Ok(()) - }) - } - - pub fn all_servers(&self) -> Vec> { - self.servers.values().cloned().collect() - } - - pub fn running_servers(&self) -> Vec> { - self.servers - .values() - .filter(|server| server.client().is_some()) - .cloned() - .collect() - } - - async fn maintain_servers(this: WeakEntity, cx: &mut AsyncApp) -> Result<()> { - let mut desired_servers = HashMap::default(); - - let (registry, project) = this.update(cx, |this, cx| { - let location = this - .project - .read(cx) - .visible_worktrees(cx) - .next() - .map(|worktree| settings::SettingsLocation { - worktree_id: worktree.read(cx).id(), - path: Path::new(""), - }); - let settings = ContextServerSettings::get(location, cx); - desired_servers = settings.context_servers.clone(); - - (this.registry.clone(), this.project.clone()) - })?; - - for (id, descriptor) in - registry.read_with(cx, |registry, _| registry.context_server_descriptors())? - { - let config = desired_servers.entry(id).or_default(); - if config.command.is_none() { - if let Some(extension_command) = - descriptor.command(project.clone(), &cx).await.log_err() - { - config.command = Some(extension_command); - } - } - } - - let mut servers_to_start = HashMap::default(); - let mut servers_to_stop = HashMap::default(); - - this.update(cx, |this, _cx| { - this.servers.retain(|id, server| { - if desired_servers.contains_key(id) { - true - } else { - servers_to_stop.insert(id.clone(), server.clone()); - false - } - }); - - for (id, config) in desired_servers { - let existing_config = this.servers.get(&id).map(|server| server.config()); - if existing_config.as_deref() != Some(&config) { - let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config))); - servers_to_start.insert(id.clone(), server.clone()); - if let Some(old_server) = this.servers.remove(&id) { - servers_to_stop.insert(id, old_server); - } - } - } - })?; - - for (_, server) in servers_to_stop { - this.update(cx, |this, cx| this.stop_server(server, cx).ok())?; - } - - for (_, server) in servers_to_start { - Self::run_server(this.clone(), server, cx).await.ok(); - } - - Ok(()) - } - - async fn run_server( - this: WeakEntity, - server: Arc, - cx: &mut AsyncApp, - ) -> Result<()> { - let id = server.id(); - - this.update(cx, |this, cx| { - this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx); - this.servers.insert(id.clone(), server.clone()); - })?; - - match server.start(&cx).await { - Ok(_) => { - log::debug!("`{}` context server started", id); - this.update(cx, |this, cx| { - this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx) - })?; - Ok(()) - } - Err(err) => { - log::error!("`{}` context server failed to start\n{}", id, err); - this.update(cx, |this, cx| { - this.update_server_status( - id.clone(), - Some(ContextServerStatus::Error(err.to_string().into())), - cx, - ) - })?; - Err(err) - } - } - } - - fn update_server_status( - &mut self, - id: Arc, - status: Option, - cx: &mut Context, - ) { - if let Some(status) = status.clone() { - self.server_status.insert(id.clone(), status); - } else { - self.server_status.remove(&id); - } - - cx.emit(Event::ServerStatusChanged { - server_id: id, - status, - }); - } -} - -#[cfg(test)] -mod tests { - use std::pin::Pin; - - use crate::types::{ - Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities, - }; - - use super::*; - use futures::{Stream, StreamExt as _, lock::Mutex}; - use gpui::{AppContext as _, TestAppContext}; - use project::FakeFs; - use serde_json::json; - use util::path; - - #[gpui::test] - async fn test_context_server_status(cx: &mut TestAppContext) { - init_test_settings(cx); - let project = create_test_project(cx, json!({"code.rs": ""})).await; - - let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); - let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx)); - - let server_1_id: Arc = "mcp-1".into(); - let server_2_id: Arc = "mcp-2".into(); - - let transport_1 = Arc::new(FakeTransport::new( - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-1".to_string())) - } - _ => None, - }, - )); - - let transport_2 = Arc::new(FakeTransport::new( - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-2".to_string())) - } - _ => None, - }, - )); - - let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone()); - let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone()); - - manager - .update(cx, |manager, cx| manager.start_server(server_1, cx)) - .await - .unwrap(); - - cx.update(|cx| { - assert_eq!( - manager.read(cx).status_for_server(&server_1_id), - Some(ContextServerStatus::Running) - ); - assert_eq!(manager.read(cx).status_for_server(&server_2_id), None); - }); - - manager - .update(cx, |manager, cx| manager.start_server(server_2.clone(), cx)) - .await - .unwrap(); - - cx.update(|cx| { - assert_eq!( - manager.read(cx).status_for_server(&server_1_id), - Some(ContextServerStatus::Running) - ); - assert_eq!( - manager.read(cx).status_for_server(&server_2_id), - Some(ContextServerStatus::Running) - ); - }); - - manager - .update(cx, |manager, cx| manager.stop_server(server_2, cx)) - .unwrap(); - - cx.update(|cx| { - assert_eq!( - manager.read(cx).status_for_server(&server_1_id), - Some(ContextServerStatus::Running) - ); - assert_eq!(manager.read(cx).status_for_server(&server_2_id), None); - }); - } - - async fn create_test_project( - cx: &mut TestAppContext, - files: serde_json::Value, - ) -> Entity { - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/test"), files).await; - Project::test(fs, [path!("/test").as_ref()], cx).await - } - - fn init_test_settings(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - ContextServerSettings::register(cx); - }); - } - - fn create_initialize_response(server_name: String) -> serde_json::Value { - serde_json::to_value(&InitializeResponse { - protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()), - server_info: Implementation { - name: server_name, - version: "1.0.0".to_string(), - }, - capabilities: ServerCapabilities::default(), - meta: None, - }) - .unwrap() - } - - struct FakeTransport { - on_request: Arc< - dyn Fn(u64, Option, serde_json::Value) -> Option - + Send - + Sync, - >, - tx: futures::channel::mpsc::UnboundedSender, - rx: Arc>>, - } - - impl FakeTransport { - fn new( - on_request: impl Fn( - u64, - Option, - serde_json::Value, - ) -> Option - + 'static - + Send - + Sync, - ) -> Self { - let (tx, rx) = futures::channel::mpsc::unbounded(); - Self { - on_request: Arc::new(on_request), - tx, - rx: Arc::new(Mutex::new(rx)), - } - } - } - - #[async_trait::async_trait] - impl Transport for FakeTransport { - async fn send(&self, message: String) -> Result<()> { - if let Ok(msg) = serde_json::from_str::(&message) { - let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); - - if let Some(method) = msg.get("method") { - let request_type = method - .as_str() - .and_then(|method| types::RequestType::try_from(method).ok()); - if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) { - let response = serde_json::json!({ - "jsonrpc": "2.0", - "id": id, - "result": payload - }); - - self.tx - .unbounded_send(response.to_string()) - .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?; - } - } - } - Ok(()) - } - - fn receive(&self) -> Pin + Send>> { - let rx = self.rx.clone(); - Box::pin(futures::stream::unfold(rx, |rx| async move { - let mut rx_guard = rx.lock().await; - if let Some(message) = rx_guard.next().await { - drop(rx_guard); - Some((message, rx)) - } else { - None - } - })) - } - - fn receive_err(&self) -> Pin + Send>> { - Box::pin(futures::stream::empty()) - } - } -} diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 91fa9289cc..0700a36feb 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -16,7 +16,7 @@ pub struct ModelContextProtocol { } impl ModelContextProtocol { - pub fn new(inner: Client) -> Self { + pub(crate) fn new(inner: Client) -> Self { Self { inner } } diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 7478ae44af..a287759125 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -610,7 +610,7 @@ pub enum ToolResponseContent { Resource { resource: ResourceContents }, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListToolsResponse { pub tools: Vec, diff --git a/crates/context_server_settings/Cargo.toml b/crates/context_server_settings/Cargo.toml deleted file mode 100644 index c1be563963..0000000000 --- a/crates/context_server_settings/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "context_server_settings" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/context_server_settings.rs" - -[dependencies] -anyhow.workspace = true -collections.workspace = true -gpui.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -settings.workspace = true -workspace-hack.workspace = true diff --git a/crates/context_server_settings/LICENSE-GPL b/crates/context_server_settings/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/context_server_settings/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/context_server_settings/src/context_server_settings.rs b/crates/context_server_settings/src/context_server_settings.rs deleted file mode 100644 index 8047eab297..0000000000 --- a/crates/context_server_settings/src/context_server_settings.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::sync::Arc; - -use collections::HashMap; -use gpui::App; -use schemars::JsonSchema; -use schemars::r#gen::SchemaGenerator; -use schemars::schema::{InstanceType, Schema, SchemaObject}; -use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; - -pub fn init(cx: &mut App) { - ContextServerSettings::register(cx); -} - -#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)] -pub struct ServerConfig { - /// The command to run this context server. - /// - /// This will override the command set by an extension. - pub command: Option, - /// The settings for this context server. - /// - /// Consult the documentation for the context server to see what settings - /// are supported. - #[schemars(schema_with = "server_config_settings_json_schema")] - pub settings: Option, -} - -fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema { - Schema::Object(SchemaObject { - instance_type: Some(InstanceType::Object.into()), - ..Default::default() - }) -} - -#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)] -pub struct ServerCommand { - pub path: String, - pub args: Vec, - pub env: Option>, -} - -#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)] -pub struct ContextServerSettings { - /// Settings for context servers used in the Assistant. - #[serde(default)] - pub context_servers: HashMap, ServerConfig>, -} - -impl Settings for ContextServerSettings { - const KEY: Option<&'static str> = None; - - type FileContent = Self; - - fn load( - sources: SettingsSources, - _: &mut gpui::App, - ) -> anyhow::Result { - sources.json_merge() - } - - fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) { - // we don't handle "inputs" replacement strings, see perplexity-key in this example: - // https://code.visualstudio.com/docs/copilot/chat/mcp-servers#_configuration-example - #[derive(Deserialize)] - struct VsCodeServerCommand { - command: String, - args: Option>, - env: Option>, - // note: we don't support envFile and type - } - impl From for ServerCommand { - fn from(cmd: VsCodeServerCommand) -> Self { - Self { - path: cmd.command, - args: cmd.args.unwrap_or_default(), - env: cmd.env, - } - } - } - if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) { - current - .context_servers - .extend(mcp.iter().filter_map(|(k, v)| { - Some(( - k.clone().into(), - ServerConfig { - command: Some( - serde_json::from_value::(v.clone()) - .ok()? - .into(), - ), - settings: None, - }, - )) - })); - } - } -} diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index af7930ba51..2cd7290788 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -30,7 +30,6 @@ chrono.workspace = true clap.workspace = true client.workspace = true collections.workspace = true -context_server.workspace = true dirs.workspace = true dotenv.workspace = true env_logger.workspace = true diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index f0c8cee30c..6623d56d1d 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -423,7 +423,6 @@ pub fn init(cx: &mut App) -> Arc { language_model::init(client.clone(), cx); language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); languages::init(languages.clone(), node_runtime.clone(), cx); - context_server::init(cx); prompt_store::init(cx); let stdout_is_a_pty = false; let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx); diff --git a/crates/extension/src/extension_host_proxy.rs b/crates/extension/src/extension_host_proxy.rs index 513011d083..7858a1eddf 100644 --- a/crates/extension/src/extension_host_proxy.rs +++ b/crates/extension/src/extension_host_proxy.rs @@ -362,6 +362,8 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static { server_id: Arc, cx: &mut App, ); + + fn unregister_context_server(&self, server_id: Arc, cx: &mut App); } impl ExtensionContextServerProxy for ExtensionHostProxy { @@ -377,6 +379,14 @@ impl ExtensionContextServerProxy for ExtensionHostProxy { proxy.register_context_server(extension, server_id, cx) } + + fn unregister_context_server(&self, server_id: Arc, cx: &mut App) { + let Some(proxy) = self.context_server_proxy.read().clone() else { + return; + }; + + proxy.unregister_context_server(server_id, cx) + } } pub trait ExtensionIndexedDocsProviderProxy: Send + Sync + 'static { diff --git a/crates/extension_host/Cargo.toml b/crates/extension_host/Cargo.toml index 5d0fdd1fcf..1e1f99168f 100644 --- a/crates/extension_host/Cargo.toml +++ b/crates/extension_host/Cargo.toml @@ -22,7 +22,6 @@ async-tar.workspace = true async-trait.workspace = true client.workspace = true collections.workspace = true -context_server_settings.workspace = true extension.workspace = true fs.workspace = true futures.workspace = true diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index 87f315f658..578b0526d7 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -1130,6 +1130,10 @@ impl ExtensionStore { .remove_language_server(&language, language_server_name); } } + + for (server_id, _) in extension.manifest.context_servers.iter() { + self.proxy.unregister_context_server(server_id.clone(), cx); + } } self.wasm_extensions diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_5_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_5_0.rs index 8dc6674286..5a234da63a 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_5_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_5_0.rs @@ -7,7 +7,6 @@ use anyhow::{Context, Result, anyhow, bail}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use async_trait::async_trait; -use context_server_settings::ContextServerSettings; use extension::{ ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate, }; @@ -676,21 +675,23 @@ impl ExtensionImports for WasmState { })?) } "context_servers" => { - let settings = key + let configuration = key .and_then(|key| { - ContextServerSettings::get(location, cx) + ProjectSettings::get(location, cx) .context_servers .get(key.as_str()) }) .cloned() .unwrap_or_default(); Ok(serde_json::to_string(&settings::ContextServerSettings { - command: settings.command.map(|command| settings::CommandSettings { - path: Some(command.path), - arguments: Some(command.args), - env: command.env.map(|env| env.into_iter().collect()), + command: configuration.command.map(|command| { + settings::CommandSettings { + path: Some(command.path), + arguments: Some(command.args), + env: command.env.map(|env| env.into_iter().collect()), + } }), - settings: settings.settings, + settings: configuration.settings, })?) } _ => { diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 346d8910df..0314a91721 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -36,6 +36,7 @@ circular-buffer.workspace = true client.workspace = true clock.workspace = true collections.workspace = true +context_server.workspace = true dap.workspace = true extension.workspace = true fancy-regex.workspace = true diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs new file mode 100644 index 0000000000..2015bf66bd --- /dev/null +++ b/crates/project/src/context_server_store.rs @@ -0,0 +1,1129 @@ +pub mod extension; +pub mod registry; + +use std::{path::Path, sync::Arc}; + +use anyhow::{Context as _, Result}; +use collections::{HashMap, HashSet}; +use context_server::{ContextServer, ContextServerId}; +use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions}; +use registry::ContextServerDescriptorRegistry; +use settings::{Settings as _, SettingsStore}; +use util::ResultExt as _; + +use crate::{ + project_settings::{ContextServerConfiguration, ProjectSettings}, + worktree_store::WorktreeStore, +}; + +pub fn init(cx: &mut App) { + extension::init(cx); +} + +actions!(context_server, [Restart]); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ContextServerStatus { + Starting, + Running, + Stopped, + Error(Arc), +} + +impl ContextServerStatus { + fn from_state(state: &ContextServerState) -> Self { + match state { + ContextServerState::Starting { .. } => ContextServerStatus::Starting, + ContextServerState::Running { .. } => ContextServerStatus::Running, + ContextServerState::Stopped { error, .. } => { + if let Some(error) = error { + ContextServerStatus::Error(error.clone()) + } else { + ContextServerStatus::Stopped + } + } + } + } +} + +enum ContextServerState { + Starting { + server: Arc, + configuration: Arc, + _task: Task<()>, + }, + Running { + server: Arc, + configuration: Arc, + }, + Stopped { + server: Arc, + configuration: Arc, + error: Option>, + }, +} + +impl ContextServerState { + pub fn server(&self) -> Arc { + match self { + ContextServerState::Starting { server, .. } => server.clone(), + ContextServerState::Running { server, .. } => server.clone(), + ContextServerState::Stopped { server, .. } => server.clone(), + } + } + + pub fn configuration(&self) -> Arc { + match self { + ContextServerState::Starting { configuration, .. } => configuration.clone(), + ContextServerState::Running { configuration, .. } => configuration.clone(), + ContextServerState::Stopped { configuration, .. } => configuration.clone(), + } + } +} + +pub type ContextServerFactory = + Box) -> Arc>; + +pub struct ContextServerStore { + servers: HashMap, + worktree_store: Entity, + registry: Entity, + update_servers_task: Option>>, + context_server_factory: Option, + needs_server_update: bool, + _subscriptions: Vec, +} + +pub enum Event { + ServerStatusChanged { + server_id: ContextServerId, + status: ContextServerStatus, + }, +} + +impl EventEmitter for ContextServerStore {} + +impl ContextServerStore { + pub fn new(worktree_store: Entity, cx: &mut Context) -> Self { + Self::new_internal( + true, + None, + ContextServerDescriptorRegistry::default_global(cx), + worktree_store, + cx, + ) + } + + #[cfg(any(test, feature = "test-support"))] + pub fn test( + registry: Entity, + worktree_store: Entity, + cx: &mut Context, + ) -> Self { + Self::new_internal(false, None, registry, worktree_store, cx) + } + + #[cfg(any(test, feature = "test-support"))] + pub fn test_maintain_server_loop( + context_server_factory: ContextServerFactory, + registry: Entity, + worktree_store: Entity, + cx: &mut Context, + ) -> Self { + Self::new_internal( + true, + Some(context_server_factory), + registry, + worktree_store, + cx, + ) + } + + fn new_internal( + maintain_server_loop: bool, + context_server_factory: Option, + registry: Entity, + worktree_store: Entity, + cx: &mut Context, + ) -> Self { + let subscriptions = if maintain_server_loop { + vec![ + cx.observe(®istry, |this, _registry, cx| { + this.available_context_servers_changed(cx); + }), + cx.observe_global::(|this, cx| { + this.available_context_servers_changed(cx); + }), + ] + } else { + Vec::new() + }; + + let mut this = Self { + _subscriptions: subscriptions, + worktree_store, + registry, + needs_server_update: false, + servers: HashMap::default(), + update_servers_task: None, + context_server_factory, + }; + if maintain_server_loop { + this.available_context_servers_changed(cx); + } + this + } + + pub fn get_server(&self, id: &ContextServerId) -> Option> { + self.servers.get(id).map(|state| state.server()) + } + + pub fn get_running_server(&self, id: &ContextServerId) -> Option> { + if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) { + Some(server.clone()) + } else { + None + } + } + + pub fn status_for_server(&self, id: &ContextServerId) -> Option { + self.servers.get(id).map(ContextServerStatus::from_state) + } + + pub fn all_server_ids(&self) -> Vec { + self.servers.keys().cloned().collect() + } + + pub fn running_servers(&self) -> Vec> { + self.servers + .values() + .filter_map(|state| { + if let ContextServerState::Running { server, .. } = state { + Some(server.clone()) + } else { + None + } + }) + .collect() + } + + pub fn start_server( + &mut self, + server: Arc, + cx: &mut Context, + ) -> Result<()> { + let location = self + .worktree_store + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| settings::SettingsLocation { + worktree_id: worktree.read(cx).id(), + path: Path::new(""), + }); + let settings = ProjectSettings::get(location, cx); + let configuration = settings + .context_servers + .get(&server.id().0) + .context("Failed to load context server configuration from settings")? + .clone(); + + self.run_server(server, Arc::new(configuration), cx); + Ok(()) + } + + pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context) -> Result<()> { + let Some(state) = self.servers.remove(id) else { + return Err(anyhow::anyhow!("Context server not found")); + }; + + let server = state.server(); + let configuration = state.configuration(); + let mut result = Ok(()); + if let ContextServerState::Running { server, .. } = &state { + result = server.stop(); + } + drop(state); + + self.update_server_state( + id.clone(), + ContextServerState::Stopped { + configuration, + server, + error: None, + }, + cx, + ); + + result + } + + pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context) -> Result<()> { + if let Some(state) = self.servers.get(&id) { + let configuration = state.configuration(); + + self.stop_server(&state.server().id(), cx)?; + let new_server = self.create_context_server(id.clone(), configuration.clone())?; + self.run_server(new_server, configuration, cx); + } + Ok(()) + } + + fn run_server( + &mut self, + server: Arc, + configuration: Arc, + cx: &mut Context, + ) { + let id = server.id(); + if matches!( + self.servers.get(&id), + Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. }) + ) { + self.stop_server(&id, cx).log_err(); + } + + let task = cx.spawn({ + let id = server.id(); + let server = server.clone(); + let configuration = configuration.clone(); + async move |this, cx| { + match server.clone().start(&cx).await { + Ok(_) => { + log::info!("Started {} context server", id); + debug_assert!(server.client().is_some()); + + this.update(cx, |this, cx| { + this.update_server_state( + id.clone(), + ContextServerState::Running { + server, + configuration, + }, + cx, + ) + }) + .log_err() + } + Err(err) => { + log::error!("{} context server failed to start: {}", id, err); + this.update(cx, |this, cx| { + this.update_server_state( + id.clone(), + ContextServerState::Stopped { + configuration, + server, + error: Some(err.to_string().into()), + }, + cx, + ) + }) + .log_err() + } + }; + } + }); + + self.update_server_state( + id.clone(), + ContextServerState::Starting { + configuration, + _task: task, + server, + }, + cx, + ); + } + + fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context) -> Result<()> { + let Some(state) = self.servers.remove(id) else { + return Err(anyhow::anyhow!("Context server not found")); + }; + drop(state); + cx.emit(Event::ServerStatusChanged { + server_id: id.clone(), + status: ContextServerStatus::Stopped, + }); + Ok(()) + } + + fn is_configuration_valid(&self, configuration: &ContextServerConfiguration) -> bool { + // Command must be some when we are running in stdio mode. + self.context_server_factory.as_ref().is_some() || configuration.command.is_some() + } + + fn create_context_server( + &self, + id: ContextServerId, + configuration: Arc, + ) -> Result> { + if let Some(factory) = self.context_server_factory.as_ref() { + Ok(factory(id, configuration)) + } else { + let command = configuration + .command + .clone() + .context("Missing command to run context server")?; + Ok(Arc::new(ContextServer::stdio(id, command))) + } + } + + fn update_server_state( + &mut self, + id: ContextServerId, + state: ContextServerState, + cx: &mut Context, + ) { + let status = ContextServerStatus::from_state(&state); + self.servers.insert(id.clone(), state); + cx.emit(Event::ServerStatusChanged { + server_id: id, + status, + }); + } + + fn available_context_servers_changed(&mut self, cx: &mut Context) { + if self.update_servers_task.is_some() { + self.needs_server_update = true; + } else { + self.needs_server_update = false; + self.update_servers_task = Some(cx.spawn(async move |this, cx| { + if let Err(err) = Self::maintain_servers(this.clone(), cx).await { + log::error!("Error maintaining context servers: {}", err); + } + + this.update(cx, |this, cx| { + this.update_servers_task.take(); + if this.needs_server_update { + this.available_context_servers_changed(cx); + } + })?; + + Ok(()) + })); + } + } + + async fn maintain_servers(this: WeakEntity, cx: &mut AsyncApp) -> Result<()> { + let mut desired_servers = HashMap::default(); + + let (registry, worktree_store) = this.update(cx, |this, cx| { + let location = this + .worktree_store + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| settings::SettingsLocation { + worktree_id: worktree.read(cx).id(), + path: Path::new(""), + }); + let settings = ProjectSettings::get(location, cx); + desired_servers = settings.context_servers.clone(); + + (this.registry.clone(), this.worktree_store.clone()) + })?; + + for (id, descriptor) in + registry.read_with(cx, |registry, _| registry.context_server_descriptors())? + { + let config = desired_servers.entry(id.clone()).or_default(); + if config.command.is_none() { + if let Some(extension_command) = descriptor + .command(worktree_store.clone(), &cx) + .await + .log_err() + { + config.command = Some(extension_command); + } + } + } + + this.update(cx, |this, _| { + // Filter out configurations without commands, the user uninstalled an extension. + desired_servers.retain(|_, configuration| this.is_configuration_valid(configuration)); + })?; + + let mut servers_to_start = Vec::new(); + let mut servers_to_remove = HashSet::default(); + let mut servers_to_stop = HashSet::default(); + + this.update(cx, |this, _cx| { + for server_id in this.servers.keys() { + // All servers that are not in desired_servers should be removed from the store. + // E.g. this can happen if the user removed a server from the configuration, + // or the user uninstalled an extension. + if !desired_servers.contains_key(&server_id.0) { + servers_to_remove.insert(server_id.clone()); + } + } + + for (id, config) in desired_servers { + let id = ContextServerId(id.clone()); + + let existing_config = this.servers.get(&id).map(|state| state.configuration()); + if existing_config.as_deref() != Some(&config) { + let config = Arc::new(config); + if let Some(server) = this + .create_context_server(id.clone(), config.clone()) + .log_err() + { + servers_to_start.push((server, config)); + if this.servers.contains_key(&id) { + servers_to_stop.insert(id); + } + } + } + } + })?; + + for id in servers_to_stop { + this.update(cx, |this, cx| this.stop_server(&id, cx).ok())?; + } + + for id in servers_to_remove { + this.update(cx, |this, cx| this.remove_server(&id, cx).ok())?; + } + + for (server, config) in servers_to_start { + this.update(cx, |this, cx| this.run_server(server, config, cx)) + .log_err(); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FakeFs, Project, project_settings::ProjectSettings}; + use context_server::{ + transport::Transport, + types::{ + self, Implementation, InitializeResponse, ProtocolVersion, RequestType, + ServerCapabilities, + }, + }; + use futures::{Stream, StreamExt as _, lock::Mutex}; + use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _}; + use serde_json::json; + use std::{cell::RefCell, pin::Pin, rc::Rc}; + use util::path; + + #[gpui::test] + async fn test_context_server_status(cx: &mut TestAppContext) { + const SERVER_1_ID: &'static str = "mcp-1"; + const SERVER_2_ID: &'static str = "mcp-2"; + + let (_fs, project) = setup_context_server_test( + cx, + json!({"code.rs": ""}), + vec![ + (SERVER_1_ID.into(), ContextServerConfiguration::default()), + (SERVER_2_ID.into(), ContextServerConfiguration::default()), + ], + ) + .await; + + let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); + let store = cx.new(|cx| { + ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + }); + + let server_1_id = ContextServerId("mcp-1".into()); + let server_2_id = ContextServerId("mcp-2".into()); + + let transport_1 = + Arc::new(FakeTransport::new( + cx.executor(), + |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response("mcp-1".to_string())) + } + _ => None, + }, + )); + + let transport_2 = + Arc::new(FakeTransport::new( + cx.executor(), + |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response("mcp-2".to_string())) + } + _ => None, + }, + )); + + let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone())); + let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone())); + + store + .update(cx, |store, cx| store.start_server(server_1, cx)) + .unwrap(); + + cx.run_until_parked(); + + cx.update(|cx| { + assert_eq!( + store.read(cx).status_for_server(&server_1_id), + Some(ContextServerStatus::Running) + ); + assert_eq!(store.read(cx).status_for_server(&server_2_id), None); + }); + + store + .update(cx, |store, cx| store.start_server(server_2.clone(), cx)) + .unwrap(); + + cx.run_until_parked(); + + cx.update(|cx| { + assert_eq!( + store.read(cx).status_for_server(&server_1_id), + Some(ContextServerStatus::Running) + ); + assert_eq!( + store.read(cx).status_for_server(&server_2_id), + Some(ContextServerStatus::Running) + ); + }); + + store + .update(cx, |store, cx| store.stop_server(&server_2_id, cx)) + .unwrap(); + + cx.update(|cx| { + assert_eq!( + store.read(cx).status_for_server(&server_1_id), + Some(ContextServerStatus::Running) + ); + assert_eq!( + store.read(cx).status_for_server(&server_2_id), + Some(ContextServerStatus::Stopped) + ); + }); + } + + #[gpui::test] + async fn test_context_server_status_events(cx: &mut TestAppContext) { + const SERVER_1_ID: &'static str = "mcp-1"; + const SERVER_2_ID: &'static str = "mcp-2"; + + let (_fs, project) = setup_context_server_test( + cx, + json!({"code.rs": ""}), + vec![ + (SERVER_1_ID.into(), ContextServerConfiguration::default()), + (SERVER_2_ID.into(), ContextServerConfiguration::default()), + ], + ) + .await; + + let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); + let store = cx.new(|cx| { + ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + }); + + let server_1_id = ContextServerId("mcp-1".into()); + let server_2_id = ContextServerId("mcp-2".into()); + + let transport_1 = + Arc::new(FakeTransport::new( + cx.executor(), + |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response("mcp-1".to_string())) + } + _ => None, + }, + )); + + let transport_2 = + Arc::new(FakeTransport::new( + cx.executor(), + |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response("mcp-2".to_string())) + } + _ => None, + }, + )); + + let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone())); + let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone())); + + let _server_events = assert_server_events( + &store, + vec![ + (server_1_id.clone(), ContextServerStatus::Starting), + (server_1_id.clone(), ContextServerStatus::Running), + (server_2_id.clone(), ContextServerStatus::Starting), + (server_2_id.clone(), ContextServerStatus::Running), + (server_2_id.clone(), ContextServerStatus::Stopped), + ], + cx, + ); + + store + .update(cx, |store, cx| store.start_server(server_1, cx)) + .unwrap(); + + cx.run_until_parked(); + + store + .update(cx, |store, cx| store.start_server(server_2.clone(), cx)) + .unwrap(); + + cx.run_until_parked(); + + store + .update(cx, |store, cx| store.stop_server(&server_2_id, cx)) + .unwrap(); + } + + #[gpui::test(iterations = 25)] + async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) { + const SERVER_1_ID: &'static str = "mcp-1"; + + let (_fs, project) = setup_context_server_test( + cx, + json!({"code.rs": ""}), + vec![(SERVER_1_ID.into(), ContextServerConfiguration::default())], + ) + .await; + + let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); + let store = cx.new(|cx| { + ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + }); + + let server_id = ContextServerId(SERVER_1_ID.into()); + + let transport_1 = + Arc::new(FakeTransport::new( + cx.executor(), + |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response(SERVER_1_ID.to_string())) + } + _ => None, + }, + )); + + let transport_2 = + Arc::new(FakeTransport::new( + cx.executor(), + |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response(SERVER_1_ID.to_string())) + } + _ => None, + }, + )); + + let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1)); + let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2)); + + // If we start another server with the same id, we should report that we stopped the previous one + let _server_events = assert_server_events( + &store, + vec![ + (server_id.clone(), ContextServerStatus::Starting), + (server_id.clone(), ContextServerStatus::Stopped), + (server_id.clone(), ContextServerStatus::Starting), + (server_id.clone(), ContextServerStatus::Running), + ], + cx, + ); + + store + .update(cx, |store, cx| { + store.start_server(server_with_same_id_1.clone(), cx) + }) + .unwrap(); + store + .update(cx, |store, cx| { + store.start_server(server_with_same_id_2.clone(), cx) + }) + .unwrap(); + cx.update(|cx| { + assert_eq!( + store.read(cx).status_for_server(&server_id), + Some(ContextServerStatus::Starting) + ); + }); + + cx.run_until_parked(); + + cx.update(|cx| { + assert_eq!( + store.read(cx).status_for_server(&server_id), + Some(ContextServerStatus::Running) + ); + }); + } + + #[gpui::test] + async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) { + const SERVER_1_ID: &'static str = "mcp-1"; + const SERVER_2_ID: &'static str = "mcp-2"; + + let server_1_id = ContextServerId(SERVER_1_ID.into()); + let server_2_id = ContextServerId(SERVER_2_ID.into()); + + let (_fs, project) = setup_context_server_test( + cx, + json!({"code.rs": ""}), + vec![( + SERVER_1_ID.into(), + ContextServerConfiguration { + command: None, + settings: Some(json!({ + "somevalue": true + })), + }, + )], + ) + .await; + + let executor = cx.executor(); + let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); + let store = cx.new(|cx| { + ContextServerStore::test_maintain_server_loop( + Box::new(move |id, _| { + let transport = FakeTransport::new(executor.clone(), { + let id = id.0.clone(); + move |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response(id.clone().to_string())) + } + _ => None, + } + }); + Arc::new(ContextServer::new(id.clone(), Arc::new(transport))) + }), + registry.clone(), + project.read(cx).worktree_store(), + cx, + ) + }); + + // Ensure that mcp-1 starts up + { + let _server_events = assert_server_events( + &store, + vec![ + (server_1_id.clone(), ContextServerStatus::Starting), + (server_1_id.clone(), ContextServerStatus::Running), + ], + cx, + ); + cx.run_until_parked(); + } + + // Ensure that mcp-1 is restarted when the configuration was changed + { + let _server_events = assert_server_events( + &store, + vec![ + (server_1_id.clone(), ContextServerStatus::Stopped), + (server_1_id.clone(), ContextServerStatus::Starting), + (server_1_id.clone(), ContextServerStatus::Running), + ], + cx, + ); + set_context_server_configuration( + vec![( + server_1_id.0.clone(), + ContextServerConfiguration { + command: None, + settings: Some(json!({ + "somevalue": false + })), + }, + )], + cx, + ); + + cx.run_until_parked(); + } + + // Ensure that mcp-1 is not restarted when the configuration was not changed + { + let _server_events = assert_server_events(&store, vec![], cx); + set_context_server_configuration( + vec![( + server_1_id.0.clone(), + ContextServerConfiguration { + command: None, + settings: Some(json!({ + "somevalue": false + })), + }, + )], + cx, + ); + + cx.run_until_parked(); + } + + // Ensure that mcp-2 is started once it is added to the settings + { + let _server_events = assert_server_events( + &store, + vec![ + (server_2_id.clone(), ContextServerStatus::Starting), + (server_2_id.clone(), ContextServerStatus::Running), + ], + cx, + ); + set_context_server_configuration( + vec![ + ( + server_1_id.0.clone(), + ContextServerConfiguration { + command: None, + settings: Some(json!({ + "somevalue": false + })), + }, + ), + ( + server_2_id.0.clone(), + ContextServerConfiguration { + command: None, + settings: Some(json!({ + "somevalue": true + })), + }, + ), + ], + cx, + ); + + cx.run_until_parked(); + } + + // Ensure that mcp-2 is removed once it is removed from the settings + { + let _server_events = assert_server_events( + &store, + vec![(server_2_id.clone(), ContextServerStatus::Stopped)], + cx, + ); + set_context_server_configuration( + vec![( + server_1_id.0.clone(), + ContextServerConfiguration { + command: None, + settings: Some(json!({ + "somevalue": false + })), + }, + )], + cx, + ); + + cx.run_until_parked(); + + cx.update(|cx| { + assert_eq!(store.read(cx).status_for_server(&server_2_id), None); + }); + } + } + + fn set_context_server_configuration( + context_servers: Vec<(Arc, ContextServerConfiguration)>, + cx: &mut TestAppContext, + ) { + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + let mut settings = ProjectSettings::default(); + for (id, config) in context_servers { + settings.context_servers.insert(id, config); + } + store + .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx) + .unwrap(); + }) + }); + } + + struct ServerEvents { + received_event_count: Rc>, + expected_event_count: usize, + _subscription: Subscription, + } + + impl Drop for ServerEvents { + fn drop(&mut self) { + let actual_event_count = *self.received_event_count.borrow(); + assert_eq!( + actual_event_count, self.expected_event_count, + " + Expected to receive {} context server store events, but received {} events", + self.expected_event_count, actual_event_count + ); + } + } + + fn assert_server_events( + store: &Entity, + expected_events: Vec<(ContextServerId, ContextServerStatus)>, + cx: &mut TestAppContext, + ) -> ServerEvents { + cx.update(|cx| { + let mut ix = 0; + let received_event_count = Rc::new(RefCell::new(0)); + let expected_event_count = expected_events.len(); + let subscription = cx.subscribe(store, { + let received_event_count = received_event_count.clone(); + move |_, event, _| match event { + Event::ServerStatusChanged { + server_id: actual_server_id, + status: actual_status, + } => { + let (expected_server_id, expected_status) = &expected_events[ix]; + + assert_eq!( + actual_server_id, expected_server_id, + "Expected different server id at index {}", + ix + ); + assert_eq!( + actual_status, expected_status, + "Expected different status at index {}", + ix + ); + ix += 1; + *received_event_count.borrow_mut() += 1; + } + } + }); + ServerEvents { + expected_event_count, + received_event_count, + _subscription: subscription, + } + }) + } + + async fn setup_context_server_test( + cx: &mut TestAppContext, + files: serde_json::Value, + context_server_configurations: Vec<(Arc, ContextServerConfiguration)>, + ) -> (Arc, Entity) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + let mut settings = ProjectSettings::get_global(cx).clone(); + for (id, config) in context_server_configurations { + settings.context_servers.insert(id, config); + } + ProjectSettings::override_global(settings, cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), files).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; + + (fs, project) + } + + fn create_initialize_response(server_name: String) -> serde_json::Value { + serde_json::to_value(&InitializeResponse { + protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()), + server_info: Implementation { + name: server_name, + version: "1.0.0".to_string(), + }, + capabilities: ServerCapabilities::default(), + meta: None, + }) + .unwrap() + } + + struct FakeTransport { + on_request: Arc< + dyn Fn(u64, Option, serde_json::Value) -> Option + + Send + + Sync, + >, + tx: futures::channel::mpsc::UnboundedSender, + rx: Arc>>, + executor: BackgroundExecutor, + } + + impl FakeTransport { + fn new( + executor: BackgroundExecutor, + on_request: impl Fn( + u64, + Option, + serde_json::Value, + ) -> Option + + 'static + + Send + + Sync, + ) -> Self { + let (tx, rx) = futures::channel::mpsc::unbounded(); + Self { + on_request: Arc::new(on_request), + tx, + rx: Arc::new(Mutex::new(rx)), + executor, + } + } + } + + #[async_trait::async_trait] + impl Transport for FakeTransport { + async fn send(&self, message: String) -> Result<()> { + if let Ok(msg) = serde_json::from_str::(&message) { + let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); + + if let Some(method) = msg.get("method") { + let request_type = method + .as_str() + .and_then(|method| types::RequestType::try_from(method).ok()); + if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) { + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": payload + }); + + self.tx + .unbounded_send(response.to_string()) + .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?; + } + } + } + Ok(()) + } + + fn receive(&self) -> Pin + Send>> { + let rx = self.rx.clone(); + let executor = self.executor.clone(); + Box::pin(futures::stream::unfold(rx, move |rx| { + let executor = executor.clone(); + async move { + let mut rx_guard = rx.lock().await; + executor.simulate_random_delay().await; + if let Some(message) = rx_guard.next().await { + drop(rx_guard); + Some((message, rx)) + } else { + None + } + } + })) + } + + fn receive_err(&self) -> Pin + Send>> { + Box::pin(futures::stream::empty()) + } + } +} diff --git a/crates/context_server/src/extension_context_server.rs b/crates/project/src/context_server_store/extension.rs similarity index 74% rename from crates/context_server/src/extension_context_server.rs rename to crates/project/src/context_server_store/extension.rs index 1fb138d56f..825ee0b678 100644 --- a/crates/context_server/src/extension_context_server.rs +++ b/crates/project/src/context_server_store/extension.rs @@ -1,19 +1,21 @@ use std::sync::Arc; use anyhow::Result; +use context_server::ContextServerCommand; use extension::{ ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy, ProjectDelegate, }; use gpui::{App, AsyncApp, Entity, Task}; -use project::Project; -use crate::{ContextServerDescriptorRegistry, ServerCommand, registry}; +use crate::worktree_store::WorktreeStore; + +use super::registry::{self, ContextServerDescriptorRegistry}; pub fn init(cx: &mut App) { let proxy = ExtensionHostProxy::default_global(cx); proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy { - context_server_factory_registry: ContextServerDescriptorRegistry::global(cx), + context_server_factory_registry: ContextServerDescriptorRegistry::default_global(cx), }); } @@ -32,10 +34,13 @@ struct ContextServerDescriptor { extension: Arc, } -fn extension_project(project: Entity, cx: &mut AsyncApp) -> Result> { - project.update(cx, |project, cx| { +fn extension_project( + worktree_store: Entity, + cx: &mut AsyncApp, +) -> Result> { + worktree_store.update(cx, |worktree_store, cx| { Arc::new(ExtensionProject { - worktree_ids: project + worktree_ids: worktree_store .visible_worktrees(cx) .map(|worktree| worktree.read(cx).id().to_proto()) .collect(), @@ -44,11 +49,15 @@ fn extension_project(project: Entity, cx: &mut AsyncApp) -> Result, cx: &AsyncApp) -> Task> { + fn command( + &self, + worktree_store: Entity, + cx: &AsyncApp, + ) -> Task> { let id = self.id.clone(); let extension = self.extension.clone(); cx.spawn(async move |cx| { - let extension_project = extension_project(project, cx)?; + let extension_project = extension_project(worktree_store, cx)?; let mut command = extension .context_server_command(id.clone(), extension_project.clone()) .await?; @@ -59,7 +68,7 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor { log::info!("loaded command for context server {id}: {command:?}"); - Ok(ServerCommand { + Ok(ContextServerCommand { path: command.command, args: command.args, env: Some(command.env.into_iter().collect()), @@ -69,13 +78,13 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor { fn configuration( &self, - project: Entity, + worktree_store: Entity, cx: &AsyncApp, ) -> Task>> { let id = self.id.clone(); let extension = self.extension.clone(); cx.spawn(async move |cx| { - let extension_project = extension_project(project, cx)?; + let extension_project = extension_project(worktree_store, cx)?; let configuration = extension .context_server_configuration(id.clone(), extension_project) .await?; @@ -102,4 +111,11 @@ impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy { ) }); } + + fn unregister_context_server(&self, server_id: Arc, cx: &mut App) { + self.context_server_factory_registry + .update(cx, |registry, _| { + registry.unregister_context_server_descriptor_by_id(&server_id) + }); + } } diff --git a/crates/context_server/src/registry.rs b/crates/project/src/context_server_store/registry.rs similarity index 82% rename from crates/context_server/src/registry.rs rename to crates/project/src/context_server_store/registry.rs index 96fd7459ef..972ec6642d 100644 --- a/crates/context_server/src/registry.rs +++ b/crates/project/src/context_server_store/registry.rs @@ -2,17 +2,21 @@ use std::sync::Arc; use anyhow::Result; use collections::HashMap; +use context_server::ContextServerCommand; use extension::ContextServerConfiguration; -use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task}; -use project::Project; +use gpui::{App, AppContext as _, AsyncApp, Entity, Global, Task}; -use crate::ServerCommand; +use crate::worktree_store::WorktreeStore; pub trait ContextServerDescriptor { - fn command(&self, project: Entity, cx: &AsyncApp) -> Task>; + fn command( + &self, + worktree_store: Entity, + cx: &AsyncApp, + ) -> Task>; fn configuration( &self, - project: Entity, + worktree_store: Entity, cx: &AsyncApp, ) -> Task>>; } @@ -27,11 +31,6 @@ pub struct ContextServerDescriptorRegistry { } impl ContextServerDescriptorRegistry { - /// Returns the global [`ContextServerDescriptorRegistry`]. - pub fn global(cx: &App) -> Entity { - GlobalContextServerDescriptorRegistry::global(cx).0.clone() - } - /// Returns the global [`ContextServerDescriptorRegistry`]. /// /// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist. diff --git a/crates/project/src/git_store/git_traversal.rs b/crates/project/src/git_store/git_traversal.rs index 531cd35b6c..363a7b7d7d 100644 --- a/crates/project/src/git_store/git_traversal.rs +++ b/crates/project/src/git_store/git_traversal.rs @@ -244,9 +244,8 @@ mod tests { use git::status::{FileStatus, StatusCode, TrackedSummary, UnmergedStatus, UnmergedStatusCode}; use gpui::TestAppContext; use serde_json::json; - use settings::{Settings as _, SettingsStore}; + use settings::SettingsStore; use util::path; - use worktree::WorktreeSettings; const CONFLICT: FileStatus = FileStatus::Unmerged(UnmergedStatus { first_head: UnmergedStatusCode::Updated, @@ -682,7 +681,7 @@ mod tests { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); - WorktreeSettings::register(cx); + Project::init_settings(cx); }); } diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 7d5adc5a69..e2bfc8f004 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1,6 +1,7 @@ pub mod buffer_store; mod color_extractor; pub mod connection_manager; +pub mod context_server_store; pub mod debounced_delay; pub mod debugger; pub mod git_store; @@ -23,6 +24,7 @@ mod project_tests; mod direnv; mod environment; use buffer_diff::BufferDiff; +use context_server_store::ContextServerStore; pub use environment::{EnvironmentErrorMessage, ProjectEnvironmentEvent}; use git_store::{Repository, RepositoryId}; pub mod search_history; @@ -182,6 +184,7 @@ pub struct Project { client_subscriptions: Vec, worktree_store: Entity, buffer_store: Entity, + context_server_store: Entity, image_store: Entity, lsp_store: Entity, _subscriptions: Vec, @@ -845,6 +848,7 @@ impl Project { ToolchainStore::init(&client); DapStore::init(&client, cx); BreakpointStore::init(&client); + context_server_store::init(cx); } pub fn local( @@ -865,6 +869,9 @@ impl Project { cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); + let context_server_store = + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx)); + let environment = cx.new(|_| ProjectEnvironment::new(env)); let toolchain_store = cx.new(|cx| { ToolchainStore::local( @@ -965,6 +972,7 @@ impl Project { buffer_store, image_store, lsp_store, + context_server_store, join_project_response_message_id: 0, client_state: ProjectClientState::Local, git_store, @@ -1025,6 +1033,9 @@ impl Project { cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); + let context_server_store = + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx)); + let buffer_store = cx.new(|cx| { BufferStore::remote( worktree_store.clone(), @@ -1109,6 +1120,7 @@ impl Project { buffer_store, image_store, lsp_store, + context_server_store, breakpoint_store, dap_store, join_project_response_message_id: 0, @@ -1267,6 +1279,8 @@ impl Project { let image_store = cx.new(|cx| { ImageStore::remote(worktree_store.clone(), client.clone().into(), remote_id, cx) })?; + let context_server_store = + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx))?; let environment = cx.new(|_| ProjectEnvironment::new(None))?; @@ -1360,6 +1374,7 @@ impl Project { image_store, worktree_store: worktree_store.clone(), lsp_store: lsp_store.clone(), + context_server_store, active_entry: None, collaborators: Default::default(), join_project_response_message_id: response.message_id, @@ -1590,6 +1605,10 @@ impl Project { self.worktree_store.clone() } + pub fn context_server_store(&self) -> Entity { + self.context_server_store.clone() + } + pub fn buffer_for_id(&self, remote_id: BufferId, cx: &App) -> Option> { self.buffer_store.read(cx).get(remote_id) } diff --git a/crates/project/src/project_settings.rs b/crates/project/src/project_settings.rs index 857293262f..4eb25699d4 100644 --- a/crates/project/src/project_settings.rs +++ b/crates/project/src/project_settings.rs @@ -1,5 +1,6 @@ use anyhow::Context as _; use collections::HashMap; +use context_server::ContextServerCommand; use dap::adapters::DebugAdapterName; use fs::Fs; use futures::StreamExt as _; @@ -51,6 +52,10 @@ pub struct ProjectSettings { #[serde(default)] pub dap: HashMap, + /// Settings for context servers used for AI-related features. + #[serde(default)] + pub context_servers: HashMap, ContextServerConfiguration>, + /// Configuration for Diagnostics-related features. #[serde(default)] pub diagnostics: DiagnosticsSettings, @@ -78,6 +83,19 @@ pub struct DapSettings { pub binary: Option, } +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)] +pub struct ContextServerConfiguration { + /// The command to run this context server. + /// + /// This will override the command set by an extension. + pub command: Option, + /// The settings for this context server. + /// + /// Consult the documentation for the context server to see what settings + /// are supported. + pub settings: Option, +} + #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)] pub struct NodeBinarySettings { /// The path to the Node binary. @@ -354,6 +372,40 @@ impl Settings for ProjectSettings { } } + #[derive(Deserialize)] + struct VsCodeContextServerCommand { + command: String, + args: Option>, + env: Option>, + // note: we don't support envFile and type + } + impl From for ContextServerCommand { + fn from(cmd: VsCodeContextServerCommand) -> Self { + Self { + path: cmd.command, + args: cmd.args.unwrap_or_default(), + env: cmd.env, + } + } + } + if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) { + current + .context_servers + .extend(mcp.iter().filter_map(|(k, v)| { + Some(( + k.clone().into(), + ContextServerConfiguration { + command: Some( + serde_json::from_value::(v.clone()) + .ok()? + .into(), + ), + settings: None, + }, + )) + })); + } + // TODO: translate lsp settings for rust-analyzer and other popular ones to old.lsp } }