context_store: Refactor state management (#29910)

Because we instantiated `ContextServerManager` both in `agent` and
`assistant-context-editor`, and these two entities track the running MCP
servers separately, we were effectively running every MCP server twice.

This PR moves the `ContextServerManager` into the project crate (now
called `ContextServerStore`). The store can be accessed via a project
instance. This ensures that we only instantiate one `ContextServerStore`
per project.

Also, this PR adds a bunch of tests to ensure that the
`ContextServerStore` behaves correctly (Previously there were none).

Closes #28714
Closes #29530

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-05-05 21:36:12 +02:00 committed by Joseph T. Lyons
parent 71f7100083
commit 108005f1b8
43 changed files with 1570 additions and 1049 deletions

31
Cargo.lock generated
View File

@ -484,7 +484,6 @@ dependencies = [
"client",
"collections",
"command_palette_hooks",
"context_server",
"ctor",
"db",
"editor",
@ -3328,40 +3327,19 @@ name = "context_server"
version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"async-trait",
"collections",
"command_palette_hooks",
"context_server_settings",
"extension",
"futures 0.3.31",
"gpui",
"icons",
"language_model",
"log",
"parking_lot",
"postage",
"project",
"serde",
"serde_json",
"settings",
"smol",
"url",
"util",
"workspace-hack",
]
[[package]]
name = "context_server_settings"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"gpui",
"schemars",
"serde",
"serde_json",
"settings",
"smol",
"url",
"util",
"workspace-hack",
]
@ -5028,7 +5006,6 @@ dependencies = [
"clap",
"client",
"collections",
"context_server",
"dirs 4.0.0",
"dotenv",
"env_logger 0.11.8",
@ -5181,7 +5158,6 @@ dependencies = [
"async-trait",
"client",
"collections",
"context_server_settings",
"ctor",
"env_logger 0.11.8",
"extension",
@ -11110,6 +11086,7 @@ dependencies = [
"client",
"clock",
"collections",
"context_server",
"dap",
"dap_adapters",
"env_logger 0.11.8",

View File

@ -34,7 +34,6 @@ members = [
"crates/component",
"crates/component_preview",
"crates/context_server",
"crates/context_server_settings",
"crates/copilot",
"crates/credentials_provider",
"crates/dap",
@ -243,7 +242,6 @@ command_palette_hooks = { path = "crates/command_palette_hooks" }
component = { path = "crates/component" }
component_preview = { path = "crates/component_preview" }
context_server = { path = "crates/context_server" }
context_server_settings = { path = "crates/context_server_settings" }
copilot = { path = "crates/copilot" }
credentials_provider = { path = "crates/credentials_provider" }
dap = { path = "crates/dap" }

View File

@ -3487,7 +3487,6 @@ fn open_editor_at_position(
#[cfg(test)]
mod tests {
use assistant_tool::{ToolRegistry, ToolWorkingSet};
use context_server::ContextServerSettings;
use editor::EditorSettings;
use fs::FakeFs;
use gpui::{TestAppContext, VisualTestContext};
@ -3559,7 +3558,6 @@ mod tests {
workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
ToolRegistry::default_global(cx);
});

View File

@ -1748,7 +1748,6 @@ mod tests {
use crate::{Keep, ThreadStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolWorkingSet;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::{TestAppContext, UpdateGlobal, VisualTestContext};
use project::{FakeFs, Project};
@ -1771,7 +1770,6 @@ mod tests {
thread_store::init(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
language_model::init_settings(cx);
});
@ -1928,7 +1926,6 @@ mod tests {
thread_store::init(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
language_model::init_settings(cx);
workspace::register_project_item::<Editor>(cx);

View File

@ -7,6 +7,7 @@ mod buffer_codegen;
mod context;
mod context_picker;
mod context_server_configuration;
mod context_server_tool;
mod context_store;
mod context_strip;
mod debug;

View File

@ -8,13 +8,14 @@ use std::{sync::Arc, time::Duration};
use assistant_settings::AssistantSettings;
use assistant_tool::{ToolSource, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus};
use context_server::ContextServerId;
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
Focusable, ScrollHandle, Subscription, pulsating_between,
};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use settings::{Settings, update_settings_file};
use ui::{
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
@ -33,8 +34,8 @@ pub struct AssistantConfiguration {
fs: Arc<dyn Fs>,
focus_handle: FocusHandle,
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_manager: Entity<ContextServerManager>,
expanded_context_server_tools: HashMap<Arc<str>, bool>,
context_server_store: Entity<ContextServerStore>,
expanded_context_server_tools: HashMap<ContextServerId, bool>,
tools: Entity<ToolWorkingSet>,
_registry_subscription: Subscription,
scroll_handle: ScrollHandle,
@ -44,7 +45,7 @@ pub struct AssistantConfiguration {
impl AssistantConfiguration {
pub fn new(
fs: Arc<dyn Fs>,
context_server_manager: Entity<ContextServerManager>,
context_server_store: Entity<ContextServerStore>,
tools: Entity<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
@ -75,7 +76,7 @@ impl AssistantConfiguration {
fs,
focus_handle,
configuration_views_by_provider: HashMap::default(),
context_server_manager,
context_server_store,
expanded_context_server_tools: HashMap::default(),
tools,
_registry_subscription: registry_subscription,
@ -306,7 +307,7 @@ impl AssistantConfiguration {
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
@ -322,9 +323,9 @@ impl AssistantConfiguration {
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(
context_servers
.into_iter()
.map(|context_server| self.render_context_server(context_server, window, cx)),
context_server_ids.into_iter().map(|context_server_id| {
self.render_context_server(context_server_id, window, cx)
}),
)
.child(
h_flex()
@ -374,19 +375,20 @@ impl AssistantConfiguration {
fn render_context_server(
&self,
context_server: Arc<ContextServer>,
context_server_id: ContextServerId,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl use<> + IntoElement {
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let server_status = self
.context_server_manager
.context_server_store
.read(cx)
.status_for_server(&context_server.id());
.status_for_server(&context_server_id)
.unwrap_or(ContextServerStatus::Stopped);
let is_running = matches!(server_status, Some(ContextServerStatus::Running));
let is_running = matches!(server_status, ContextServerStatus::Running);
let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() {
let error = if let ContextServerStatus::Error(error) = server_status.clone() {
Some(error)
} else {
None
@ -394,13 +396,13 @@ impl AssistantConfiguration {
let are_tools_expanded = self
.expanded_context_server_tools
.get(&context_server.id())
.get(&context_server_id)
.copied()
.unwrap_or_default();
let tools = tools_by_source
.get(&ToolSource::ContextServer {
id: context_server.id().into(),
id: context_server_id.0.clone().into(),
})
.map_or([].as_slice(), |tools| tools.as_slice());
let tool_count = tools.len();
@ -408,7 +410,7 @@ impl AssistantConfiguration {
let border_color = cx.theme().colors().border.opacity(0.6);
v_flex()
.id(SharedString::from(context_server.id()))
.id(SharedString::from(context_server_id.0.clone()))
.border_1()
.rounded_md()
.border_color(border_color)
@ -432,7 +434,7 @@ impl AssistantConfiguration {
)
.disabled(tool_count == 0)
.on_click(cx.listener({
let context_server_id = context_server.id();
let context_server_id = context_server_id.clone();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_context_server_tools
@ -444,14 +446,14 @@ impl AssistantConfiguration {
})),
)
.child(match server_status {
Some(ContextServerStatus::Starting) => {
ContextServerStatus::Starting => {
let color = Color::Success.color(cx);
Indicator::dot()
.color(Color::Success)
.with_animation(
SharedString::from(format!(
"{}-starting",
context_server.id(),
context_server_id.0.clone(),
)),
Animation::new(Duration::from_secs(2))
.repeat()
@ -462,15 +464,17 @@ impl AssistantConfiguration {
)
.into_any_element()
}
Some(ContextServerStatus::Running) => {
ContextServerStatus::Running => {
Indicator::dot().color(Color::Success).into_any_element()
}
Some(ContextServerStatus::Error(_)) => {
ContextServerStatus::Error(_) => {
Indicator::dot().color(Color::Error).into_any_element()
}
None => Indicator::dot().color(Color::Muted).into_any_element(),
ContextServerStatus::Stopped => {
Indicator::dot().color(Color::Muted).into_any_element()
}
})
.child(Label::new(context_server.id()).ml_0p5())
.child(Label::new(context_server_id.0.clone()).ml_0p5())
.when(is_running, |this| {
this.child(
Label::new(if tool_count == 1 {
@ -487,32 +491,22 @@ impl AssistantConfiguration {
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
let context_server_manager = self.context_server_manager.clone();
let context_server = context_server.clone();
let context_server_manager = self.context_server_store.clone();
let context_server_id = context_server_id.clone();
move |state, _window, cx| match state {
ToggleState::Unselected | ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
this.stop_server(context_server.clone(), cx).log_err();
this.stop_server(&context_server_id, cx).log_err();
});
}
ToggleState::Selected => {
cx.spawn({
let context_server_manager =
context_server_manager.clone();
let context_server = context_server.clone();
async move |cx| {
if let Some(start_server_task) =
context_server_manager
.update(cx, |this, cx| {
this.start_server(context_server, cx)
})
.log_err()
{
start_server_task.await.log_err();
}
context_server_manager.update(cx, |this, cx| {
if let Some(server) =
this.get_server(&context_server_id)
{
this.start_server(server, cx).log_err();
}
})
.detach();
}
}
}),

View File

@ -1,5 +1,6 @@
use context_server::{ContextServerSettings, ServerCommand, ServerConfig};
use context_server::ContextServerCommand;
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, prelude::*};
use project::project_settings::{ContextServerConfiguration, ProjectSettings};
use serde_json::json;
use settings::update_settings_file;
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip, prelude::*};
@ -77,11 +78,11 @@ impl AddContextServerModal {
if let Some(workspace) = self.workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
let fs = workspace.app_state().fs.clone();
update_settings_file::<ContextServerSettings>(fs.clone(), cx, |settings, _| {
update_settings_file::<ProjectSettings>(fs.clone(), cx, |settings, _| {
settings.context_servers.insert(
name.into(),
ServerConfig {
command: Some(ServerCommand {
ContextServerConfiguration {
command: Some(ContextServerCommand {
path,
args,
env: None,

View File

@ -4,9 +4,8 @@ use std::{
};
use anyhow::Context as _;
use context_server::manager::{ContextServerManager, ContextServerStatus};
use context_server::ContextServerId;
use editor::{Editor, EditorElement, EditorStyle};
use extension::ContextServerConfiguration;
use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task,
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage,
@ -14,6 +13,10 @@ use gpui::{
use language::{Language, LanguageRegistry};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use notifications::status_toast::{StatusToast, ToastIcon};
use project::{
context_server_store::{ContextServerStatus, ContextServerStore},
project_settings::{ContextServerConfiguration, ProjectSettings},
};
use settings::{Settings as _, update_settings_file};
use theme::ThemeSettings;
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
@ -23,11 +26,11 @@ use workspace::{ModalView, Workspace};
pub(crate) struct ConfigureContextServerModal {
workspace: WeakEntity<Workspace>,
context_servers_to_setup: Vec<ConfigureContextServer>,
context_server_manager: Entity<ContextServerManager>,
context_server_store: Entity<ContextServerStore>,
}
struct ConfigureContextServer {
id: Arc<str>,
id: ContextServerId,
installation_instructions: Entity<markdown::Markdown>,
settings_validator: Option<jsonschema::Validator>,
settings_editor: Entity<Editor>,
@ -37,9 +40,9 @@ struct ConfigureContextServer {
impl ConfigureContextServerModal {
pub fn new(
configurations: impl Iterator<Item = (Arc<str>, ContextServerConfiguration)>,
configurations: impl Iterator<Item = (ContextServerId, extension::ContextServerConfiguration)>,
context_server_store: Entity<ContextServerStore>,
jsonc_language: Option<Arc<Language>>,
context_server_manager: Entity<ContextServerManager>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
@ -85,7 +88,7 @@ impl ConfigureContextServerModal {
Some(Self {
workspace,
context_servers_to_setup,
context_server_manager,
context_server_store,
})
}
}
@ -126,14 +129,14 @@ impl ConfigureContextServerModal {
}
let id = configuration.id.clone();
let settings_changed = context_server::ContextServerSettings::get_global(cx)
let settings_changed = ProjectSettings::get_global(cx)
.context_servers
.get(&id)
.get(&id.0)
.map_or(true, |config| {
config.settings.as_ref() != Some(&settings_value)
});
let is_running = self.context_server_manager.read(cx).status_for_server(&id)
let is_running = self.context_server_store.read(cx).status_for_server(&id)
== Some(ContextServerStatus::Running);
if !settings_changed && is_running {
@ -143,7 +146,7 @@ impl ConfigureContextServerModal {
configuration.waiting_for_context_server = true;
let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx);
let task = wait_for_context_server(&self.context_server_store, id.clone(), cx);
cx.spawn({
let id = id.clone();
async move |this, cx| {
@ -167,29 +170,25 @@ impl ConfigureContextServerModal {
.detach();
// When we write the settings to the file, the context server will be restarted.
update_settings_file::<context_server::ContextServerSettings>(
workspace.read(cx).app_state().fs.clone(),
cx,
{
let id = id.clone();
|settings, _| {
if let Some(server_config) = settings.context_servers.get_mut(&id) {
server_config.settings = Some(settings_value);
} else {
settings.context_servers.insert(
id,
context_server::ServerConfig {
settings: Some(settings_value),
..Default::default()
},
);
}
update_settings_file::<ProjectSettings>(workspace.read(cx).app_state().fs.clone(), cx, {
let id = id.clone();
|settings, _| {
if let Some(server_config) = settings.context_servers.get_mut(&id.0) {
server_config.settings = Some(settings_value);
} else {
settings.context_servers.insert(
id.0,
ContextServerConfiguration {
settings: Some(settings_value),
..Default::default()
},
);
}
},
);
}
});
}
fn complete_setup(&mut self, id: Arc<str>, cx: &mut Context<Self>) {
fn complete_setup(&mut self, id: ContextServerId, cx: &mut Context<Self>) {
self.context_servers_to_setup.remove(0);
cx.notify();
@ -223,31 +222,40 @@ impl ConfigureContextServerModal {
}
fn wait_for_context_server(
context_server_manager: &Entity<ContextServerManager>,
context_server_id: Arc<str>,
context_server_store: &Entity<ContextServerStore>,
context_server_id: ContextServerId,
cx: &mut App,
) -> Task<Result<(), Arc<str>>> {
let (tx, rx) = futures::channel::oneshot::channel();
let tx = Arc::new(Mutex::new(Some(tx)));
let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event {
context_server::manager::Event::ServerStatusChanged { server_id, status } => match status {
Some(ContextServerStatus::Running) => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Ok(()));
let subscription = cx.subscribe(context_server_store, move |_, event, _cx| match event {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
ContextServerStatus::Running => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Ok(()));
}
}
}
}
Some(ContextServerStatus::Error(error)) => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Err(error.clone()));
ContextServerStatus::Stopped => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Err("Context server stopped running".into()));
}
}
}
ContextServerStatus::Error(error) => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Err(error.clone()));
}
}
}
_ => {}
}
_ => {}
},
}
});
cx.spawn(async move |_cx| {

View File

@ -1026,14 +1026,14 @@ impl AssistantPanel {
}
pub(crate) fn open_configuration(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let context_server_manager = self.thread_store.read(cx).context_server_manager();
let context_server_store = self.project.read(cx).context_server_store();
let tools = self.thread_store.read(cx).tools();
let fs = self.fs.clone();
self.set_active_view(ActiveView::Configuration, window, cx);
self.configuration =
Some(cx.new(|cx| {
AssistantConfiguration::new(fs, context_server_manager, tools, window, cx)
AssistantConfiguration::new(fs, context_server_store, tools, window, cx)
}));
if let Some(configuration) = self.configuration.as_ref() {

View File

@ -1095,9 +1095,7 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_model::LanguageModelRegistry::test);
cx.update(language_settings::init);
init_test(cx);
let text = indoc! {"
fn main() {
@ -1167,8 +1165,7 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
init_test(cx);
let text = indoc! {"
fn main() {
@ -1237,9 +1234,7 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.update(LanguageModelRegistry::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
init_test(cx);
let text = concat!(
"fn main() {\n",
@ -1305,9 +1300,7 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
cx.update(LanguageModelRegistry::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
init_test(cx);
let text = indoc! {"
func main() {
@ -1367,9 +1360,7 @@ mod tests {
#[gpui::test]
async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
cx.update(LanguageModelRegistry::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
init_test(cx);
let text = indoc! {"
fn main() {
@ -1473,6 +1464,13 @@ mod tests {
}
}
fn init_test(cx: &mut TestAppContext) {
cx.update(LanguageModelRegistry::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(Project::init_settings);
cx.update(language_settings::init);
}
fn simulate_response_stream(
codegen: Entity<CodegenAlternative>,
cx: &mut TestAppContext,

View File

@ -1,14 +1,15 @@
use std::sync::Arc;
use anyhow::Context as _;
use context_server::ContextServerDescriptorRegistry;
use context_server::ContextServerId;
use extension::ExtensionManifest;
use language::LanguageRegistry;
use project::context_server_store::registry::ContextServerDescriptorRegistry;
use ui::prelude::*;
use util::ResultExt;
use workspace::Workspace;
use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal};
use crate::assistant_configuration::ConfigureContextServerModal;
pub(crate) fn init(language_registry: Arc<LanguageRegistry>, cx: &mut App) {
cx.observe_new(move |_: &mut Workspace, window, cx| {
@ -60,18 +61,10 @@ fn show_configure_mcp_modal(
window: &mut Window,
cx: &mut Context<'_, Workspace>,
) {
let Some(context_server_manager) = workspace.panel::<AssistantPanel>(cx).map(|panel| {
panel
.read(cx)
.thread_store()
.read(cx)
.context_server_manager()
}) else {
return;
};
let context_server_store = workspace.project().read(cx).context_server_store();
let registry = ContextServerDescriptorRegistry::global(cx).read(cx);
let project = workspace.project().clone();
let registry = ContextServerDescriptorRegistry::default_global(cx).read(cx);
let worktree_store = workspace.project().read(cx).worktree_store();
let configuration_tasks = manifest
.context_servers
.keys()
@ -80,15 +73,15 @@ fn show_configure_mcp_modal(
|key| {
let descriptor = registry.context_server_descriptor(&key)?;
Some(cx.spawn({
let project = project.clone();
let worktree_store = worktree_store.clone();
async move |_, cx| {
descriptor
.configuration(project, &cx)
.configuration(worktree_store.clone(), &cx)
.await
.context("Failed to resolve context server configuration")
.log_err()
.flatten()
.map(|config| (key, config))
.map(|config| (ContextServerId(key), config))
}
}))
}
@ -104,8 +97,8 @@ fn show_configure_mcp_modal(
this.update_in(cx, |this, window, cx| {
let modal = ConfigureContextServerModal::new(
descriptors.into_iter().flatten(),
context_server_store,
jsonc_language,
context_server_manager,
language_registry,
cx.entity().downgrade(),
window,

View File

@ -2,29 +2,27 @@ use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
use context_server::{ContextServerId, types};
use gpui::{AnyWindowHandle, App, Entity, Task};
use icons::IconName;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use crate::manager::ContextServerManager;
use crate::types;
use project::{Project, context_server_store::ContextServerStore};
use ui::IconName;
pub struct ContextServerTool {
server_manager: Entity<ContextServerManager>,
server_id: Arc<str>,
store: Entity<ContextServerStore>,
server_id: ContextServerId,
tool: types::Tool,
}
impl ContextServerTool {
pub fn new(
server_manager: Entity<ContextServerManager>,
server_id: impl Into<Arc<str>>,
store: Entity<ContextServerStore>,
server_id: ContextServerId,
tool: types::Tool,
) -> Self {
Self {
server_manager,
server_id: server_id.into(),
store,
server_id,
tool,
}
}
@ -45,7 +43,7 @@ impl Tool for ContextServerTool {
fn source(&self) -> ToolSource {
ToolSource::ContextServer {
id: self.server_id.clone().into(),
id: self.server_id.clone().0.into(),
}
}
@ -80,7 +78,7 @@ impl Tool for ContextServerTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
let tool_name = self.tool.name.clone();
let server_clone = server.clone();
let input_clone = input.clone();

View File

@ -2660,7 +2660,6 @@ mod tests {
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
@ -3082,7 +3081,6 @@ fn main() {{
workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
ToolRegistry::default_global(cx);
});

View File

@ -9,8 +9,7 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::manager::{ContextServerManager, ContextServerStatus};
use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
use context_server::ContextServerId;
use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Shared};
use futures::{FutureExt as _, StreamExt as _};
@ -21,6 +20,7 @@ use gpui::{
use heed::Database;
use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
@ -30,6 +30,7 @@ use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
use crate::context_server_tool::ContextServerTool;
use crate::thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
};
@ -62,8 +63,7 @@ pub struct ThreadStore {
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
project_context: SharedProjectContext,
reload_system_prompt_tx: mpsc::Sender<()>,
@ -108,11 +108,6 @@ impl ThreadStore {
prompt_store: Option<Entity<PromptStore>>,
cx: &mut Context<Self>,
) -> (Self, oneshot::Receiver<()>) {
let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let mut subscriptions = vec![
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
@ -159,7 +154,6 @@ impl ThreadStore {
tools,
prompt_builder,
prompt_store,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
project_context: SharedProjectContext::default(),
@ -354,10 +348,6 @@ impl ThreadStore {
})
}
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
self.context_server_manager.clone()
}
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
&self.prompt_store
}
@ -494,11 +484,17 @@ impl ThreadStore {
});
if profile.enable_all_context_servers {
for context_server in self.context_server_manager.read(cx).all_servers() {
for context_server_id in self
.project
.read(cx)
.context_server_store()
.read(cx)
.all_server_ids()
{
self.tools.update(cx, |tools, cx| {
tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
id: context_server_id.0.into(),
},
cx,
);
@ -541,7 +537,7 @@ impl ThreadStore {
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe(
&self.context_server_manager.clone(),
&self.project.read(cx).context_server_store(),
Self::handle_context_server_event,
)
.detach();
@ -549,18 +545,19 @@ impl ThreadStore {
fn handle_context_server_event(
&mut self,
context_server_manager: Entity<ContextServerManager>,
event: &context_server::manager::Event,
context_server_store: Entity<ContextServerStore>,
event: &project::context_server_store::Event,
cx: &mut Context<Self>,
) {
let tool_working_set = self.tools.clone();
match event {
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
Some(ContextServerStatus::Running) => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
ContextServerStatus::Running => {
if let Some(server) =
context_server_store.read(cx).get_running_server(server_id)
{
let context_server_manager = context_server_manager.clone();
let context_server_manager = context_server_store.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
@ -608,7 +605,7 @@ impl ThreadStore {
.detach();
}
}
None => {
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);

View File

@ -31,7 +31,6 @@ async-watch.workspace = true
client.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
context_server.workspace = true
db.workspace = true
editor.workspace = true
feature_flags.workspace = true

View File

@ -106,7 +106,6 @@ pub fn init(
assistant_slash_command::init(cx);
assistant_tool::init(cx);
assistant_panel::init(cx);
context_server::init(cx);
register_slash_commands(cx);
inline_assistant::init(

View File

@ -1192,21 +1192,19 @@ impl AssistantPanel {
fn restart_context_servers(
workspace: &mut Workspace,
_action: &context_server::Restart,
_action: &project::context_server_store::Restart,
_: &mut Window,
cx: &mut Context<Workspace>,
) {
let Some(assistant_panel) = workspace.panel::<AssistantPanel>(cx) else {
return;
};
assistant_panel.update(cx, |assistant_panel, cx| {
assistant_panel
.context_store
.update(cx, |context_store, cx| {
context_store.restart_context_servers(cx);
});
});
workspace
.project()
.read(cx)
.context_server_store()
.update(cx, |store, cx| {
for server in store.running_servers() {
store.restart_server(&server.id(), cx).log_err();
}
});
}
}

View File

@ -7,15 +7,17 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::HashMap;
use context_server::ContextServerDescriptorRegistry;
use context_server::manager::{ContextServerManager, ContextServerStatus};
use context_server::ContextServerId;
use fs::{Fs, RemoveOptions};
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::LanguageRegistry;
use paths::contexts_dir;
use project::Project;
use project::{
Project,
context_server_store::{ContextServerStatus, ContextServerStore},
};
use prompt_store::PromptBuilder;
use regex::Regex;
use rpc::AnyProtoClient;
@ -40,8 +42,7 @@ pub struct RemoteContextMetadata {
pub struct ContextStore {
contexts: Vec<ContextHandle>,
contexts_metadata: Vec<SavedContextMetadata>,
context_server_manager: Entity<ContextServerManager>,
context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
host_contexts: Vec<RemoteContextMetadata>,
fs: Arc<dyn Fs>,
languages: Arc<LanguageRegistry>,
@ -98,15 +99,9 @@ impl ContextStore {
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
let this = cx.new(|cx: &mut Context<Self>| {
let context_server_factory_registry =
ContextServerDescriptorRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let mut this = Self {
contexts: Vec::new(),
contexts_metadata: Vec::new(),
context_server_manager,
context_server_slash_command_ids: HashMap::default(),
host_contexts: Vec::new(),
fs,
@ -802,22 +797,9 @@ impl ContextStore {
})
}
pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
cx.update_entity(
&self.context_server_manager,
|context_server_manager, cx| {
for server in context_server_manager.running_servers() {
context_server_manager
.restart_server(&server.id(), cx)
.detach_and_log_err(cx);
}
},
);
}
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe(
&self.context_server_manager.clone(),
&self.project.read(cx).context_server_store(),
Self::handle_context_server_event,
)
.detach();
@ -825,16 +807,18 @@ impl ContextStore {
fn handle_context_server_event(
&mut self,
context_server_manager: Entity<ContextServerManager>,
event: &context_server::manager::Event,
context_server_manager: Entity<ContextServerStore>,
event: &project::context_server_store::Event,
cx: &mut Context<Self>,
) {
let slash_command_working_set = self.slash_commands.clone();
match event {
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
Some(ContextServerStatus::Running) => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
ContextServerStatus::Running => {
if let Some(server) = context_server_manager
.read(cx)
.get_running_server(server_id)
{
let context_server_manager = context_server_manager.clone();
cx.spawn({
@ -858,7 +842,7 @@ impl ContextStore {
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_manager.clone(),
&server,
server.id(),
prompt,
),
))
@ -877,7 +861,7 @@ impl ContextStore {
.detach();
}
}
None => {
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{

View File

@ -4,12 +4,10 @@ use assistant_slash_command::{
SlashCommandOutputSection, SlashCommandResult,
};
use collections::HashMap;
use context_server::{
manager::{ContextServer, ContextServerManager},
types::Prompt,
};
use context_server::{ContextServerId, types::Prompt};
use gpui::{App, Entity, Task, WeakEntity, Window};
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
use project::context_server_store::ContextServerStore;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use text::LineEnding;
@ -19,21 +17,17 @@ use workspace::Workspace;
use crate::create_label_for_command;
pub struct ContextServerSlashCommand {
server_manager: Entity<ContextServerManager>,
server_id: Arc<str>,
store: Entity<ContextServerStore>,
server_id: ContextServerId,
prompt: Prompt,
}
impl ContextServerSlashCommand {
pub fn new(
server_manager: Entity<ContextServerManager>,
server: &Arc<ContextServer>,
prompt: Prompt,
) -> Self {
pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
Self {
server_id: server.id(),
server_id: id,
prompt,
server_manager,
store,
}
}
}
@ -88,7 +82,7 @@ impl SlashCommand for ContextServerSlashCommand {
let server_id = self.server_id.clone();
let prompt_name = self.prompt.name.clone();
if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
cx.foreground_executor().spawn(async move {
let Some(protocol) = server.client() else {
return Err(anyhow!("Context server not initialized"));
@ -142,8 +136,8 @@ impl SlashCommand for ContextServerSlashCommand {
Err(e) => return Task::ready(Err(e)),
};
let manager = self.server_manager.read(cx);
if let Some(server) = manager.get_server(&server_id) {
let store = self.store.read(cx);
if let Some(server) = store.get_running_server(&server_id) {
cx.foreground_executor().spawn(async move {
let Some(protocol) = server.client() else {
return Err(anyhow!("Context server not initialized"));

View File

@ -6709,8 +6709,6 @@ async fn test_context_collaboration_with_reconnect(
assert_eq!(project.collaborators().len(), 1);
});
cx_a.update(context_server::init);
cx_b.update(context_server::init);
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context_store_a = cx_a
.update(|cx| {

View File

@ -712,6 +712,7 @@ impl TestClient {
worktree
.read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
.await;
cx.run_until_parked();
(project, worktree.read_with(cx, |tree, _| tree.id()))
}

View File

@ -13,28 +13,17 @@ path = "src/context_server.rs"
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
async-trait.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
context_server_settings.workspace = true
extension.workspace = true
futures.workspace = true
gpui.workspace = true
icons.workspace = true
language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
postage.workspace = true
project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }

View File

@ -40,7 +40,7 @@ pub enum RequestId {
Str(String),
}
pub struct Client {
pub(crate) struct Client {
server_id: ContextServerId,
next_id: AtomicI32,
outbound_tx: channel::Sender<String>,
@ -59,7 +59,7 @@ pub struct Client {
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct ContextServerId(pub Arc<str>);
pub(crate) struct ContextServerId(pub Arc<str>);
fn is_null_value<T: Serialize>(value: &T) -> bool {
if let Ok(Value::Null) = serde_json::to_value(value) {
@ -367,6 +367,7 @@ impl Client {
Ok(())
}
#[allow(unused)]
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncApp),
@ -375,14 +376,6 @@ impl Client {
.lock()
.insert(method, Box::new(f));
}
pub fn name(&self) -> &str {
&self.name
}
pub fn server_id(&self) -> ContextServerId {
self.server_id.clone()
}
}
impl fmt::Display for ContextServerId {

View File

@ -1,30 +1,117 @@
pub mod client;
mod context_server_tool;
mod extension_context_server;
pub mod manager;
pub mod protocol;
mod registry;
mod transport;
pub mod transport;
pub mod types;
use command_palette_hooks::CommandPaletteFilter;
pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerConfig};
use gpui::{App, actions};
use std::fmt::Display;
use std::path::Path;
use std::sync::Arc;
pub use crate::context_server_tool::ContextServerTool;
pub use crate::registry::ContextServerDescriptorRegistry;
use anyhow::Result;
use client::Client;
use collections::HashMap;
use gpui::AsyncApp;
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
actions!(context_servers, [Restart]);
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);
/// The namespace for the context servers actions.
pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut App) {
context_server_settings::init(cx);
ContextServerDescriptorRegistry::default_global(cx);
extension_context_server::init(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
});
impl Display for ContextServerId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ContextServerCommand {
pub path: String,
pub args: Vec<String>,
pub env: Option<HashMap<String, String>>,
}
enum ContextServerTransport {
Stdio(ContextServerCommand),
Custom(Arc<dyn crate::transport::Transport>),
}
pub struct ContextServer {
id: ContextServerId,
client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
configuration: ContextServerTransport,
}
impl ContextServer {
pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
Self {
id,
client: RwLock::new(None),
configuration: ContextServerTransport::Stdio(command),
}
}
pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
Self {
id,
client: RwLock::new(None),
configuration: ContextServerTransport::Custom(transport),
}
}
pub fn id(&self) -> ContextServerId {
self.id.clone()
}
pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
self.client.read().clone()
}
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
let client = match &self.configuration {
ContextServerTransport::Stdio(command) => Client::stdio(
client::ContextServerId(self.id.0.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?,
ContextServerTransport::Custom(transport) => Client::new(
client::ContextServerId(self.id.0.clone()),
self.id().0,
transport.clone(),
cx.clone(),
)?,
};
self.initialize(client).await
}
async fn initialize(&self, client: Client) -> Result<()> {
log::info!("starting context server {}", self.id);
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let initialized_protocol = protocol.initialize(client_info).await?;
log::debug!(
"context server {} initialized: {:?}",
self.id,
initialized_protocol.initialize,
);
*self.client.write() = Some(Arc::new(initialized_protocol));
Ok(())
}
pub fn stop(&self) -> Result<()> {
let mut client = self.client.write();
if let Some(protocol) = client.take() {
drop(protocol);
}
Ok(())
}
}

View File

@ -1,584 +0,0 @@
//! This module implements a context server management system for Zed.
//!
//! It provides functionality to:
//! - Define and load context server settings
//! - Manage individual context servers (start, stop, restart)
//! - Maintain a global manager for all context servers
//!
//! Key components:
//! - `ContextServerSettings`: Defines the structure for server configurations
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages multiple context servers
//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
//!
//! The module also includes initialization logic to set up the context server system
//! and react to changes in settings.
use std::path::Path;
use std::sync::Arc;
use anyhow::{Result, bail};
use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter;
use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
use log;
use parking_lot::RwLock;
use project::Project;
use settings::{Settings, SettingsStore};
use util::ResultExt as _;
use crate::transport::Transport;
use crate::{ContextServerSettings, ServerConfig};
use crate::{
CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
client::{self, Client},
types,
};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
Starting,
Running,
Error(Arc<str>),
}
pub struct ContextServer {
pub id: Arc<str>,
pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
transport: Option<Arc<dyn Transport>>,
}
impl ContextServer {
pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
Self {
id,
config,
client: RwLock::new(None),
transport: None,
}
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
Arc::new(Self {
id,
client: RwLock::new(None),
config: Arc::new(ServerConfig::default()),
transport: Some(transport),
})
}
pub fn id(&self) -> Arc<str> {
self.id.clone()
}
pub fn config(&self) -> Arc<ServerConfig> {
self.config.clone()
}
pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
self.client.read().clone()
}
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
let client = if let Some(transport) = self.transport.clone() {
Client::new(
client::ContextServerId(self.id.clone()),
self.id(),
transport,
cx.clone(),
)?
} else {
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
};
Client::stdio(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?
};
self.initialize(client).await
}
async fn initialize(&self, client: Client) -> Result<()> {
log::info!("starting context server {}", self.id);
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let initialized_protocol = protocol.initialize(client_info).await?;
log::debug!(
"context server {} initialized: {:?}",
self.id,
initialized_protocol.initialize,
);
*self.client.write() = Some(Arc::new(initialized_protocol));
Ok(())
}
pub fn stop(&self) -> Result<()> {
let mut client = self.client.write();
if let Some(protocol) = client.take() {
drop(protocol);
}
Ok(())
}
}
pub struct ContextServerManager {
servers: HashMap<Arc<str>, Arc<ContextServer>>,
server_status: HashMap<Arc<str>, ContextServerStatus>,
project: Entity<Project>,
registry: Entity<ContextServerDescriptorRegistry>,
update_servers_task: Option<Task<Result<()>>>,
needs_server_update: bool,
_subscriptions: Vec<Subscription>,
}
pub enum Event {
ServerStatusChanged {
server_id: Arc<str>,
status: Option<ContextServerStatus>,
},
}
impl EventEmitter<Event> for ContextServerManager {}
impl ContextServerManager {
pub fn new(
registry: Entity<ContextServerDescriptorRegistry>,
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Self {
let mut this = Self {
_subscriptions: vec![
cx.observe(&registry, |this, _registry, cx| {
this.available_context_servers_changed(cx);
}),
cx.observe_global::<SettingsStore>(|this, cx| {
this.available_context_servers_changed(cx);
}),
],
project,
registry,
needs_server_update: false,
servers: HashMap::default(),
server_status: HashMap::default(),
update_servers_task: None,
};
this.available_context_servers_changed(cx);
this
}
fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
if self.update_servers_task.is_some() {
self.needs_server_update = true;
} else {
self.update_servers_task = Some(cx.spawn(async move |this, cx| {
this.update(cx, |this, _| {
this.needs_server_update = false;
})?;
if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
log::error!("Error maintaining context servers: {}", err);
}
this.update(cx, |this, cx| {
let has_any_context_servers = !this.running_servers().is_empty();
if has_any_context_servers {
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
});
}
this.update_servers_task.take();
if this.needs_server_update {
this.available_context_servers_changed(cx);
}
})?;
Ok(())
}));
}
}
pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
self.servers
.get(id)
.filter(|server| server.client().is_some())
.cloned()
}
pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
self.server_status.get(id).cloned()
}
pub fn start_server(
&self,
server: Arc<ContextServer>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
}
pub fn stop_server(
&mut self,
server: Arc<ContextServer>,
cx: &mut Context<Self>,
) -> Result<()> {
server.stop().log_err();
self.update_server_status(server.id().clone(), None, cx);
Ok(())
}
pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
let id = id.clone();
cx.spawn(async move |this, cx| {
if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
let config = server.config();
this.update(cx, |this, cx| this.stop_server(server, cx))??;
let new_server = Arc::new(ContextServer::new(id.clone(), config));
Self::run_server(this, new_server, cx).await?;
}
Ok(())
})
}
pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
self.servers.values().cloned().collect()
}
pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
self.servers
.values()
.filter(|server| server.client().is_some())
.cloned()
.collect()
}
async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
let mut desired_servers = HashMap::default();
let (registry, project) = this.update(cx, |this, cx| {
let location = this
.project
.read(cx)
.visible_worktrees(cx)
.next()
.map(|worktree| settings::SettingsLocation {
worktree_id: worktree.read(cx).id(),
path: Path::new(""),
});
let settings = ContextServerSettings::get(location, cx);
desired_servers = settings.context_servers.clone();
(this.registry.clone(), this.project.clone())
})?;
for (id, descriptor) in
registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
{
let config = desired_servers.entry(id).or_default();
if config.command.is_none() {
if let Some(extension_command) =
descriptor.command(project.clone(), &cx).await.log_err()
{
config.command = Some(extension_command);
}
}
}
let mut servers_to_start = HashMap::default();
let mut servers_to_stop = HashMap::default();
this.update(cx, |this, _cx| {
this.servers.retain(|id, server| {
if desired_servers.contains_key(id) {
true
} else {
servers_to_stop.insert(id.clone(), server.clone());
false
}
});
for (id, config) in desired_servers {
let existing_config = this.servers.get(&id).map(|server| server.config());
if existing_config.as_deref() != Some(&config) {
let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
servers_to_start.insert(id.clone(), server.clone());
if let Some(old_server) = this.servers.remove(&id) {
servers_to_stop.insert(id, old_server);
}
}
}
})?;
for (_, server) in servers_to_stop {
this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
}
for (_, server) in servers_to_start {
Self::run_server(this.clone(), server, cx).await.ok();
}
Ok(())
}
async fn run_server(
this: WeakEntity<Self>,
server: Arc<ContextServer>,
cx: &mut AsyncApp,
) -> Result<()> {
let id = server.id();
this.update(cx, |this, cx| {
this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
this.servers.insert(id.clone(), server.clone());
})?;
match server.start(&cx).await {
Ok(_) => {
log::debug!("`{}` context server started", id);
this.update(cx, |this, cx| {
this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
})?;
Ok(())
}
Err(err) => {
log::error!("`{}` context server failed to start\n{}", id, err);
this.update(cx, |this, cx| {
this.update_server_status(
id.clone(),
Some(ContextServerStatus::Error(err.to_string().into())),
cx,
)
})?;
Err(err)
}
}
}
fn update_server_status(
&mut self,
id: Arc<str>,
status: Option<ContextServerStatus>,
cx: &mut Context<Self>,
) {
if let Some(status) = status.clone() {
self.server_status.insert(id.clone(), status);
} else {
self.server_status.remove(&id);
}
cx.emit(Event::ServerStatusChanged {
server_id: id,
status,
});
}
}
#[cfg(test)]
mod tests {
use std::pin::Pin;
use crate::types::{
Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
};
use super::*;
use futures::{Stream, StreamExt as _, lock::Mutex};
use gpui::{AppContext as _, TestAppContext};
use project::FakeFs;
use serde_json::json;
use util::path;
#[gpui::test]
async fn test_context_server_status(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(cx, json!({"code.rs": ""})).await;
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
let server_1_id: Arc<str> = "mcp-1".into();
let server_2_id: Arc<str> = "mcp-2".into();
let transport_1 = Arc::new(FakeTransport::new(
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-1".to_string()))
}
_ => None,
},
));
let transport_2 = Arc::new(FakeTransport::new(
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-2".to_string()))
}
_ => None,
},
));
let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
manager
.update(cx, |manager, cx| manager.start_server(server_1, cx))
.await
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
});
manager
.update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
.await
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(
manager.read(cx).status_for_server(&server_2_id),
Some(ContextServerStatus::Running)
);
});
manager
.update(cx, |manager, cx| manager.stop_server(server_2, cx))
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
});
}
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
ContextServerSettings::register(cx);
});
}
fn create_initialize_response(server_name: String) -> serde_json::Value {
serde_json::to_value(&InitializeResponse {
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
server_info: Implementation {
name: server_name,
version: "1.0.0".to_string(),
},
capabilities: ServerCapabilities::default(),
meta: None,
})
.unwrap()
}
struct FakeTransport {
on_request: Arc<
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
+ Send
+ Sync,
>,
tx: futures::channel::mpsc::UnboundedSender<String>,
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
}
impl FakeTransport {
fn new(
on_request: impl Fn(
u64,
Option<RequestType>,
serde_json::Value,
) -> Option<serde_json::Value>
+ 'static
+ Send
+ Sync,
) -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self {
on_request: Arc::new(on_request),
tx,
rx: Arc::new(Mutex::new(rx)),
}
}
}
#[async_trait::async_trait]
impl Transport for FakeTransport {
async fn send(&self, message: String) -> Result<()> {
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
if let Some(method) = msg.get("method") {
let request_type = method
.as_str()
.and_then(|method| types::RequestType::try_from(method).ok());
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": payload
});
self.tx
.unbounded_send(response.to_string())
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
}
}
}
Ok(())
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
let rx = self.rx.clone();
Box::pin(futures::stream::unfold(rx, |rx| async move {
let mut rx_guard = rx.lock().await;
if let Some(message) = rx_guard.next().await {
drop(rx_guard);
Some((message, rx))
} else {
None
}
}))
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(futures::stream::empty())
}
}
}

View File

@ -16,7 +16,7 @@ pub struct ModelContextProtocol {
}
impl ModelContextProtocol {
pub fn new(inner: Client) -> Self {
pub(crate) fn new(inner: Client) -> Self {
Self { inner }
}

View File

@ -610,7 +610,7 @@ pub enum ToolResponseContent {
Resource { resource: ResourceContents },
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListToolsResponse {
pub tools: Vec<Tool>,

View File

@ -1,22 +0,0 @@
[package]
name = "context_server_settings"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/context_server_settings.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
gpui.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
workspace-hack.workspace = true

View File

@ -1 +0,0 @@
../../LICENSE-GPL

View File

@ -1,99 +0,0 @@
use std::sync::Arc;
use collections::HashMap;
use gpui::App;
use schemars::JsonSchema;
use schemars::r#gen::SchemaGenerator;
use schemars::schema::{InstanceType, Schema, SchemaObject};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
pub fn init(cx: &mut App) {
ContextServerSettings::register(cx);
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
pub struct ServerConfig {
/// The command to run this context server.
///
/// This will override the command set by an extension.
pub command: Option<ServerCommand>,
/// The settings for this context server.
///
/// Consult the documentation for the context server to see what settings
/// are supported.
#[schemars(schema_with = "server_config_settings_json_schema")]
pub settings: Option<serde_json::Value>,
}
fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::Object.into()),
..Default::default()
})
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ServerCommand {
pub path: String,
pub args: Vec<String>,
pub env: Option<HashMap<String, String>>,
}
#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ContextServerSettings {
/// Settings for context servers used in the Assistant.
#[serde(default)]
pub context_servers: HashMap<Arc<str>, ServerConfig>,
}
impl Settings for ContextServerSettings {
const KEY: Option<&'static str> = None;
type FileContent = Self;
fn load(
sources: SettingsSources<Self::FileContent>,
_: &mut gpui::App,
) -> anyhow::Result<Self> {
sources.json_merge()
}
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
// we don't handle "inputs" replacement strings, see perplexity-key in this example:
// https://code.visualstudio.com/docs/copilot/chat/mcp-servers#_configuration-example
#[derive(Deserialize)]
struct VsCodeServerCommand {
command: String,
args: Option<Vec<String>>,
env: Option<HashMap<String, String>>,
// note: we don't support envFile and type
}
impl From<VsCodeServerCommand> for ServerCommand {
fn from(cmd: VsCodeServerCommand) -> Self {
Self {
path: cmd.command,
args: cmd.args.unwrap_or_default(),
env: cmd.env,
}
}
}
if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) {
current
.context_servers
.extend(mcp.iter().filter_map(|(k, v)| {
Some((
k.clone().into(),
ServerConfig {
command: Some(
serde_json::from_value::<VsCodeServerCommand>(v.clone())
.ok()?
.into(),
),
settings: None,
},
))
}));
}
}
}

View File

@ -30,7 +30,6 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
collections.workspace = true
context_server.workspace = true
dirs.workspace = true
dotenv.workspace = true
env_logger.workspace = true

View File

@ -423,7 +423,6 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
context_server::init(cx);
prompt_store::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);

View File

@ -362,6 +362,8 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static {
server_id: Arc<str>,
cx: &mut App,
);
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App);
}
impl ExtensionContextServerProxy for ExtensionHostProxy {
@ -377,6 +379,14 @@ impl ExtensionContextServerProxy for ExtensionHostProxy {
proxy.register_context_server(extension, server_id, cx)
}
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App) {
let Some(proxy) = self.context_server_proxy.read().clone() else {
return;
};
proxy.unregister_context_server(server_id, cx)
}
}
pub trait ExtensionIndexedDocsProviderProxy: Send + Sync + 'static {

View File

@ -22,7 +22,6 @@ async-tar.workspace = true
async-trait.workspace = true
client.workspace = true
collections.workspace = true
context_server_settings.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true

View File

@ -1130,6 +1130,10 @@ impl ExtensionStore {
.remove_language_server(&language, language_server_name);
}
}
for (server_id, _) in extension.manifest.context_servers.iter() {
self.proxy.unregister_context_server(server_id.clone(), cx);
}
}
self.wasm_extensions

View File

@ -7,7 +7,6 @@ use anyhow::{Context, Result, anyhow, bail};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use async_trait::async_trait;
use context_server_settings::ContextServerSettings;
use extension::{
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
};
@ -676,21 +675,23 @@ impl ExtensionImports for WasmState {
})?)
}
"context_servers" => {
let settings = key
let configuration = key
.and_then(|key| {
ContextServerSettings::get(location, cx)
ProjectSettings::get(location, cx)
.context_servers
.get(key.as_str())
})
.cloned()
.unwrap_or_default();
Ok(serde_json::to_string(&settings::ContextServerSettings {
command: settings.command.map(|command| settings::CommandSettings {
path: Some(command.path),
arguments: Some(command.args),
env: command.env.map(|env| env.into_iter().collect()),
command: configuration.command.map(|command| {
settings::CommandSettings {
path: Some(command.path),
arguments: Some(command.args),
env: command.env.map(|env| env.into_iter().collect()),
}
}),
settings: settings.settings,
settings: configuration.settings,
})?)
}
_ => {

View File

@ -36,6 +36,7 @@ circular-buffer.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
context_server.workspace = true
dap.workspace = true
extension.workspace = true
fancy-regex.workspace = true

File diff suppressed because it is too large Load Diff

View File

@ -1,19 +1,21 @@
use std::sync::Arc;
use anyhow::Result;
use context_server::ContextServerCommand;
use extension::{
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
ProjectDelegate,
};
use gpui::{App, AsyncApp, Entity, Task};
use project::Project;
use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
use crate::worktree_store::WorktreeStore;
use super::registry::{self, ContextServerDescriptorRegistry};
pub fn init(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
context_server_factory_registry: ContextServerDescriptorRegistry::default_global(cx),
});
}
@ -32,10 +34,13 @@ struct ContextServerDescriptor {
extension: Arc<dyn Extension>,
}
fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
project.update(cx, |project, cx| {
fn extension_project(
worktree_store: Entity<WorktreeStore>,
cx: &mut AsyncApp,
) -> Result<Arc<ExtensionProject>> {
worktree_store.update(cx, |worktree_store, cx| {
Arc::new(ExtensionProject {
worktree_ids: project
worktree_ids: worktree_store
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
@ -44,11 +49,15 @@ fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<
}
impl registry::ContextServerDescriptor for ContextServerDescriptor {
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
fn command(
&self,
worktree_store: Entity<WorktreeStore>,
cx: &AsyncApp,
) -> Task<Result<ContextServerCommand>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
let extension_project = extension_project(project, cx)?;
let extension_project = extension_project(worktree_store, cx)?;
let mut command = extension
.context_server_command(id.clone(), extension_project.clone())
.await?;
@ -59,7 +68,7 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor {
log::info!("loaded command for context server {id}: {command:?}");
Ok(ServerCommand {
Ok(ContextServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
@ -69,13 +78,13 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor {
fn configuration(
&self,
project: Entity<Project>,
worktree_store: Entity<WorktreeStore>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
let extension_project = extension_project(project, cx)?;
let extension_project = extension_project(worktree_store, cx)?;
let configuration = extension
.context_server_configuration(id.clone(), extension_project)
.await?;
@ -102,4 +111,11 @@ impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
)
});
}
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App) {
self.context_server_factory_registry
.update(cx, |registry, _| {
registry.unregister_context_server_descriptor_by_id(&server_id)
});
}
}

View File

@ -2,17 +2,21 @@ use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
use context_server::ContextServerCommand;
use extension::ContextServerConfiguration;
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
use project::Project;
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, Task};
use crate::ServerCommand;
use crate::worktree_store::WorktreeStore;
pub trait ContextServerDescriptor {
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
fn command(
&self,
worktree_store: Entity<WorktreeStore>,
cx: &AsyncApp,
) -> Task<Result<ContextServerCommand>>;
fn configuration(
&self,
project: Entity<Project>,
worktree_store: Entity<WorktreeStore>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>>;
}
@ -27,11 +31,6 @@ pub struct ContextServerDescriptorRegistry {
}
impl ContextServerDescriptorRegistry {
/// Returns the global [`ContextServerDescriptorRegistry`].
pub fn global(cx: &App) -> Entity<Self> {
GlobalContextServerDescriptorRegistry::global(cx).0.clone()
}
/// Returns the global [`ContextServerDescriptorRegistry`].
///
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.

View File

@ -244,9 +244,8 @@ mod tests {
use git::status::{FileStatus, StatusCode, TrackedSummary, UnmergedStatus, UnmergedStatusCode};
use gpui::TestAppContext;
use serde_json::json;
use settings::{Settings as _, SettingsStore};
use settings::SettingsStore;
use util::path;
use worktree::WorktreeSettings;
const CONFLICT: FileStatus = FileStatus::Unmerged(UnmergedStatus {
first_head: UnmergedStatusCode::Updated,
@ -682,7 +681,7 @@ mod tests {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
WorktreeSettings::register(cx);
Project::init_settings(cx);
});
}

View File

@ -1,6 +1,7 @@
pub mod buffer_store;
mod color_extractor;
pub mod connection_manager;
pub mod context_server_store;
pub mod debounced_delay;
pub mod debugger;
pub mod git_store;
@ -23,6 +24,7 @@ mod project_tests;
mod direnv;
mod environment;
use buffer_diff::BufferDiff;
use context_server_store::ContextServerStore;
pub use environment::{EnvironmentErrorMessage, ProjectEnvironmentEvent};
use git_store::{Repository, RepositoryId};
pub mod search_history;
@ -182,6 +184,7 @@ pub struct Project {
client_subscriptions: Vec<client::Subscription>,
worktree_store: Entity<WorktreeStore>,
buffer_store: Entity<BufferStore>,
context_server_store: Entity<ContextServerStore>,
image_store: Entity<ImageStore>,
lsp_store: Entity<LspStore>,
_subscriptions: Vec<gpui::Subscription>,
@ -845,6 +848,7 @@ impl Project {
ToolchainStore::init(&client);
DapStore::init(&client, cx);
BreakpointStore::init(&client);
context_server_store::init(cx);
}
pub fn local(
@ -865,6 +869,9 @@ impl Project {
cx.subscribe(&worktree_store, Self::on_worktree_store_event)
.detach();
let context_server_store =
cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx));
let environment = cx.new(|_| ProjectEnvironment::new(env));
let toolchain_store = cx.new(|cx| {
ToolchainStore::local(
@ -965,6 +972,7 @@ impl Project {
buffer_store,
image_store,
lsp_store,
context_server_store,
join_project_response_message_id: 0,
client_state: ProjectClientState::Local,
git_store,
@ -1025,6 +1033,9 @@ impl Project {
cx.subscribe(&worktree_store, Self::on_worktree_store_event)
.detach();
let context_server_store =
cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx));
let buffer_store = cx.new(|cx| {
BufferStore::remote(
worktree_store.clone(),
@ -1109,6 +1120,7 @@ impl Project {
buffer_store,
image_store,
lsp_store,
context_server_store,
breakpoint_store,
dap_store,
join_project_response_message_id: 0,
@ -1267,6 +1279,8 @@ impl Project {
let image_store = cx.new(|cx| {
ImageStore::remote(worktree_store.clone(), client.clone().into(), remote_id, cx)
})?;
let context_server_store =
cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx))?;
let environment = cx.new(|_| ProjectEnvironment::new(None))?;
@ -1360,6 +1374,7 @@ impl Project {
image_store,
worktree_store: worktree_store.clone(),
lsp_store: lsp_store.clone(),
context_server_store,
active_entry: None,
collaborators: Default::default(),
join_project_response_message_id: response.message_id,
@ -1590,6 +1605,10 @@ impl Project {
self.worktree_store.clone()
}
pub fn context_server_store(&self) -> Entity<ContextServerStore> {
self.context_server_store.clone()
}
pub fn buffer_for_id(&self, remote_id: BufferId, cx: &App) -> Option<Entity<Buffer>> {
self.buffer_store.read(cx).get(remote_id)
}

View File

@ -1,5 +1,6 @@
use anyhow::Context as _;
use collections::HashMap;
use context_server::ContextServerCommand;
use dap::adapters::DebugAdapterName;
use fs::Fs;
use futures::StreamExt as _;
@ -51,6 +52,10 @@ pub struct ProjectSettings {
#[serde(default)]
pub dap: HashMap<DebugAdapterName, DapSettings>,
/// Settings for context servers used for AI-related features.
#[serde(default)]
pub context_servers: HashMap<Arc<str>, ContextServerConfiguration>,
/// Configuration for Diagnostics-related features.
#[serde(default)]
pub diagnostics: DiagnosticsSettings,
@ -78,6 +83,19 @@ pub struct DapSettings {
pub binary: Option<String>,
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
pub struct ContextServerConfiguration {
/// The command to run this context server.
///
/// This will override the command set by an extension.
pub command: Option<ContextServerCommand>,
/// The settings for this context server.
///
/// Consult the documentation for the context server to see what settings
/// are supported.
pub settings: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NodeBinarySettings {
/// The path to the Node binary.
@ -354,6 +372,40 @@ impl Settings for ProjectSettings {
}
}
#[derive(Deserialize)]
struct VsCodeContextServerCommand {
command: String,
args: Option<Vec<String>>,
env: Option<HashMap<String, String>>,
// note: we don't support envFile and type
}
impl From<VsCodeContextServerCommand> for ContextServerCommand {
fn from(cmd: VsCodeContextServerCommand) -> Self {
Self {
path: cmd.command,
args: cmd.args.unwrap_or_default(),
env: cmd.env,
}
}
}
if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) {
current
.context_servers
.extend(mcp.iter().filter_map(|(k, v)| {
Some((
k.clone().into(),
ContextServerConfiguration {
command: Some(
serde_json::from_value::<VsCodeContextServerCommand>(v.clone())
.ok()?
.into(),
),
settings: None,
},
))
}));
}
// TODO: translate lsp settings for rust-analyzer and other popular ones to old.lsp
}
}