context_store: Refactor state management (#29910)
Because we instantiated `ContextServerManager` both in `agent` and `assistant-context-editor`, and these two entities track the running MCP servers separately, we were effectively running every MCP server twice. This PR moves the `ContextServerManager` into the project crate (now called `ContextServerStore`). The store can be accessed via a project instance. This ensures that we only instantiate one `ContextServerStore` per project. Also, this PR adds a bunch of tests to ensure that the `ContextServerStore` behaves correctly (Previously there were none). Closes #28714 Closes #29530 Release Notes: - N/A
This commit is contained in:
parent
71f7100083
commit
108005f1b8
31
Cargo.lock
generated
31
Cargo.lock
generated
@ -484,7 +484,6 @@ dependencies = [
|
||||
"client",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"context_server",
|
||||
"ctor",
|
||||
"db",
|
||||
"editor",
|
||||
@ -3328,40 +3327,19 @@ name = "context_server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"async-trait",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"context_server_settings",
|
||||
"extension",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"icons",
|
||||
"language_model",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"postage",
|
||||
"project",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"url",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "context_server_settings"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"gpui",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"url",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@ -5028,7 +5006,6 @@ dependencies = [
|
||||
"clap",
|
||||
"client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"dirs 4.0.0",
|
||||
"dotenv",
|
||||
"env_logger 0.11.8",
|
||||
@ -5181,7 +5158,6 @@ dependencies = [
|
||||
"async-trait",
|
||||
"client",
|
||||
"collections",
|
||||
"context_server_settings",
|
||||
"ctor",
|
||||
"env_logger 0.11.8",
|
||||
"extension",
|
||||
@ -11110,6 +11086,7 @@ dependencies = [
|
||||
"client",
|
||||
"clock",
|
||||
"collections",
|
||||
"context_server",
|
||||
"dap",
|
||||
"dap_adapters",
|
||||
"env_logger 0.11.8",
|
||||
|
@ -34,7 +34,6 @@ members = [
|
||||
"crates/component",
|
||||
"crates/component_preview",
|
||||
"crates/context_server",
|
||||
"crates/context_server_settings",
|
||||
"crates/copilot",
|
||||
"crates/credentials_provider",
|
||||
"crates/dap",
|
||||
@ -243,7 +242,6 @@ command_palette_hooks = { path = "crates/command_palette_hooks" }
|
||||
component = { path = "crates/component" }
|
||||
component_preview = { path = "crates/component_preview" }
|
||||
context_server = { path = "crates/context_server" }
|
||||
context_server_settings = { path = "crates/context_server_settings" }
|
||||
copilot = { path = "crates/copilot" }
|
||||
credentials_provider = { path = "crates/credentials_provider" }
|
||||
dap = { path = "crates/dap" }
|
||||
|
@ -3487,7 +3487,6 @@ fn open_editor_at_position(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use assistant_tool::{ToolRegistry, ToolWorkingSet};
|
||||
use context_server::ContextServerSettings;
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use gpui::{TestAppContext, VisualTestContext};
|
||||
@ -3559,7 +3558,6 @@ mod tests {
|
||||
workspace::init_settings(cx);
|
||||
language_model::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
EditorSettings::register(cx);
|
||||
ToolRegistry::default_global(cx);
|
||||
});
|
||||
|
@ -1748,7 +1748,6 @@ mod tests {
|
||||
use crate::{Keep, ThreadStore, thread_store};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use context_server::ContextServerSettings;
|
||||
use editor::EditorSettings;
|
||||
use gpui::{TestAppContext, UpdateGlobal, VisualTestContext};
|
||||
use project::{FakeFs, Project};
|
||||
@ -1771,7 +1770,6 @@ mod tests {
|
||||
thread_store::init(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
EditorSettings::register(cx);
|
||||
language_model::init_settings(cx);
|
||||
});
|
||||
@ -1928,7 +1926,6 @@ mod tests {
|
||||
thread_store::init(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
EditorSettings::register(cx);
|
||||
language_model::init_settings(cx);
|
||||
workspace::register_project_item::<Editor>(cx);
|
||||
|
@ -7,6 +7,7 @@ mod buffer_codegen;
|
||||
mod context;
|
||||
mod context_picker;
|
||||
mod context_server_configuration;
|
||||
mod context_server_tool;
|
||||
mod context_store;
|
||||
mod context_strip;
|
||||
mod debug;
|
||||
|
@ -8,13 +8,14 @@ use std::{sync::Arc, time::Duration};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus};
|
||||
use context_server::ContextServerId;
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, ScrollHandle, Subscription, pulsating_between,
|
||||
};
|
||||
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
|
||||
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{
|
||||
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
|
||||
@ -33,8 +34,8 @@ pub struct AssistantConfiguration {
|
||||
fs: Arc<dyn Fs>,
|
||||
focus_handle: FocusHandle,
|
||||
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
expanded_context_server_tools: HashMap<Arc<str>, bool>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
expanded_context_server_tools: HashMap<ContextServerId, bool>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
_registry_subscription: Subscription,
|
||||
scroll_handle: ScrollHandle,
|
||||
@ -44,7 +45,7 @@ pub struct AssistantConfiguration {
|
||||
impl AssistantConfiguration {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
@ -75,7 +76,7 @@ impl AssistantConfiguration {
|
||||
fs,
|
||||
focus_handle,
|
||||
configuration_views_by_provider: HashMap::default(),
|
||||
context_server_manager,
|
||||
context_server_store,
|
||||
expanded_context_server_tools: HashMap::default(),
|
||||
tools,
|
||||
_registry_subscription: registry_subscription,
|
||||
@ -306,7 +307,7 @@ impl AssistantConfiguration {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
|
||||
let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone();
|
||||
|
||||
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
|
||||
|
||||
@ -322,9 +323,9 @@ impl AssistantConfiguration {
|
||||
.child(Label::new(SUBHEADING).color(Color::Muted)),
|
||||
)
|
||||
.children(
|
||||
context_servers
|
||||
.into_iter()
|
||||
.map(|context_server| self.render_context_server(context_server, window, cx)),
|
||||
context_server_ids.into_iter().map(|context_server_id| {
|
||||
self.render_context_server(context_server_id, window, cx)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
@ -374,19 +375,20 @@ impl AssistantConfiguration {
|
||||
|
||||
fn render_context_server(
|
||||
&self,
|
||||
context_server: Arc<ContextServer>,
|
||||
context_server_id: ContextServerId,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl use<> + IntoElement {
|
||||
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
|
||||
let server_status = self
|
||||
.context_server_manager
|
||||
.context_server_store
|
||||
.read(cx)
|
||||
.status_for_server(&context_server.id());
|
||||
.status_for_server(&context_server_id)
|
||||
.unwrap_or(ContextServerStatus::Stopped);
|
||||
|
||||
let is_running = matches!(server_status, Some(ContextServerStatus::Running));
|
||||
let is_running = matches!(server_status, ContextServerStatus::Running);
|
||||
|
||||
let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() {
|
||||
let error = if let ContextServerStatus::Error(error) = server_status.clone() {
|
||||
Some(error)
|
||||
} else {
|
||||
None
|
||||
@ -394,13 +396,13 @@ impl AssistantConfiguration {
|
||||
|
||||
let are_tools_expanded = self
|
||||
.expanded_context_server_tools
|
||||
.get(&context_server.id())
|
||||
.get(&context_server_id)
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
|
||||
let tools = tools_by_source
|
||||
.get(&ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
id: context_server_id.0.clone().into(),
|
||||
})
|
||||
.map_or([].as_slice(), |tools| tools.as_slice());
|
||||
let tool_count = tools.len();
|
||||
@ -408,7 +410,7 @@ impl AssistantConfiguration {
|
||||
let border_color = cx.theme().colors().border.opacity(0.6);
|
||||
|
||||
v_flex()
|
||||
.id(SharedString::from(context_server.id()))
|
||||
.id(SharedString::from(context_server_id.0.clone()))
|
||||
.border_1()
|
||||
.rounded_md()
|
||||
.border_color(border_color)
|
||||
@ -432,7 +434,7 @@ impl AssistantConfiguration {
|
||||
)
|
||||
.disabled(tool_count == 0)
|
||||
.on_click(cx.listener({
|
||||
let context_server_id = context_server.id();
|
||||
let context_server_id = context_server_id.clone();
|
||||
move |this, _event, _window, _cx| {
|
||||
let is_open = this
|
||||
.expanded_context_server_tools
|
||||
@ -444,14 +446,14 @@ impl AssistantConfiguration {
|
||||
})),
|
||||
)
|
||||
.child(match server_status {
|
||||
Some(ContextServerStatus::Starting) => {
|
||||
ContextServerStatus::Starting => {
|
||||
let color = Color::Success.color(cx);
|
||||
Indicator::dot()
|
||||
.color(Color::Success)
|
||||
.with_animation(
|
||||
SharedString::from(format!(
|
||||
"{}-starting",
|
||||
context_server.id(),
|
||||
context_server_id.0.clone(),
|
||||
)),
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
@ -462,15 +464,17 @@ impl AssistantConfiguration {
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
Some(ContextServerStatus::Running) => {
|
||||
ContextServerStatus::Running => {
|
||||
Indicator::dot().color(Color::Success).into_any_element()
|
||||
}
|
||||
Some(ContextServerStatus::Error(_)) => {
|
||||
ContextServerStatus::Error(_) => {
|
||||
Indicator::dot().color(Color::Error).into_any_element()
|
||||
}
|
||||
None => Indicator::dot().color(Color::Muted).into_any_element(),
|
||||
ContextServerStatus::Stopped => {
|
||||
Indicator::dot().color(Color::Muted).into_any_element()
|
||||
}
|
||||
})
|
||||
.child(Label::new(context_server.id()).ml_0p5())
|
||||
.child(Label::new(context_server_id.0.clone()).ml_0p5())
|
||||
.when(is_running, |this| {
|
||||
this.child(
|
||||
Label::new(if tool_count == 1 {
|
||||
@ -487,32 +491,22 @@ impl AssistantConfiguration {
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager = self.context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
let context_server_manager = self.context_server_store.clone();
|
||||
let context_server_id = context_server_id.clone();
|
||||
move |state, _window, cx| match state {
|
||||
ToggleState::Unselected | ToggleState::Indeterminate => {
|
||||
context_server_manager.update(cx, |this, cx| {
|
||||
this.stop_server(context_server.clone(), cx).log_err();
|
||||
this.stop_server(&context_server_id, cx).log_err();
|
||||
});
|
||||
}
|
||||
ToggleState::Selected => {
|
||||
cx.spawn({
|
||||
let context_server_manager =
|
||||
context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
async move |cx| {
|
||||
if let Some(start_server_task) =
|
||||
context_server_manager
|
||||
.update(cx, |this, cx| {
|
||||
this.start_server(context_server, cx)
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
start_server_task.await.log_err();
|
||||
}
|
||||
context_server_manager.update(cx, |this, cx| {
|
||||
if let Some(server) =
|
||||
this.get_server(&context_server_id)
|
||||
{
|
||||
this.start_server(server, cx).log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}),
|
||||
|
@ -1,5 +1,6 @@
|
||||
use context_server::{ContextServerSettings, ServerCommand, ServerConfig};
|
||||
use context_server::ContextServerCommand;
|
||||
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, prelude::*};
|
||||
use project::project_settings::{ContextServerConfiguration, ProjectSettings};
|
||||
use serde_json::json;
|
||||
use settings::update_settings_file;
|
||||
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip, prelude::*};
|
||||
@ -77,11 +78,11 @@ impl AddContextServerModal {
|
||||
if let Some(workspace) = self.workspace.upgrade() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
update_settings_file::<ContextServerSettings>(fs.clone(), cx, |settings, _| {
|
||||
update_settings_file::<ProjectSettings>(fs.clone(), cx, |settings, _| {
|
||||
settings.context_servers.insert(
|
||||
name.into(),
|
||||
ServerConfig {
|
||||
command: Some(ServerCommand {
|
||||
ContextServerConfiguration {
|
||||
command: Some(ContextServerCommand {
|
||||
path,
|
||||
args,
|
||||
env: None,
|
||||
|
@ -4,9 +4,8 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use context_server::ContextServerId;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use extension::ContextServerConfiguration;
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task,
|
||||
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage,
|
||||
@ -14,6 +13,10 @@ use gpui::{
|
||||
use language::{Language, LanguageRegistry};
|
||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
|
||||
use notifications::status_toast::{StatusToast, ToastIcon};
|
||||
use project::{
|
||||
context_server_store::{ContextServerStatus, ContextServerStore},
|
||||
project_settings::{ContextServerConfiguration, ProjectSettings},
|
||||
};
|
||||
use settings::{Settings as _, update_settings_file};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
|
||||
@ -23,11 +26,11 @@ use workspace::{ModalView, Workspace};
|
||||
pub(crate) struct ConfigureContextServerModal {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_servers_to_setup: Vec<ConfigureContextServer>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
}
|
||||
|
||||
struct ConfigureContextServer {
|
||||
id: Arc<str>,
|
||||
id: ContextServerId,
|
||||
installation_instructions: Entity<markdown::Markdown>,
|
||||
settings_validator: Option<jsonschema::Validator>,
|
||||
settings_editor: Entity<Editor>,
|
||||
@ -37,9 +40,9 @@ struct ConfigureContextServer {
|
||||
|
||||
impl ConfigureContextServerModal {
|
||||
pub fn new(
|
||||
configurations: impl Iterator<Item = (Arc<str>, ContextServerConfiguration)>,
|
||||
configurations: impl Iterator<Item = (ContextServerId, extension::ContextServerConfiguration)>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
jsonc_language: Option<Arc<Language>>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
window: &mut Window,
|
||||
@ -85,7 +88,7 @@ impl ConfigureContextServerModal {
|
||||
Some(Self {
|
||||
workspace,
|
||||
context_servers_to_setup,
|
||||
context_server_manager,
|
||||
context_server_store,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -126,14 +129,14 @@ impl ConfigureContextServerModal {
|
||||
}
|
||||
let id = configuration.id.clone();
|
||||
|
||||
let settings_changed = context_server::ContextServerSettings::get_global(cx)
|
||||
let settings_changed = ProjectSettings::get_global(cx)
|
||||
.context_servers
|
||||
.get(&id)
|
||||
.get(&id.0)
|
||||
.map_or(true, |config| {
|
||||
config.settings.as_ref() != Some(&settings_value)
|
||||
});
|
||||
|
||||
let is_running = self.context_server_manager.read(cx).status_for_server(&id)
|
||||
let is_running = self.context_server_store.read(cx).status_for_server(&id)
|
||||
== Some(ContextServerStatus::Running);
|
||||
|
||||
if !settings_changed && is_running {
|
||||
@ -143,7 +146,7 @@ impl ConfigureContextServerModal {
|
||||
|
||||
configuration.waiting_for_context_server = true;
|
||||
|
||||
let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx);
|
||||
let task = wait_for_context_server(&self.context_server_store, id.clone(), cx);
|
||||
cx.spawn({
|
||||
let id = id.clone();
|
||||
async move |this, cx| {
|
||||
@ -167,29 +170,25 @@ impl ConfigureContextServerModal {
|
||||
.detach();
|
||||
|
||||
// When we write the settings to the file, the context server will be restarted.
|
||||
update_settings_file::<context_server::ContextServerSettings>(
|
||||
workspace.read(cx).app_state().fs.clone(),
|
||||
cx,
|
||||
{
|
||||
let id = id.clone();
|
||||
|settings, _| {
|
||||
if let Some(server_config) = settings.context_servers.get_mut(&id) {
|
||||
server_config.settings = Some(settings_value);
|
||||
} else {
|
||||
settings.context_servers.insert(
|
||||
id,
|
||||
context_server::ServerConfig {
|
||||
settings: Some(settings_value),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
}
|
||||
update_settings_file::<ProjectSettings>(workspace.read(cx).app_state().fs.clone(), cx, {
|
||||
let id = id.clone();
|
||||
|settings, _| {
|
||||
if let Some(server_config) = settings.context_servers.get_mut(&id.0) {
|
||||
server_config.settings = Some(settings_value);
|
||||
} else {
|
||||
settings.context_servers.insert(
|
||||
id.0,
|
||||
ContextServerConfiguration {
|
||||
settings: Some(settings_value),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn complete_setup(&mut self, id: Arc<str>, cx: &mut Context<Self>) {
|
||||
fn complete_setup(&mut self, id: ContextServerId, cx: &mut Context<Self>) {
|
||||
self.context_servers_to_setup.remove(0);
|
||||
cx.notify();
|
||||
|
||||
@ -223,31 +222,40 @@ impl ConfigureContextServerModal {
|
||||
}
|
||||
|
||||
fn wait_for_context_server(
|
||||
context_server_manager: &Entity<ContextServerManager>,
|
||||
context_server_id: Arc<str>,
|
||||
context_server_store: &Entity<ContextServerStore>,
|
||||
context_server_id: ContextServerId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(), Arc<str>>> {
|
||||
let (tx, rx) = futures::channel::oneshot::channel();
|
||||
let tx = Arc::new(Mutex::new(Some(tx)));
|
||||
|
||||
let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event {
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Ok(()));
|
||||
let subscription = cx.subscribe(context_server_store, move |_, event, _cx| match event {
|
||||
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
ContextServerStatus::Running => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Ok(()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ContextServerStatus::Error(error)) => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Err(error.clone()));
|
||||
ContextServerStatus::Stopped => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Err("Context server stopped running".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
ContextServerStatus::Error(error) => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Err(error.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
});
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
|
@ -1026,14 +1026,14 @@ impl AssistantPanel {
|
||||
}
|
||||
|
||||
pub(crate) fn open_configuration(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let context_server_manager = self.thread_store.read(cx).context_server_manager();
|
||||
let context_server_store = self.project.read(cx).context_server_store();
|
||||
let tools = self.thread_store.read(cx).tools();
|
||||
let fs = self.fs.clone();
|
||||
|
||||
self.set_active_view(ActiveView::Configuration, window, cx);
|
||||
self.configuration =
|
||||
Some(cx.new(|cx| {
|
||||
AssistantConfiguration::new(fs, context_server_manager, tools, window, cx)
|
||||
AssistantConfiguration::new(fs, context_server_store, tools, window, cx)
|
||||
}));
|
||||
|
||||
if let Some(configuration) = self.configuration.as_ref() {
|
||||
|
@ -1095,9 +1095,7 @@ mod tests {
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(language_settings::init);
|
||||
init_test(cx);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
@ -1167,8 +1165,7 @@ mod tests {
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
init_test(cx);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
@ -1237,9 +1234,7 @@ mod tests {
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
init_test(cx);
|
||||
|
||||
let text = concat!(
|
||||
"fn main() {\n",
|
||||
@ -1305,9 +1300,7 @@ mod tests {
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
init_test(cx);
|
||||
|
||||
let text = indoc! {"
|
||||
func main() {
|
||||
@ -1367,9 +1360,7 @@ mod tests {
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
init_test(cx);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
@ -1473,6 +1464,13 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(language_settings::init);
|
||||
}
|
||||
|
||||
fn simulate_response_stream(
|
||||
codegen: Entity<CodegenAlternative>,
|
||||
cx: &mut TestAppContext,
|
||||
|
@ -1,14 +1,15 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use context_server::ContextServerDescriptorRegistry;
|
||||
use context_server::ContextServerId;
|
||||
use extension::ExtensionManifest;
|
||||
use language::LanguageRegistry;
|
||||
use project::context_server_store::registry::ContextServerDescriptorRegistry;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal};
|
||||
use crate::assistant_configuration::ConfigureContextServerModal;
|
||||
|
||||
pub(crate) fn init(language_registry: Arc<LanguageRegistry>, cx: &mut App) {
|
||||
cx.observe_new(move |_: &mut Workspace, window, cx| {
|
||||
@ -60,18 +61,10 @@ fn show_configure_mcp_modal(
|
||||
window: &mut Window,
|
||||
cx: &mut Context<'_, Workspace>,
|
||||
) {
|
||||
let Some(context_server_manager) = workspace.panel::<AssistantPanel>(cx).map(|panel| {
|
||||
panel
|
||||
.read(cx)
|
||||
.thread_store()
|
||||
.read(cx)
|
||||
.context_server_manager()
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
let context_server_store = workspace.project().read(cx).context_server_store();
|
||||
|
||||
let registry = ContextServerDescriptorRegistry::global(cx).read(cx);
|
||||
let project = workspace.project().clone();
|
||||
let registry = ContextServerDescriptorRegistry::default_global(cx).read(cx);
|
||||
let worktree_store = workspace.project().read(cx).worktree_store();
|
||||
let configuration_tasks = manifest
|
||||
.context_servers
|
||||
.keys()
|
||||
@ -80,15 +73,15 @@ fn show_configure_mcp_modal(
|
||||
|key| {
|
||||
let descriptor = registry.context_server_descriptor(&key)?;
|
||||
Some(cx.spawn({
|
||||
let project = project.clone();
|
||||
let worktree_store = worktree_store.clone();
|
||||
async move |_, cx| {
|
||||
descriptor
|
||||
.configuration(project, &cx)
|
||||
.configuration(worktree_store.clone(), &cx)
|
||||
.await
|
||||
.context("Failed to resolve context server configuration")
|
||||
.log_err()
|
||||
.flatten()
|
||||
.map(|config| (key, config))
|
||||
.map(|config| (ContextServerId(key), config))
|
||||
}
|
||||
}))
|
||||
}
|
||||
@ -104,8 +97,8 @@ fn show_configure_mcp_modal(
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
let modal = ConfigureContextServerModal::new(
|
||||
descriptors.into_iter().flatten(),
|
||||
context_server_store,
|
||||
jsonc_language,
|
||||
context_server_manager,
|
||||
language_registry,
|
||||
cx.entity().downgrade(),
|
||||
window,
|
||||
|
@ -2,29 +2,27 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
|
||||
use context_server::{ContextServerId, types};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use icons::IconName;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
|
||||
use crate::manager::ContextServerManager;
|
||||
use crate::types;
|
||||
use project::{Project, context_server_store::ContextServerStore};
|
||||
use ui::IconName;
|
||||
|
||||
pub struct ContextServerTool {
|
||||
server_manager: Entity<ContextServerManager>,
|
||||
server_id: Arc<str>,
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
tool: types::Tool,
|
||||
}
|
||||
|
||||
impl ContextServerTool {
|
||||
pub fn new(
|
||||
server_manager: Entity<ContextServerManager>,
|
||||
server_id: impl Into<Arc<str>>,
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
tool: types::Tool,
|
||||
) -> Self {
|
||||
Self {
|
||||
server_manager,
|
||||
server_id: server_id.into(),
|
||||
store,
|
||||
server_id,
|
||||
tool,
|
||||
}
|
||||
}
|
||||
@ -45,7 +43,7 @@ impl Tool for ContextServerTool {
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
ToolSource::ContextServer {
|
||||
id: self.server_id.clone().into(),
|
||||
id: self.server_id.clone().0.into(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,7 +78,7 @@ impl Tool for ContextServerTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
|
||||
if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
|
||||
let tool_name = self.tool.name.clone();
|
||||
let server_clone = server.clone();
|
||||
let input_clone = input.clone();
|
@ -2660,7 +2660,6 @@ mod tests {
|
||||
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::ToolRegistry;
|
||||
use context_server::ContextServerSettings;
|
||||
use editor::EditorSettings;
|
||||
use gpui::TestAppContext;
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
@ -3082,7 +3081,6 @@ fn main() {{
|
||||
workspace::init_settings(cx);
|
||||
language_model::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
EditorSettings::register(cx);
|
||||
ToolRegistry::default_global(cx);
|
||||
});
|
||||
|
@ -9,8 +9,7 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
|
||||
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
|
||||
use context_server::ContextServerId;
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use futures::future::{self, BoxFuture, Shared};
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
@ -21,6 +20,7 @@ use gpui::{
|
||||
use heed::Database;
|
||||
use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
|
||||
@ -30,6 +30,7 @@ use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::context_server_tool::ContextServerTool;
|
||||
use crate::thread::{
|
||||
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
|
||||
};
|
||||
@ -62,8 +63,7 @@ pub struct ThreadStore {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
|
||||
context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
|
||||
threads: Vec<SerializedThreadMetadata>,
|
||||
project_context: SharedProjectContext,
|
||||
reload_system_prompt_tx: mpsc::Sender<()>,
|
||||
@ -108,11 +108,6 @@ impl ThreadStore {
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> (Self, oneshot::Receiver<()>) {
|
||||
let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
|
||||
let mut subscriptions = vec![
|
||||
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
|
||||
this.load_default_profile(cx);
|
||||
@ -159,7 +154,6 @@ impl ThreadStore {
|
||||
tools,
|
||||
prompt_builder,
|
||||
prompt_store,
|
||||
context_server_manager,
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
project_context: SharedProjectContext::default(),
|
||||
@ -354,10 +348,6 @@ impl ThreadStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
|
||||
self.context_server_manager.clone()
|
||||
}
|
||||
|
||||
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
}
|
||||
@ -494,11 +484,17 @@ impl ThreadStore {
|
||||
});
|
||||
|
||||
if profile.enable_all_context_servers {
|
||||
for context_server in self.context_server_manager.read(cx).all_servers() {
|
||||
for context_server_id in self
|
||||
.project
|
||||
.read(cx)
|
||||
.context_server_store()
|
||||
.read(cx)
|
||||
.all_server_ids()
|
||||
{
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.enable_source(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
id: context_server_id.0.into(),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
@ -541,7 +537,7 @@ impl ThreadStore {
|
||||
|
||||
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
||||
cx.subscribe(
|
||||
&self.context_server_manager.clone(),
|
||||
&self.project.read(cx).context_server_store(),
|
||||
Self::handle_context_server_event,
|
||||
)
|
||||
.detach();
|
||||
@ -549,18 +545,19 @@ impl ThreadStore {
|
||||
|
||||
fn handle_context_server_event(
|
||||
&mut self,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
event: &context_server::manager::Event,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
event: &project::context_server_store::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let tool_working_set = self.tools.clone();
|
||||
match event {
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
|
||||
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
|
||||
ContextServerStatus::Running => {
|
||||
if let Some(server) =
|
||||
context_server_store.read(cx).get_running_server(server_id)
|
||||
{
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
let context_server_manager = context_server_store.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
@ -608,7 +605,7 @@ impl ThreadStore {
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
None => {
|
||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
|
@ -31,7 +31,6 @@ async-watch.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
context_server.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
|
@ -106,7 +106,6 @@ pub fn init(
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_tool::init(cx);
|
||||
assistant_panel::init(cx);
|
||||
context_server::init(cx);
|
||||
|
||||
register_slash_commands(cx);
|
||||
inline_assistant::init(
|
||||
|
@ -1192,21 +1192,19 @@ impl AssistantPanel {
|
||||
|
||||
fn restart_context_servers(
|
||||
workspace: &mut Workspace,
|
||||
_action: &context_server::Restart,
|
||||
_action: &project::context_server_store::Restart,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let Some(assistant_panel) = workspace.panel::<AssistantPanel>(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
assistant_panel.update(cx, |assistant_panel, cx| {
|
||||
assistant_panel
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.restart_context_servers(cx);
|
||||
});
|
||||
});
|
||||
workspace
|
||||
.project()
|
||||
.read(cx)
|
||||
.context_server_store()
|
||||
.update(cx, |store, cx| {
|
||||
for server in store.running_servers() {
|
||||
store.restart_server(&server.id(), cx).log_err();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,15 +7,17 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
|
||||
use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerDescriptorRegistry;
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use context_server::ContextServerId;
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::StreamExt;
|
||||
use fuzzy::StringMatchCandidate;
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
|
||||
use language::LanguageRegistry;
|
||||
use paths::contexts_dir;
|
||||
use project::Project;
|
||||
use project::{
|
||||
Project,
|
||||
context_server_store::{ContextServerStatus, ContextServerStore},
|
||||
};
|
||||
use prompt_store::PromptBuilder;
|
||||
use regex::Regex;
|
||||
use rpc::AnyProtoClient;
|
||||
@ -40,8 +42,7 @@ pub struct RemoteContextMetadata {
|
||||
pub struct ContextStore {
|
||||
contexts: Vec<ContextHandle>,
|
||||
contexts_metadata: Vec<SavedContextMetadata>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
|
||||
context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
|
||||
host_contexts: Vec<RemoteContextMetadata>,
|
||||
fs: Arc<dyn Fs>,
|
||||
languages: Arc<LanguageRegistry>,
|
||||
@ -98,15 +99,9 @@ impl ContextStore {
|
||||
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
|
||||
|
||||
let this = cx.new(|cx: &mut Context<Self>| {
|
||||
let context_server_factory_registry =
|
||||
ContextServerDescriptorRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
let mut this = Self {
|
||||
contexts: Vec::new(),
|
||||
contexts_metadata: Vec::new(),
|
||||
context_server_manager,
|
||||
context_server_slash_command_ids: HashMap::default(),
|
||||
host_contexts: Vec::new(),
|
||||
fs,
|
||||
@ -802,22 +797,9 @@ impl ContextStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
|
||||
cx.update_entity(
|
||||
&self.context_server_manager,
|
||||
|context_server_manager, cx| {
|
||||
for server in context_server_manager.running_servers() {
|
||||
context_server_manager
|
||||
.restart_server(&server.id(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
||||
cx.subscribe(
|
||||
&self.context_server_manager.clone(),
|
||||
&self.project.read(cx).context_server_store(),
|
||||
Self::handle_context_server_event,
|
||||
)
|
||||
.detach();
|
||||
@ -825,16 +807,18 @@ impl ContextStore {
|
||||
|
||||
fn handle_context_server_event(
|
||||
&mut self,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
event: &context_server::manager::Event,
|
||||
context_server_manager: Entity<ContextServerStore>,
|
||||
event: &project::context_server_store::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let slash_command_working_set = self.slash_commands.clone();
|
||||
match event {
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
|
||||
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
|
||||
ContextServerStatus::Running => {
|
||||
if let Some(server) = context_server_manager
|
||||
.read(cx)
|
||||
.get_running_server(server_id)
|
||||
{
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
@ -858,7 +842,7 @@ impl ContextStore {
|
||||
slash_command_working_set.insert(Arc::new(
|
||||
assistant_slash_commands::ContextServerSlashCommand::new(
|
||||
context_server_manager.clone(),
|
||||
&server,
|
||||
server.id(),
|
||||
prompt,
|
||||
),
|
||||
))
|
||||
@ -877,7 +861,7 @@ impl ContextStore {
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
None => {
|
||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||
if let Some(slash_command_ids) =
|
||||
self.context_server_slash_command_ids.remove(server_id)
|
||||
{
|
||||
|
@ -4,12 +4,10 @@ use assistant_slash_command::{
|
||||
SlashCommandOutputSection, SlashCommandResult,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use context_server::{
|
||||
manager::{ContextServer, ContextServerManager},
|
||||
types::Prompt,
|
||||
};
|
||||
use context_server::{ContextServerId, types::Prompt};
|
||||
use gpui::{App, Entity, Task, WeakEntity, Window};
|
||||
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
|
||||
use project::context_server_store::ContextServerStore;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use text::LineEnding;
|
||||
@ -19,21 +17,17 @@ use workspace::Workspace;
|
||||
use crate::create_label_for_command;
|
||||
|
||||
pub struct ContextServerSlashCommand {
|
||||
server_manager: Entity<ContextServerManager>,
|
||||
server_id: Arc<str>,
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
prompt: Prompt,
|
||||
}
|
||||
|
||||
impl ContextServerSlashCommand {
|
||||
pub fn new(
|
||||
server_manager: Entity<ContextServerManager>,
|
||||
server: &Arc<ContextServer>,
|
||||
prompt: Prompt,
|
||||
) -> Self {
|
||||
pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
|
||||
Self {
|
||||
server_id: server.id(),
|
||||
server_id: id,
|
||||
prompt,
|
||||
server_manager,
|
||||
store,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -88,7 +82,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||
let server_id = self.server_id.clone();
|
||||
let prompt_name = self.prompt.name.clone();
|
||||
|
||||
if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
|
||||
if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let Some(protocol) = server.client() else {
|
||||
return Err(anyhow!("Context server not initialized"));
|
||||
@ -142,8 +136,8 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||
Err(e) => return Task::ready(Err(e)),
|
||||
};
|
||||
|
||||
let manager = self.server_manager.read(cx);
|
||||
if let Some(server) = manager.get_server(&server_id) {
|
||||
let store = self.store.read(cx);
|
||||
if let Some(server) = store.get_running_server(&server_id) {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let Some(protocol) = server.client() else {
|
||||
return Err(anyhow!("Context server not initialized"));
|
||||
|
@ -6709,8 +6709,6 @@ async fn test_context_collaboration_with_reconnect(
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
cx_a.update(context_server::init);
|
||||
cx_b.update(context_server::init);
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let context_store_a = cx_a
|
||||
.update(|cx| {
|
||||
|
@ -712,6 +712,7 @@ impl TestClient {
|
||||
worktree
|
||||
.read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
(project, worktree.read_with(cx, |tree, _| tree.id()))
|
||||
}
|
||||
|
||||
|
@ -13,28 +13,17 @@ path = "src/context_server.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
async-trait.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
context_server_settings.workspace = true
|
||||
extension.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
icons.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
|
@ -40,7 +40,7 @@ pub enum RequestId {
|
||||
Str(String),
|
||||
}
|
||||
|
||||
pub struct Client {
|
||||
pub(crate) struct Client {
|
||||
server_id: ContextServerId,
|
||||
next_id: AtomicI32,
|
||||
outbound_tx: channel::Sender<String>,
|
||||
@ -59,7 +59,7 @@ pub struct Client {
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[repr(transparent)]
|
||||
pub struct ContextServerId(pub Arc<str>);
|
||||
pub(crate) struct ContextServerId(pub Arc<str>);
|
||||
|
||||
fn is_null_value<T: Serialize>(value: &T) -> bool {
|
||||
if let Ok(Value::Null) = serde_json::to_value(value) {
|
||||
@ -367,6 +367,7 @@ impl Client {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
@ -375,14 +376,6 @@ impl Client {
|
||||
.lock()
|
||||
.insert(method, Box::new(f));
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn server_id(&self) -> ContextServerId {
|
||||
self.server_id.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ContextServerId {
|
||||
|
@ -1,30 +1,117 @@
|
||||
pub mod client;
|
||||
mod context_server_tool;
|
||||
mod extension_context_server;
|
||||
pub mod manager;
|
||||
pub mod protocol;
|
||||
mod registry;
|
||||
mod transport;
|
||||
pub mod transport;
|
||||
pub mod types;
|
||||
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerConfig};
|
||||
use gpui::{App, actions};
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use crate::context_server_tool::ContextServerTool;
|
||||
pub use crate::registry::ContextServerDescriptorRegistry;
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use collections::HashMap;
|
||||
use gpui::AsyncApp;
|
||||
use parking_lot::RwLock;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
actions!(context_servers, [Restart]);
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ContextServerId(pub Arc<str>);
|
||||
|
||||
/// The namespace for the context servers actions.
|
||||
pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
context_server_settings::init(cx);
|
||||
ContextServerDescriptorRegistry::default_global(cx);
|
||||
extension_context_server::init(cx);
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
|
||||
});
|
||||
impl Display for ContextServerId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
|
||||
pub struct ContextServerCommand {
|
||||
pub path: String,
|
||||
pub args: Vec<String>,
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
enum ContextServerTransport {
|
||||
Stdio(ContextServerCommand),
|
||||
Custom(Arc<dyn crate::transport::Transport>),
|
||||
}
|
||||
|
||||
pub struct ContextServer {
|
||||
id: ContextServerId,
|
||||
client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
||||
configuration: ContextServerTransport,
|
||||
}
|
||||
|
||||
impl ContextServer {
|
||||
pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
|
||||
Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
configuration: ContextServerTransport::Stdio(command),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
configuration: ContextServerTransport::Custom(transport),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> ContextServerId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
|
||||
self.client.read().clone()
|
||||
}
|
||||
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
let client = match &self.configuration {
|
||||
ContextServerTransport::Stdio(command) => Client::stdio(
|
||||
client::ContextServerId(self.id.0.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
cx.clone(),
|
||||
)?,
|
||||
ContextServerTransport::Custom(transport) => Client::new(
|
||||
client::ContextServerId(self.id.0.clone()),
|
||||
self.id().0,
|
||||
transport.clone(),
|
||||
cx.clone(),
|
||||
)?,
|
||||
};
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
log::info!("starting context server {}", self.id);
|
||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||
let client_info = types::Implementation {
|
||||
name: "Zed".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
let initialized_protocol = protocol.initialize(client_info).await?;
|
||||
|
||||
log::debug!(
|
||||
"context server {} initialized: {:?}",
|
||||
self.id,
|
||||
initialized_protocol.initialize,
|
||||
);
|
||||
|
||||
*self.client.write() = Some(Arc::new(initialized_protocol));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn stop(&self) -> Result<()> {
|
||||
let mut client = self.client.write();
|
||||
if let Some(protocol) = client.take() {
|
||||
drop(protocol);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -1,584 +0,0 @@
|
||||
//! This module implements a context server management system for Zed.
|
||||
//!
|
||||
//! It provides functionality to:
|
||||
//! - Define and load context server settings
|
||||
//! - Manage individual context servers (start, stop, restart)
|
||||
//! - Maintain a global manager for all context servers
|
||||
//!
|
||||
//! Key components:
|
||||
//! - `ContextServerSettings`: Defines the structure for server configurations
|
||||
//! - `ContextServer`: Represents an individual context server
|
||||
//! - `ContextServerManager`: Manages multiple context servers
|
||||
//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
|
||||
//!
|
||||
//! The module also includes initialization logic to set up the context server system
|
||||
//! and react to changes in settings.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
use collections::HashMap;
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
|
||||
use log;
|
||||
use parking_lot::RwLock;
|
||||
use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::transport::Transport;
|
||||
use crate::{ContextServerSettings, ServerConfig};
|
||||
|
||||
use crate::{
|
||||
CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
|
||||
client::{self, Client},
|
||||
types,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum ContextServerStatus {
|
||||
Starting,
|
||||
Running,
|
||||
Error(Arc<str>),
|
||||
}
|
||||
|
||||
pub struct ContextServer {
|
||||
pub id: Arc<str>,
|
||||
pub config: Arc<ServerConfig>,
|
||||
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
||||
transport: Option<Arc<dyn Transport>>,
|
||||
}
|
||||
|
||||
impl ContextServer {
|
||||
pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
config,
|
||||
client: RwLock::new(None),
|
||||
transport: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
config: Arc::new(ServerConfig::default()),
|
||||
transport: Some(transport),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn id(&self) -> Arc<str> {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
pub fn config(&self) -> Arc<ServerConfig> {
|
||||
self.config.clone()
|
||||
}
|
||||
|
||||
pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
|
||||
self.client.read().clone()
|
||||
}
|
||||
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
let client = if let Some(transport) = self.transport.clone() {
|
||||
Client::new(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
self.id(),
|
||||
transport,
|
||||
cx.clone(),
|
||||
)?
|
||||
} else {
|
||||
let Some(command) = &self.config.command else {
|
||||
bail!("no command specified for server {}", self.id);
|
||||
};
|
||||
Client::stdio(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
cx.clone(),
|
||||
)?
|
||||
};
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
log::info!("starting context server {}", self.id);
|
||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||
let client_info = types::Implementation {
|
||||
name: "Zed".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
let initialized_protocol = protocol.initialize(client_info).await?;
|
||||
|
||||
log::debug!(
|
||||
"context server {} initialized: {:?}",
|
||||
self.id,
|
||||
initialized_protocol.initialize,
|
||||
);
|
||||
|
||||
*self.client.write() = Some(Arc::new(initialized_protocol));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn stop(&self) -> Result<()> {
|
||||
let mut client = self.client.write();
|
||||
if let Some(protocol) = client.take() {
|
||||
drop(protocol);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ContextServerManager {
|
||||
servers: HashMap<Arc<str>, Arc<ContextServer>>,
|
||||
server_status: HashMap<Arc<str>, ContextServerStatus>,
|
||||
project: Entity<Project>,
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
update_servers_task: Option<Task<Result<()>>>,
|
||||
needs_server_update: bool,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
ServerStatusChanged {
|
||||
server_id: Arc<str>,
|
||||
status: Option<ContextServerStatus>,
|
||||
},
|
||||
}
|
||||
|
||||
impl EventEmitter<Event> for ContextServerManager {}
|
||||
|
||||
impl ContextServerManager {
|
||||
pub fn new(
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let mut this = Self {
|
||||
_subscriptions: vec![
|
||||
cx.observe(®istry, |this, _registry, cx| {
|
||||
this.available_context_servers_changed(cx);
|
||||
}),
|
||||
cx.observe_global::<SettingsStore>(|this, cx| {
|
||||
this.available_context_servers_changed(cx);
|
||||
}),
|
||||
],
|
||||
project,
|
||||
registry,
|
||||
needs_server_update: false,
|
||||
servers: HashMap::default(),
|
||||
server_status: HashMap::default(),
|
||||
update_servers_task: None,
|
||||
};
|
||||
this.available_context_servers_changed(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
|
||||
if self.update_servers_task.is_some() {
|
||||
self.needs_server_update = true;
|
||||
} else {
|
||||
self.update_servers_task = Some(cx.spawn(async move |this, cx| {
|
||||
this.update(cx, |this, _| {
|
||||
this.needs_server_update = false;
|
||||
})?;
|
||||
|
||||
if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
|
||||
log::error!("Error maintaining context servers: {}", err);
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let has_any_context_servers = !this.running_servers().is_empty();
|
||||
if has_any_context_servers {
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
|
||||
});
|
||||
}
|
||||
|
||||
this.update_servers_task.take();
|
||||
if this.needs_server_update {
|
||||
this.available_context_servers_changed(cx);
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
|
||||
self.servers
|
||||
.get(id)
|
||||
.filter(|server| server.client().is_some())
|
||||
.cloned()
|
||||
}
|
||||
|
||||
pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
|
||||
self.server_status.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn start_server(
|
||||
&self,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
|
||||
}
|
||||
|
||||
pub fn stop_server(
|
||||
&mut self,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
server.stop().log_err();
|
||||
self.update_server_status(server.id().clone(), None, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let id = id.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
|
||||
let config = server.config();
|
||||
|
||||
this.update(cx, |this, cx| this.stop_server(server, cx))??;
|
||||
let new_server = Arc::new(ContextServer::new(id.clone(), config));
|
||||
Self::run_server(this, new_server, cx).await?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
|
||||
self.servers.values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
|
||||
self.servers
|
||||
.values()
|
||||
.filter(|server| server.client().is_some())
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
|
||||
let mut desired_servers = HashMap::default();
|
||||
|
||||
let (registry, project) = this.update(cx, |this, cx| {
|
||||
let location = this
|
||||
.project
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.map(|worktree| settings::SettingsLocation {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: Path::new(""),
|
||||
});
|
||||
let settings = ContextServerSettings::get(location, cx);
|
||||
desired_servers = settings.context_servers.clone();
|
||||
|
||||
(this.registry.clone(), this.project.clone())
|
||||
})?;
|
||||
|
||||
for (id, descriptor) in
|
||||
registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
|
||||
{
|
||||
let config = desired_servers.entry(id).or_default();
|
||||
if config.command.is_none() {
|
||||
if let Some(extension_command) =
|
||||
descriptor.command(project.clone(), &cx).await.log_err()
|
||||
{
|
||||
config.command = Some(extension_command);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut servers_to_start = HashMap::default();
|
||||
let mut servers_to_stop = HashMap::default();
|
||||
|
||||
this.update(cx, |this, _cx| {
|
||||
this.servers.retain(|id, server| {
|
||||
if desired_servers.contains_key(id) {
|
||||
true
|
||||
} else {
|
||||
servers_to_stop.insert(id.clone(), server.clone());
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
for (id, config) in desired_servers {
|
||||
let existing_config = this.servers.get(&id).map(|server| server.config());
|
||||
if existing_config.as_deref() != Some(&config) {
|
||||
let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
|
||||
servers_to_start.insert(id.clone(), server.clone());
|
||||
if let Some(old_server) = this.servers.remove(&id) {
|
||||
servers_to_stop.insert(id, old_server);
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
for (_, server) in servers_to_stop {
|
||||
this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
|
||||
}
|
||||
|
||||
for (_, server) in servers_to_start {
|
||||
Self::run_server(this.clone(), server, cx).await.ok();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_server(
|
||||
this: WeakEntity<Self>,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let id = server.id();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
|
||||
this.servers.insert(id.clone(), server.clone());
|
||||
})?;
|
||||
|
||||
match server.start(&cx).await {
|
||||
Ok(_) => {
|
||||
log::debug!("`{}` context server started", id);
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("`{}` context server failed to start\n{}", id, err);
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(
|
||||
id.clone(),
|
||||
Some(ContextServerStatus::Error(err.to_string().into())),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_server_status(
|
||||
&mut self,
|
||||
id: Arc<str>,
|
||||
status: Option<ContextServerStatus>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(status) = status.clone() {
|
||||
self.server_status.insert(id.clone(), status);
|
||||
} else {
|
||||
self.server_status.remove(&id);
|
||||
}
|
||||
|
||||
cx.emit(Event::ServerStatusChanged {
|
||||
server_id: id,
|
||||
status,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::types::{
|
||||
Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use futures::{Stream, StreamExt as _, lock::Mutex};
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use project::FakeFs;
|
||||
use serde_json::json;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_context_server_status(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
let project = create_test_project(cx, json!({"code.rs": ""})).await;
|
||||
|
||||
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
|
||||
let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
|
||||
|
||||
let server_1_id: Arc<str> = "mcp-1".into();
|
||||
let server_2_id: Arc<str> = "mcp-2".into();
|
||||
|
||||
let transport_1 = Arc::new(FakeTransport::new(
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-1".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let transport_2 = Arc::new(FakeTransport::new(
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-2".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
|
||||
let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.start_server(server_1, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
|
||||
});
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_2_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
});
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.stop_server(server_2, cx))
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
|
||||
});
|
||||
}
|
||||
|
||||
async fn create_test_project(
|
||||
cx: &mut TestAppContext,
|
||||
files: serde_json::Value,
|
||||
) -> Entity<Project> {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), files).await;
|
||||
Project::test(fs, [path!("/test").as_ref()], cx).await
|
||||
}
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn create_initialize_response(server_name: String) -> serde_json::Value {
|
||||
serde_json::to_value(&InitializeResponse {
|
||||
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
|
||||
server_info: Implementation {
|
||||
name: server_name,
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities::default(),
|
||||
meta: None,
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
struct FakeTransport {
|
||||
on_request: Arc<
|
||||
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>,
|
||||
tx: futures::channel::mpsc::UnboundedSender<String>,
|
||||
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
||||
}
|
||||
|
||||
impl FakeTransport {
|
||||
fn new(
|
||||
on_request: impl Fn(
|
||||
u64,
|
||||
Option<RequestType>,
|
||||
serde_json::Value,
|
||||
) -> Option<serde_json::Value>
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Self {
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
||||
Self {
|
||||
on_request: Arc::new(on_request),
|
||||
tx,
|
||||
rx: Arc::new(Mutex::new(rx)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for FakeTransport {
|
||||
async fn send(&self, message: String) -> Result<()> {
|
||||
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
|
||||
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
|
||||
|
||||
if let Some(method) = msg.get("method") {
|
||||
let request_type = method
|
||||
.as_str()
|
||||
.and_then(|method| types::RequestType::try_from(method).ok());
|
||||
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
|
||||
let response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": payload
|
||||
});
|
||||
|
||||
self.tx
|
||||
.unbounded_send(response.to_string())
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
let rx = self.rx.clone();
|
||||
Box::pin(futures::stream::unfold(rx, |rx| async move {
|
||||
let mut rx_guard = rx.lock().await;
|
||||
if let Some(message) = rx_guard.next().await {
|
||||
drop(rx_guard);
|
||||
Some((message, rx))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(futures::stream::empty())
|
||||
}
|
||||
}
|
||||
}
|
@ -16,7 +16,7 @@ pub struct ModelContextProtocol {
|
||||
}
|
||||
|
||||
impl ModelContextProtocol {
|
||||
pub fn new(inner: Client) -> Self {
|
||||
pub(crate) fn new(inner: Client) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
|
@ -610,7 +610,7 @@ pub enum ToolResponseContent {
|
||||
Resource { resource: ResourceContents },
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ListToolsResponse {
|
||||
pub tools: Vec<Tool>,
|
||||
|
@ -1,22 +0,0 @@
|
||||
[package]
|
||||
name = "context_server_settings"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/context_server_settings.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
workspace-hack.workspace = true
|
@ -1 +0,0 @@
|
||||
../../LICENSE-GPL
|
@ -1,99 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use collections::HashMap;
|
||||
use gpui::App;
|
||||
use schemars::JsonSchema;
|
||||
use schemars::r#gen::SchemaGenerator;
|
||||
use schemars::schema::{InstanceType, Schema, SchemaObject};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
ContextServerSettings::register(cx);
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
|
||||
pub struct ServerConfig {
|
||||
/// The command to run this context server.
|
||||
///
|
||||
/// This will override the command set by an extension.
|
||||
pub command: Option<ServerCommand>,
|
||||
/// The settings for this context server.
|
||||
///
|
||||
/// Consult the documentation for the context server to see what settings
|
||||
/// are supported.
|
||||
#[schemars(schema_with = "server_config_settings_json_schema")]
|
||||
pub settings: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema {
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::Object.into()),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
|
||||
pub struct ServerCommand {
|
||||
pub path: String,
|
||||
pub args: Vec<String>,
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
|
||||
pub struct ContextServerSettings {
|
||||
/// Settings for context servers used in the Assistant.
|
||||
#[serde(default)]
|
||||
pub context_servers: HashMap<Arc<str>, ServerConfig>,
|
||||
}
|
||||
|
||||
impl Settings for ContextServerSettings {
|
||||
const KEY: Option<&'static str> = None;
|
||||
|
||||
type FileContent = Self;
|
||||
|
||||
fn load(
|
||||
sources: SettingsSources<Self::FileContent>,
|
||||
_: &mut gpui::App,
|
||||
) -> anyhow::Result<Self> {
|
||||
sources.json_merge()
|
||||
}
|
||||
|
||||
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
||||
// we don't handle "inputs" replacement strings, see perplexity-key in this example:
|
||||
// https://code.visualstudio.com/docs/copilot/chat/mcp-servers#_configuration-example
|
||||
#[derive(Deserialize)]
|
||||
struct VsCodeServerCommand {
|
||||
command: String,
|
||||
args: Option<Vec<String>>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
// note: we don't support envFile and type
|
||||
}
|
||||
impl From<VsCodeServerCommand> for ServerCommand {
|
||||
fn from(cmd: VsCodeServerCommand) -> Self {
|
||||
Self {
|
||||
path: cmd.command,
|
||||
args: cmd.args.unwrap_or_default(),
|
||||
env: cmd.env,
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) {
|
||||
current
|
||||
.context_servers
|
||||
.extend(mcp.iter().filter_map(|(k, v)| {
|
||||
Some((
|
||||
k.clone().into(),
|
||||
ServerConfig {
|
||||
command: Some(
|
||||
serde_json::from_value::<VsCodeServerCommand>(v.clone())
|
||||
.ok()?
|
||||
.into(),
|
||||
),
|
||||
settings: None,
|
||||
},
|
||||
))
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
@ -30,7 +30,6 @@ chrono.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
dirs.workspace = true
|
||||
dotenv.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
@ -423,7 +423,6 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
languages::init(languages.clone(), node_runtime.clone(), cx);
|
||||
context_server::init(cx);
|
||||
prompt_store::init(cx);
|
||||
let stdout_is_a_pty = false;
|
||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||
|
@ -362,6 +362,8 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static {
|
||||
server_id: Arc<str>,
|
||||
cx: &mut App,
|
||||
);
|
||||
|
||||
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App);
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ExtensionHostProxy {
|
||||
@ -377,6 +379,14 @@ impl ExtensionContextServerProxy for ExtensionHostProxy {
|
||||
|
||||
proxy.register_context_server(extension, server_id, cx)
|
||||
}
|
||||
|
||||
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App) {
|
||||
let Some(proxy) = self.context_server_proxy.read().clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
proxy.unregister_context_server(server_id, cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ExtensionIndexedDocsProviderProxy: Send + Sync + 'static {
|
||||
|
@ -22,7 +22,6 @@ async-tar.workspace = true
|
||||
async-trait.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server_settings.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
|
@ -1130,6 +1130,10 @@ impl ExtensionStore {
|
||||
.remove_language_server(&language, language_server_name);
|
||||
}
|
||||
}
|
||||
|
||||
for (server_id, _) in extension.manifest.context_servers.iter() {
|
||||
self.proxy.unregister_context_server(server_id.clone(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
self.wasm_extensions
|
||||
|
@ -7,7 +7,6 @@ use anyhow::{Context, Result, anyhow, bail};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
use async_tar::Archive;
|
||||
use async_trait::async_trait;
|
||||
use context_server_settings::ContextServerSettings;
|
||||
use extension::{
|
||||
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
|
||||
};
|
||||
@ -676,21 +675,23 @@ impl ExtensionImports for WasmState {
|
||||
})?)
|
||||
}
|
||||
"context_servers" => {
|
||||
let settings = key
|
||||
let configuration = key
|
||||
.and_then(|key| {
|
||||
ContextServerSettings::get(location, cx)
|
||||
ProjectSettings::get(location, cx)
|
||||
.context_servers
|
||||
.get(key.as_str())
|
||||
})
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
Ok(serde_json::to_string(&settings::ContextServerSettings {
|
||||
command: settings.command.map(|command| settings::CommandSettings {
|
||||
path: Some(command.path),
|
||||
arguments: Some(command.args),
|
||||
env: command.env.map(|env| env.into_iter().collect()),
|
||||
command: configuration.command.map(|command| {
|
||||
settings::CommandSettings {
|
||||
path: Some(command.path),
|
||||
arguments: Some(command.args),
|
||||
env: command.env.map(|env| env.into_iter().collect()),
|
||||
}
|
||||
}),
|
||||
settings: settings.settings,
|
||||
settings: configuration.settings,
|
||||
})?)
|
||||
}
|
||||
_ => {
|
||||
|
@ -36,6 +36,7 @@ circular-buffer.workspace = true
|
||||
client.workspace = true
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
dap.workspace = true
|
||||
extension.workspace = true
|
||||
fancy-regex.workspace = true
|
||||
|
1129
crates/project/src/context_server_store.rs
Normal file
1129
crates/project/src/context_server_store.rs
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,19 +1,21 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use context_server::ContextServerCommand;
|
||||
use extension::{
|
||||
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
|
||||
ProjectDelegate,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
|
||||
use crate::worktree_store::WorktreeStore;
|
||||
|
||||
use super::registry::{self, ContextServerDescriptorRegistry};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let proxy = ExtensionHostProxy::default_global(cx);
|
||||
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
|
||||
context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
|
||||
context_server_factory_registry: ContextServerDescriptorRegistry::default_global(cx),
|
||||
});
|
||||
}
|
||||
|
||||
@ -32,10 +34,13 @@ struct ContextServerDescriptor {
|
||||
extension: Arc<dyn Extension>,
|
||||
}
|
||||
|
||||
fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
|
||||
project.update(cx, |project, cx| {
|
||||
fn extension_project(
|
||||
worktree_store: Entity<WorktreeStore>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Arc<ExtensionProject>> {
|
||||
worktree_store.update(cx, |worktree_store, cx| {
|
||||
Arc::new(ExtensionProject {
|
||||
worktree_ids: project
|
||||
worktree_ids: worktree_store
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).id().to_proto())
|
||||
.collect(),
|
||||
@ -44,11 +49,15 @@ fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<
|
||||
}
|
||||
|
||||
impl registry::ContextServerDescriptor for ContextServerDescriptor {
|
||||
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
|
||||
fn command(
|
||||
&self,
|
||||
worktree_store: Entity<WorktreeStore>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<ContextServerCommand>> {
|
||||
let id = self.id.clone();
|
||||
let extension = self.extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = extension_project(project, cx)?;
|
||||
let extension_project = extension_project(worktree_store, cx)?;
|
||||
let mut command = extension
|
||||
.context_server_command(id.clone(), extension_project.clone())
|
||||
.await?;
|
||||
@ -59,7 +68,7 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor {
|
||||
|
||||
log::info!("loaded command for context server {id}: {command:?}");
|
||||
|
||||
Ok(ServerCommand {
|
||||
Ok(ContextServerCommand {
|
||||
path: command.command,
|
||||
args: command.args,
|
||||
env: Some(command.env.into_iter().collect()),
|
||||
@ -69,13 +78,13 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor {
|
||||
|
||||
fn configuration(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
worktree_store: Entity<WorktreeStore>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<ContextServerConfiguration>>> {
|
||||
let id = self.id.clone();
|
||||
let extension = self.extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = extension_project(project, cx)?;
|
||||
let extension_project = extension_project(worktree_store, cx)?;
|
||||
let configuration = extension
|
||||
.context_server_configuration(id.clone(), extension_project)
|
||||
.await?;
|
||||
@ -102,4 +111,11 @@ impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App) {
|
||||
self.context_server_factory_registry
|
||||
.update(cx, |registry, _| {
|
||||
registry.unregister_context_server_descriptor_by_id(&server_id)
|
||||
});
|
||||
}
|
||||
}
|
@ -2,17 +2,21 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerCommand;
|
||||
use extension::ContextServerConfiguration;
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
|
||||
use project::Project;
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, Task};
|
||||
|
||||
use crate::ServerCommand;
|
||||
use crate::worktree_store::WorktreeStore;
|
||||
|
||||
pub trait ContextServerDescriptor {
|
||||
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
|
||||
fn command(
|
||||
&self,
|
||||
worktree_store: Entity<WorktreeStore>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<ContextServerCommand>>;
|
||||
fn configuration(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
worktree_store: Entity<WorktreeStore>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<ContextServerConfiguration>>>;
|
||||
}
|
||||
@ -27,11 +31,6 @@ pub struct ContextServerDescriptorRegistry {
|
||||
}
|
||||
|
||||
impl ContextServerDescriptorRegistry {
|
||||
/// Returns the global [`ContextServerDescriptorRegistry`].
|
||||
pub fn global(cx: &App) -> Entity<Self> {
|
||||
GlobalContextServerDescriptorRegistry::global(cx).0.clone()
|
||||
}
|
||||
|
||||
/// Returns the global [`ContextServerDescriptorRegistry`].
|
||||
///
|
||||
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.
|
@ -244,9 +244,8 @@ mod tests {
|
||||
use git::status::{FileStatus, StatusCode, TrackedSummary, UnmergedStatus, UnmergedStatusCode};
|
||||
use gpui::TestAppContext;
|
||||
use serde_json::json;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
use worktree::WorktreeSettings;
|
||||
|
||||
const CONFLICT: FileStatus = FileStatus::Unmerged(UnmergedStatus {
|
||||
first_head: UnmergedStatusCode::Updated,
|
||||
@ -682,7 +681,7 @@ mod tests {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
WorktreeSettings::register(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
pub mod buffer_store;
|
||||
mod color_extractor;
|
||||
pub mod connection_manager;
|
||||
pub mod context_server_store;
|
||||
pub mod debounced_delay;
|
||||
pub mod debugger;
|
||||
pub mod git_store;
|
||||
@ -23,6 +24,7 @@ mod project_tests;
|
||||
mod direnv;
|
||||
mod environment;
|
||||
use buffer_diff::BufferDiff;
|
||||
use context_server_store::ContextServerStore;
|
||||
pub use environment::{EnvironmentErrorMessage, ProjectEnvironmentEvent};
|
||||
use git_store::{Repository, RepositoryId};
|
||||
pub mod search_history;
|
||||
@ -182,6 +184,7 @@ pub struct Project {
|
||||
client_subscriptions: Vec<client::Subscription>,
|
||||
worktree_store: Entity<WorktreeStore>,
|
||||
buffer_store: Entity<BufferStore>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
image_store: Entity<ImageStore>,
|
||||
lsp_store: Entity<LspStore>,
|
||||
_subscriptions: Vec<gpui::Subscription>,
|
||||
@ -845,6 +848,7 @@ impl Project {
|
||||
ToolchainStore::init(&client);
|
||||
DapStore::init(&client, cx);
|
||||
BreakpointStore::init(&client);
|
||||
context_server_store::init(cx);
|
||||
}
|
||||
|
||||
pub fn local(
|
||||
@ -865,6 +869,9 @@ impl Project {
|
||||
cx.subscribe(&worktree_store, Self::on_worktree_store_event)
|
||||
.detach();
|
||||
|
||||
let context_server_store =
|
||||
cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx));
|
||||
|
||||
let environment = cx.new(|_| ProjectEnvironment::new(env));
|
||||
let toolchain_store = cx.new(|cx| {
|
||||
ToolchainStore::local(
|
||||
@ -965,6 +972,7 @@ impl Project {
|
||||
buffer_store,
|
||||
image_store,
|
||||
lsp_store,
|
||||
context_server_store,
|
||||
join_project_response_message_id: 0,
|
||||
client_state: ProjectClientState::Local,
|
||||
git_store,
|
||||
@ -1025,6 +1033,9 @@ impl Project {
|
||||
cx.subscribe(&worktree_store, Self::on_worktree_store_event)
|
||||
.detach();
|
||||
|
||||
let context_server_store =
|
||||
cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx));
|
||||
|
||||
let buffer_store = cx.new(|cx| {
|
||||
BufferStore::remote(
|
||||
worktree_store.clone(),
|
||||
@ -1109,6 +1120,7 @@ impl Project {
|
||||
buffer_store,
|
||||
image_store,
|
||||
lsp_store,
|
||||
context_server_store,
|
||||
breakpoint_store,
|
||||
dap_store,
|
||||
join_project_response_message_id: 0,
|
||||
@ -1267,6 +1279,8 @@ impl Project {
|
||||
let image_store = cx.new(|cx| {
|
||||
ImageStore::remote(worktree_store.clone(), client.clone().into(), remote_id, cx)
|
||||
})?;
|
||||
let context_server_store =
|
||||
cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx))?;
|
||||
|
||||
let environment = cx.new(|_| ProjectEnvironment::new(None))?;
|
||||
|
||||
@ -1360,6 +1374,7 @@ impl Project {
|
||||
image_store,
|
||||
worktree_store: worktree_store.clone(),
|
||||
lsp_store: lsp_store.clone(),
|
||||
context_server_store,
|
||||
active_entry: None,
|
||||
collaborators: Default::default(),
|
||||
join_project_response_message_id: response.message_id,
|
||||
@ -1590,6 +1605,10 @@ impl Project {
|
||||
self.worktree_store.clone()
|
||||
}
|
||||
|
||||
pub fn context_server_store(&self) -> Entity<ContextServerStore> {
|
||||
self.context_server_store.clone()
|
||||
}
|
||||
|
||||
pub fn buffer_for_id(&self, remote_id: BufferId, cx: &App) -> Option<Entity<Buffer>> {
|
||||
self.buffer_store.read(cx).get(remote_id)
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use anyhow::Context as _;
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerCommand;
|
||||
use dap::adapters::DebugAdapterName;
|
||||
use fs::Fs;
|
||||
use futures::StreamExt as _;
|
||||
@ -51,6 +52,10 @@ pub struct ProjectSettings {
|
||||
#[serde(default)]
|
||||
pub dap: HashMap<DebugAdapterName, DapSettings>,
|
||||
|
||||
/// Settings for context servers used for AI-related features.
|
||||
#[serde(default)]
|
||||
pub context_servers: HashMap<Arc<str>, ContextServerConfiguration>,
|
||||
|
||||
/// Configuration for Diagnostics-related features.
|
||||
#[serde(default)]
|
||||
pub diagnostics: DiagnosticsSettings,
|
||||
@ -78,6 +83,19 @@ pub struct DapSettings {
|
||||
pub binary: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
|
||||
pub struct ContextServerConfiguration {
|
||||
/// The command to run this context server.
|
||||
///
|
||||
/// This will override the command set by an extension.
|
||||
pub command: Option<ContextServerCommand>,
|
||||
/// The settings for this context server.
|
||||
///
|
||||
/// Consult the documentation for the context server to see what settings
|
||||
/// are supported.
|
||||
pub settings: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct NodeBinarySettings {
|
||||
/// The path to the Node binary.
|
||||
@ -354,6 +372,40 @@ impl Settings for ProjectSettings {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct VsCodeContextServerCommand {
|
||||
command: String,
|
||||
args: Option<Vec<String>>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
// note: we don't support envFile and type
|
||||
}
|
||||
impl From<VsCodeContextServerCommand> for ContextServerCommand {
|
||||
fn from(cmd: VsCodeContextServerCommand) -> Self {
|
||||
Self {
|
||||
path: cmd.command,
|
||||
args: cmd.args.unwrap_or_default(),
|
||||
env: cmd.env,
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) {
|
||||
current
|
||||
.context_servers
|
||||
.extend(mcp.iter().filter_map(|(k, v)| {
|
||||
Some((
|
||||
k.clone().into(),
|
||||
ContextServerConfiguration {
|
||||
command: Some(
|
||||
serde_json::from_value::<VsCodeContextServerCommand>(v.clone())
|
||||
.ok()?
|
||||
.into(),
|
||||
),
|
||||
settings: None,
|
||||
},
|
||||
))
|
||||
}));
|
||||
}
|
||||
|
||||
// TODO: translate lsp settings for rust-analyzer and other popular ones to old.lsp
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user