From 24eb039752d4ffc1aaf222e80180b3b11a3cab9b Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Thu, 1 May 2025 20:02:14 +0200 Subject: [PATCH] context servers: Show configuration modal when extension is installed (#29309) WIP Release Notes: - N/A --------- Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Co-authored-by: Danilo Leal Co-authored-by: Marshall Bowers Co-authored-by: Cole Miller Co-authored-by: Antonio Scandurra Co-authored-by: Oleksiy Syvokon --- Cargo.lock | 150 +++++- Cargo.toml | 1 + assets/keymaps/default-linux.json | 8 + assets/keymaps/default-macos.json | 9 + crates/agent/Cargo.toml | 4 + crates/agent/src/assistant.rs | 4 + crates/agent/src/assistant_configuration.rs | 330 +++++++------ .../configure_context_server_modal.rs | 443 ++++++++++++++++++ .../agent/src/context_server_configuration.rs | 120 +++++ crates/agent/src/thread_store.rs | 112 ++--- .../src/context_store.rs | 100 ++-- crates/context_server/Cargo.toml | 4 + crates/context_server/src/client.rs | 25 +- crates/context_server/src/context_server.rs | 4 +- .../src/extension_context_server.rs | 128 +++-- crates/context_server/src/manager.rs | 389 ++++++++++++--- crates/context_server/src/registry.rs | 65 ++- crates/context_server/src/types.rs | 28 +- crates/eval/src/eval.rs | 8 +- crates/extension/src/extension.rs | 6 + crates/extension/src/extension_events.rs | 6 + crates/extension/src/types.rs | 2 + crates/extension/src/types/context_server.rs | 10 + crates/extension_api/src/extension_api.rs | 18 + .../wit/since_v0.5.0/context-server.wit | 11 + .../wit/since_v0.5.0/extension.wit | 5 + crates/extension_host/src/extension_host.rs | 32 +- crates/extension_host/src/wasm_host.rs | 32 +- crates/extension_host/src/wasm_host/wit.rs | 24 + .../src/wasm_host/wit/since_v0_5_0.rs | 18 + crates/extensions_ui/Cargo.toml | 1 + crates/extensions_ui/src/extensions_ui.rs | 167 +++++-- crates/project/src/lsp_store.rs | 12 +- crates/zed/src/main.rs | 1 + tooling/workspace-hack/Cargo.toml | 26 +- 35 files changed, 1866 insertions(+), 437 deletions(-) create mode 100644 crates/agent/src/assistant_configuration/configure_context_server_modal.rs create mode 100644 crates/agent/src/context_server_configuration.rs create mode 100644 crates/extension/src/types/context_server.rs create mode 100644 crates/extension_api/wit/since_v0.5.0/context-server.wit diff --git a/Cargo.lock b/Cargo.lock index c13e488081..6fed6289e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -68,6 +68,7 @@ dependencies = [ "convert_case 0.8.0", "db", "editor", + "extension", "feature_flags", "file_icons", "fs", @@ -81,6 +82,7 @@ dependencies = [ "indexmap", "indoc", "itertools 0.14.0", + "jsonschema", "language", "language_model", "language_model_selector", @@ -90,6 +92,7 @@ dependencies = [ "markdown", "menu", "multi_buffer", + "notifications", "ordered-float 2.10.1", "parking_lot", "paths", @@ -106,6 +109,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "serde_json_lenient", "settings", "smallvec", "smol", @@ -148,7 +152,9 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", + "getrandom 0.2.15", "once_cell", + "serde", "version_check", "zerocopy 0.7.35", ] @@ -2186,6 +2192,12 @@ dependencies = [ "piper", ] +[[package]] +name = "borrow-or-share" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32" + [[package]] name = "borsh" version = "1.5.7" @@ -2301,6 +2313,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "bytecount" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" + [[package]] name = "bytemuck" version = "1.22.0" @@ -4783,6 +4801,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" +dependencies = [ + "serde", +] + [[package]] name = "embed-resource" version = "3.0.2" @@ -5198,6 +5225,7 @@ dependencies = [ "collections", "db", "editor", + "extension", "extension_host", "fs", "fuzzy", @@ -5430,6 +5458,17 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bf7cc16383c4b8d58b9905a8509f02926ce3058053c056376248d958c9df1e8" +[[package]] +name = "fluent-uri" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5" +dependencies = [ + "borrow-or-share", + "ref-cast", + "serde", +] + [[package]] name = "flume" version = "0.11.1" @@ -5584,6 +5623,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fraction" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7" +dependencies = [ + "lazy_static", + "num", +] + [[package]] name = "freetype-sys" version = "0.20.1" @@ -7587,6 +7636,33 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonschema" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1b46a0365a611fbf1d2143104dcf910aada96fafd295bab16c60b802bf6fa1d" +dependencies = [ + "ahash 0.8.11", + "base64 0.22.1", + "bytecount", + "email_address", + "fancy-regex 0.14.0", + "fraction", + "idna", + "itoa", + "num-cmp", + "num-traits", + "once_cell", + "percent-encoding", + "referencing", + "regex", + "regex-syntax 0.8.5", + "reqwest 0.12.15 (registry+https://github.com/rust-lang/crates.io-index)", + "serde", + "serde_json", + "uuid-simd", +] + [[package]] name = "jsonwebtoken" version = "9.3.1" @@ -8271,7 +8347,7 @@ dependencies = [ "prost 0.9.0", "prost-build 0.9.0", "prost-types 0.9.0", - "reqwest 0.12.15", + "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "serde", "workspace-hack", ] @@ -9181,6 +9257,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + [[package]] name = "num-complex" version = "0.4.6" @@ -11774,6 +11856,20 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "referencing" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8eff4fa778b5c2a57e85c5f2fe3a709c52f0e60d23146e2151cbef5893f420e" +dependencies = [ + "ahash 0.8.11", + "fluent-uri", + "once_cell", + "parking_lot", + "percent-encoding", + "serde_json", +] + [[package]] name = "refineable" version = "0.1.0" @@ -12043,6 +12139,43 @@ dependencies = [ "winreg 0.50.0", ] +[[package]] +name = "reqwest" +version = "0.12.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" +dependencies = [ + "base64 0.22.1", + "bytes 1.10.1", + "futures-channel", + "futures-core", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tower 0.5.2", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "windows-registry 0.4.0", +] + [[package]] name = "reqwest" version = "0.12.15" @@ -12103,7 +12236,7 @@ dependencies = [ "http_client_tls", "log", "regex", - "reqwest 0.12.15", + "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "serde", "smol", "tokio", @@ -15954,6 +16087,17 @@ dependencies = [ "sha1_smol", ] +[[package]] +name = "uuid-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8" +dependencies = [ + "outref", + "uuid", + "vsimd", +] + [[package]] name = "v_frame" version = "0.3.8" @@ -18054,6 +18198,7 @@ dependencies = [ "hmac", "hyper 0.14.32", "hyper-rustls 0.27.5", + "idna", "indexmap", "inout", "itertools 0.12.1", @@ -18077,6 +18222,7 @@ dependencies = [ "num-bigint-dig", "num-integer", "num-iter", + "num-rational", "num-traits", "object", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index 743beb6ff0..1fc04f838e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -462,6 +462,7 @@ indexmap = { version = "2.7.0", features = ["serde"] } indoc = "2" inventory = "0.3.19" itertools = "0.14.0" +jsonschema = "0.30.0" jsonwebtoken = "9.3" jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" } diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 786a2a346a..d3e02e5952 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -963,6 +963,14 @@ "escape": "menu::Cancel" } }, + { + "context": "ConfigureContextServerModal > Editor", + "bindings": { + "escape": "menu::Cancel", + "enter": "editor::Newline", + "ctrl-enter": "menu::Confirm" + } + }, { "context": "Diagnostics", "use_key_equivalents": true, diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index b6214f461f..b252617b47 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -1069,6 +1069,15 @@ "escape": "menu::Cancel" } }, + { + "context": "ConfigureContextServerModal > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel", + "enter": "editor::Newline", + "cmd-enter": "menu::Confirm" + } + }, { "context": "Diagnostics", "use_key_equivalents": true, diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 83562a321b..8594551e80 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -35,6 +35,7 @@ context_server.workspace = true convert_case.workspace = true db.workspace = true editor.workspace = true +extension.workspace = true feature_flags.workspace = true file_icons.workspace = true fs.workspace = true @@ -47,6 +48,7 @@ html_to_markdown.workspace = true http_client.workspace = true indexmap.workspace = true itertools.workspace = true +jsonschema.workspace = true language.workspace = true language_model.workspace = true language_model_selector.workspace = true @@ -56,6 +58,7 @@ lsp.workspace = true markdown.workspace = true menu.workspace = true multi_buffer.workspace = true +notifications.workspace = true ordered-float.workspace = true parking_lot.workspace = true paths.workspace = true @@ -71,6 +74,7 @@ rope.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +serde_json_lenient.workspace = true settings.workspace = true smallvec.workspace = true smol.workspace = true diff --git a/crates/agent/src/assistant.rs b/crates/agent/src/assistant.rs index 3c9b915b87..7c20501595 100644 --- a/crates/agent/src/assistant.rs +++ b/crates/agent/src/assistant.rs @@ -6,6 +6,7 @@ mod assistant_panel; mod buffer_codegen; mod context; mod context_picker; +mod context_server_configuration; mod context_store; mod context_strip; mod history_store; @@ -30,6 +31,7 @@ use command_palette_hooks::CommandPaletteFilter; use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt}; use fs::Fs; use gpui::{App, actions, impl_actions}; +use language::LanguageRegistry; use prompt_store::PromptBuilder; use schemars::JsonSchema; use serde::Deserialize; @@ -107,11 +109,13 @@ pub fn init( fs: Arc, client: Arc, prompt_builder: Arc, + language_registry: Arc, cx: &mut App, ) { AssistantSettings::register(cx); thread_store::init(cx); assistant_panel::init(cx); + context_server_configuration::init(language_registry, cx); inline_assistant::init( fs.clone(), diff --git a/crates/agent/src/assistant_configuration.rs b/crates/agent/src/assistant_configuration.rs index 5c52dd2b36..a399641091 100644 --- a/crates/agent/src/assistant_configuration.rs +++ b/crates/agent/src/assistant_configuration.rs @@ -1,16 +1,18 @@ mod add_context_server_modal; +mod configure_context_server_modal; mod manage_profiles_modal; mod tool_picker; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use assistant_settings::AssistantSettings; use assistant_tool::{ToolSource, ToolWorkingSet}; use collections::HashMap; -use context_server::manager::ContextServerManager; +use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus}; use fs::Fs; use gpui::{ - Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription, + Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle, + Focusable, ScrollHandle, Subscription, pulsating_between, }; use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry}; use settings::{Settings, update_settings_file}; @@ -22,6 +24,7 @@ use util::ResultExt as _; use zed_actions::ExtensionCategoryFilter; pub(crate) use add_context_server_modal::AddContextServerModal; +pub(crate) use configure_context_server_modal::ConfigureContextServerModal; pub(crate) use manage_profiles_modal::ManageProfilesModal; use crate::AddContextServer; @@ -256,8 +259,6 @@ impl AssistantConfiguration { fn render_context_servers_section(&mut self, cx: &mut Context) -> impl IntoElement { let context_servers = self.context_server_manager.read(cx).all_servers().clone(); - let tools_by_source = self.tools.read(cx).tools_by_source(cx); - let empty = Vec::new(); const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly."; @@ -272,136 +273,11 @@ impl AssistantConfiguration { .child(Headline::new("Model Context Protocol (MCP) Servers")) .child(Label::new(SUBHEADING).color(Color::Muted)), ) - .children(context_servers.into_iter().map(|context_server| { - let is_running = context_server.client().is_some(); - let are_tools_expanded = self - .expanded_context_server_tools - .get(&context_server.id()) - .copied() - .unwrap_or_default(); - - let tools = tools_by_source - .get(&ToolSource::ContextServer { - id: context_server.id().into(), - }) - .unwrap_or_else(|| &empty); - let tool_count = tools.len(); - - v_flex() - .id(SharedString::from(context_server.id())) - .border_1() - .rounded_md() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().background.opacity(0.25)) - .child( - h_flex() - .p_1() - .justify_between() - .when(are_tools_expanded && tool_count > 1, |element| { - element - .border_b_1() - .border_color(cx.theme().colors().border) - }) - .child( - h_flex() - .gap_2() - .child( - Disclosure::new("tool-list-disclosure", are_tools_expanded) - .disabled(tool_count == 0) - .on_click(cx.listener({ - let context_server_id = context_server.id(); - move |this, _event, _window, _cx| { - let is_open = this - .expanded_context_server_tools - .entry(context_server_id.clone()) - .or_insert(false); - - *is_open = !*is_open; - } - })), - ) - .child(Indicator::dot().color(if is_running { - Color::Success - } else { - Color::Error - })) - .child(Label::new(context_server.id())) - .child( - Label::new(format!("{tool_count} tools")) - .color(Color::Muted) - .size(LabelSize::Small), - ), - ) - .child( - Switch::new("context-server-switch", is_running.into()) - .color(SwitchColor::Accent) - .on_click({ - let context_server_manager = - self.context_server_manager.clone(); - let context_server = context_server.clone(); - move |state, _window, cx| match state { - ToggleState::Unselected - | ToggleState::Indeterminate => { - context_server_manager.update(cx, |this, cx| { - this.stop_server(context_server.clone(), cx) - .log_err(); - }); - } - ToggleState::Selected => { - cx.spawn({ - let context_server_manager = - context_server_manager.clone(); - let context_server = context_server.clone(); - async move |cx| { - if let Some(start_server_task) = - context_server_manager - .update(cx, |this, cx| { - this.start_server( - context_server, - cx, - ) - }) - .log_err() - { - start_server_task.await.log_err(); - } - } - }) - .detach(); - } - } - }), - ), - ) - .map(|parent| { - if !are_tools_expanded { - return parent; - } - - parent.child(v_flex().py_1p5().px_1().gap_1().children( - tools.into_iter().enumerate().map(|(ix, tool)| { - h_flex() - .id(("tool-item", ix)) - .px_1() - .gap_2() - .justify_between() - .hover(|style| style.bg(cx.theme().colors().element_hover)) - .rounded_sm() - .child( - Label::new(tool.name()) - .buffer_font(cx) - .size(LabelSize::Small), - ) - .child( - Icon::new(IconName::Info) - .size(IconSize::Small) - .color(Color::Ignored), - ) - .tooltip(Tooltip::text(tool.description())) - }), - )) - }) - })) + .children( + context_servers + .into_iter() + .map(|context_server| self.render_context_server(context_server, cx)), + ) .child( h_flex() .justify_between() @@ -447,6 +323,190 @@ impl AssistantConfiguration { ), ) } + + fn render_context_server( + &self, + context_server: Arc, + cx: &mut Context, + ) -> impl use<> + IntoElement { + let tools_by_source = self.tools.read(cx).tools_by_source(cx); + let server_status = self + .context_server_manager + .read(cx) + .status_for_server(&context_server.id()); + + let is_running = matches!(server_status, Some(ContextServerStatus::Running)); + + let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() { + Some(error) + } else { + None + }; + + let are_tools_expanded = self + .expanded_context_server_tools + .get(&context_server.id()) + .copied() + .unwrap_or_default(); + + let tools = tools_by_source + .get(&ToolSource::ContextServer { + id: context_server.id().into(), + }) + .map_or([].as_slice(), |tools| tools.as_slice()); + let tool_count = tools.len(); + + v_flex() + .id(SharedString::from(context_server.id())) + .border_1() + .rounded_md() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background.opacity(0.25)) + .child( + h_flex() + .p_1() + .justify_between() + .when(are_tools_expanded && tool_count > 1, |element| { + element + .border_b_1() + .border_color(cx.theme().colors().border) + }) + .child( + h_flex() + .gap_2() + .child( + Disclosure::new( + "tool-list-disclosure", + are_tools_expanded || error.is_some(), + ) + .disabled(tool_count == 0) + .on_click(cx.listener({ + let context_server_id = context_server.id(); + move |this, _event, _window, _cx| { + let is_open = this + .expanded_context_server_tools + .entry(context_server_id.clone()) + .or_insert(false); + + *is_open = !*is_open; + } + })), + ) + .child(match server_status { + Some(ContextServerStatus::Starting) => { + let color = Color::Success.color(cx); + Indicator::dot() + .color(Color::Success) + .with_animation( + SharedString::from(format!( + "{}-starting", + context_server.id(), + )), + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 1.)), + move |this, delta| { + this.color(color.alpha(delta).into()) + }, + ) + .into_any_element() + } + Some(ContextServerStatus::Running) => { + Indicator::dot().color(Color::Success).into_any_element() + } + Some(ContextServerStatus::Error(_)) => { + Indicator::dot().color(Color::Error).into_any_element() + } + None => Indicator::dot().color(Color::Muted).into_any_element(), + }) + .child(Label::new(context_server.id())) + .when(is_running, |this| { + this.child( + Label::new(if tool_count == 1 { + SharedString::from("1 tool") + } else { + SharedString::from(format!("{} tools", tool_count)) + }) + .color(Color::Muted) + .size(LabelSize::Small), + ) + }), + ) + .child( + Switch::new("context-server-switch", is_running.into()) + .color(SwitchColor::Accent) + .on_click({ + let context_server_manager = self.context_server_manager.clone(); + let context_server = context_server.clone(); + move |state, _window, cx| match state { + ToggleState::Unselected | ToggleState::Indeterminate => { + context_server_manager.update(cx, |this, cx| { + this.stop_server(context_server.clone(), cx).log_err(); + }); + } + ToggleState::Selected => { + cx.spawn({ + let context_server_manager = + context_server_manager.clone(); + let context_server = context_server.clone(); + async move |cx| { + if let Some(start_server_task) = + context_server_manager + .update(cx, |this, cx| { + this.start_server(context_server, cx) + }) + .log_err() + { + start_server_task.await.log_err(); + } + } + }) + .detach(); + } + } + }), + ), + ) + .map(|parent| { + if let Some(error) = error { + return parent.child( + div().py_1p5().px_2().child( + Label::new(error) + .color(Color::Muted) + .buffer_font(cx) + .size(LabelSize::Small), + ), + ); + } + + if !are_tools_expanded || tools.is_empty() { + return parent; + } + + parent.child(v_flex().py_1p5().px_1().gap_1().children( + tools.into_iter().enumerate().map(|(ix, tool)| { + h_flex() + .id(("tool-item", ix)) + .px_1() + .gap_2() + .justify_between() + .hover(|style| style.bg(cx.theme().colors().element_hover)) + .rounded_sm() + .child( + Label::new(tool.name()) + .buffer_font(cx) + .size(LabelSize::Small), + ) + .child( + Icon::new(IconName::Info) + .size(IconSize::Small) + .color(Color::Ignored), + ) + .tooltip(Tooltip::text(tool.description())) + }), + )) + }) + } } impl Render for AssistantConfiguration { diff --git a/crates/agent/src/assistant_configuration/configure_context_server_modal.rs b/crates/agent/src/assistant_configuration/configure_context_server_modal.rs new file mode 100644 index 0000000000..b200e156d1 --- /dev/null +++ b/crates/agent/src/assistant_configuration/configure_context_server_modal.rs @@ -0,0 +1,443 @@ +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; + +use anyhow::Context as _; +use context_server::manager::{ContextServerManager, ContextServerStatus}; +use editor::{Editor, EditorElement, EditorStyle}; +use extension::ContextServerConfiguration; +use gpui::{ + Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task, + TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage, +}; +use language::{Language, LanguageRegistry}; +use markdown::{Markdown, MarkdownElement, MarkdownStyle}; +use notifications::status_toast::{StatusToast, ToastIcon}; +use settings::{Settings as _, update_settings_file}; +use theme::ThemeSettings; +use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*}; +use util::ResultExt; +use workspace::{ModalView, Workspace}; + +pub(crate) struct ConfigureContextServerModal { + workspace: WeakEntity, + context_servers_to_setup: Vec, + context_server_manager: Entity, +} + +struct ConfigureContextServer { + id: Arc, + installation_instructions: Entity, + settings_validator: Option, + settings_editor: Entity, + last_error: Option, + waiting_for_context_server: bool, +} + +impl ConfigureContextServerModal { + pub fn new( + configurations: impl Iterator, ContextServerConfiguration)>, + jsonc_language: Option>, + context_server_manager: Entity, + language_registry: Arc, + workspace: WeakEntity, + window: &mut Window, + cx: &mut App, + ) -> Option { + let context_servers_to_setup = configurations + .map(|(id, manifest)| { + let jsonc_language = jsonc_language.clone(); + let settings_validator = jsonschema::validator_for(&manifest.settings_schema) + .context("Failed to load JSON schema for context server settings") + .log_err(); + ConfigureContextServer { + id: id.clone(), + installation_instructions: cx.new(|cx| { + Markdown::new( + manifest.installation_instructions.clone().into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + settings_validator, + settings_editor: cx.new(|cx| { + let mut editor = Editor::auto_height(16, window, cx); + editor.set_text(manifest.default_settings.trim(), window, cx); + if let Some(buffer) = editor.buffer().read(cx).as_singleton() { + buffer.update(cx, |buffer, cx| buffer.set_language(jsonc_language, cx)) + } + editor + }), + waiting_for_context_server: false, + last_error: None, + } + }) + .collect::>(); + + if context_servers_to_setup.is_empty() { + return None; + } + + Some(Self { + workspace, + context_servers_to_setup, + context_server_manager, + }) + } +} + +impl ConfigureContextServerModal { + pub fn confirm(&mut self, cx: &mut Context) { + if self.context_servers_to_setup.is_empty() { + return; + } + + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + + let configuration = &mut self.context_servers_to_setup[0]; + if configuration.waiting_for_context_server { + return; + } + + let settings_value = match serde_json_lenient::from_str::( + &configuration.settings_editor.read(cx).text(cx), + ) { + Ok(value) => value, + Err(error) => { + configuration.last_error = Some(error.to_string().into()); + cx.notify(); + return; + } + }; + + if let Some(validator) = configuration.settings_validator.as_ref() { + if let Err(error) = validator.validate(&settings_value) { + configuration.last_error = Some(error.to_string().into()); + cx.notify(); + return; + } + } + let id = configuration.id.clone(); + + let settings_changed = context_server::ContextServerSettings::get_global(cx) + .context_servers + .get(&id) + .map_or(true, |config| { + config.settings.as_ref() != Some(&settings_value) + }); + + let is_running = self.context_server_manager.read(cx).status_for_server(&id) + == Some(ContextServerStatus::Running); + + if !settings_changed && is_running { + self.complete_setup(id, cx); + return; + } + + configuration.waiting_for_context_server = true; + + let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx); + cx.spawn({ + let id = id.clone(); + async move |this, cx| { + let result = task.await; + this.update(cx, |this, cx| match result { + Ok(_) => { + this.complete_setup(id, cx); + } + Err(err) => { + if let Some(configuration) = this.context_servers_to_setup.get_mut(0) { + configuration.last_error = Some(err.into()); + configuration.waiting_for_context_server = false; + } else { + this.dismiss(cx); + } + cx.notify(); + } + }) + } + }) + .detach(); + + // When we write the settings to the file, the context server will be restarted. + update_settings_file::( + workspace.read(cx).app_state().fs.clone(), + cx, + { + let id = id.clone(); + |settings, _| { + if let Some(server_config) = settings.context_servers.get_mut(&id) { + server_config.settings = Some(settings_value); + } else { + settings.context_servers.insert( + id, + context_server::ServerConfig { + settings: Some(settings_value), + ..Default::default() + }, + ); + } + } + }, + ); + } + + fn complete_setup(&mut self, id: Arc, cx: &mut Context) { + self.context_servers_to_setup.remove(0); + cx.notify(); + + if !self.context_servers_to_setup.is_empty() { + return; + } + + self.workspace + .update(cx, { + |workspace, cx| { + let status_toast = StatusToast::new( + format!("{} MCP configured successfully", id), + cx, + |this, _cx| { + this.icon(ToastIcon::new(IconName::DatabaseZap).color(Color::Muted)) + .action("Dismiss", |_, _| {}) + }, + ); + + workspace.toggle_status_toast(status_toast, cx); + } + }) + .log_err(); + + self.dismiss(cx); + } + + fn dismiss(&self, cx: &mut Context) { + cx.emit(DismissEvent); + } +} + +fn wait_for_context_server( + context_server_manager: &Entity, + context_server_id: Arc, + cx: &mut App, +) -> Task>> { + let (tx, rx) = futures::channel::oneshot::channel(); + let tx = Arc::new(Mutex::new(Some(tx))); + + let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event { + context_server::manager::Event::ServerStatusChanged { server_id, status } => match status { + Some(ContextServerStatus::Running) => { + if server_id == &context_server_id { + if let Some(tx) = tx.lock().unwrap().take() { + let _ = tx.send(Ok(())); + } + } + } + Some(ContextServerStatus::Error(error)) => { + if server_id == &context_server_id { + if let Some(tx) = tx.lock().unwrap().take() { + let _ = tx.send(Err(error.clone())); + } + } + } + _ => {} + }, + }); + + cx.spawn(async move |_cx| { + let result = rx.await.unwrap(); + drop(subscription); + result + }) +} + +impl Render for ConfigureContextServerModal { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let Some(configuration) = self.context_servers_to_setup.first() else { + return div().child("No context servers to setup"); + }; + + let focus_handle = self.focus_handle(cx); + + div() + .elevation_3(cx) + .w(rems(34.)) + .key_context("ConfigureContextServerModal") + .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| this.confirm(cx))) + .on_action(cx.listener(|this, _: &menu::Cancel, _window, cx| this.dismiss(cx))) + .capture_any_mouse_down(cx.listener(|this, _, window, cx| { + this.focus_handle(cx).focus(window); + })) + .child( + Modal::new("configure-context-server", None) + .header(ModalHeader::new().headline(format!("Configure {}", configuration.id))) + .section( + Section::new() + .child(div().py_2().child(MarkdownElement::new( + configuration.installation_instructions.clone(), + default_markdown_style(window, cx), + ))) + .child( + div() + .p_2() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border_variant) + .bg(cx.theme().colors().editor_background) + .gap_1() + .child({ + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_size: settings.buffer_font_size(cx).into(), + font_weight: settings.buffer_font.weight, + line_height: relative( + settings.buffer_line_height.value(), + ), + ..Default::default() + }; + EditorElement::new( + &configuration.settings_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + ) + }) + .when_some(configuration.last_error.clone(), |this, error| { + this.child( + h_flex() + .gap_2() + .px_2() + .py_1() + .child( + Icon::new(IconName::Warning) + .size(IconSize::XSmall) + .color(Color::Warning), + ) + .child( + div().w_full().child( + Label::new(error) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ), + ) + }), + ) + .when(configuration.waiting_for_context_server, |this| { + this.child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::ArrowCircle) + .size(IconSize::XSmall) + .color(Color::Info) + .with_animation( + "arrow-circle", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| { + icon.transform(Transformation::rotate( + percentage(delta), + )) + }, + ) + .into_any_element(), + ) + .child( + Label::new("Waiting for Context Server") + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + }), + ) + .footer( + ModalFooter::new().end_slot( + h_flex() + .gap_1() + .child( + Button::new("cancel", "Cancel") + .key_binding( + KeyBinding::for_action_in( + &menu::Cancel, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _event, _window, cx| { + this.dismiss(cx) + })), + ) + .child( + Button::new("configure-server", "Configure MCP") + .disabled(configuration.waiting_for_context_server) + .key_binding( + KeyBinding::for_action_in( + &menu::Confirm, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _event, _window, cx| { + this.confirm(cx) + })), + ), + ), + ), + ) + } +} + +pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { + let theme_settings = ThemeSettings::get_global(cx); + let colors = cx.theme().colors(); + let mut text_style = window.text_style(); + text_style.refine(&TextStyleRefinement { + font_family: Some(theme_settings.ui_font.family.clone()), + font_fallbacks: theme_settings.ui_font.fallbacks.clone(), + font_features: Some(theme_settings.ui_font.features.clone()), + font_size: Some(TextSize::XSmall.rems(cx).into()), + color: Some(colors.text_muted), + ..Default::default() + }); + + MarkdownStyle { + base_text_style: text_style.clone(), + selection_background_color: cx.theme().players().local().selection, + link: TextStyleRefinement { + background_color: Some(colors.editor_foreground.opacity(0.025)), + underline: Some(UnderlineStyle { + color: Some(colors.text_accent.opacity(0.5)), + thickness: px(1.), + ..Default::default() + }), + ..Default::default() + }, + ..Default::default() + } +} + +impl ModalView for ConfigureContextServerModal {} +impl EventEmitter for ConfigureContextServerModal {} +impl Focusable for ConfigureContextServerModal { + fn focus_handle(&self, cx: &App) -> FocusHandle { + if let Some(current) = self.context_servers_to_setup.first() { + current.settings_editor.read(cx).focus_handle(cx) + } else { + cx.focus_handle() + } + } +} diff --git a/crates/agent/src/context_server_configuration.rs b/crates/agent/src/context_server_configuration.rs new file mode 100644 index 0000000000..ec92b70149 --- /dev/null +++ b/crates/agent/src/context_server_configuration.rs @@ -0,0 +1,120 @@ +use std::sync::Arc; + +use anyhow::Context as _; +use context_server::ContextServerDescriptorRegistry; +use extension::ExtensionManifest; +use language::LanguageRegistry; +use ui::prelude::*; +use util::ResultExt; +use workspace::Workspace; + +use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal}; + +pub(crate) fn init(language_registry: Arc, cx: &mut App) { + cx.observe_new(move |_: &mut Workspace, window, cx| { + let Some(window) = window else { + return; + }; + + if let Some(extension_events) = extension::ExtensionEvents::try_global(cx).as_ref() { + cx.subscribe_in(extension_events, window, { + let language_registry = language_registry.clone(); + move |workspace, _, event, window, cx| match event { + extension::Event::ExtensionInstalled(manifest) => { + show_configure_mcp_modal( + language_registry.clone(), + manifest, + workspace, + window, + cx, + ); + } + extension::Event::ConfigureExtensionRequested(manifest) => { + if !manifest.context_servers.is_empty() { + show_configure_mcp_modal( + language_registry.clone(), + manifest, + workspace, + window, + cx, + ); + } + } + _ => {} + } + }) + .detach(); + } else { + log::info!( + "No extension events global found. Skipping context server configuration wizard" + ); + } + }) + .detach(); +} + +fn show_configure_mcp_modal( + language_registry: Arc, + manifest: &Arc, + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context<'_, Workspace>, +) { + let Some(context_server_manager) = workspace.panel::(cx).map(|panel| { + panel + .read(cx) + .thread_store() + .read(cx) + .context_server_manager() + }) else { + return; + }; + + let registry = ContextServerDescriptorRegistry::global(cx).read(cx); + let project = workspace.project().clone(); + let configuration_tasks = manifest + .context_servers + .keys() + .cloned() + .filter_map({ + |key| { + let descriptor = registry.context_server_descriptor(&key)?; + Some(cx.spawn({ + let project = project.clone(); + async move |_, cx| { + descriptor + .configuration(project, &cx) + .await + .context("Failed to resolve context server configuration") + .log_err() + .flatten() + .map(|config| (key, config)) + } + })) + } + }) + .collect::>(); + + let jsonc_language = language_registry.language_for_name("jsonc"); + + cx.spawn_in(window, async move |this, cx| { + let descriptors = futures::future::join_all(configuration_tasks).await; + let jsonc_language = jsonc_language.await.ok(); + + this.update_in(cx, |this, window, cx| { + let modal = ConfigureContextServerModal::new( + descriptors.into_iter().flatten(), + jsonc_language, + context_server_manager, + language_registry, + cx.entity().downgrade(), + window, + cx, + ); + if let Some(modal) = modal { + this.toggle_modal(window, cx, |_, _| modal); + } + }) + }) + .detach(); +} diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 8465144b39..372ba33c25 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -9,8 +9,8 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings}; use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; -use context_server::manager::ContextServerManager; -use context_server::{ContextServerFactoryRegistry, ContextServerTool}; +use context_server::manager::{ContextServerManager, ContextServerStatus}; +use context_server::{ContextServerDescriptorRegistry, ContextServerTool}; use futures::channel::{mpsc, oneshot}; use futures::future::{self, BoxFuture, Shared}; use futures::{FutureExt as _, StreamExt as _}; @@ -108,7 +108,7 @@ impl ThreadStore { prompt_store: Option>, cx: &mut Context, ) -> (Self, oneshot::Receiver<()>) { - let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx); + let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx); let context_server_manager = cx.new(|cx| { ContextServerManager::new(context_server_factory_registry, project.clone(), cx) }); @@ -555,62 +555,68 @@ impl ThreadStore { ) { let tool_working_set = self.tools.clone(); match event { - context_server::manager::Event::ServerStarted { server_id } => { - if let Some(server) = context_server_manager.read(cx).get_server(server_id) { - let context_server_manager = context_server_manager.clone(); - cx.spawn({ - let server = server.clone(); - let server_id = server_id.clone(); - async move |this, cx| { - let Some(protocol) = server.client() else { - return; - }; + context_server::manager::Event::ServerStatusChanged { server_id, status } => { + match status { + Some(ContextServerStatus::Running) => { + if let Some(server) = context_server_manager.read(cx).get_server(server_id) + { + let context_server_manager = context_server_manager.clone(); + cx.spawn({ + let server = server.clone(); + let server_id = server_id.clone(); + async move |this, cx| { + let Some(protocol) = server.client() else { + return; + }; - if protocol.capable(context_server::protocol::ServerCapability::Tools) { - if let Some(tools) = protocol.list_tools().await.log_err() { - let tool_ids = tool_working_set - .update(cx, |tool_working_set, _| { - tools - .tools - .into_iter() - .map(|tool| { - log::info!( - "registering context server tool: {:?}", - tool.name - ); - tool_working_set.insert(Arc::new( - ContextServerTool::new( - context_server_manager.clone(), - server.id(), - tool, - ), - )) + if protocol.capable(context_server::protocol::ServerCapability::Tools) { + if let Some(tools) = protocol.list_tools().await.log_err() { + let tool_ids = tool_working_set + .update(cx, |tool_working_set, _| { + tools + .tools + .into_iter() + .map(|tool| { + log::info!( + "registering context server tool: {:?}", + tool.name + ); + tool_working_set.insert(Arc::new( + ContextServerTool::new( + context_server_manager.clone(), + server.id(), + tool, + ), + )) + }) + .collect::>() }) - .collect::>() - }) - .log_err(); + .log_err(); - if let Some(tool_ids) = tool_ids { - this.update(cx, |this, cx| { - this.context_server_tool_ids - .insert(server_id, tool_ids); - this.load_default_profile(cx); - }) - .log_err(); + if let Some(tool_ids) = tool_ids { + this.update(cx, |this, cx| { + this.context_server_tool_ids + .insert(server_id, tool_ids); + this.load_default_profile(cx); + }) + .log_err(); + } + } } } - } + }) + .detach(); } - }) - .detach(); - } - } - context_server::manager::Event::ServerStopped { server_id } => { - if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { - tool_working_set.update(cx, |tool_working_set, _| { - tool_working_set.remove(&tool_ids); - }); - self.load_default_profile(cx); + } + None => { + if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { + tool_working_set.update(cx, |tool_working_set, _| { + tool_working_set.remove(&tool_ids); + }); + self.load_default_profile(cx); + } + } + _ => {} } } } diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index e1ea41e40a..f0f67506a6 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -7,8 +7,8 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet}; use client::{Client, TypedEnvelope, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::HashMap; -use context_server::ContextServerFactoryRegistry; -use context_server::manager::ContextServerManager; +use context_server::ContextServerDescriptorRegistry; +use context_server::manager::{ContextServerManager, ContextServerStatus}; use fs::{Fs, RemoveOptions}; use futures::StreamExt; use fuzzy::StringMatchCandidate; @@ -99,7 +99,7 @@ impl ContextStore { let this = cx.new(|cx: &mut Context| { let context_server_factory_registry = - ContextServerFactoryRegistry::default_global(cx); + ContextServerDescriptorRegistry::default_global(cx); let context_server_manager = cx.new(|cx| { ContextServerManager::new(context_server_factory_registry, project.clone(), cx) }); @@ -831,54 +831,60 @@ impl ContextStore { ) { let slash_command_working_set = self.slash_commands.clone(); match event { - context_server::manager::Event::ServerStarted { server_id } => { - if let Some(server) = context_server_manager.read(cx).get_server(server_id) { - let context_server_manager = context_server_manager.clone(); - cx.spawn({ - let server = server.clone(); - let server_id = server_id.clone(); - async move |this, cx| { - let Some(protocol) = server.client() else { - return; - }; + context_server::manager::Event::ServerStatusChanged { server_id, status } => { + match status { + Some(ContextServerStatus::Running) => { + if let Some(server) = context_server_manager.read(cx).get_server(server_id) + { + let context_server_manager = context_server_manager.clone(); + cx.spawn({ + let server = server.clone(); + let server_id = server_id.clone(); + async move |this, cx| { + let Some(protocol) = server.client() else { + return; + }; - if protocol.capable(context_server::protocol::ServerCapability::Prompts) { - if let Some(prompts) = protocol.list_prompts().await.log_err() { - let slash_command_ids = prompts - .into_iter() - .filter(assistant_slash_commands::acceptable_prompt) - .map(|prompt| { - log::info!( - "registering context server command: {:?}", - prompt.name - ); - slash_command_working_set.insert(Arc::new( - assistant_slash_commands::ContextServerSlashCommand::new( - context_server_manager.clone(), - &server, - prompt, - ), - )) - }) - .collect::>(); + if protocol.capable(context_server::protocol::ServerCapability::Prompts) { + if let Some(prompts) = protocol.list_prompts().await.log_err() { + let slash_command_ids = prompts + .into_iter() + .filter(assistant_slash_commands::acceptable_prompt) + .map(|prompt| { + log::info!( + "registering context server command: {:?}", + prompt.name + ); + slash_command_working_set.insert(Arc::new( + assistant_slash_commands::ContextServerSlashCommand::new( + context_server_manager.clone(), + &server, + prompt, + ), + )) + }) + .collect::>(); - this.update( cx, |this, _cx| { - this.context_server_slash_command_ids - .insert(server_id.clone(), slash_command_ids); - }) - .log_err(); + this.update( cx, |this, _cx| { + this.context_server_slash_command_ids + .insert(server_id.clone(), slash_command_ids); + }) + .log_err(); + } + } } - } + }) + .detach(); } - }) - .detach(); - } - } - context_server::manager::Event::ServerStopped { server_id } => { - if let Some(slash_command_ids) = - self.context_server_slash_command_ids.remove(server_id) - { - slash_command_working_set.remove(&slash_command_ids); + } + None => { + if let Some(slash_command_ids) = + self.context_server_slash_command_ids.remove(server_id) + { + slash_command_working_set.remove(&slash_command_ids); + } + } + _ => {} } } } diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 3344c1f0ed..f229ed3e65 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -34,3 +34,7 @@ smol.workspace = true url = { workspace = true, features = ["serde"] } util.workspace = true workspace-hack.workspace = true + +[dev-dependencies] +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 119279b26f..7a2ae71c8b 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -140,7 +140,7 @@ impl Client { /// This function initializes a new Client by spawning a child process for the context server, /// setting up communication channels, and initializing handlers for input/output operations. /// It takes a server ID, binary information, and an async app context as input. - pub fn new( + pub fn stdio( server_id: ContextServerId, binary: ModelContextServerBinary, cx: AsyncApp, @@ -158,7 +158,16 @@ impl Client { .unwrap_or_else(String::new); let transport = Arc::new(StdioTransport::new(binary, &cx)?); + Self::new(server_id, server_name.into(), transport, cx) + } + /// Creates a new Client instance for a context server. + pub fn new( + server_id: ContextServerId, + server_name: Arc, + transport: Arc, + cx: AsyncApp, + ) -> Result { let (outbound_tx, outbound_rx) = channel::unbounded::(); let (output_done_tx, output_done_rx) = barrier::channel(); @@ -167,7 +176,7 @@ impl Client { let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); - let stdout_input_task = cx.spawn({ + let receive_input_task = cx.spawn({ let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); let transport = transport.clone(); @@ -177,13 +186,13 @@ impl Client { .await } }); - let stderr_input_task = cx.spawn({ + let receive_err_task = cx.spawn({ let transport = transport.clone(); - async move |_| Self::handle_stderr(transport).log_err().await + async move |_| Self::handle_err(transport).log_err().await }); let input_task = cx.spawn(async move |_| { - let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); - stdout.or(stderr) + let (input, err) = futures::join!(receive_input_task, receive_err_task); + input.or(err) }); let output_task = cx.background_spawn({ @@ -201,7 +210,7 @@ impl Client { server_id, notification_handlers, response_handlers, - name: server_name.into(), + name: server_name, next_id: Default::default(), outbound_tx, executor: cx.background_executor().clone(), @@ -247,7 +256,7 @@ impl Client { /// Handles the stderr output from the context server. /// Continuously reads and logs any error messages from the server. - async fn handle_stderr(transport: Arc) -> anyhow::Result<()> { + async fn handle_err(transport: Arc) -> anyhow::Result<()> { while let Some(err) = transport.receive_err().next().await { log::warn!("context server stderr: {}", err.trim()); } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 27d285ba93..72b5a7cfc9 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -12,7 +12,7 @@ pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerCo use gpui::{App, actions}; pub use crate::context_server_tool::ContextServerTool; -pub use crate::registry::ContextServerFactoryRegistry; +pub use crate::registry::ContextServerDescriptorRegistry; actions!(context_servers, [Restart]); @@ -21,7 +21,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers"; pub fn init(cx: &mut App) { context_server_settings::init(cx); - ContextServerFactoryRegistry::default_global(cx); + ContextServerDescriptorRegistry::default_global(cx); extension_context_server::init(cx); CommandPaletteFilter::update_global(cx, |filter, _cx| { diff --git a/crates/context_server/src/extension_context_server.rs b/crates/context_server/src/extension_context_server.rs index 90ddc1609a..1fb138d56f 100644 --- a/crates/context_server/src/extension_context_server.rs +++ b/crates/context_server/src/extension_context_server.rs @@ -1,9 +1,21 @@ use std::sync::Arc; -use extension::{Extension, ExtensionContextServerProxy, ExtensionHostProxy, ProjectDelegate}; -use gpui::{App, Entity}; +use anyhow::Result; +use extension::{ + ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy, + ProjectDelegate, +}; +use gpui::{App, AsyncApp, Entity, Task}; +use project::Project; -use crate::{ContextServerFactoryRegistry, ServerCommand}; +use crate::{ContextServerDescriptorRegistry, ServerCommand, registry}; + +pub fn init(cx: &mut App) { + let proxy = ExtensionHostProxy::default_global(cx); + proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy { + context_server_factory_registry: ContextServerDescriptorRegistry::global(cx), + }); +} struct ExtensionProject { worktree_ids: Vec, @@ -15,60 +27,78 @@ impl ProjectDelegate for ExtensionProject { } } -pub fn init(cx: &mut App) { - let proxy = ExtensionHostProxy::default_global(cx); - proxy.register_context_server_proxy(ContextServerFactoryRegistryProxy { - context_server_factory_registry: ContextServerFactoryRegistry::global(cx), - }); +struct ContextServerDescriptor { + id: Arc, + extension: Arc, } -struct ContextServerFactoryRegistryProxy { - context_server_factory_registry: Entity, +fn extension_project(project: Entity, cx: &mut AsyncApp) -> Result> { + project.update(cx, |project, cx| { + Arc::new(ExtensionProject { + worktree_ids: project + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).id().to_proto()) + .collect(), + }) + }) } -impl ExtensionContextServerProxy for ContextServerFactoryRegistryProxy { +impl registry::ContextServerDescriptor for ContextServerDescriptor { + fn command(&self, project: Entity, cx: &AsyncApp) -> Task> { + let id = self.id.clone(); + let extension = self.extension.clone(); + cx.spawn(async move |cx| { + let extension_project = extension_project(project, cx)?; + let mut command = extension + .context_server_command(id.clone(), extension_project.clone()) + .await?; + command.command = extension + .path_from_extension(command.command.as_ref()) + .to_string_lossy() + .to_string(); + + log::info!("loaded command for context server {id}: {command:?}"); + + Ok(ServerCommand { + path: command.command, + args: command.args, + env: Some(command.env.into_iter().collect()), + }) + }) + } + + fn configuration( + &self, + project: Entity, + cx: &AsyncApp, + ) -> Task>> { + let id = self.id.clone(); + let extension = self.extension.clone(); + cx.spawn(async move |cx| { + let extension_project = extension_project(project, cx)?; + let configuration = extension + .context_server_configuration(id.clone(), extension_project) + .await?; + + log::debug!("loaded configuration for context server {id}: {configuration:?}"); + + Ok(configuration) + }) + } +} + +struct ContextServerDescriptorRegistryProxy { + context_server_factory_registry: Entity, +} + +impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy { fn register_context_server(&self, extension: Arc, id: Arc, cx: &mut App) { self.context_server_factory_registry .update(cx, |registry, _| { - registry.register_server_factory( + registry.register_context_server_descriptor( id.clone(), - Arc::new({ - move |project, cx| { - log::info!( - "loading command for context server {id} from extension {}", - extension.manifest().id - ); - - let id = id.clone(); - let extension = extension.clone(); - cx.spawn(async move |cx| { - let extension_project = project.update(cx, |project, cx| { - Arc::new(ExtensionProject { - worktree_ids: project - .visible_worktrees(cx) - .map(|worktree| worktree.read(cx).id().to_proto()) - .collect(), - }) - })?; - - let mut command = extension - .context_server_command(id.clone(), extension_project) - .await?; - command.command = extension - .path_from_extension(command.command.as_ref()) - .to_string_lossy() - .to_string(); - - log::info!("loaded command for context server {id}: {command:?}"); - - Ok(ServerCommand { - path: command.command, - args: command.args, - env: Some(command.env.into_iter().collect()), - }) - }) - } - }), + Arc::new(ContextServerDescriptor { id, extension }) + as Arc, ) }); } diff --git a/crates/context_server/src/manager.rs b/crates/context_server/src/manager.rs index cb163a655a..7a74e4879e 100644 --- a/crates/context_server/src/manager.rs +++ b/crates/context_server/src/manager.rs @@ -27,18 +27,27 @@ use project::Project; use settings::{Settings, SettingsStore}; use util::ResultExt as _; +use crate::transport::Transport; use crate::{ContextServerSettings, ServerConfig}; use crate::{ - CONTEXT_SERVERS_NAMESPACE, ContextServerFactoryRegistry, + CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry, client::{self, Client}, types, }; +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ContextServerStatus { + Starting, + Running, + Error(Arc), +} + pub struct ContextServer { pub id: Arc, pub config: Arc, pub client: RwLock>>, + transport: Option>, } impl ContextServer { @@ -47,9 +56,20 @@ impl ContextServer { id, config, client: RwLock::new(None), + transport: None, } } + #[cfg(any(test, feature = "test-support"))] + pub fn test(id: Arc, transport: Arc) -> Arc { + Arc::new(Self { + id, + client: RwLock::new(None), + config: Arc::new(ServerConfig::default()), + transport: Some(transport), + }) + } + pub fn id(&self) -> Arc { self.id.clone() } @@ -63,20 +83,32 @@ impl ContextServer { } pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - log::info!("starting context server {}", self.id); - let Some(command) = &self.config.command else { - bail!("no command specified for server {}", self.id); + let client = if let Some(transport) = self.transport.clone() { + Client::new( + client::ContextServerId(self.id.clone()), + self.id(), + transport, + cx.clone(), + )? + } else { + let Some(command) = &self.config.command else { + bail!("no command specified for server {}", self.id); + }; + Client::stdio( + client::ContextServerId(self.id.clone()), + client::ModelContextServerBinary { + executable: Path::new(&command.path).to_path_buf(), + args: command.args.clone(), + env: command.env.clone(), + }, + cx.clone(), + )? }; - let client = Client::new( - client::ContextServerId(self.id.clone()), - client::ModelContextServerBinary { - executable: Path::new(&command.path).to_path_buf(), - args: command.args.clone(), - env: command.env.clone(), - }, - cx.clone(), - )?; + self.initialize(client).await + } + async fn initialize(&self, client: Client) -> Result<()> { + log::info!("starting context server {}", self.id); let protocol = crate::protocol::ModelContextProtocol::new(client); let client_info = types::Implementation { name: "Zed".to_string(), @@ -105,23 +137,26 @@ impl ContextServer { pub struct ContextServerManager { servers: HashMap, Arc>, + server_status: HashMap, ContextServerStatus>, project: Entity, - registry: Entity, + registry: Entity, update_servers_task: Option>>, needs_server_update: bool, _subscriptions: Vec, } pub enum Event { - ServerStarted { server_id: Arc }, - ServerStopped { server_id: Arc }, + ServerStatusChanged { + server_id: Arc, + status: Option, + }, } impl EventEmitter for ContextServerManager {} impl ContextServerManager { pub fn new( - registry: Entity, + registry: Entity, project: Entity, cx: &mut Context, ) -> Self { @@ -138,6 +173,7 @@ impl ContextServerManager { registry, needs_server_update: false, servers: HashMap::default(), + server_status: HashMap::default(), update_servers_task: None, }; this.available_context_servers_changed(cx); @@ -153,7 +189,9 @@ impl ContextServerManager { this.needs_server_update = false; })?; - Self::maintain_servers(this.clone(), cx).await?; + if let Err(err) = Self::maintain_servers(this.clone(), cx).await { + log::error!("Error maintaining context servers: {}", err); + } this.update(cx, |this, cx| { let has_any_context_servers = !this.running_servers().is_empty(); @@ -181,52 +219,37 @@ impl ContextServerManager { .cloned() } + pub fn status_for_server(&self, id: &str) -> Option { + self.server_status.get(id).cloned() + } + pub fn start_server( &self, server: Arc, cx: &mut Context, - ) -> Task> { - cx.spawn(async move |this, cx| { - let id = server.id.clone(); - server.start(&cx).await?; - this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?; - Ok(()) - }) + ) -> Task> { + cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await) } pub fn stop_server( - &self, + &mut self, server: Arc, cx: &mut Context, - ) -> anyhow::Result<()> { - server.stop()?; - cx.emit(Event::ServerStopped { - server_id: server.id(), - }); + ) -> Result<()> { + server.stop().log_err(); + self.update_server_status(server.id().clone(), None, cx); Ok(()) } - pub fn restart_server( - &mut self, - id: &Arc, - cx: &mut Context, - ) -> Task> { + pub fn restart_server(&mut self, id: &Arc, cx: &mut Context) -> Task> { let id = id.clone(); cx.spawn(async move |this, cx| { if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? { - server.stop()?; let config = server.config(); + + this.update(cx, |this, cx| this.stop_server(server, cx))??; let new_server = Arc::new(ContextServer::new(id.clone(), config)); - new_server.clone().start(&cx).await?; - this.update(cx, |this, cx| { - this.servers.insert(id.clone(), new_server); - cx.emit(Event::ServerStopped { - server_id: id.clone(), - }); - cx.emit(Event::ServerStarted { - server_id: id.clone(), - }); - })?; + Self::run_server(this, new_server, cx).await?; } Ok(()) }) @@ -263,12 +286,14 @@ impl ContextServerManager { (this.registry.clone(), this.project.clone()) })?; - for (id, factory) in - registry.read_with(cx, |registry, _| registry.context_server_factories())? + for (id, descriptor) in + registry.read_with(cx, |registry, _| registry.context_server_descriptors())? { let config = desired_servers.entry(id).or_default(); if config.command.is_none() { - if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() { + if let Some(extension_command) = + descriptor.command(project.clone(), &cx).await.log_err() + { config.command = Some(extension_command); } } @@ -290,28 +315,270 @@ impl ContextServerManager { for (id, config) in desired_servers { let existing_config = this.servers.get(&id).map(|server| server.config()); if existing_config.as_deref() != Some(&config) { - let config = Arc::new(config); - let server = Arc::new(ContextServer::new(id.clone(), config)); + let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config))); servers_to_start.insert(id.clone(), server.clone()); - let old_server = this.servers.insert(id.clone(), server); - if let Some(old_server) = old_server { + if let Some(old_server) = this.servers.remove(&id) { servers_to_stop.insert(id, old_server); } } } })?; - for (id, server) in servers_to_stop { - server.stop().log_err(); - this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?; + for (_, server) in servers_to_stop { + this.update(cx, |this, cx| this.stop_server(server, cx).ok())?; } - for (id, server) in servers_to_start { - if server.start(&cx).await.log_err().is_some() { - this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?; - } + for (_, server) in servers_to_start { + Self::run_server(this.clone(), server, cx).await.ok(); } Ok(()) } + + async fn run_server( + this: WeakEntity, + server: Arc, + cx: &mut AsyncApp, + ) -> Result<()> { + let id = server.id(); + + this.update(cx, |this, cx| { + this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx); + this.servers.insert(id.clone(), server.clone()); + })?; + + match server.start(&cx).await { + Ok(_) => { + log::debug!("`{}` context server started", id); + this.update(cx, |this, cx| { + this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx) + })?; + Ok(()) + } + Err(err) => { + log::error!("`{}` context server failed to start\n{}", id, err); + this.update(cx, |this, cx| { + this.update_server_status( + id.clone(), + Some(ContextServerStatus::Error(err.to_string().into())), + cx, + ) + })?; + Err(err) + } + } + } + + fn update_server_status( + &mut self, + id: Arc, + status: Option, + cx: &mut Context, + ) { + if let Some(status) = status.clone() { + self.server_status.insert(id.clone(), status); + } else { + self.server_status.remove(&id); + } + + cx.emit(Event::ServerStatusChanged { + server_id: id, + status, + }); + } +} + +#[cfg(test)] +mod tests { + use std::pin::Pin; + + use crate::types::{ + Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities, + }; + + use super::*; + use futures::{Stream, StreamExt as _, lock::Mutex}; + use gpui::{AppContext as _, TestAppContext}; + use project::FakeFs; + use serde_json::json; + use util::path; + + #[gpui::test] + async fn test_context_server_status(cx: &mut TestAppContext) { + init_test_settings(cx); + let project = create_test_project(cx, json!({"code.rs": ""})).await; + + let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); + let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx)); + + let server_1_id: Arc = "mcp-1".into(); + let server_2_id: Arc = "mcp-2".into(); + + let transport_1 = Arc::new(FakeTransport::new( + |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response("mcp-1".to_string())) + } + _ => None, + }, + )); + + let transport_2 = Arc::new(FakeTransport::new( + |_, request_type, _| match request_type { + Some(RequestType::Initialize) => { + Some(create_initialize_response("mcp-2".to_string())) + } + _ => None, + }, + )); + + let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone()); + let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone()); + + manager + .update(cx, |manager, cx| manager.start_server(server_1, cx)) + .await + .unwrap(); + + cx.update(|cx| { + assert_eq!( + manager.read(cx).status_for_server(&server_1_id), + Some(ContextServerStatus::Running) + ); + assert_eq!(manager.read(cx).status_for_server(&server_2_id), None); + }); + + manager + .update(cx, |manager, cx| manager.start_server(server_2.clone(), cx)) + .await + .unwrap(); + + cx.update(|cx| { + assert_eq!( + manager.read(cx).status_for_server(&server_1_id), + Some(ContextServerStatus::Running) + ); + assert_eq!( + manager.read(cx).status_for_server(&server_2_id), + Some(ContextServerStatus::Running) + ); + }); + + manager + .update(cx, |manager, cx| manager.stop_server(server_2, cx)) + .unwrap(); + + cx.update(|cx| { + assert_eq!( + manager.read(cx).status_for_server(&server_1_id), + Some(ContextServerStatus::Running) + ); + assert_eq!(manager.read(cx).status_for_server(&server_2_id), None); + }); + } + + async fn create_test_project( + cx: &mut TestAppContext, + files: serde_json::Value, + ) -> Entity { + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), files).await; + Project::test(fs, [path!("/test").as_ref()], cx).await + } + + fn init_test_settings(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + ContextServerSettings::register(cx); + }); + } + + fn create_initialize_response(server_name: String) -> serde_json::Value { + serde_json::to_value(&InitializeResponse { + protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()), + server_info: Implementation { + name: server_name, + version: "1.0.0".to_string(), + }, + capabilities: ServerCapabilities::default(), + meta: None, + }) + .unwrap() + } + + struct FakeTransport { + on_request: Arc< + dyn Fn(u64, Option, serde_json::Value) -> Option + + Send + + Sync, + >, + tx: futures::channel::mpsc::UnboundedSender, + rx: Arc>>, + } + + impl FakeTransport { + fn new( + on_request: impl Fn( + u64, + Option, + serde_json::Value, + ) -> Option + + 'static + + Send + + Sync, + ) -> Self { + let (tx, rx) = futures::channel::mpsc::unbounded(); + Self { + on_request: Arc::new(on_request), + tx, + rx: Arc::new(Mutex::new(rx)), + } + } + } + + #[async_trait::async_trait] + impl Transport for FakeTransport { + async fn send(&self, message: String) -> Result<()> { + if let Ok(msg) = serde_json::from_str::(&message) { + let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); + + if let Some(method) = msg.get("method") { + let request_type = method + .as_str() + .and_then(|method| types::RequestType::try_from(method).ok()); + if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) { + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": payload + }); + + self.tx + .unbounded_send(response.to_string()) + .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?; + } + } + } + Ok(()) + } + + fn receive(&self) -> Pin + Send>> { + let rx = self.rx.clone(); + Box::pin(futures::stream::unfold(rx, |rx| async move { + let mut rx_guard = rx.lock().await; + if let Some(message) = rx_guard.next().await { + drop(rx_guard); + Some((message, rx)) + } else { + None + } + })) + } + + fn receive_err(&self) -> Pin + Send>> { + Box::pin(futures::stream::empty()) + } + } } diff --git a/crates/context_server/src/registry.rs b/crates/context_server/src/registry.rs index e11a9ffbb9..96fd7459ef 100644 --- a/crates/context_server/src/registry.rs +++ b/crates/context_server/src/registry.rs @@ -2,38 +2,47 @@ use std::sync::Arc; use anyhow::Result; use collections::HashMap; +use extension::ContextServerConfiguration; use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task}; use project::Project; use crate::ServerCommand; -pub type ContextServerFactory = - Arc, &AsyncApp) -> Task> + Send + Sync + 'static>; - -struct GlobalContextServerFactoryRegistry(Entity); - -impl Global for GlobalContextServerFactoryRegistry {} - -#[derive(Default)] -pub struct ContextServerFactoryRegistry { - context_servers: HashMap, ContextServerFactory>, +pub trait ContextServerDescriptor { + fn command(&self, project: Entity, cx: &AsyncApp) -> Task>; + fn configuration( + &self, + project: Entity, + cx: &AsyncApp, + ) -> Task>>; } -impl ContextServerFactoryRegistry { - /// Returns the global [`ContextServerFactoryRegistry`]. +struct GlobalContextServerDescriptorRegistry(Entity); + +impl Global for GlobalContextServerDescriptorRegistry {} + +#[derive(Default)] +pub struct ContextServerDescriptorRegistry { + context_servers: HashMap, Arc>, +} + +impl ContextServerDescriptorRegistry { + /// Returns the global [`ContextServerDescriptorRegistry`]. pub fn global(cx: &App) -> Entity { - GlobalContextServerFactoryRegistry::global(cx).0.clone() + GlobalContextServerDescriptorRegistry::global(cx).0.clone() } - /// Returns the global [`ContextServerFactoryRegistry`]. + /// Returns the global [`ContextServerDescriptorRegistry`]. /// - /// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist. + /// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist. pub fn default_global(cx: &mut App) -> Entity { - if !cx.has_global::() { + if !cx.has_global::() { let registry = cx.new(|_| Self::new()); - cx.set_global(GlobalContextServerFactoryRegistry(registry)); + cx.set_global(GlobalContextServerDescriptorRegistry(registry)); } - cx.global::().0.clone() + cx.global::() + .0 + .clone() } pub fn new() -> Self { @@ -42,20 +51,28 @@ impl ContextServerFactoryRegistry { } } - pub fn context_server_factories(&self) -> Vec<(Arc, ContextServerFactory)> { + pub fn context_server_descriptors(&self) -> Vec<(Arc, Arc)> { self.context_servers .iter() .map(|(id, factory)| (id.clone(), factory.clone())) .collect() } - /// Registers the provided [`ContextServerFactory`]. - pub fn register_server_factory(&mut self, id: Arc, factory: ContextServerFactory) { - self.context_servers.insert(id, factory); + pub fn context_server_descriptor(&self, id: &str) -> Option> { + self.context_servers.get(id).cloned() } - /// Unregisters the [`ContextServerFactory`] for the server with the given ID. - pub fn unregister_server_factory_by_id(&mut self, server_id: &str) { + /// Registers the provided [`ContextServerDescriptor`]. + pub fn register_context_server_descriptor( + &mut self, + id: Arc, + descriptor: Arc, + ) { + self.context_servers.insert(id, descriptor); + } + + /// Unregisters the [`ContextServerDescriptor`] for the server with the given ID. + pub fn unregister_context_server_descriptor_by_id(&mut self, server_id: &str) { self.context_servers.remove(server_id); } } diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index f3c6e1c5e2..7478ae44af 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -42,6 +42,30 @@ impl RequestType { } } +impl TryFrom<&str> for RequestType { + type Error = (); + + fn try_from(s: &str) -> Result { + match s { + "initialize" => Ok(RequestType::Initialize), + "tools/call" => Ok(RequestType::CallTool), + "resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe), + "resources/subscribe" => Ok(RequestType::ResourcesSubscribe), + "resources/read" => Ok(RequestType::ResourcesRead), + "resources/list" => Ok(RequestType::ResourcesList), + "logging/setLevel" => Ok(RequestType::LoggingSetLevel), + "prompts/get" => Ok(RequestType::PromptsGet), + "prompts/list" => Ok(RequestType::PromptsList), + "completion/complete" => Ok(RequestType::CompletionComplete), + "ping" => Ok(RequestType::Ping), + "tools/list" => Ok(RequestType::ListTools), + "resources/templates/list" => Ok(RequestType::ListResourceTemplates), + "roots/list" => Ok(RequestType::ListRoots), + _ => Err(()), + } + } +} + #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct ProtocolVersion(pub String); @@ -154,7 +178,7 @@ pub struct CompletionArgument { pub value: String, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InitializeResponse { pub protocol_version: ProtocolVersion, @@ -343,7 +367,7 @@ pub struct ClientCapabilities { pub roots: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Default, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ServerCapabilities { #[serde(skip_serializing_if = "Option::is_none")] diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index f19f8de23a..646a8e23c0 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -424,7 +424,13 @@ pub fn init(cx: &mut App) -> Arc { prompt_store::init(cx); let stdout_is_a_pty = false; let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx); - agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx); + agent::init( + fs.clone(), + client.clone(), + prompt_builder.clone(), + languages.clone(), + cx, + ); assistant_tools::init(client.http_client(), cx); SettingsStore::update_global(cx, |store, cx| { diff --git a/crates/extension/src/extension.rs b/crates/extension/src/extension.rs index 1955c9f3d0..9f732a114d 100644 --- a/crates/extension/src/extension.rs +++ b/crates/extension/src/extension.rs @@ -121,6 +121,12 @@ pub trait Extension: Send + Sync + 'static { project: Arc, ) -> Result; + async fn context_server_configuration( + &self, + context_server_id: Arc, + project: Arc, + ) -> Result>; + async fn suggest_docs_packages(&self, provider: Arc) -> Result>; async fn index_docs( diff --git a/crates/extension/src/extension_events.rs b/crates/extension/src/extension_events.rs index 831010177d..73075067fd 100644 --- a/crates/extension/src/extension_events.rs +++ b/crates/extension/src/extension_events.rs @@ -1,5 +1,9 @@ +use std::sync::Arc; + use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global}; +use crate::ExtensionManifest; + pub fn init(cx: &mut App) { let extension_events = cx.new(ExtensionEvents::new); cx.set_global(GlobalExtensionEvents(extension_events)); @@ -31,7 +35,9 @@ impl ExtensionEvents { #[derive(Clone)] pub enum Event { + ExtensionInstalled(Arc), ExtensionsInstalledChanged, + ConfigureExtensionRequested(Arc), } impl EventEmitter for ExtensionEvents {} diff --git a/crates/extension/src/types.rs b/crates/extension/src/types.rs index f04d31300f..2e5b9c135c 100644 --- a/crates/extension/src/types.rs +++ b/crates/extension/src/types.rs @@ -1,8 +1,10 @@ +mod context_server; mod lsp; mod slash_command; use std::ops::Range; +pub use context_server::*; pub use lsp::*; pub use slash_command::*; diff --git a/crates/extension/src/types/context_server.rs b/crates/extension/src/types/context_server.rs new file mode 100644 index 0000000000..e0bac5b0d9 --- /dev/null +++ b/crates/extension/src/types/context_server.rs @@ -0,0 +1,10 @@ +/// Configuration for a context server. +#[derive(Debug, Clone)] +pub struct ContextServerConfiguration { + /// Installation instructions for the user. + pub installation_instructions: String, + /// Default settings for the context server. + pub default_settings: String, + /// JSON schema describing server settings. + pub settings_schema: serde_json::Value, +} diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs index 052a4181b8..1ecff4a5cc 100644 --- a/crates/extension_api/src/extension_api.rs +++ b/crates/extension_api/src/extension_api.rs @@ -18,6 +18,7 @@ pub use wit::{ CodeLabel, CodeLabelSpan, CodeLabelSpanLiteral, Command, DownloadedFileType, EnvVars, KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree, download_file, make_file_executable, + zed::extension::context_server::ContextServerConfiguration, zed::extension::github::{ GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name, latest_github_release, @@ -159,6 +160,15 @@ pub trait Extension: Send + Sync { Err("`context_server_command` not implemented".to_string()) } + /// Returns the configuration options for the specified context server. + fn context_server_configuration( + &mut self, + _context_server_id: &ContextServerId, + _project: &Project, + ) -> Result> { + Ok(None) + } + /// Returns a list of package names as suggestions to be included in the /// search results of the `/docs` slash command. /// @@ -342,6 +352,14 @@ impl wit::Guest for Component { extension().context_server_command(&context_server_id, project) } + fn context_server_configuration( + context_server_id: String, + project: &Project, + ) -> Result, String> { + let context_server_id = ContextServerId(context_server_id); + extension().context_server_configuration(&context_server_id, project) + } + fn suggest_docs_packages(provider: String) -> Result, String> { extension().suggest_docs_packages(provider) } diff --git a/crates/extension_api/wit/since_v0.5.0/context-server.wit b/crates/extension_api/wit/since_v0.5.0/context-server.wit new file mode 100644 index 0000000000..89dc99c85b --- /dev/null +++ b/crates/extension_api/wit/since_v0.5.0/context-server.wit @@ -0,0 +1,11 @@ +interface context-server { + /// + record context-server-configuration { + /// + installation-instructions: string, + /// + settings-schema: string, + /// + default-settings: string, + } +} diff --git a/crates/extension_api/wit/since_v0.5.0/extension.wit b/crates/extension_api/wit/since_v0.5.0/extension.wit index 3caf8b60b7..f21cc1bf21 100644 --- a/crates/extension_api/wit/since_v0.5.0/extension.wit +++ b/crates/extension_api/wit/since_v0.5.0/extension.wit @@ -1,6 +1,7 @@ package zed:extension; world extension { + import context-server; import github; import http-client; import platform; @@ -8,6 +9,7 @@ world extension { import nodejs; use common.{env-vars, range}; + use context-server.{context-server-configuration}; use lsp.{completion, symbol}; use process.{command}; use slash-command.{slash-command, slash-command-argument-completion, slash-command-output}; @@ -139,6 +141,9 @@ world extension { /// Returns the command used to start up a context server. export context-server-command: func(context-server-id: string, project: borrow) -> result; + /// Returns the configuration for a context server. + export context-server-configuration: func(context-server-id: string, project: borrow) -> result, string>; + /// Returns a list of packages as suggestions to be included in the `/docs` /// search results. /// diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index 411cb23bf1..87f315f658 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -431,6 +431,13 @@ impl ExtensionStore { .filter_map(|extension| extension.dev.then_some(&extension.manifest)) } + pub fn extension_manifest_for_id(&self, extension_id: &str) -> Option<&Arc> { + self.extension_index + .extensions + .get(extension_id) + .map(|extension| &extension.manifest) + } + /// Returns the names of themes provided by extensions. pub fn extension_themes<'a>( &'a self, @@ -744,8 +751,18 @@ impl ExtensionStore { .await; if let ExtensionOperation::Install = operation { - this.update( cx, |_, cx| { - cx.emit(Event::ExtensionInstalled(extension_id)); + this.update( cx, |this, cx| { + cx.emit(Event::ExtensionInstalled(extension_id.clone())); + if let Some(events) = ExtensionEvents::try_global(cx) { + if let Some(manifest) = this.extension_manifest_for_id(&extension_id) { + events.update(cx, |this, cx| { + this.emit( + extension::Event::ExtensionInstalled(manifest.clone()), + cx, + ) + }); + } + } }) .ok(); } @@ -935,6 +952,17 @@ impl ExtensionStore { .await?; this.update(cx, |this, cx| this.reload(None, cx))?.await; + this.update(cx, |this, cx| { + cx.emit(Event::ExtensionInstalled(extension_id.clone())); + if let Some(events) = ExtensionEvents::try_global(cx) { + if let Some(manifest) = this.extension_manifest_for_id(&extension_id) { + events.update(cx, |this, cx| { + this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx) + }); + } + } + })?; + Ok(()) }) } diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs index 1308b8e421..7c61bd9ae8 100644 --- a/crates/extension_host/src/wasm_host.rs +++ b/crates/extension_host/src/wasm_host.rs @@ -4,8 +4,9 @@ use crate::ExtensionManifest; use anyhow::{Context as _, Result, anyhow, bail}; use async_trait::async_trait; use extension::{ - CodeLabel, Command, Completion, ExtensionHostProxy, KeyValueStoreDelegate, ProjectDelegate, - SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate, + CodeLabel, Command, Completion, ContextServerConfiguration, ExtensionHostProxy, + KeyValueStoreDelegate, ProjectDelegate, SlashCommand, SlashCommandArgumentCompletion, + SlashCommandOutput, Symbol, WorktreeDelegate, }; use fs::{Fs, normalize_path}; use futures::future::LocalBoxFuture; @@ -306,6 +307,33 @@ impl extension::Extension for WasmExtension { .await } + async fn context_server_configuration( + &self, + context_server_id: Arc, + project: Arc, + ) -> Result> { + self.call(|extension, store| { + async move { + let project_resource = store.data_mut().table().push(project)?; + let Some(configuration) = extension + .call_context_server_configuration( + store, + context_server_id.clone(), + project_resource, + ) + .await? + .map_err(|err| anyhow!("{err}"))? + else { + return Ok(None); + }; + + Ok(Some(configuration.try_into()?)) + } + .boxed() + }) + .await + } + async fn suggest_docs_packages(&self, provider: Arc) -> Result> { self.call(|extension, store| { async move { diff --git a/crates/extension_host/src/wasm_host/wit.rs b/crates/extension_host/src/wasm_host/wit.rs index d9c88606b7..732e55418d 100644 --- a/crates/extension_host/src/wasm_host/wit.rs +++ b/crates/extension_host/src/wasm_host/wit.rs @@ -25,6 +25,7 @@ use wasmtime::{ pub use latest::CodeLabelSpanLiteral; pub use latest::{ CodeLabel, CodeLabelSpan, Command, ExtensionProject, Range, SlashCommand, + zed::extension::context_server::ContextServerConfiguration, zed::extension::lsp::{ Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind, }, @@ -726,6 +727,29 @@ impl Extension { } } + pub async fn call_context_server_configuration( + &self, + store: &mut Store, + context_server_id: Arc, + project: Resource, + ) -> Result, String>> { + match self { + Extension::V0_5_0(ext) => { + ext.call_context_server_configuration(store, &context_server_id, project) + .await + } + Extension::V0_0_1(_) + | Extension::V0_0_4(_) + | Extension::V0_0_6(_) + | Extension::V0_1_0(_) + | Extension::V0_2_0(_) + | Extension::V0_3_0(_) + | Extension::V0_4_0(_) => Err(anyhow!( + "`context_server_configuration` not available prior to v0.5.0" + )), + } + } + pub async fn call_suggest_docs_packages( &self, store: &mut Store, diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_5_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_5_0.rs index 55a4d8efb8..8dc6674286 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_5_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_5_0.rs @@ -247,6 +247,21 @@ impl From for extension::SlashCommandArgumentCom } } +impl TryFrom for extension::ContextServerConfiguration { + type Error = anyhow::Error; + + fn try_from(value: ContextServerConfiguration) -> Result { + let settings_schema: serde_json::Value = serde_json::from_str(&value.settings_schema) + .context("Failed to parse settings_schema")?; + + Ok(Self { + installation_instructions: value.installation_instructions, + default_settings: value.default_settings, + settings_schema, + }) + } +} + impl HostKeyValueStore for WasmState { async fn insert( &mut self, @@ -610,6 +625,9 @@ impl process::Host for WasmState { #[async_trait] impl slash_command::Host for WasmState {} +#[async_trait] +impl context_server::Host for WasmState {} + impl ExtensionImports for WasmState { async fn get_settings( &mut self, diff --git a/crates/extensions_ui/Cargo.toml b/crates/extensions_ui/Cargo.toml index 8415041dc1..bc68c98ebc 100644 --- a/crates/extensions_ui/Cargo.toml +++ b/crates/extensions_ui/Cargo.toml @@ -17,6 +17,7 @@ client.workspace = true collections.workspace = true db.workspace = true editor.workspace = true +extension.workspace = true extension_host.workspace = true fs.workspace = true fuzzy.workspace = true diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index b2ad3e2699..15c73effee 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -246,6 +246,12 @@ fn keywords_by_feature() -> &'static BTreeMap> { }) } +struct ExtensionCardButtons { + install_or_uninstall: Button, + upgrade: Option