diff --git a/Cargo.lock b/Cargo.lock index f3d62686f9..f7df83814a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,6 +358,7 @@ dependencies = [ "clock", "collections", "command_palette_hooks", + "context_servers", "ctor", "db", "editor", @@ -2668,6 +2669,27 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "context_servers" +version = "0.1.0" +dependencies = [ + "anyhow", + "collections", + "futures 0.3.30", + "gpui", + "log", + "parking_lot", + "postage", + "schemars", + "serde", + "serde_json", + "settings", + "smol", + "url", + "util", + "workspace", +] + [[package]] name = "convert_case" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 514b3708e2..3bd85b5793 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "crates/collections", "crates/command_palette", "crates/command_palette_hooks", + "crates/context_servers", "crates/copilot", "crates/db", "crates/dev_server_projects", @@ -189,6 +190,7 @@ collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } command_palette = { path = "crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } +context_servers = { path = "crates/context_servers" } copilot = { path = "crates/copilot" } db = { path = "crates/db" } dev_server_projects = { path = "crates/dev_server_projects" } diff --git a/assets/settings/default.json b/assets/settings/default.json index 1648237d2d..86c73dfefb 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1010,5 +1010,16 @@ // ] // } // ] - "ssh_connections": null + "ssh_connections": null, + // Configures the Context Server Protocol binaries + // + // Examples: + // { + // "id": "server-1", + // "executable": "/path", + // "args": ['arg1", "args2"] + // } + "experimental.context_servers": { + "servers": [] + } } diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index ba39e741e9..9915b32d0f 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -33,6 +33,7 @@ clock.workspace = true collections.workspace = true command_palette_hooks.workspace = true db.workspace = true +context_servers.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index e7adff1286..db109e029d 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -21,9 +21,11 @@ use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; pub use context::*; +use context_servers::ContextServerRegistry; pub use context_store::*; use feature_flags::FeatureFlagAppExt; use fs::Fs; +use gpui::Context as _; use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal}; use indexed_docs::IndexedDocsRegistry; pub(crate) use inline_assistant::*; @@ -37,9 +39,9 @@ use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsStore}; use slash_command::{ - default_command, diagnostics_command, docs_command, fetch_command, file_command, now_command, - project_command, prompt_command, search_command, symbols_command, tab_command, - terminal_command, workflow_command, + context_server_command, default_command, diagnostics_command, docs_command, fetch_command, + file_command, now_command, project_command, prompt_command, search_command, symbols_command, + tab_command, terminal_command, workflow_command, }; use std::sync::Arc; pub(crate) use streaming_diff::*; @@ -221,6 +223,7 @@ pub fn init( init_language_model_settings(cx); assistant_slash_command::init(cx); assistant_panel::init(cx); + context_servers::init(cx); let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext { dev_mode, @@ -261,9 +264,69 @@ pub fn init( }) .detach(); + register_context_server_handlers(cx); + prompt_builder } +fn register_context_server_handlers(cx: &mut AppContext) { + cx.subscribe( + &context_servers::manager::ContextServerManager::global(cx), + |manager, event, cx| match event { + context_servers::manager::Event::ServerStarted { server_id } => { + cx.update_model( + &manager, + |manager: &mut context_servers::manager::ContextServerManager, cx| { + let slash_command_registry = SlashCommandRegistry::global(cx); + let context_server_registry = ContextServerRegistry::global(cx); + if let Some(server) = manager.get_server(server_id) { + cx.spawn(|_, _| async move { + let Some(protocol) = server.client.read().clone() else { + return; + }; + + if let Some(prompts) = protocol.list_prompts().await.log_err() { + for prompt in prompts + .into_iter() + .filter(context_server_command::acceptable_prompt) + { + log::info!( + "registering context server command: {:?}", + prompt.name + ); + context_server_registry.register_command( + server.id.clone(), + prompt.name.as_str(), + ); + slash_command_registry.register_command( + context_server_command::ContextServerSlashCommand::new( + &server, prompt, + ), + true, + ); + } + } + }) + .detach(); + } + }, + ); + } + context_servers::manager::Event::ServerStopped { server_id } => { + let slash_command_registry = SlashCommandRegistry::global(cx); + let context_server_registry = ContextServerRegistry::global(cx); + if let Some(commands) = context_server_registry.get_commands(server_id) { + for command_name in commands { + slash_command_registry.unregister_command_by_name(&command_name); + context_server_registry.unregister_command(&server_id, &command_name); + } + } + } + }, + ) + .detach(); +} + fn init_language_model_settings(cx: &mut AppContext) { update_active_language_model_from_settings(cx); diff --git a/crates/assistant/src/slash_command.rs b/crates/assistant/src/slash_command.rs index 37fcb6358e..9f0d24ea26 100644 --- a/crates/assistant/src/slash_command.rs +++ b/crates/assistant/src/slash_command.rs @@ -18,6 +18,7 @@ use std::{ use ui::ActiveTheme; use workspace::Workspace; +pub mod context_server_command; pub mod default_command; pub mod diagnostics_command; pub mod docs_command; diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs new file mode 100644 index 0000000000..95c58be1ee --- /dev/null +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -0,0 +1,125 @@ +use anyhow::{anyhow, Result}; +use assistant_slash_command::{ + ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, +}; +use collections::HashMap; +use context_servers::{ + manager::{ContextServer, ContextServerManager}, + protocol::PromptInfo, +}; +use gpui::{Task, WeakView, WindowContext}; +use language::LspAdapterDelegate; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use ui::{IconName, SharedString}; +use workspace::Workspace; + +pub struct ContextServerSlashCommand { + server_id: String, + prompt: PromptInfo, +} + +impl ContextServerSlashCommand { + pub fn new(server: &Arc, prompt: PromptInfo) -> Self { + Self { + server_id: server.id.clone(), + prompt, + } + } +} + +impl SlashCommand for ContextServerSlashCommand { + fn name(&self) -> String { + self.prompt.name.clone() + } + + fn description(&self) -> String { + format!("Run context server command: {}", self.prompt.name) + } + + fn menu_text(&self) -> String { + format!("Run '{}' from {}", self.prompt.name, self.server_id) + } + + fn requires_argument(&self) -> bool { + self.prompt + .arguments + .as_ref() + .map_or(false, |args| !args.is_empty()) + } + + fn complete_argument( + self: Arc, + _arguments: &[String], + _cancel: Arc, + _workspace: Option>, + _cx: &mut WindowContext, + ) -> Task>> { + Task::ready(Ok(Vec::new())) + } + + fn run( + self: Arc, + arguments: &[String], + _workspace: WeakView, + _delegate: Option>, + cx: &mut WindowContext, + ) -> Task> { + let server_id = self.server_id.clone(); + let prompt_name = self.prompt.name.clone(); + let argument = arguments.first().cloned(); + + let manager = ContextServerManager::global(cx); + let manager = manager.read(cx); + if let Some(server) = manager.get_server(&server_id) { + cx.foreground_executor().spawn(async move { + let Some(protocol) = server.client.read().clone() else { + return Err(anyhow!("Context server not initialized")); + }; + + let result = protocol + .run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?) + .await?; + + Ok(SlashCommandOutput { + sections: vec![SlashCommandOutputSection { + range: 0..result.len(), + icon: IconName::ZedAssistant, + label: SharedString::from(format!("Result from {}", prompt_name)), + }], + text: result, + run_commands_in_text: false, + }) + }) + } else { + Task::ready(Err(anyhow!("Context server not found"))) + } + } +} + +fn prompt_arguments( + prompt: &PromptInfo, + argument: Option, +) -> Result> { + match &prompt.arguments { + Some(args) if args.len() >= 2 => Err(anyhow!( + "Prompt has more than one argument, which is not supported" + )), + Some(args) if args.len() == 1 => match argument { + Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])), + None => Err(anyhow!("Prompt expects argument but none given")), + }, + Some(_) | None => Ok(HashMap::default()), + } +} + +/// MCP servers can return prompts with multiple arguments. Since we only +/// support one argument, we ignore all others. This is the necessary predicate +/// for this. +pub fn acceptable_prompt(prompt: &PromptInfo) -> bool { + match &prompt.arguments { + None => true, + Some(args) if args.len() == 1 => true, + _ => false, + } +} diff --git a/crates/assistant_slash_command/src/slash_command_registry.rs b/crates/assistant_slash_command/src/slash_command_registry.rs index f0afc60234..d8a4014cfc 100644 --- a/crates/assistant_slash_command/src/slash_command_registry.rs +++ b/crates/assistant_slash_command/src/slash_command_registry.rs @@ -58,10 +58,14 @@ impl SlashCommandRegistry { /// Unregisters the provided [`SlashCommand`]. pub fn unregister_command(&self, command: impl SlashCommand) { + self.unregister_command_by_name(command.name().as_str()) + } + + /// Unregisters the command with the given name. + pub fn unregister_command_by_name(&self, command_name: &str) { let mut state = self.state.write(); - let command_name: Arc = command.name().into(); - state.featured_commands.remove(&command_name); - state.commands.remove(&command_name); + state.featured_commands.remove(command_name); + state.commands.remove(command_name); } /// Returns the names of registered [`SlashCommand`]s. diff --git a/crates/context_servers/Cargo.toml b/crates/context_servers/Cargo.toml new file mode 100644 index 0000000000..21bf6a1fc8 --- /dev/null +++ b/crates/context_servers/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "context_servers" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/context_servers.rs" + +[dependencies] +anyhow.workspace = true +collections.workspace = true +futures.workspace = true +gpui.workspace = true +log.workspace = true +parking_lot.workspace = true +postage.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +url = { workspace = true, features = ["serde"] } +util.workspace = true +workspace.workspace = true diff --git a/crates/context_servers/LICENSE-GPL b/crates/context_servers/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/context_servers/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/context_servers/src/client.rs b/crates/context_servers/src/client.rs new file mode 100644 index 0000000000..9cc6343203 --- /dev/null +++ b/crates/context_servers/src/client.rs @@ -0,0 +1,432 @@ +use anyhow::{anyhow, Context, Result}; +use collections::HashMap; +use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt}; +use gpui::{AsyncAppContext, BackgroundExecutor, Task}; +use parking_lot::Mutex; +use postage::barrier; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::{value::RawValue, Value}; +use smol::{ + channel, + io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, + process::{self, Child}, +}; +use std::{ + fmt, + path::PathBuf, + sync::{ + atomic::{AtomicI32, Ordering::SeqCst}, + Arc, + }, + time::{Duration, Instant}, +}; +use util::TryFutureExt; + +const JSON_RPC_VERSION: &str = "2.0"; +const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); + +type ResponseHandler = Box)>; +type NotificationHandler = Box; + +#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] +#[serde(untagged)] +pub enum RequestId { + Int(i32), + Str(String), +} + +pub struct Client { + server_id: ContextServerId, + next_id: AtomicI32, + outbound_tx: channel::Sender, + name: Arc, + notification_handlers: Arc>>, + response_handlers: Arc>>>, + #[allow(clippy::type_complexity)] + #[allow(dead_code)] + io_tasks: Mutex>, Task>)>>, + #[allow(dead_code)] + output_done_rx: Mutex>, + executor: BackgroundExecutor, + server: Arc>>, +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct ContextServerId(pub String); + +#[derive(Serialize, Deserialize)] +struct Request<'a, T> { + jsonrpc: &'static str, + id: RequestId, + method: &'a str, + params: T, +} + +#[derive(Serialize, Deserialize)] +struct AnyResponse<'a> { + jsonrpc: &'a str, + id: RequestId, + #[serde(default)] + error: Option, + #[serde(borrow)] + result: Option<&'a RawValue>, +} + +#[derive(Deserialize)] +#[allow(dead_code)] +struct Response { + jsonrpc: &'static str, + id: RequestId, + #[serde(flatten)] + value: CspResult, +} + +#[derive(Deserialize)] +#[serde(rename_all = "snake_case")] +enum CspResult { + #[serde(rename = "result")] + Ok(Option), + #[allow(dead_code)] + Error(Option), +} + +#[derive(Serialize, Deserialize)] +struct Notification<'a, T> { + jsonrpc: &'static str, + id: RequestId, + #[serde(borrow)] + method: &'a str, + params: T, +} + +#[derive(Debug, Clone, Deserialize)] +struct AnyNotification<'a> { + jsonrpc: &'a str, + id: RequestId, + method: String, + #[serde(default)] + params: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Error { + message: String, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ModelContextServerBinary { + pub executable: PathBuf, + pub args: Vec, + pub env: Option>, +} + +impl Client { + /// Creates a new Client instance for a context server. + /// + /// This function initializes a new Client by spawning a child process for the context server, + /// setting up communication channels, and initializing handlers for input/output operations. + /// It takes a server ID, binary information, and an async app context as input. + pub fn new( + server_id: ContextServerId, + binary: ModelContextServerBinary, + cx: AsyncAppContext, + ) -> Result { + log::info!( + "starting context server (executable={:?}, args={:?})", + binary.executable, + &binary.args + ); + + let mut command = process::Command::new(&binary.executable); + command + .args(&binary.args) + .envs(binary.env.unwrap_or_default()) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true); + + let mut server = command.spawn().with_context(|| { + format!( + "failed to spawn command. (path={:?}, args={:?})", + binary.executable, &binary.args + ) + })?; + + let stdin = server.stdin.take().unwrap(); + let stdout = server.stdout.take().unwrap(); + let stderr = server.stderr.take().unwrap(); + + let (outbound_tx, outbound_rx) = channel::unbounded::(); + let (output_done_tx, output_done_rx) = barrier::channel(); + + let notification_handlers = + Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); + let response_handlers = + Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); + + let stdout_input_task = cx.spawn({ + let notification_handlers = notification_handlers.clone(); + let response_handlers = response_handlers.clone(); + move |cx| { + Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err() + } + }); + let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err()); + let input_task = cx.spawn(|_| async move { + let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); + stdout.or(stderr) + }); + let output_task = cx.background_executor().spawn({ + Self::handle_output( + stdin, + outbound_rx, + output_done_tx, + response_handlers.clone(), + ) + .log_err() + }); + + let mut context_server = Self { + server_id, + notification_handlers, + response_handlers, + name: "".into(), + next_id: Default::default(), + outbound_tx, + executor: cx.background_executor().clone(), + io_tasks: Mutex::new(Some((input_task, output_task))), + output_done_rx: Mutex::new(Some(output_done_rx)), + server: Arc::new(Mutex::new(Some(server))), + }; + + if let Some(name) = binary.executable.file_name() { + context_server.name = name.to_string_lossy().into(); + } + + Ok(context_server) + } + + /// Handles input from the server's stdout. + /// + /// This function continuously reads lines from the provided stdout stream, + /// parses them as JSON-RPC responses or notifications, and dispatches them + /// to the appropriate handlers. It processes both responses (which are matched + /// to pending requests) and notifications (which trigger registered handlers). + async fn handle_input( + stdout: Stdout, + notification_handlers: Arc>>, + response_handlers: Arc>>>, + cx: AsyncAppContext, + ) -> anyhow::Result<()> + where + Stdout: AsyncRead + Unpin + Send + 'static, + { + let mut stdout = BufReader::new(stdout); + let mut buffer = String::new(); + + loop { + buffer.clear(); + if stdout.read_line(&mut buffer).await? == 0 { + return Ok(()); + } + + let content = buffer.trim(); + + if !content.is_empty() { + if let Ok(response) = serde_json::from_str::(&content) { + if let Some(handlers) = response_handlers.lock().as_mut() { + if let Some(handler) = handlers.remove(&response.id) { + handler(Ok(content.to_string())); + } + } + } else if let Ok(notification) = serde_json::from_str::(&content) { + let mut notification_handlers = notification_handlers.lock(); + if let Some(handler) = + notification_handlers.get_mut(notification.method.as_str()) + { + handler( + notification.id, + notification.params.unwrap_or(Value::Null), + cx.clone(), + ); + } + } + } + + smol::future::yield_now().await; + } + } + + /// Handles the stderr output from the context server. + /// Continuously reads and logs any error messages from the server. + async fn handle_stderr(stderr: Stderr) -> anyhow::Result<()> + where + Stderr: AsyncRead + Unpin + Send + 'static, + { + let mut stderr = BufReader::new(stderr); + let mut buffer = String::new(); + + loop { + buffer.clear(); + if stderr.read_line(&mut buffer).await? == 0 { + return Ok(()); + } + log::warn!("context server stderr: {}", buffer.trim()); + smol::future::yield_now().await; + } + } + + /// Handles the output to the context server's stdin. + /// This function continuously receives messages from the outbound channel, + /// writes them to the server's stdin, and manages the lifecycle of response handlers. + async fn handle_output( + stdin: Stdin, + outbound_rx: channel::Receiver, + output_done_tx: barrier::Sender, + response_handlers: Arc>>>, + ) -> anyhow::Result<()> + where + Stdin: AsyncWrite + Unpin + Send + 'static, + { + let mut stdin = BufWriter::new(stdin); + let _clear_response_handlers = util::defer({ + let response_handlers = response_handlers.clone(); + move || { + response_handlers.lock().take(); + } + }); + while let Ok(message) = outbound_rx.recv().await { + log::trace!("outgoing message: {}", message); + + stdin.write_all(message.as_bytes()).await?; + stdin.write_all(b"\n").await?; + stdin.flush().await?; + } + drop(output_done_tx); + Ok(()) + } + + /// Sends a JSON-RPC request to the context server and waits for a response. + /// This function handles serialization, deserialization, timeout, and error handling. + pub async fn request( + &self, + method: &str, + params: impl Serialize, + ) -> Result { + let id = self.next_id.fetch_add(1, SeqCst); + let request = serde_json::to_string(&Request { + jsonrpc: JSON_RPC_VERSION, + id: RequestId::Int(id), + method, + params, + }) + .unwrap(); + + let (tx, rx) = oneshot::channel(); + let handle_response = self + .response_handlers + .lock() + .as_mut() + .ok_or_else(|| anyhow!("server shut down")) + .map(|handlers| { + handlers.insert( + RequestId::Int(id), + Box::new(move |result| { + let _ = tx.send(result); + }), + ); + }); + + let send = self + .outbound_tx + .try_send(request) + .context("failed to write to context server's stdin"); + + let executor = self.executor.clone(); + let started = Instant::now(); + handle_response?; + send?; + + let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); + select! { + response = rx.fuse() => { + let elapsed = started.elapsed(); + log::trace!("took {elapsed:?} to receive response to {method:?} id {id}"); + match response? { + Ok(response) => { + let parsed: AnyResponse = serde_json::from_str(&response)?; + if let Some(error) = parsed.error { + Err(anyhow!(error.message)) + } else if let Some(result) = parsed.result { + Ok(serde_json::from_str(result.get())?) + } else { + Err(anyhow!("Invalid response: no result or error")) + } + } + Err(_) => anyhow::bail!("cancelled") + } + } + _ = timeout => { + log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); + anyhow::bail!("Context server request timeout"); + } + } + } + + /// Sends a notification to the context server without expecting a response. + /// This function serializes the notification and sends it through the outbound channel. + pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> { + let id = self.next_id.fetch_add(1, SeqCst); + let notification = serde_json::to_string(&Notification { + jsonrpc: JSON_RPC_VERSION, + id: RequestId::Int(id), + method, + params, + }) + .unwrap(); + self.outbound_tx.try_send(notification)?; + Ok(()) + } + + pub fn on_notification(&self, method: &'static str, mut f: F) + where + F: 'static + Send + FnMut(Value, AsyncAppContext), + { + self.notification_handlers + .lock() + .insert(method, Box::new(move |_, params, cx| f(params, cx))); + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn server_id(&self) -> ContextServerId { + self.server_id.clone() + } +} + +impl Drop for Client { + fn drop(&mut self) { + if let Some(mut server) = self.server.lock().take() { + let _ = server.kill(); + } + } +} + +impl fmt::Display for ContextServerId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Debug for Client { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Context Server Client") + .field("id", &self.server_id.0) + .field("name", &self.name) + .finish_non_exhaustive() + } +} diff --git a/crates/context_servers/src/context_servers.rs b/crates/context_servers/src/context_servers.rs new file mode 100644 index 0000000000..3892adff56 --- /dev/null +++ b/crates/context_servers/src/context_servers.rs @@ -0,0 +1,36 @@ +use gpui::{actions, AppContext, Context, ViewContext}; +use log; +use manager::ContextServerManager; +use workspace::Workspace; + +pub mod client; +pub mod manager; +pub mod protocol; +mod registry; +pub mod types; + +pub use registry::*; + +actions!(context_servers, [Restart]); + +pub fn init(cx: &mut AppContext) { + log::info!("initializing context server client"); + manager::init(cx); + ContextServerRegistry::register(cx); + + cx.observe_new_views( + |workspace: &mut Workspace, _cx: &mut ViewContext| { + workspace.register_action(restart_servers); + }, + ) + .detach(); +} + +fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewContext) { + let model = ContextServerManager::global(&cx); + cx.update_model(&model, |manager, cx| { + for server in manager.servers() { + manager.restart_server(&server.id, cx).detach(); + } + }); +} diff --git a/crates/context_servers/src/manager.rs b/crates/context_servers/src/manager.rs new file mode 100644 index 0000000000..9d7d67a72f --- /dev/null +++ b/crates/context_servers/src/manager.rs @@ -0,0 +1,278 @@ +//! This module implements a context server management system for Zed. +//! +//! It provides functionality to: +//! - Define and load context server settings +//! - Manage individual context servers (start, stop, restart) +//! - Maintain a global manager for all context servers +//! +//! Key components: +//! - `ContextServerSettings`: Defines the structure for server configurations +//! - `ContextServer`: Represents an individual context server +//! - `ContextServerManager`: Manages multiple context servers +//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager +//! +//! The module also includes initialization logic to set up the context server system +//! and react to changes in settings. + +use collections::{HashMap, HashSet}; +use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task}; +use log; +use parking_lot::RwLock; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsSources, SettingsStore}; +use std::path::Path; +use std::sync::Arc; + +use crate::{ + client::{self, Client}, + types, +}; + +#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)] +pub struct ContextServerSettings { + pub servers: Vec, +} + +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)] +pub struct ServerConfig { + pub id: String, + pub executable: String, + pub args: Vec, +} + +impl Settings for ContextServerSettings { + const KEY: Option<&'static str> = Some("experimental.context_servers"); + + type FileContent = Self; + + fn load( + sources: SettingsSources, + _: &mut gpui::AppContext, + ) -> anyhow::Result { + sources.json_merge() + } +} + +pub struct ContextServer { + pub id: String, + pub config: ServerConfig, + pub client: RwLock>>, +} + +impl ContextServer { + fn new(config: ServerConfig) -> Self { + Self { + id: config.id.clone(), + config, + client: RwLock::new(None), + } + } + + async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> { + log::info!("starting context server {}", self.config.id); + let client = Client::new( + client::ContextServerId(self.config.id.clone()), + client::ModelContextServerBinary { + executable: Path::new(&self.config.executable).to_path_buf(), + args: self.config.args.clone(), + env: None, + }, + cx.clone(), + )?; + + let protocol = crate::protocol::ModelContextProtocol::new(client); + let client_info = types::EntityInfo { + name: "Zed".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }; + let initialized_protocol = protocol.initialize(client_info).await?; + + log::debug!( + "context server {} initialized: {:?}", + self.config.id, + initialized_protocol.initialize, + ); + + *self.client.write() = Some(Arc::new(initialized_protocol)); + Ok(()) + } + + async fn stop(&self) -> anyhow::Result<()> { + let mut client = self.client.write(); + if let Some(protocol) = client.take() { + drop(protocol); + } + Ok(()) + } +} + +/// A Context server manager manages the starting and stopping +/// of all servers. To obtain a server to interact with, a crate +/// must go through the `GlobalContextServerManager` which holds +/// a model to the ContextServerManager. +pub struct ContextServerManager { + servers: HashMap>, + pending_servers: HashSet, +} + +pub enum Event { + ServerStarted { server_id: String }, + ServerStopped { server_id: String }, +} + +impl Global for ContextServerManager {} +impl EventEmitter for ContextServerManager {} + +impl ContextServerManager { + pub fn new() -> Self { + Self { + servers: HashMap::default(), + pending_servers: HashSet::default(), + } + } + pub fn global(cx: &AppContext) -> Model { + cx.global::().0.clone() + } + + pub fn add_server( + &mut self, + config: ServerConfig, + cx: &mut ModelContext, + ) -> Task> { + let server_id = config.id.clone(); + let server_id2 = config.id.clone(); + + if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) { + return Task::ready(Ok(())); + } + + let task = cx.spawn(|this, mut cx| async move { + let server = Arc::new(ContextServer::new(config)); + server.start(&cx).await?; + this.update(&mut cx, |this, cx| { + this.servers.insert(server_id.clone(), server); + this.pending_servers.remove(&server_id); + cx.emit(Event::ServerStarted { + server_id: server_id.clone(), + }); + })?; + Ok(()) + }); + + self.pending_servers.insert(server_id2); + task + } + + pub fn get_server(&self, id: &str) -> Option> { + self.servers.get(id).cloned() + } + + pub fn remove_server( + &mut self, + id: &str, + cx: &mut ModelContext, + ) -> Task> { + let id = id.to_string(); + cx.spawn(|this, mut cx| async move { + if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { + server.stop().await?; + } + this.update(&mut cx, |this, cx| { + this.pending_servers.remove(&id); + cx.emit(Event::ServerStopped { + server_id: id.clone(), + }) + })?; + Ok(()) + }) + } + + pub fn restart_server( + &mut self, + id: &str, + cx: &mut ModelContext, + ) -> Task> { + let id = id.to_string(); + cx.spawn(|this, mut cx| async move { + if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { + server.stop().await?; + let config = server.config.clone(); + let new_server = Arc::new(ContextServer::new(config)); + new_server.start(&cx).await?; + this.update(&mut cx, |this, cx| { + this.servers.insert(id.clone(), new_server); + cx.emit(Event::ServerStopped { + server_id: id.clone(), + }); + cx.emit(Event::ServerStarted { + server_id: id.clone(), + }); + })?; + } + Ok(()) + }) + } + + pub fn servers(&self) -> Vec> { + self.servers.values().cloned().collect() + } + + pub fn model(cx: &mut AppContext) -> Model { + cx.new_model(|_cx| ContextServerManager::new()) + } +} + +pub struct GlobalContextServerManager(Model); +impl Global for GlobalContextServerManager {} + +impl GlobalContextServerManager { + fn register(cx: &mut AppContext) { + let model = ContextServerManager::model(cx); + cx.set_global(Self(model)); + } +} + +pub fn init(cx: &mut AppContext) { + ContextServerSettings::register(cx); + GlobalContextServerManager::register(cx); + cx.observe_global::(|cx| { + let manager = ContextServerManager::global(cx); + cx.update_model(&manager, |manager, cx| { + let settings = ContextServerSettings::get_global(cx); + let current_servers: HashMap = manager + .servers() + .into_iter() + .map(|server| (server.id.clone(), server.config.clone())) + .collect(); + + let new_servers = settings + .servers + .iter() + .map(|config| (config.id.clone(), config.clone())) + .collect::>(); + + let servers_to_add = new_servers + .values() + .filter(|config| !current_servers.contains_key(&config.id)) + .cloned() + .collect::>(); + + let servers_to_remove = current_servers + .keys() + .filter(|id| !new_servers.contains_key(*id)) + .cloned() + .collect::>(); + + log::trace!("servers_to_add={:?}", servers_to_add); + for config in servers_to_add { + manager.add_server(config, cx).detach(); + } + + for id in servers_to_remove { + manager.remove_server(&id, cx).detach(); + } + }) + }) + .detach(); +} diff --git a/crates/context_servers/src/protocol.rs b/crates/context_servers/src/protocol.rs new file mode 100644 index 0000000000..779ae89a05 --- /dev/null +++ b/crates/context_servers/src/protocol.rs @@ -0,0 +1,140 @@ +//! This module implements parts of the Model Context Protocol. +//! +//! It handles the lifecycle messages, and provides a general interface to +//! interacting with an MCP server. It uses the generic JSON-RPC client to +//! read/write messages and the types from types.rs for serialization/deserialization +//! of messages. + +use anyhow::Result; +use collections::HashMap; + +use crate::client::Client; +use crate::types; + +pub use types::PromptInfo; + +const PROTOCOL_VERSION: u32 = 1; + +pub struct ModelContextProtocol { + inner: Client, +} + +impl ModelContextProtocol { + pub fn new(inner: Client) -> Self { + Self { inner } + } + + pub async fn initialize( + self, + client_info: types::EntityInfo, + ) -> Result { + let params = types::InitializeParams { + protocol_version: PROTOCOL_VERSION, + capabilities: types::ClientCapabilities { + experimental: None, + sampling: None, + }, + client_info, + }; + + let response: types::InitializeResponse = self + .inner + .request(types::RequestType::Initialize.as_str(), params) + .await?; + + log::trace!("mcp server info {:?}", response.server_info); + + self.inner.notify( + types::NotificationType::Initialized.as_str(), + serde_json::json!({}), + )?; + + let initialized_protocol = InitializedContextServerProtocol { + inner: self.inner, + initialize: response, + }; + + Ok(initialized_protocol) + } +} + +pub struct InitializedContextServerProtocol { + inner: Client, + pub initialize: types::InitializeResponse, +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum ServerCapability { + Experimental, + Logging, + Prompts, + Resources, + Tools, +} + +impl InitializedContextServerProtocol { + /// Check if the server supports a specific capability + pub fn capable(&self, capability: ServerCapability) -> bool { + match capability { + ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(), + ServerCapability::Logging => self.initialize.capabilities.logging.is_some(), + ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(), + ServerCapability::Resources => self.initialize.capabilities.resources.is_some(), + ServerCapability::Tools => self.initialize.capabilities.tools.is_some(), + } + } + + fn check_capability(&self, capability: ServerCapability) -> Result<()> { + if self.capable(capability) { + Ok(()) + } else { + Err(anyhow::anyhow!( + "Server does not support {:?} capability", + capability + )) + } + } + + /// List the MCP prompts. + pub async fn list_prompts(&self) -> Result> { + self.check_capability(ServerCapability::Prompts)?; + + let response: types::PromptsListResponse = self + .inner + .request(types::RequestType::PromptsList.as_str(), ()) + .await?; + + Ok(response.prompts) + } + + /// Executes a prompt with the given arguments and returns the result. + pub async fn run_prompt>( + &self, + prompt: P, + arguments: HashMap, + ) -> Result { + self.check_capability(ServerCapability::Prompts)?; + + let params = types::PromptsGetParams { + name: prompt.as_ref().to_string(), + arguments: Some(arguments), + }; + + let response: types::PromptsGetResponse = self + .inner + .request(types::RequestType::PromptsGet.as_str(), params) + .await?; + + Ok(response.prompt) + } +} + +impl InitializedContextServerProtocol { + pub async fn request( + &self, + method: &str, + params: impl serde::Serialize, + ) -> Result { + self.inner.request(method, params).await + } +} diff --git a/crates/context_servers/src/registry.rs b/crates/context_servers/src/registry.rs new file mode 100644 index 0000000000..625f308c15 --- /dev/null +++ b/crates/context_servers/src/registry.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use collections::HashMap; +use gpui::{AppContext, Global, ReadGlobal}; +use parking_lot::RwLock; + +struct GlobalContextServerRegistry(Arc); + +impl Global for GlobalContextServerRegistry {} + +pub struct ContextServerRegistry { + registry: RwLock>>>, +} + +impl ContextServerRegistry { + pub fn global(cx: &AppContext) -> Arc { + GlobalContextServerRegistry::global(cx).0.clone() + } + + pub fn register(cx: &mut AppContext) { + cx.set_global(GlobalContextServerRegistry(Arc::new( + ContextServerRegistry { + registry: RwLock::new(HashMap::default()), + }, + ))) + } + + pub fn register_command(&self, server_id: String, command_name: &str) { + let mut registry = self.registry.write(); + registry + .entry(server_id) + .or_default() + .push(command_name.into()); + } + + pub fn unregister_command(&self, server_id: &str, command_name: &str) { + let mut registry = self.registry.write(); + if let Some(commands) = registry.get_mut(server_id) { + commands.retain(|name| name.as_ref() != command_name); + } + } + + pub fn get_commands(&self, server_id: &str) -> Option>> { + let registry = self.registry.read(); + registry.get(server_id).cloned() + } +} diff --git a/crates/context_servers/src/types.rs b/crates/context_servers/src/types.rs new file mode 100644 index 0000000000..fb736d9de5 --- /dev/null +++ b/crates/context_servers/src/types.rs @@ -0,0 +1,234 @@ +use collections::HashMap; +use serde::{Deserialize, Serialize}; +use url::Url; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub enum RequestType { + Initialize, + CallTool, + ResourcesUnsubscribe, + ResourcesSubscribe, + ResourcesRead, + ResourcesList, + LoggingSetLevel, + PromptsGet, + PromptsList, +} + +impl RequestType { + pub fn as_str(&self) -> &'static str { + match self { + RequestType::Initialize => "initialize", + RequestType::CallTool => "tools/call", + RequestType::ResourcesUnsubscribe => "resources/unsubscribe", + RequestType::ResourcesSubscribe => "resources/subscribe", + RequestType::ResourcesRead => "resources/read", + RequestType::ResourcesList => "resources/list", + RequestType::LoggingSetLevel => "logging/setLevel", + RequestType::PromptsGet => "prompts/get", + RequestType::PromptsList => "prompts/list", + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeParams { + pub protocol_version: u32, + pub capabilities: ClientCapabilities, + pub client_info: EntityInfo, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolParams { + pub name: String, + pub arguments: Option, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesUnsubscribeParams { + pub uri: Url, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesSubscribeParams { + pub uri: Url, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesReadParams { + pub uri: Url, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct LoggingSetLevelParams { + pub level: LoggingLevel, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptsGetParams { + pub name: String, + pub arguments: Option>, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeResponse { + pub protocol_version: u32, + pub capabilities: ServerCapabilities, + pub server_info: EntityInfo, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesReadResponse { + pub contents: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesListResponse { + pub resource_templates: Option>, + pub resources: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptsGetResponse { + pub prompt: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptsListResponse { + pub prompts: Vec, +} + +#[derive(Debug, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct PromptInfo { + pub name: String, + pub arguments: Option>, +} + +#[derive(Debug, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct PromptArgument { + pub name: String, + pub description: Option, + pub required: Option, +} + +// Shared Types + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClientCapabilities { + pub experimental: Option>, + pub sampling: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ServerCapabilities { + pub experimental: Option>, + pub logging: Option>, + pub prompts: Option>, + pub resources: Option, + pub tools: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesCapabilities { + pub subscribe: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Tool { + pub name: String, + pub description: Option, + pub input_schema: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EntityInfo { + pub name: String, + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Resource { + pub uri: Url, + pub mime_type: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourceContent { + pub uri: Url, + pub mime_type: Option, + pub content_type: String, + pub text: Option, + pub data: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourceTemplate { + pub uri_template: String, + pub name: Option, + pub description: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum LoggingLevel { + Debug, + Info, + Warning, + Error, +} + +// Client Notifications + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub enum NotificationType { + Initialized, + Progress, +} + +impl NotificationType { + pub fn as_str(&self) -> &'static str { + match self { + NotificationType::Initialized => "notifications/initialized", + NotificationType::Progress => "notifications/progress", + } + } +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum ClientNotification { + Initialized, + Progress(ProgressParams), +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ProgressParams { + pub progress_token: String, + pub progress: f64, + pub total: Option, +}