diff --git a/Cargo.lock b/Cargo.lock index 63a57d7df1..98a035ff74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,7 +61,6 @@ dependencies = [ "buffer_diff", "chrono", "client", - "clock", "collections", "command_palette_hooks", "component", @@ -99,6 +98,7 @@ dependencies = [ "prompt_store", "proto", "rand 0.8.5", + "ref-cast", "release_channel", "rope", "rules_library", @@ -11716,6 +11716,26 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "refineable" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 3f4caf37b7..d90a256a9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -500,6 +500,7 @@ prost-types = "0.9" pulldown-cmark = { version = "0.12.0", default-features = false } quote = "1.0.9" rand = "0.8.5" +ref-cast = "1.0.24" rayon = "1.8" regex = "1.5" repair_json = "0.1.0" diff --git a/clippy.toml b/clippy.toml index 8c8da03a26..e606ad4c79 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,2 +1,7 @@ allow-private-module-inception = true avoid-breaking-exported-api = false +ignore-interior-mutability = [ + # Suppresses clippy::mutable_key_type, which is a false positive as the Eq + # and Hash impls do not use fields with interior mutability. + "agent::context::AgentContextKey" +] diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index ec366da1ad..4152843d22 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -28,7 +28,6 @@ async-watch.workspace = true buffer_diff.workspace = true chrono.workspace = true client.workspace = true -clock.workspace = true collections.workspace = true command_palette_hooks.workspace = true component.workspace = true @@ -65,6 +64,7 @@ project.workspace = true rules_library.workspace = true prompt_store.workspace = true proto.workspace = true +ref-cast.workspace = true release_channel.workspace = true rope.workspace = true schemars.workspace = true diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 310a283287..0d0e2dd08d 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1,4 +1,4 @@ -use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string}; +use crate::context::{AgentContext, RULES_ICON}; use crate::context_picker::MentionLink; use crate::thread::{ LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent, @@ -25,8 +25,8 @@ use gpui::{ }; use language::{Buffer, LanguageRegistry}; use language_model::{ - LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, RequestUsage, Role, - StopReason, + LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, + RequestUsage, Role, StopReason, }; use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; @@ -47,13 +47,10 @@ use util::markdown::MarkdownString; use workspace::{OpenOptions, Workspace}; use zed_actions::assistant::OpenRulesLibrary; -use crate::context_store::ContextStore; - pub struct ActiveThread { language_registry: Arc, thread_store: Entity, thread: Entity, - context_store: Entity, workspace: WeakEntity, save_thread_task: Option>, messages: Vec, @@ -717,7 +714,6 @@ impl ActiveThread { thread: Entity, thread_store: Entity, language_registry: Arc, - context_store: Entity, workspace: WeakEntity, window: &mut Window, cx: &mut Context, @@ -740,7 +736,6 @@ impl ActiveThread { language_registry, thread_store, thread: thread.clone(), - context_store, workspace, save_thread_task: None, messages: Vec::new(), @@ -780,10 +775,6 @@ impl ActiveThread { this } - pub fn context_store(&self) -> &Entity { - &self.context_store - } - pub fn thread(&self) -> &Entity { &self.thread } @@ -1273,26 +1264,36 @@ impl ActiveThread { } let token_count = if let Some(task) = cx.update(|cx| { - let context = thread.read(cx).context_for_message(message_id); - let new_context = thread.read(cx).filter_new_context(context); - let context_text = - format_context_as_string(new_context, cx).unwrap_or(String::new()); + let Some(message) = thread.read(cx).message(message_id) else { + log::error!("Message that was being edited no longer exists"); + return None; + }; let message_text = editor.read(cx).text(cx); - let content = context_text + &message_text; - - if content.is_empty() { + if message_text.is_empty() && message.loaded_context.is_empty() { return None; } + let mut request_message = LanguageModelRequestMessage { + role: language_model::Role::User, + content: Vec::new(), + cache: false, + }; + + message + .loaded_context + .add_to_request_message(&mut request_message); + + if !message_text.is_empty() { + request_message + .content + .push(MessageContent::Text(message_text)); + } + let request = language_model::LanguageModelRequest { thread_id: None, prompt_id: None, - messages: vec![LanguageModelRequestMessage { - role: language_model::Role::User, - content: vec![content.into()], - cache: false, - }], + messages: vec![request_message], tools: vec![], stop: vec![], temperature: None, @@ -1487,13 +1488,21 @@ impl ActiveThread { return Empty.into_any(); }; - let context_store = self.context_store.clone(); let workspace = self.workspace.clone(); let thread = self.thread.read(cx); + let prompt_store = self.thread_store.read(cx).prompt_store().as_ref(); // Get all the data we need from thread before we start using it in closures let checkpoint = thread.checkpoint_for_message(message_id); - let context = thread.context_for_message(message_id).collect::>(); + let added_context = if let Some(workspace) = workspace.upgrade() { + let project = workspace.read(cx).project().read(cx); + thread + .context_for_message(message_id) + .flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx)) + .collect::>() + } else { + return Empty.into_any(); + }; let tool_uses = thread.tool_uses_for_message(message_id, cx); let has_tool_uses = !tool_uses.is_empty(); @@ -1641,90 +1650,78 @@ impl ActiveThread { }; let message_is_empty = message.should_display_content(); - let has_content = !message_is_empty || !context.is_empty(); + let has_content = !message_is_empty || !added_context.is_empty(); - let message_content = - has_content.then(|| { - v_flex() - .gap_1() - .when(!message_is_empty, |parent| { - parent.child( - if let Some(edit_message_editor) = edit_message_editor.clone() { - let settings = ThemeSettings::get_global(cx); - let font_size = TextSize::Small.rems(cx); - let line_height = font_size.to_pixels(window.rem_size()) * 1.5; + let message_content = has_content.then(|| { + v_flex() + .gap_1() + .when(!message_is_empty, |parent| { + parent.child( + if let Some(edit_message_editor) = edit_message_editor.clone() { + let settings = ThemeSettings::get_global(cx); + let font_size = TextSize::Small.rems(cx); + let line_height = font_size.to_pixels(window.rem_size()) * 1.5; - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.buffer_font.family.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_features: settings.buffer_font.features.clone(), - font_size: font_size.into(), - line_height: line_height.into(), - ..Default::default() - }; + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; - div() - .key_context("EditMessageEditor") - .on_action(cx.listener(Self::cancel_editing_message)) - .on_action(cx.listener(Self::confirm_editing_message)) - .min_h_6() - .child(EditorElement::new( - &edit_message_editor, - EditorStyle { - background: colors.editor_background, - local_player: cx.theme().players().local(), - text: text_style, - syntax: cx.theme().syntax().clone(), - ..Default::default() - }, - )) - .into_any() - } else { - div() - .min_h_6() - .child(self.render_message_content( - message_id, - rendered_message, - has_tool_uses, - workspace.clone(), - window, - cx, - )) - .into_any() - }, - ) - }) - .when(!context.is_empty(), |parent| { - parent.child(h_flex().flex_wrap().gap_1().children( - context.into_iter().map(|context| { - let context_id = context.id(); - ContextPill::added( - AddedContext::new(context, cx), - false, - false, - None, - ) - .on_click(Rc::new(cx.listener({ + div() + .key_context("EditMessageEditor") + .on_action(cx.listener(Self::cancel_editing_message)) + .on_action(cx.listener(Self::confirm_editing_message)) + .min_h_6() + .child(EditorElement::new( + &edit_message_editor, + EditorStyle { + background: colors.editor_background, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + )) + .into_any() + } else { + div() + .min_h_6() + .child(self.render_message_content( + message_id, + rendered_message, + has_tool_uses, + workspace.clone(), + window, + cx, + )) + .into_any() + }, + ) + }) + .when(!added_context.is_empty(), |parent| { + parent.child(h_flex().flex_wrap().gap_1().children( + added_context.into_iter().map(|added_context| { + let context = added_context.context.clone(); + ContextPill::added(added_context, false, false, None).on_click(Rc::new( + cx.listener({ let workspace = workspace.clone(); - let context_store = context_store.clone(); move |_, _, window, cx| { if let Some(workspace) = workspace.upgrade() { - open_context( - context_id, - context_store.clone(), - workspace, - window, - cx, - ); + open_context(&context, workspace, window, cx); cx.notify(); } } - }))) - }), - )) - }) - }); + }), + )) + }), + )) + }) + }); let styled_message = match message.role { Role::User => v_flex() @@ -3173,20 +3170,14 @@ impl Render for ActiveThread { } pub(crate) fn open_context( - id: ContextId, - context_store: Entity, + context: &AgentContext, workspace: Entity, window: &mut Window, cx: &mut App, ) { - let Some(context) = context_store.read(cx).context_for_id(id) else { - return; - }; - match context { - AssistantContext::File(file_context) => { - if let Some(project_path) = file_context.context_buffer.buffer.read(cx).project_path(cx) - { + AgentContext::File(file_context) => { + if let Some(project_path) = file_context.project_path(cx) { workspace.update(cx, |workspace, cx| { workspace .open_path(project_path, None, true, window, cx) @@ -3194,7 +3185,8 @@ pub(crate) fn open_context( }); } } - AssistantContext::Directory(directory_context) => { + + AgentContext::Directory(directory_context) => { let entry_id = directory_context.entry_id; workspace.update(cx, |workspace, cx| { workspace.project().update(cx, |_project, cx| { @@ -3202,61 +3194,51 @@ pub(crate) fn open_context( }) }) } - AssistantContext::Symbol(symbol_context) => { - if let Some(project_path) = symbol_context - .context_symbol - .buffer - .read(cx) - .project_path(cx) - { - let snapshot = symbol_context.context_symbol.buffer.read(cx).snapshot(); - let target_position = symbol_context - .context_symbol - .id - .range - .start - .to_point(&snapshot); + AgentContext::Symbol(symbol_context) => { + let buffer = symbol_context.buffer.read(cx); + if let Some(project_path) = buffer.project_path(cx) { + let snapshot = buffer.snapshot(); + let target_position = symbol_context.range.start.to_point(&snapshot); open_editor_at_position(project_path, target_position, &workspace, window, cx) .detach(); } } - AssistantContext::Selection(selection_context) => { - if let Some(project_path) = selection_context - .context_buffer - .buffer - .read(cx) - .project_path(cx) - { - let snapshot = selection_context.context_buffer.buffer.read(cx).snapshot(); + + AgentContext::Selection(selection_context) => { + let buffer = selection_context.buffer.read(cx); + if let Some(project_path) = buffer.project_path(cx) { + let snapshot = buffer.snapshot(); let target_position = selection_context.range.start.to_point(&snapshot); open_editor_at_position(project_path, target_position, &workspace, window, cx) .detach(); } } - AssistantContext::FetchedUrl(fetched_url_context) => { + + AgentContext::FetchedUrl(fetched_url_context) => { cx.open_url(&fetched_url_context.url); } - AssistantContext::Thread(thread_context) => { - let thread_id = thread_context.thread.read(cx).id().clone(); - workspace.update(cx, |workspace, cx| { - if let Some(panel) = workspace.panel::(cx) { - panel.update(cx, |panel, cx| { - panel - .open_thread(&thread_id, window, cx) - .detach_and_log_err(cx) - }); - } - }) - } - AssistantContext::Rules(rules_context) => window.dispatch_action( + + AgentContext::Thread(thread_context) => workspace.update(cx, |workspace, cx| { + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + let thread_id = thread_context.thread.read(cx).id().clone(); + panel + .open_thread(&thread_id, window, cx) + .detach_and_log_err(cx) + }); + } + }), + + AgentContext::Rules(rules_context) => window.dispatch_action( Box::new(OpenRulesLibrary { prompt_to_select: Some(rules_context.prompt_id.0), }), cx, ), - AssistantContext::Image(_) => {} + + AgentContext::Image(_) => {} } } diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index ad2aee93ab..0e66308f37 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -962,11 +962,13 @@ mod tests { }) .unwrap(); + let prompt_store = None; let thread_store = cx .update(|cx| { ThreadStore::load( project.clone(), cx.new(|_| ToolWorkingSet::default()), + prompt_store, Arc::new(PromptBuilder::new(None).unwrap()), cx, ) diff --git a/crates/agent/src/assistant.rs b/crates/agent/src/assistant.rs index 03e13d6f68..80c69665fb 100644 --- a/crates/agent/src/assistant.rs +++ b/crates/agent/src/assistant.rs @@ -39,6 +39,7 @@ use thread::ThreadId; pub use crate::active_thread::ActiveThread; use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal}; pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate}; +pub use crate::context::{ContextLoadResult, LoadedContext}; pub use crate::inline_assistant::InlineAssistant; pub use crate::thread::{Message, Thread, ThreadEvent}; pub use crate::thread_store::ThreadStore; diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 3e83c1c224..404702c0b2 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -24,7 +24,7 @@ use language::LanguageRegistry; use language_model::{LanguageModelProviderTosView, LanguageModelRegistry}; use language_model_selector::ToggleModelSelector; use project::Project; -use prompt_store::{PromptBuilder, PromptId, UserPromptId}; +use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; use proto::Plan; use rules_library::{RulesLibrary, open_rules_library}; use settings::{Settings, update_settings_file}; @@ -189,6 +189,7 @@ pub struct AssistantPanel { message_editor: Entity, _active_thread_subscriptions: Vec, context_store: Entity, + prompt_store: Option>, configuration: Option>, configuration_subscription: Option, local_timezone: UtcOffset, @@ -205,14 +206,25 @@ impl AssistantPanel { pub fn load( workspace: WeakEntity, prompt_builder: Arc, - cx: AsyncWindowContext, + mut cx: AsyncWindowContext, ) -> Task>> { + let prompt_store = cx.update(|_window, cx| PromptStore::global(cx)); cx.spawn(async move |cx| { + let prompt_store = match prompt_store { + Ok(prompt_store) => prompt_store.await.ok(), + Err(_) => None, + }; let tools = cx.new(|_| ToolWorkingSet::default())?; let thread_store = workspace .update(cx, |workspace, cx| { let project = workspace.project().clone(); - ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx) + ThreadStore::load( + project, + tools.clone(), + prompt_store.clone(), + prompt_builder.clone(), + cx, + ) })? .await?; @@ -230,7 +242,16 @@ impl AssistantPanel { .await?; workspace.update_in(cx, |workspace, window, cx| { - cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx)) + cx.new(|cx| { + Self::new( + workspace, + thread_store, + context_store, + prompt_store, + window, + cx, + ) + }) }) }) } @@ -239,6 +260,7 @@ impl AssistantPanel { workspace: &Workspace, thread_store: Entity, context_store: Entity, + prompt_store: Option>, window: &mut Window, cx: &mut Context, ) -> Self { @@ -262,6 +284,7 @@ impl AssistantPanel { fs.clone(), workspace.clone(), message_editor_context_store.clone(), + prompt_store.clone(), thread_store.downgrade(), thread.clone(), window, @@ -293,7 +316,6 @@ impl AssistantPanel { thread.clone(), thread_store.clone(), language_registry.clone(), - message_editor_context_store.clone(), workspace.clone(), window, cx, @@ -322,6 +344,7 @@ impl AssistantPanel { message_editor_subscription, ], context_store, + prompt_store, configuration: None, configuration_subscription: None, local_timezone: UtcOffset::from_whole_seconds( @@ -355,6 +378,10 @@ impl AssistantPanel { self.local_timezone } + pub(crate) fn prompt_store(&self) -> &Option> { + &self.prompt_store + } + pub(crate) fn thread_store(&self) -> &Entity { &self.thread_store } @@ -411,7 +438,6 @@ impl AssistantPanel { thread.clone(), self.thread_store.clone(), self.language_registry.clone(), - message_editor_context_store.clone(), self.workspace.clone(), window, cx, @@ -430,6 +456,7 @@ impl AssistantPanel { self.fs.clone(), self.workspace.clone(), message_editor_context_store, + self.prompt_store.clone(), self.thread_store.downgrade(), thread, window, @@ -500,9 +527,9 @@ impl AssistantPanel { None, )) }), - action.prompt_to_select.map(|uuid| PromptId::User { - uuid: UserPromptId(uuid), - }), + action + .prompt_to_select + .map(|uuid| UserPromptId(uuid).into()), cx, ) .detach_and_log_err(cx); @@ -598,7 +625,6 @@ impl AssistantPanel { thread.clone(), this.thread_store.clone(), this.language_registry.clone(), - message_editor_context_store.clone(), this.workspace.clone(), window, cx, @@ -617,6 +643,7 @@ impl AssistantPanel { this.fs.clone(), this.workspace.clone(), message_editor_context_store, + this.prompt_store.clone(), this.thread_store.downgrade(), thread, window, @@ -1876,11 +1903,14 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist { else { return; }; + let prompt_store = None; + let thread_store = None; assistant.assist( &prompt_editor, self.workspace.clone(), project, - None, + prompt_store, + thread_store, window, cx, ) @@ -1959,8 +1989,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate { // being updated. cx.defer_in(window, move |panel, window, cx| { if panel.has_active_thread() { - panel.thread.update(cx, |thread, cx| { - thread.context_store().update(cx, |store, cx| { + panel.message_editor.update(cx, |message_editor, cx| { + message_editor.context_store().update(cx, |store, cx| { let buffer = buffer.read(cx); let selection_ranges = selection_ranges .into_iter() @@ -1977,9 +2007,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate { .collect::>(); for (buffer, range) in selection_ranges { - store - .add_selection(buffer, range, cx) - .detach_and_log_err(cx); + store.add_selection(buffer, range, cx); } }) }) diff --git a/crates/agent/src/buffer_codegen.rs b/crates/agent/src/buffer_codegen.rs index ebdf9e3d9f..4dbadcbaa5 100644 --- a/crates/agent/src/buffer_codegen.rs +++ b/crates/agent/src/buffer_codegen.rs @@ -1,6 +1,6 @@ -use crate::context::attach_context_to_message; -use crate::context_store::ContextStore; +use crate::context::ContextLoadResult; use crate::inline_prompt_editor::CodegenStatus; +use crate::{context::load_context, context_store::ContextStore}; use anyhow::Result; use client::telemetry::Telemetry; use collections::HashSet; @@ -8,7 +8,7 @@ use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset use futures::{ SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join, }; -use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task}; +use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity}; use language::{Buffer, IndentKind, Point, TransactionId, line_diff}; use language_model::{ LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, @@ -16,7 +16,9 @@ use language_model::{ }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; +use project::Project; use prompt_store::PromptBuilder; +use prompt_store::PromptStore; use rope::Rope; use smol::future::FutureExt; use std::{ @@ -41,6 +43,8 @@ pub struct BufferCodegen { range: Range, initial_transaction_id: Option, context_store: Entity, + project: WeakEntity, + prompt_store: Option>, telemetry: Arc, builder: Arc, pub is_insertion: bool, @@ -52,6 +56,8 @@ impl BufferCodegen { range: Range, initial_transaction_id: Option, context_store: Entity, + project: WeakEntity, + prompt_store: Option>, telemetry: Arc, builder: Arc, cx: &mut Context, @@ -62,6 +68,8 @@ impl BufferCodegen { range.clone(), false, Some(context_store.clone()), + project.clone(), + prompt_store.clone(), Some(telemetry.clone()), builder.clone(), cx, @@ -77,6 +85,8 @@ impl BufferCodegen { range, initial_transaction_id, context_store, + project, + prompt_store, telemetry, builder, }; @@ -155,6 +165,8 @@ impl BufferCodegen { self.range.clone(), false, Some(self.context_store.clone()), + self.project.clone(), + self.prompt_store.clone(), Some(self.telemetry.clone()), self.builder.clone(), cx, @@ -231,13 +243,14 @@ pub struct CodegenAlternative { generation: Task<()>, diff: Diff, context_store: Option>, + project: WeakEntity, + prompt_store: Option>, telemetry: Option>, _subscription: gpui::Subscription, builder: Arc, active: bool, edits: Vec<(Range, String)>, line_operations: Vec, - request: Option, elapsed_time: Option, completion: Option, pub message_id: Option, @@ -251,6 +264,8 @@ impl CodegenAlternative { range: Range, active: bool, context_store: Option>, + project: WeakEntity, + prompt_store: Option>, telemetry: Option>, builder: Arc, cx: &mut Context, @@ -292,6 +307,8 @@ impl CodegenAlternative { generation: Task::ready(()), diff: Diff::default(), context_store, + project, + prompt_store, telemetry, _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), builder, @@ -299,7 +316,6 @@ impl CodegenAlternative { edits: Vec::new(), line_operations: Vec::new(), range, - request: None, elapsed_time: None, completion: None, } @@ -368,16 +384,18 @@ impl CodegenAlternative { async { Ok(LanguageModelTextStream::default()) }.boxed_local() } else { let request = self.build_request(user_prompt, cx)?; - self.request = Some(request.clone()); - - cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await) + cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await) .boxed_local() }; self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); Ok(()) } - fn build_request(&self, user_prompt: String, cx: &mut App) -> Result { + fn build_request( + &self, + user_prompt: String, + cx: &mut App, + ) -> Result> { let buffer = self.buffer.read(cx).snapshot(cx); let language = buffer.language_at(self.range.start); let language_name = if let Some(language) = language.as_ref() { @@ -410,30 +428,44 @@ impl CodegenAlternative { .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range) .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?; - let mut request_message = LanguageModelRequestMessage { - role: Role::User, - content: Vec::new(), - cache: false, - }; + let context_task = self.context_store.as_ref().map(|context_store| { + if let Some(project) = self.project.upgrade() { + let context = context_store + .read(cx) + .context() + .cloned() + .collect::>(); + load_context(context, &project, &self.prompt_store, cx) + } else { + Task::ready(ContextLoadResult::default()) + } + }); - if let Some(context_store) = &self.context_store { - attach_context_to_message( - &mut request_message, - context_store.read(cx).context().iter(), - cx, - ); - } + Ok(cx.spawn(async move |_cx| { + let mut request_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + }; - request_message.content.push(prompt.into()); + if let Some(context_task) = context_task { + context_task + .await + .loaded_context + .add_to_request_message(&mut request_message); + } - Ok(LanguageModelRequest { - thread_id: None, - prompt_id: None, - tools: Vec::new(), - stop: Vec::new(), - temperature: None, - messages: vec![request_message], - }) + request_message.content.push(prompt.into()); + + LanguageModelRequest { + thread_id: None, + prompt_id: None, + tools: Vec::new(), + stop: Vec::new(), + temperature: None, + messages: vec![request_message], + } + })) } pub fn handle_stream( @@ -1038,6 +1070,7 @@ impl Diff { #[cfg(test)] mod tests { use super::*; + use fs::FakeFs; use futures::{ Stream, stream::{self}, @@ -1080,12 +1113,16 @@ mod tests { snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), true, None, + project.downgrade(), + None, None, prompt_builder, cx, @@ -1144,12 +1181,16 @@ mod tests { snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), true, None, + project.downgrade(), + None, None, prompt_builder, cx, @@ -1211,12 +1252,16 @@ mod tests { snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), true, None, + project.downgrade(), + None, None, prompt_builder, cx, @@ -1278,12 +1323,16 @@ mod tests { snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), true, None, + project.downgrade(), + None, None, prompt_builder, cx, @@ -1333,12 +1382,16 @@ mod tests { snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), false, None, + project.downgrade(), + None, None, prompt_builder, cx, diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index cf813a1426..c9101c5615 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -1,34 +1,25 @@ -use std::{ - ops::Range, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::hash::{Hash, Hasher}; +use std::usize; +use std::{ops::Range, path::Path, sync::Arc}; +use collections::HashSet; +use futures::future; use futures::{FutureExt, future::Shared}; -use gpui::{App, Entity, SharedString, Task}; +use gpui::{App, AppContext as _, Entity, SharedString, Task}; use language::Buffer; -use language_model::{LanguageModelImage, LanguageModelRequestMessage}; -use project::{ProjectEntryId, ProjectPath, Worktree}; -use prompt_store::UserPromptId; -use rope::Point; -use serde::{Deserialize, Serialize}; -use text::{Anchor, BufferId}; -use ui::IconName; -use util::post_inc; +use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent}; +use project::{Project, ProjectEntryId, ProjectPath, Worktree}; +use prompt_store::{PromptStore, UserPromptId}; +use ref_cast::RefCast; +use rope::{Point, Rope}; +use text::{Anchor, OffsetRangeExt as _}; +use ui::{ElementId, IconName}; +use util::{ResultExt as _, post_inc}; use crate::thread::Thread; pub const RULES_ICON: IconName = IconName::Context; -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -pub struct ContextId(pub(crate) usize); - -impl ContextId { - pub fn post_inc(&mut self) -> Self { - Self(post_inc(&mut self.0)) - } -} - pub enum ContextKind { File, Directory, @@ -55,307 +46,761 @@ impl ContextKind { } } +/// Handle for context that can be added to a user message. +/// +/// This uses IDs that are stable enough for tracking renames and identifying when context has +/// already been added to the thread. To use this in a set, wrap it in `AgentContextKey` to opt in +/// to `PartialEq` and `Hash` impls that use the subset of the fields used for this stable identity. #[derive(Debug, Clone)] -pub enum AssistantContext { +pub enum AgentContext { File(FileContext), Directory(DirectoryContext), Symbol(SymbolContext), + Selection(SelectionContext), FetchedUrl(FetchedUrlContext), Thread(ThreadContext), - Selection(SelectionContext), Rules(RulesContext), Image(ImageContext), } -impl AssistantContext { - pub fn id(&self) -> ContextId { +impl AgentContext { + fn id(&self) -> ContextId { match self { - Self::File(file) => file.id, - Self::Directory(directory) => directory.id, - Self::Symbol(symbol) => symbol.id, - Self::FetchedUrl(url) => url.id, - Self::Thread(thread) => thread.id, - Self::Selection(selection) => selection.id, - Self::Rules(rules) => rules.id, - Self::Image(image) => image.id, + Self::File(context) => context.context_id, + Self::Directory(context) => context.context_id, + Self::Symbol(context) => context.context_id, + Self::Selection(context) => context.context_id, + Self::FetchedUrl(context) => context.context_id, + Self::Thread(context) => context.context_id, + Self::Rules(context) => context.context_id, + Self::Image(context) => context.context_id, } } + + pub fn element_id(&self, name: SharedString) -> ElementId { + ElementId::NamedInteger(name, self.id().0) + } } +/// ID created at time of context add, for use in ElementId. This is not the stable identity of a +/// context, instead that's handled by the `PartialEq` and `Hash` impls of `AgentContextKey`. +#[derive(Debug, Copy, Clone)] +pub struct ContextId(usize); + +impl ContextId { + pub fn zero() -> Self { + ContextId(0) + } + + fn for_lookup() -> Self { + ContextId(usize::MAX) + } + + pub fn post_inc(&mut self) -> Self { + Self(post_inc(&mut self.0)) + } +} + +/// File context provides the entire contents of a file. +/// +/// This holds an `Entity` so that file path renames affect its display and so that it can +/// be opened even if the file has been deleted. An alternative might be to use `ProjectEntryId`, +/// but then when deleted there is no path info or ability to open. #[derive(Debug, Clone)] pub struct FileContext { - pub id: ContextId, - pub context_buffer: ContextBuffer, + pub buffer: Entity, + pub context_id: ContextId, } -#[derive(Debug, Clone)] -pub struct DirectoryContext { - pub id: ContextId, - pub worktree: Entity, - pub entry_id: ProjectEntryId, - pub last_path: Arc, - /// Buffers of the files within the directory. - pub context_buffers: Vec, -} +impl FileContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.buffer == other.buffer + } -impl DirectoryContext { - pub fn entry<'a>(&self, cx: &'a App) -> Option<&'a project::Entry> { - self.worktree.read(cx).entry_for_id(self.entry_id) + pub fn hash_for_key(&self, state: &mut H) { + self.buffer.hash(state) } pub fn project_path(&self, cx: &App) -> Option { - let worktree = self.worktree.read(cx); - worktree - .entry_for_id(self.entry_id) - .map(|entry| ProjectPath { - worktree_id: worktree.id(), - path: entry.path.clone(), - }) + let file = self.buffer.read(cx).file()?; + Some(ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }) + } + + fn load(&self, cx: &App) -> Option)>> { + let buffer_ref = self.buffer.read(cx); + let Some(file) = buffer_ref.file() else { + log::error!("file context missing path"); + return None; + }; + let full_path = file.full_path(cx); + let rope = buffer_ref.as_rope().clone(); + let buffer = self.buffer.clone(); + Some( + cx.background_spawn( + async move { (to_fenced_codeblock(&full_path, rope, None), buffer) }, + ), + ) + } +} + +/// Directory contents provides the entire contents of text files in a directory. +/// +/// This has a `ProjectEntryId` so that it follows renames. +#[derive(Debug, Clone)] +pub struct DirectoryContext { + pub entry_id: ProjectEntryId, + pub context_id: ContextId, +} + +impl DirectoryContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.entry_id == other.entry_id + } + + pub fn hash_for_key(&self, state: &mut H) { + self.entry_id.hash(state) + } + + fn load( + &self, + project: Entity, + cx: &mut App, + ) -> Option)>>> { + let worktree = project.read(cx).worktree_for_entry(self.entry_id, cx)?; + let worktree_ref = worktree.read(cx); + let entry = worktree_ref.entry_for_id(self.entry_id)?; + if entry.is_file() { + log::error!("DirectoryContext unexpectedly refers to a file."); + return None; + } + + let file_paths = collect_files_in_path(worktree_ref, entry.path.as_ref()); + let texts_future = future::join_all(file_paths.into_iter().map(|path| { + load_file_path_text_as_fenced_codeblock(project.clone(), worktree.clone(), path, cx) + })); + + Some(cx.background_spawn(async move { + texts_future.await.into_iter().flatten().collect::>() + })) } } #[derive(Debug, Clone)] pub struct SymbolContext { - pub id: ContextId, - pub context_symbol: ContextSymbol, + pub buffer: Entity, + pub symbol: SharedString, + pub range: Range, + /// The range that fully contain the symbol. e.g. for function symbol, this will include not + /// only the signature, but also the body. Not used by `PartialEq` or `Hash` for `AgentContextKey`. + pub enclosing_range: Range, + pub context_id: ContextId, +} + +impl SymbolContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.buffer == other.buffer && self.symbol == other.symbol && self.range == other.range + } + + pub fn hash_for_key(&self, state: &mut H) { + self.buffer.hash(state); + self.symbol.hash(state); + self.range.hash(state); + } + + fn load(&self, cx: &App) -> Option)>> { + let buffer_ref = self.buffer.read(cx); + let Some(file) = buffer_ref.file() else { + log::error!("symbol context's file has no path"); + return None; + }; + let full_path = file.full_path(cx); + let rope = buffer_ref + .text_for_range(self.enclosing_range.clone()) + .collect::(); + let line_range = self.enclosing_range.to_point(&buffer_ref.snapshot()); + let buffer = self.buffer.clone(); + Some(cx.background_spawn(async move { + ( + to_fenced_codeblock(&full_path, rope, Some(line_range)), + buffer, + ) + })) + } +} + +#[derive(Debug, Clone)] +pub struct SelectionContext { + pub buffer: Entity, + pub range: Range, + pub context_id: ContextId, +} + +impl SelectionContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.buffer == other.buffer && self.range == other.range + } + + pub fn hash_for_key(&self, state: &mut H) { + self.buffer.hash(state); + self.range.hash(state); + } + + fn load(&self, cx: &App) -> Option)>> { + let buffer_ref = self.buffer.read(cx); + let Some(file) = buffer_ref.file() else { + log::error!("selection context's file has no path"); + return None; + }; + let full_path = file.full_path(cx); + let rope = buffer_ref + .text_for_range(self.range.clone()) + .collect::(); + let line_range = self.range.to_point(&buffer_ref.snapshot()); + let buffer = self.buffer.clone(); + Some(cx.background_spawn(async move { + ( + to_fenced_codeblock(&full_path, rope, Some(line_range)), + buffer, + ) + })) + } } #[derive(Debug, Clone)] pub struct FetchedUrlContext { - pub id: ContextId, pub url: SharedString, + /// Text contents of the fetched url. Unlike other context types, the contents of this gets + /// populated when added rather than when sending the message. Not used by `PartialEq` or `Hash` + /// for `AgentContextKey`. pub text: SharedString, + pub context_id: ContextId, +} + +impl FetchedUrlContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.url == other.url + } + + pub fn hash_for_key(&self, state: &mut H) { + self.url.hash(state); + } + + pub fn lookup_key(url: SharedString) -> AgentContextKey { + AgentContextKey(AgentContext::FetchedUrl(FetchedUrlContext { + url, + text: "".into(), + context_id: ContextId::for_lookup(), + })) + } } #[derive(Debug, Clone)] pub struct ThreadContext { - pub id: ContextId, - // TODO: Entity holds onto the thread even if the thread is deleted. Should probably be - // a WeakEntity and handle removal from the UI when it has dropped. pub thread: Entity, - pub text: SharedString, + pub context_id: ContextId, } impl ThreadContext { - pub fn summary(&self, cx: &App) -> SharedString { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.thread == other.thread + } + + pub fn hash_for_key(&self, state: &mut H) { + self.thread.hash(state) + } + + pub fn name(&self, cx: &App) -> SharedString { self.thread .read(cx) .summary() - .unwrap_or("New thread".into()) + .unwrap_or_else(|| "New thread".into()) + } + + pub fn load(&self, cx: &App) -> String { + let name = self.name(cx); + let contents = self.thread.read(cx).latest_detailed_summary_or_text(); + let mut text = String::new(); + text.push_str(&name); + text.push('\n'); + text.push_str(&contents.trim()); + text.push('\n'); + text + } +} + +#[derive(Debug, Clone)] +pub struct RulesContext { + pub prompt_id: UserPromptId, + pub context_id: ContextId, +} + +impl RulesContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.prompt_id == other.prompt_id + } + + pub fn hash_for_key(&self, state: &mut H) { + self.prompt_id.hash(state) + } + + pub fn lookup_key(prompt_id: UserPromptId) -> AgentContextKey { + AgentContextKey(AgentContext::Rules(RulesContext { + prompt_id, + context_id: ContextId::for_lookup(), + })) + } + + pub fn load( + &self, + prompt_store: &Option>, + cx: &App, + ) -> Task> { + let Some(prompt_store) = prompt_store.as_ref() else { + return Task::ready(None); + }; + let prompt_store = prompt_store.read(cx); + let prompt_id = self.prompt_id.into(); + let Some(metadata) = prompt_store.metadata(prompt_id) else { + return Task::ready(None); + }; + let contents_task = prompt_store.load(prompt_id, cx); + cx.background_spawn(async move { + let contents = contents_task.await.ok()?; + let mut text = String::new(); + if let Some(title) = metadata.title { + text.push_str("Rules title: "); + text.push_str(&title); + text.push('\n'); + } + text.push_str("``````\n"); + text.push_str(contents.trim()); + text.push_str("\n``````\n"); + Some(text) + }) } } #[derive(Debug, Clone)] pub struct ImageContext { - pub id: ContextId, pub original_image: Arc, + // TODO: handle this elsewhere and remove `ignore-interior-mutability` opt-out in clippy.toml + // needed due to a false positive of `clippy::mutable_key_type`. pub image_task: Shared>>, + pub context_id: ContextId, +} + +pub enum ImageStatus { + Loading, + Error, + Ready, } impl ImageContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.original_image.id == other.original_image.id + } + + pub fn hash_for_key(&self, state: &mut H) { + self.original_image.id.hash(state); + } + pub fn image(&self) -> Option { self.image_task.clone().now_or_never().flatten() } - pub fn is_loading(&self) -> bool { - self.image_task.clone().now_or_never().is_none() - } - - pub fn is_error(&self) -> bool { - self.image_task - .clone() - .now_or_never() - .map(|result| result.is_none()) - .unwrap_or(false) - } -} - -#[derive(Clone)] -pub struct ContextBuffer { - pub id: BufferId, - // TODO: Entity holds onto the buffer even if the buffer is deleted. Should probably be - // a WeakEntity and handle removal from the UI when it has dropped. - pub buffer: Entity, - pub last_full_path: Arc, - pub version: clock::Global, - pub text: SharedString, -} - -impl ContextBuffer { - pub fn full_path(&self, cx: &App) -> PathBuf { - let file = self.buffer.read(cx).file(); - // Note that in practice file can't be `None` because it is present when this is created and - // there's no way for buffers to go from having a file to not. - file.map_or(self.last_full_path.to_path_buf(), |file| file.full_path(cx)) - } -} - -impl std::fmt::Debug for ContextBuffer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ContextBuffer") - .field("id", &self.id) - .field("buffer", &self.buffer) - .field("version", &self.version) - .field("text", &self.text) - .finish() - } -} - -#[derive(Debug, Clone)] -pub struct ContextSymbol { - pub id: ContextSymbolId, - pub buffer: Entity, - pub buffer_version: clock::Global, - /// The range that the symbol encloses, e.g. for function symbol, this will - /// include not only the signature, but also the body - pub enclosing_range: Range, - pub text: SharedString, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ContextSymbolId { - pub path: ProjectPath, - pub name: SharedString, - pub range: Range, -} - -#[derive(Debug, Clone)] -pub struct SelectionContext { - pub id: ContextId, - pub range: Range, - pub line_range: Range, - pub context_buffer: ContextBuffer, -} - -#[derive(Debug, Clone)] -pub struct RulesContext { - pub id: ContextId, - pub prompt_id: UserPromptId, - pub title: SharedString, - pub text: SharedString, -} - -/// Formats a collection of contexts into a string representation -pub fn format_context_as_string<'a>( - contexts: impl Iterator, - cx: &App, -) -> Option { - let mut file_context = Vec::new(); - let mut directory_context = Vec::new(); - let mut symbol_context = Vec::new(); - let mut selection_context = Vec::new(); - let mut fetch_context = Vec::new(); - let mut thread_context = Vec::new(); - let mut rules_context = Vec::new(); - - for context in contexts { - match context { - AssistantContext::File(context) => file_context.push(context), - AssistantContext::Directory(context) => directory_context.push(context), - AssistantContext::Symbol(context) => symbol_context.push(context), - AssistantContext::Selection(context) => selection_context.push(context), - AssistantContext::FetchedUrl(context) => fetch_context.push(context), - AssistantContext::Thread(context) => thread_context.push(context), - AssistantContext::Rules(context) => rules_context.push(context), - AssistantContext::Image(_) => {} + pub fn status(&self) -> ImageStatus { + match self.image_task.clone().now_or_never() { + None => ImageStatus::Loading, + Some(None) => ImageStatus::Error, + Some(Some(_)) => ImageStatus::Ready, } } +} - if file_context.is_empty() - && directory_context.is_empty() - && symbol_context.is_empty() - && selection_context.is_empty() - && fetch_context.is_empty() - && thread_context.is_empty() - && rules_context.is_empty() - { - return None; +#[derive(Debug, Clone, Default)] +pub struct ContextLoadResult { + pub loaded_context: LoadedContext, + pub referenced_buffers: HashSet>, +} + +#[derive(Debug, Clone, Default)] +pub struct LoadedContext { + pub contexts: Vec, + pub text: String, + pub images: Vec, +} + +impl LoadedContext { + pub fn is_empty(&self) -> bool { + self.text.is_empty() && self.images.is_empty() } - let mut result = String::new(); - result.push_str("\n\n\ - The following items were attached by the user. You don't need to use other tools to read them.\n\n"); - - if !file_context.is_empty() { - result.push_str("\n"); - for context in file_context { - result.push_str(&context.context_buffer.text); + pub fn add_to_request_message(&self, request_message: &mut LanguageModelRequestMessage) { + if !self.text.is_empty() { + request_message + .content + .push(MessageContent::Text(self.text.to_string())); } - result.push_str("\n"); - } - if !directory_context.is_empty() { - result.push_str("\n"); - for context in directory_context { - for context_buffer in &context.context_buffers { - result.push_str(&context_buffer.text); + if !self.images.is_empty() { + // Some providers only support image parts after an initial text part + if request_message.content.is_empty() { + request_message + .content + .push(MessageContent::Text("Images attached by user:".to_string())); + } + + for image in &self.images { + request_message + .content + .push(MessageContent::Image(image.clone())) } } - result.push_str("\n"); } +} - if !symbol_context.is_empty() { - result.push_str("\n"); - for context in symbol_context { - result.push_str(&context.context_symbol.text); - result.push('\n'); +/// Loads and formats a collection of contexts. +pub fn load_context( + contexts: Vec, + project: &Entity, + prompt_store: &Option>, + cx: &mut App, +) -> Task { + let mut file_tasks = Vec::new(); + let mut directory_tasks = Vec::new(); + let mut symbol_tasks = Vec::new(); + let mut selection_tasks = Vec::new(); + let mut fetch_context = Vec::new(); + let mut thread_context = Vec::new(); + let mut rules_tasks = Vec::new(); + let mut image_tasks = Vec::new(); + + for context in contexts.iter().cloned() { + match context { + AgentContext::File(context) => file_tasks.extend(context.load(cx)), + AgentContext::Directory(context) => { + directory_tasks.extend(context.load(project.clone(), cx)) + } + AgentContext::Symbol(context) => symbol_tasks.extend(context.load(cx)), + AgentContext::Selection(context) => selection_tasks.extend(context.load(cx)), + AgentContext::FetchedUrl(context) => fetch_context.push(context), + AgentContext::Thread(context) => thread_context.push(context.load(cx)), + AgentContext::Rules(context) => rules_tasks.push(context.load(prompt_store, cx)), + AgentContext::Image(context) => image_tasks.push(context.image_task.clone()), } - result.push_str("\n"); } - if !selection_context.is_empty() { - result.push_str("\n"); - for context in selection_context { - result.push_str(&context.context_buffer.text); - result.push('\n'); - } - result.push_str("\n"); - } - - if !fetch_context.is_empty() { - result.push_str("\n"); - for context in &fetch_context { - result.push_str(&context.url); - result.push('\n'); - result.push_str(&context.text); - result.push('\n'); - } - result.push_str("\n"); - } - - if !thread_context.is_empty() { - result.push_str("\n"); - for context in &thread_context { - result.push_str(&context.summary(cx)); - result.push('\n'); - result.push_str(&context.text); - result.push('\n'); - } - result.push_str("\n"); - } - - if !rules_context.is_empty() { - result.push_str( - "\n\ - The user has specified the following rules that should be applied:\n\n", + cx.background_spawn(async move { + let ( + file_context, + directory_context, + symbol_context, + selection_context, + rules_context, + images, + ) = futures::join!( + future::join_all(file_tasks), + future::join_all(directory_tasks), + future::join_all(symbol_tasks), + future::join_all(selection_tasks), + future::join_all(rules_tasks), + future::join_all(image_tasks) ); - for context in &rules_context { - result.push_str(&context.text); - result.push('\n'); + + let directory_context = directory_context.into_iter().flatten().collect::>(); + let rules_context = rules_context.into_iter().flatten().collect::>(); + let images = images.into_iter().flatten().collect::>(); + + let mut referenced_buffers = HashSet::default(); + let mut text = String::new(); + + if file_context.is_empty() + && directory_context.is_empty() + && symbol_context.is_empty() + && selection_context.is_empty() + && fetch_context.is_empty() + && thread_context.is_empty() + && rules_context.is_empty() + { + return ContextLoadResult { + loaded_context: LoadedContext { + contexts, + text, + images, + }, + referenced_buffers, + }; } - result.push_str("\n"); - } - result.push_str("\n"); - Some(result) + text.push_str( + "\n\n\ + The following items were attached by the user. \ + You don't need to use other tools to read them.\n\n", + ); + + if !file_context.is_empty() { + text.push_str(""); + for (file_text, buffer) in file_context { + text.push('\n'); + text.push_str(&file_text); + referenced_buffers.insert(buffer); + } + text.push_str("\n"); + } + + if !directory_context.is_empty() { + text.push_str(""); + for (file_text, buffer) in directory_context { + text.push('\n'); + text.push_str(&file_text); + referenced_buffers.insert(buffer); + } + text.push_str("\n"); + } + + if !symbol_context.is_empty() { + text.push_str(""); + for (symbol_text, buffer) in symbol_context { + text.push('\n'); + text.push_str(&symbol_text); + referenced_buffers.insert(buffer); + } + text.push_str("\n"); + } + + if !selection_context.is_empty() { + text.push_str(""); + for (selection_text, buffer) in selection_context { + text.push('\n'); + text.push_str(&selection_text); + referenced_buffers.insert(buffer); + } + text.push_str("\n"); + } + + if !fetch_context.is_empty() { + text.push_str(""); + for context in fetch_context { + text.push('\n'); + text.push_str(&context.url); + text.push('\n'); + text.push_str(&context.text); + } + text.push_str("\n"); + } + + if !thread_context.is_empty() { + text.push_str(""); + for thread_text in thread_context { + text.push('\n'); + text.push_str(&thread_text); + } + text.push_str("\n"); + } + + if !rules_context.is_empty() { + text.push_str( + "\n\ + The user has specified the following rules that should be applied:\n", + ); + for rules_text in rules_context { + text.push('\n'); + text.push_str(&rules_text); + } + text.push_str("\n"); + } + + text.push_str("\n"); + + ContextLoadResult { + loaded_context: LoadedContext { + contexts, + text, + images, + }, + referenced_buffers, + } + }) } -pub fn attach_context_to_message<'a>( - message: &mut LanguageModelRequestMessage, - contexts: impl Iterator, - cx: &App, -) { - if let Some(context_string) = format_context_as_string(contexts, cx) { - message.content.push(context_string.into()); +fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec> { + let mut files = Vec::new(); + + for entry in worktree.child_entries(path) { + if entry.is_dir() { + files.extend(collect_files_in_path(worktree, &entry.path)); + } else if entry.is_file() { + files.push(entry.path.clone()); + } + } + + files +} + +fn load_file_path_text_as_fenced_codeblock( + project: Entity, + worktree: Entity, + path: Arc, + cx: &mut App, +) -> Task)>> { + let worktree_ref = worktree.read(cx); + let worktree_id = worktree_ref.id(); + let full_path = worktree_ref.full_path(&path); + + let open_task = project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + let project_path = ProjectPath { worktree_id, path }; + buffer_store.open_buffer(project_path, cx) + }) + }); + + let rope_task = cx.spawn(async move |cx| { + let buffer = open_task.await.log_err()?; + let rope = buffer + .read_with(cx, |buffer, _cx| buffer.as_rope().clone()) + .log_err()?; + Some((rope, buffer)) + }); + + cx.background_spawn(async move { + let (rope, buffer) = rope_task.await?; + Some((to_fenced_codeblock(&full_path, rope, None), buffer)) + }) +} + +fn to_fenced_codeblock( + full_path: &Path, + content: Rope, + line_range: Option>, +) -> String { + let line_range_text = line_range.map(|range| { + if range.start.row == range.end.row { + format!(":{}", range.start.row + 1) + } else { + format!(":{}-{}", range.start.row + 1, range.end.row + 1) + } + }); + + let path_extension = full_path.extension().and_then(|ext| ext.to_str()); + let path_string = full_path.to_string_lossy(); + let capacity = 3 + + path_extension.map_or(0, |extension| extension.len() + 1) + + path_string.len() + + line_range_text.as_ref().map_or(0, |text| text.len()) + + 1 + + content.len() + + 5; + let mut buffer = String::with_capacity(capacity); + + buffer.push_str("```"); + + if let Some(extension) = path_extension { + buffer.push_str(extension); + buffer.push(' '); + } + buffer.push_str(&path_string); + + if let Some(line_range_text) = line_range_text { + buffer.push_str(&line_range_text); + } + + buffer.push('\n'); + for chunk in content.chunks() { + buffer.push_str(chunk); + } + + if !buffer.ends_with('\n') { + buffer.push('\n'); + } + + buffer.push_str("```\n"); + + debug_assert!( + buffer.len() == capacity - 1 || buffer.len() == capacity, + "to_fenced_codeblock calculated capacity of {}, but length was {}", + capacity, + buffer.len(), + ); + + buffer +} + +/// Wraps `AgentContext` to opt-in to `PartialEq` and `Hash` impls which use a subset of fields +/// needed for stable context identity. +#[derive(Debug, Clone, RefCast)] +#[repr(transparent)] +pub struct AgentContextKey(pub AgentContext); + +impl AsRef for AgentContextKey { + fn as_ref(&self) -> &AgentContext { + &self.0 + } +} + +impl Eq for AgentContextKey {} + +impl PartialEq for AgentContextKey { + fn eq(&self, other: &Self) -> bool { + match &self.0 { + AgentContext::File(context) => { + if let AgentContext::File(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Directory(context) => { + if let AgentContext::Directory(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Symbol(context) => { + if let AgentContext::Symbol(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Selection(context) => { + if let AgentContext::Selection(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::FetchedUrl(context) => { + if let AgentContext::FetchedUrl(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Thread(context) => { + if let AgentContext::Thread(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Rules(context) => { + if let AgentContext::Rules(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Image(context) => { + if let AgentContext::Image(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + } + false + } +} + +impl Hash for AgentContextKey { + fn hash(&self, state: &mut H) { + match &self.0 { + AgentContext::File(context) => context.hash_for_key(state), + AgentContext::Directory(context) => context.hash_for_key(state), + AgentContext::Symbol(context) => context.hash_for_key(state), + AgentContext::Selection(context) => context.hash_for_key(state), + AgentContext::FetchedUrl(context) => context.hash_for_key(state), + AgentContext::Thread(context) => context.hash_for_key(state), + AgentContext::Rules(context) => context.hash_for_key(state), + AgentContext::Image(context) => context.hash_for_key(state), + } } } diff --git a/crates/agent/src/context_picker.rs b/crates/agent/src/context_picker.rs index 8e5cca941c..752f8a0af2 100644 --- a/crates/agent/src/context_picker.rs +++ b/crates/agent/src/context_picker.rs @@ -10,8 +10,11 @@ use std::path::PathBuf; use std::sync::Arc; use anyhow::{Result, anyhow}; +pub use completion_provider::ContextPickerCompletionProvider; use editor::display_map::{Crease, FoldId}; use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset}; +use fetch_context_picker::FetchContextPicker; +use file_context_picker::FileContextPicker; use file_context_picker::render_file_context_entry; use gpui::{ App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, @@ -20,10 +23,10 @@ use gpui::{ use language::Buffer; use multi_buffer::MultiBufferRow; use project::{Entry, ProjectPath}; -use prompt_store::UserPromptId; -use rules_context_picker::RulesContextEntry; +use prompt_store::{PromptStore, UserPromptId}; +use rules_context_picker::{RulesContextEntry, RulesContextPicker}; use symbol_context_picker::SymbolContextPicker; -use thread_context_picker::{ThreadContextEntry, render_thread_context_entry}; +use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry}; use ui::{ ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*, }; @@ -32,11 +35,6 @@ use workspace::{Workspace, notifications::NotifyResultExt}; use crate::AssistantPanel; use crate::context::RULES_ICON; -pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider; -use crate::context_picker::fetch_context_picker::FetchContextPicker; -use crate::context_picker::file_context_picker::FileContextPicker; -use crate::context_picker::rules_context_picker::RulesContextPicker; -use crate::context_picker::thread_context_picker::ThreadContextPicker; use crate::context_store::ContextStore; use crate::thread::ThreadId; use crate::thread_store::ThreadStore; @@ -166,6 +164,7 @@ pub(super) struct ContextPicker { workspace: WeakEntity, context_store: WeakEntity, thread_store: Option>, + prompt_store: Option>, _subscriptions: Vec, } @@ -193,6 +192,13 @@ impl ContextPicker { ) .collect::>(); + let prompt_store = thread_store.as_ref().and_then(|thread_store| { + thread_store + .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) + .ok() + .flatten() + }); + ContextPicker { mode: ContextPickerState::Default(ContextMenu::build( window, @@ -202,6 +208,7 @@ impl ContextPicker { workspace, context_store, thread_store, + prompt_store, _subscriptions: subscriptions, } } @@ -226,7 +233,12 @@ impl ContextPicker { .workspace .upgrade() .map(|workspace| { - available_context_picker_entries(&self.thread_store, &workspace, cx) + available_context_picker_entries( + &self.prompt_store, + &self.thread_store, + &workspace, + cx, + ) }) .unwrap_or_default(); @@ -304,10 +316,10 @@ impl ContextPicker { })); } ContextPickerMode::Rules => { - if let Some(thread_store) = self.thread_store.as_ref() { + if let Some(prompt_store) = self.prompt_store.as_ref() { self.mode = ContextPickerState::Rules(cx.new(|cx| { RulesContextPicker::new( - thread_store.clone(), + prompt_store.clone(), context_picker.clone(), self.context_store.clone(), window, @@ -526,6 +538,7 @@ enum RecentEntry { } fn available_context_picker_entries( + prompt_store: &Option>, thread_store: &Option>, workspace: &Entity, cx: &mut App, @@ -550,6 +563,9 @@ fn available_context_picker_entries( if thread_store.is_some() { entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread)); + } + + if prompt_store.is_some() { entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules)); } @@ -585,22 +601,21 @@ fn recent_context_picker_entries( }), ); - let mut current_threads = context_store.read(cx).thread_ids(); + let current_threads = context_store.read(cx).thread_ids(); - if let Some(active_thread) = workspace + let active_thread_id = workspace .panel::(cx) - .map(|panel| panel.read(cx).active_thread(cx)) - { - current_threads.insert(active_thread.read(cx).id().clone()); - } + .map(|panel| panel.read(cx).active_thread(cx).read(cx).id()); if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) { recent.extend( thread_store .read(cx) - .threads() + .reverse_chronological_threads() .into_iter() - .filter(|thread| !current_threads.contains(&thread.id)) + .filter(|thread| { + Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id) + }) .take(2) .map(|thread| { RecentEntry::Thread(ThreadContextEntry { @@ -622,9 +637,7 @@ fn add_selections_as_context( let selection_ranges = selection_ranges(workspace, cx); context_store.update(cx, |context_store, cx| { for (buffer, range) in selection_ranges { - context_store - .add_selection(buffer, range, cx) - .detach_and_log_err(cx); + context_store.add_selection(buffer, range, cx); } }) } diff --git a/crates/agent/src/context_picker/completion_provider.rs b/crates/agent/src/context_picker/completion_provider.rs index f82dab0a1b..05cecc68e7 100644 --- a/crates/agent/src/context_picker/completion_provider.rs +++ b/crates/agent/src/context_picker/completion_provider.rs @@ -15,22 +15,21 @@ use itertools::Itertools; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId}; -use prompt_store::PromptId; +use prompt_store::PromptStore; use rope::Point; use text::{Anchor, OffsetRangeExt, ToPoint}; use ui::prelude::*; use workspace::Workspace; use crate::context::RULES_ICON; -use crate::context_picker::file_context_picker::search_files; -use crate::context_picker::symbol_context_picker::search_symbols; use crate::context_store::ContextStore; use crate::thread_store::ThreadStore; use super::fetch_context_picker::fetch_url_content; -use super::file_context_picker::FileMatch; +use super::file_context_picker::{FileMatch, search_files}; use super::rules_context_picker::{RulesContextEntry, search_rules}; use super::symbol_context_picker::SymbolMatch; +use super::symbol_context_picker::search_symbols; use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads}; use super::{ ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry, @@ -38,8 +37,8 @@ use super::{ }; pub(crate) enum Match { - Symbol(SymbolMatch), File(FileMatch), + Symbol(SymbolMatch), Thread(ThreadMatch), Fetch(SharedString), Rules(RulesContextEntry), @@ -69,6 +68,7 @@ fn search( query: String, cancellation_flag: Arc, recent_entries: Vec, + prompt_store: Option>, thread_store: Option>, workspace: Entity, cx: &mut App, @@ -85,6 +85,7 @@ fn search( .collect() }) } + Some(ContextPickerMode::Symbol) => { let search_symbols_task = search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx); @@ -96,6 +97,7 @@ fn search( .collect() }) } + Some(ContextPickerMode::Thread) => { if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) { let search_threads_task = @@ -111,6 +113,7 @@ fn search( Task::ready(Vec::new()) } } + Some(ContextPickerMode::Fetch) => { if !query.is_empty() { Task::ready(vec![Match::Fetch(query.into())]) @@ -118,10 +121,11 @@ fn search( Task::ready(Vec::new()) } } + Some(ContextPickerMode::Rules) => { - if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) { + if let Some(prompt_store) = prompt_store.as_ref() { let search_rules_task = - search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx); + search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx); cx.background_spawn(async move { search_rules_task .await @@ -133,6 +137,7 @@ fn search( Task::ready(Vec::new()) } } + None => { if query.is_empty() { let mut matches = recent_entries @@ -163,7 +168,7 @@ fn search( .collect::>(); matches.extend( - available_context_picker_entries(&thread_store, &workspace, cx) + available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx) .into_iter() .map(|mode| { Match::Entry(EntryMatch { @@ -180,7 +185,8 @@ fn search( let search_files_task = search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); - let entries = available_context_picker_entries(&thread_store, &workspace, cx); + let entries = + available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx); let entry_candidates = entries .iter() .enumerate() @@ -307,9 +313,11 @@ impl ContextPickerCompletionProvider { move |_, _: &mut Window, cx: &mut App| { context_store.update(cx, |context_store, cx| { for (buffer, range) in &selections { - context_store - .add_selection(buffer.clone(), range.clone(), cx) - .detach_and_log_err(cx) + context_store.add_selection( + buffer.clone(), + range.clone(), + cx, + ); } }); @@ -437,7 +445,6 @@ impl ContextPickerCompletionProvider { source_range: Range, editor: Entity, context_store: Entity, - thread_store: Entity, ) -> Completion { let new_text = MentionLink::for_rules(&rules); let new_text_len = new_text.len(); @@ -457,29 +464,10 @@ impl ContextPickerCompletionProvider { new_text_len, editor.clone(), move |cx| { - let prompt_uuid = rules.prompt_id; - let prompt_id = PromptId::User { uuid: prompt_uuid }; - let context_store = context_store.clone(); - let Some(prompt_store) = thread_store.read(cx).prompt_store() else { - log::error!("Can't add user rules as prompt store is missing."); - return; - }; - let prompt_store = prompt_store.read(cx); - let Some(metadata) = prompt_store.metadata(prompt_id) else { - return; - }; - let Some(title) = metadata.title else { - return; - }; - let text_task = prompt_store.load(prompt_id, cx); - - cx.spawn(async move |cx| { - let text = text_task.await?; - context_store.update(cx, |context_store, cx| { - context_store.add_rules(prompt_uuid, title, text, false, cx) - }) - }) - .detach_and_log_err(cx); + let user_prompt_id = rules.prompt_id; + context_store.update(cx, |context_store, cx| { + context_store.add_rules(user_prompt_id, false, cx); + }); }, )), } @@ -516,7 +504,7 @@ impl ContextPickerCompletionProvider { let url_to_fetch = url_to_fetch.clone(); cx.spawn(async move |cx| { if context_store.update(cx, |context_store, _| { - context_store.includes_url(&url_to_fetch).is_some() + context_store.includes_url(&url_to_fetch) })? { return Ok(()); } @@ -592,7 +580,7 @@ impl ContextPickerCompletionProvider { move |cx| { context_store.update(cx, |context_store, cx| { let task = if is_directory { - context_store.add_directory(project_path.clone(), false, cx) + Task::ready(context_store.add_directory(&project_path, false, cx)) } else { context_store.add_file_from_path(project_path.clone(), false, cx) }; @@ -732,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider { cx, ); + let prompt_store = thread_store.as_ref().and_then(|thread_store| { + thread_store + .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) + .ok() + .flatten() + }); + let search_task = search( mode, query, Arc::::default(), recent_entries, + prompt_store, thread_store.clone(), workspace.clone(), cx, @@ -768,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { cx, )) } + Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol( symbol, excerpt_id, @@ -777,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { workspace.clone(), cx, ), + Match::Thread(ThreadMatch { thread, is_recent, .. }) => { @@ -791,17 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider { thread_store, )) } - Match::Rules(user_rules) => { - let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?; - Some(Self::completion_for_rules( - user_rules, - excerpt_id, - source_range.clone(), - editor.clone(), - context_store.clone(), - thread_store, - )) - } + + Match::Rules(user_rules) => Some(Self::completion_for_rules( + user_rules, + excerpt_id, + source_range.clone(), + editor.clone(), + context_store.clone(), + )), + Match::Fetch(url) => Some(Self::completion_for_fetch( source_range.clone(), url, @@ -810,6 +806,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { context_store.clone(), http_client.clone(), )), + Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry( entry, excerpt_id, diff --git a/crates/agent/src/context_picker/fetch_context_picker.rs b/crates/agent/src/context_picker/fetch_context_picker.rs index 5c7795237b..5df47e5a28 100644 --- a/crates/agent/src/context_picker/fetch_context_picker.rs +++ b/crates/agent/src/context_picker/fetch_context_picker.rs @@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate { cx: &mut Context>, ) -> Option { let added = self.context_store.upgrade().map_or(false, |context_store| { - context_store.read(cx).includes_url(&self.url).is_some() + context_store.read(cx).includes_url(&self.url) }); Some( diff --git a/crates/agent/src/context_picker/file_context_picker.rs b/crates/agent/src/context_picker/file_context_picker.rs index 5981b471c2..1dbd209850 100644 --- a/crates/agent/src/context_picker/file_context_picker.rs +++ b/crates/agent/src/context_picker/file_context_picker.rs @@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate { .context_store .update(cx, |context_store, cx| { if is_directory { - context_store.add_directory(project_path, true, cx) + Task::ready(context_store.add_directory(&project_path, true, cx)) } else { - context_store.add_file_from_path(project_path, true, cx) + context_store.add_file_from_path(project_path.clone(), true, cx) } }) .ok() @@ -325,11 +325,11 @@ pub fn render_file_context_entry( path: path.clone(), }; if is_directory { - context_store.read(cx).includes_directory(&project_path) - } else { context_store .read(cx) - .will_include_file_path(&project_path, cx) + .path_included_in_directory(&project_path, cx) + } else { + context_store.read(cx).file_path_included(&project_path, cx) } }); @@ -357,7 +357,7 @@ pub fn render_file_context_entry( })), ) .when_some(added, |el, added| match added { - FileInclusion::Direct(_) => el.child( + FileInclusion::Direct => el.child( h_flex() .w_full() .justify_end() @@ -369,9 +369,8 @@ pub fn render_file_context_entry( ) .child(Label::new("Added").size(LabelSize::Small)), ), - FileInclusion::InDirectory(directory_project_path) => { - // TODO: Consider using worktree full_path to include worktree name. - let directory_path = directory_project_path.path.to_string_lossy().into_owned(); + FileInclusion::InDirectory { full_path } => { + let directory_full_path = full_path.to_string_lossy().into_owned(); el.child( h_flex() @@ -385,7 +384,7 @@ pub fn render_file_context_entry( ) .child(Label::new("Included").size(LabelSize::Small)), ) - .tooltip(Tooltip::text(format!("in {directory_path}"))) + .tooltip(Tooltip::text(format!("in {directory_full_path}"))) } }) } diff --git a/crates/agent/src/context_picker/rules_context_picker.rs b/crates/agent/src/context_picker/rules_context_picker.rs index 4c1fc65303..ef4676e4c3 100644 --- a/crates/agent/src/context_picker/rules_context_picker.rs +++ b/crates/agent/src/context_picker/rules_context_picker.rs @@ -1,16 +1,15 @@ use std::sync::Arc; use std::sync::atomic::AtomicBool; -use anyhow::anyhow; use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity}; use picker::{Picker, PickerDelegate}; -use prompt_store::{PromptId, UserPromptId}; +use prompt_store::{PromptId, PromptStore, UserPromptId}; use ui::{ListItem, prelude::*}; +use util::ResultExt as _; use crate::context::RULES_ICON; use crate::context_picker::ContextPicker; use crate::context_store::{self, ContextStore}; -use crate::thread_store::ThreadStore; pub struct RulesContextPicker { picker: Entity>, @@ -18,13 +17,13 @@ pub struct RulesContextPicker { impl RulesContextPicker { pub fn new( - thread_store: WeakEntity, + prompt_store: Entity, context_picker: WeakEntity, context_store: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { - let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store); + let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store); let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); RulesContextPicker { picker } @@ -50,7 +49,7 @@ pub struct RulesContextEntry { } pub struct RulesContextPickerDelegate { - thread_store: WeakEntity, + prompt_store: Entity, context_picker: WeakEntity, context_store: WeakEntity, matches: Vec, @@ -59,12 +58,12 @@ pub struct RulesContextPickerDelegate { impl RulesContextPickerDelegate { pub fn new( - thread_store: WeakEntity, + prompt_store: Entity, context_picker: WeakEntity, context_store: WeakEntity, ) -> Self { RulesContextPickerDelegate { - thread_store, + prompt_store, context_picker, context_store, matches: Vec::new(), @@ -103,11 +102,12 @@ impl PickerDelegate for RulesContextPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let Some(thread_store) = self.thread_store.upgrade() else { - return Task::ready(()); - }; - - let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx); + let search_task = search_rules( + query, + Arc::new(AtomicBool::default()), + &self.prompt_store, + cx, + ); cx.spawn_in(window, async move |this, cx| { let matches = search_task.await; this.update(cx, |this, cx| { @@ -124,31 +124,11 @@ impl PickerDelegate for RulesContextPickerDelegate { return; }; - let Some(thread_store) = self.thread_store.upgrade() else { - return; - }; - - let prompt_id = entry.prompt_id; - - let load_rules_task = thread_store.update(cx, |thread_store, cx| { - thread_store.load_rules(prompt_id, cx) - }); - - cx.spawn(async move |this, cx| { - let (metadata, text) = load_rules_task.await?; - let Some(title) = metadata.title else { - return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context.")); - }; - this.update(cx, |this, cx| { - this.delegate - .context_store - .update(cx, |context_store, cx| { - context_store.add_rules(prompt_id, title, text, true, cx) - }) - .ok(); + self.context_store + .update(cx, |context_store, cx| { + context_store.add_rules(entry.prompt_id, true, cx) }) - }) - .detach_and_log_err(cx); + .log_err(); } fn dismissed(&mut self, _window: &mut Window, cx: &mut Context>) { @@ -179,11 +159,10 @@ pub fn render_thread_context_entry( context_store: WeakEntity, cx: &mut App, ) -> Div { - let added = context_store.upgrade().map_or(false, |ctx_store| { - ctx_store + let added = context_store.upgrade().map_or(false, |context_store| { + context_store .read(cx) - .includes_user_rules(&user_rules.prompt_id) - .is_some() + .includes_user_rules(user_rules.prompt_id) }); h_flex() @@ -218,12 +197,9 @@ pub fn render_thread_context_entry( pub(crate) fn search_rules( query: String, cancellation_flag: Arc, - thread_store: Entity, + prompt_store: &Entity, cx: &mut App, ) -> Task> { - let Some(prompt_store) = thread_store.read(cx).prompt_store() else { - return Task::ready(vec![]); - }; let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx); cx.background_spawn(async move { search_task diff --git a/crates/agent/src/context_picker/symbol_context_picker.rs b/crates/agent/src/context_picker/symbol_context_picker.rs index b76d4a8093..bc70c237a4 100644 --- a/crates/agent/src/context_picker/symbol_context_picker.rs +++ b/crates/agent/src/context_picker/symbol_context_picker.rs @@ -10,7 +10,6 @@ use gpui::{ use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; use project::{DocumentSymbol, Symbol}; -use text::OffsetRangeExt; use ui::{ListItem, prelude::*}; use util::ResultExt as _; use workspace::Workspace; @@ -228,18 +227,16 @@ pub(crate) fn add_symbol( ) })?; - context_store - .update(cx, move |context_store, cx| { - context_store.add_symbol( - buffer, - name.into(), - range, - enclosing_range, - remove_if_exists, - cx, - ) - })? - .await + context_store.update(cx, move |context_store, cx| { + context_store.add_symbol( + buffer, + name.into(), + range, + enclosing_range, + remove_if_exists, + cx, + ) + }) }) } @@ -353,38 +350,13 @@ fn compute_symbol_entries( context_store: &ContextStore, cx: &App, ) -> Vec { - let mut symbol_entries = Vec::with_capacity(symbols.len()); - for SymbolMatch { symbol, .. } in symbols { - let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path); - let is_included = if let Some(symbols_for_path) = symbols_for_path { - let mut is_included = false; - for included_symbol_id in symbols_for_path { - if included_symbol_id.name.as_ref() == symbol.name.as_str() { - if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) { - let snapshot = buffer.read(cx).snapshot(); - let included_symbol_range = - included_symbol_id.range.to_point_utf16(&snapshot); - - if included_symbol_range.start == symbol.range.start.0 - && included_symbol_range.end == symbol.range.end.0 - { - is_included = true; - break; - } - } - } - } - is_included - } else { - false - }; - - symbol_entries.push(SymbolEntry { + symbols + .into_iter() + .map(|SymbolMatch { symbol, .. }| SymbolEntry { + is_included: context_store.includes_symbol(&symbol, cx), symbol, - is_included, }) - } - symbol_entries + .collect::>() } pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful
{ diff --git a/crates/agent/src/context_picker/thread_context_picker.rs b/crates/agent/src/context_picker/thread_context_picker.rs index 030eaf06af..90c21b1c93 100644 --- a/crates/agent/src/context_picker/thread_context_picker.rs +++ b/crates/agent/src/context_picker/thread_context_picker.rs @@ -173,7 +173,7 @@ pub fn render_thread_context_entry( cx: &mut App, ) -> Div { let added = context_store.upgrade().map_or(false, |ctx_store| { - ctx_store.read(cx).includes_thread(&thread.id).is_some() + ctx_store.read(cx).includes_thread(&thread.id) }); h_flex() @@ -219,7 +219,7 @@ pub(crate) fn search_threads( ) -> Task> { let threads = thread_store .read(cx) - .threads() + .reverse_chronological_threads() .into_iter() .map(|thread| ThreadContextEntry { id: thread.id, diff --git a/crates/agent/src/context_store.rs b/crates/agent/src/context_store.rs index c66cad3ef2..f8f60dc911 100644 --- a/crates/agent/src/context_store.rs +++ b/crates/agent/src/context_store.rs @@ -1,43 +1,35 @@ use std::ops::Range; -use std::path::Path; +use std::path::PathBuf; use std::sync::Arc; -use anyhow::{Context as _, Result, anyhow}; -use collections::{BTreeMap, HashMap, HashSet}; +use anyhow::{Result, anyhow}; +use collections::{HashSet, IndexSet}; use futures::future::join_all; -use futures::{self, Future, FutureExt, future}; -use gpui::{App, AppContext as _, Context, Entity, Image, SharedString, Task, WeakEntity}; +use futures::{self, FutureExt}; +use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity}; use language::Buffer; use language_model::LanguageModelImage; -use project::{Project, ProjectEntryId, ProjectItem, ProjectPath, Worktree}; +use project::{Project, ProjectItem, ProjectPath, Symbol}; use prompt_store::UserPromptId; -use rope::{Point, Rope}; -use text::{Anchor, BufferId, OffsetRangeExt}; -use util::{ResultExt as _, maybe}; +use ref_cast::RefCast as _; +use text::{Anchor, OffsetRangeExt}; +use util::ResultExt as _; use crate::ThreadStore; use crate::context::{ - AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext, - FetchedUrlContext, FileContext, ImageContext, RulesContext, SelectionContext, SymbolContext, - ThreadContext, + AgentContext, AgentContextKey, ContextId, DirectoryContext, FetchedUrlContext, FileContext, + ImageContext, RulesContext, SelectionContext, SymbolContext, ThreadContext, }; use crate::context_strip::SuggestedContext; use crate::thread::{Thread, ThreadId}; pub struct ContextStore { project: WeakEntity, - context: Vec, thread_store: Option>, - next_context_id: ContextId, - files: BTreeMap, - directories: HashMap, - symbols: HashMap, - symbol_buffers: HashMap>, - symbols_by_path: HashMap>, - threads: HashMap, thread_summary_tasks: Vec>, - fetched_urls: HashMap, - user_rules: HashMap, + next_context_id: ContextId, + context_set: IndexSet, + context_thread_ids: HashSet, } impl ContextStore { @@ -48,35 +40,33 @@ impl ContextStore { Self { project, thread_store, - context: Vec::new(), - next_context_id: ContextId(0), - files: BTreeMap::default(), - directories: HashMap::default(), - symbols: HashMap::default(), - symbol_buffers: HashMap::default(), - symbols_by_path: HashMap::default(), - threads: HashMap::default(), thread_summary_tasks: Vec::new(), - fetched_urls: HashMap::default(), - user_rules: HashMap::default(), + next_context_id: ContextId::zero(), + context_set: IndexSet::default(), + context_thread_ids: HashSet::default(), } } - pub fn context(&self) -> &Vec { - &self.context - } - - pub fn context_for_id(&self, id: ContextId) -> Option<&AssistantContext> { - self.context().iter().find(|context| context.id() == id) + pub fn context(&self) -> impl Iterator { + self.context_set.iter().map(|entry| entry.as_ref()) } pub fn clear(&mut self) { - self.context.clear(); - self.files.clear(); - self.directories.clear(); - self.threads.clear(); - self.fetched_urls.clear(); - self.user_rules.clear(); + self.context_set.clear(); + self.context_thread_ids.clear(); + } + + pub fn new_context_for_thread(&self, thread: &Thread) -> Vec { + let existing_context = thread + .messages() + .flat_map(|message| &message.loaded_context.contexts) + .map(AgentContextKey::ref_cast) + .collect::>(); + self.context_set + .iter() + .filter(|context| !existing_context.contains(context)) + .map(|entry| entry.0.clone()) + .collect::>() } pub fn add_file_from_path( @@ -93,241 +83,98 @@ impl ContextStore { let open_buffer_task = project.update(cx, |project, cx| { project.open_buffer(project_path.clone(), cx) })?; - let buffer = open_buffer_task.await?; - let buffer_id = this.update(cx, |_, cx| buffer.read(cx).remote_id())?; - - let already_included = this.update(cx, |this, cx| { - match this.will_include_buffer(buffer_id, &project_path) { - Some(FileInclusion::Direct(context_id)) => { - if remove_if_exists { - this.remove_context(context_id, cx); - } - true - } - Some(FileInclusion::InDirectory(_)) => true, - None => false, - } - })?; - - if already_included { - return anyhow::Ok(()); - } - - let context_buffer = this - .update(cx, |_, cx| load_context_buffer(buffer, cx))?? - .await; - this.update(cx, |this, cx| { - this.insert_file(context_buffer, cx); - })?; - - anyhow::Ok(()) + this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx) + }) }) } pub fn add_file_from_buffer( &mut self, + project_path: &ProjectPath, buffer: Entity, + remove_if_exists: bool, cx: &mut Context, - ) -> Task> { - cx.spawn(async move |this, cx| { - let context_buffer = this - .update(cx, |_, cx| load_context_buffer(buffer, cx))?? - .await; + ) { + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::File(FileContext { buffer, context_id }); - this.update(cx, |this, cx| this.insert_file(context_buffer, cx))?; + let already_included = if self.has_context(&context) { + if remove_if_exists { + self.remove_context(&context, cx); + } + true + } else { + self.path_included_in_directory(project_path, cx).is_some() + }; - anyhow::Ok(()) - }) - } - - fn insert_file(&mut self, context_buffer: ContextBuffer, cx: &mut Context) { - let id = self.next_context_id.post_inc(); - self.files.insert(context_buffer.id, id); - self.context - .push(AssistantContext::File(FileContext { id, context_buffer })); - cx.notify(); + if !already_included { + self.insert_context(context, cx); + } } pub fn add_directory( &mut self, - project_path: ProjectPath, + project_path: &ProjectPath, remove_if_exists: bool, cx: &mut Context, - ) -> Task> { + ) -> Result<()> { let Some(project) = self.project.upgrade() else { - return Task::ready(Err(anyhow!("failed to read project"))); + return Err(anyhow!("failed to read project")); }; let Some(entry_id) = project .read(cx) - .entry_for_path(&project_path, cx) + .entry_for_path(project_path, cx) .map(|entry| entry.id) else { - return Task::ready(Err(anyhow!("no entry found for directory context"))); + return Err(anyhow!("no entry found for directory context")); }; - let already_included = match self.includes_directory(&project_path) { - Some(FileInclusion::Direct(context_id)) => { - if remove_if_exists { - self.remove_context(context_id, cx); - } - true + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Directory(DirectoryContext { + entry_id, + context_id, + }); + + if self.has_context(&context) { + if remove_if_exists { + self.remove_context(&context, cx); } - Some(FileInclusion::InDirectory(_)) => true, - None => false, - }; - if already_included { - return Task::ready(Ok(())); + } else if self.path_included_in_directory(project_path, cx).is_none() { + self.insert_context(context, cx); } - let worktree_id = project_path.worktree_id; - cx.spawn(async move |this, cx| { - let worktree = project.update(cx, |project, cx| { - project - .worktree_for_id(worktree_id, cx) - .ok_or_else(|| anyhow!("no worktree found for {worktree_id:?}")) - })??; - - let files = worktree.update(cx, |worktree, _cx| { - collect_files_in_path(worktree, &project_path.path) - })?; - - let open_buffers_task = project.update(cx, |project, cx| { - let tasks = files.iter().map(|file_path| { - project.open_buffer( - ProjectPath { - worktree_id, - path: file_path.clone(), - }, - cx, - ) - }); - future::join_all(tasks) - })?; - - let buffers = open_buffers_task.await; - - let context_buffer_tasks = this.update(cx, |_, cx| { - buffers - .into_iter() - .flatten() - .flat_map(move |buffer| load_context_buffer(buffer, cx).log_err()) - .collect::>() - })?; - - let context_buffers = future::join_all(context_buffer_tasks).await; - - if context_buffers.is_empty() { - let full_path = cx.update(|cx| worktree.read(cx).full_path(&project_path.path))?; - return Err(anyhow!("No text files found in {}", &full_path.display())); - } - - this.update(cx, |this, cx| { - this.insert_directory(worktree, entry_id, project_path, context_buffers, cx); - })?; - - anyhow::Ok(()) - }) - } - - fn insert_directory( - &mut self, - worktree: Entity, - entry_id: ProjectEntryId, - project_path: ProjectPath, - context_buffers: Vec, - cx: &mut Context, - ) { - let id = self.next_context_id.post_inc(); - let last_path = project_path.path.clone(); - self.directories.insert(project_path, id); - - self.context - .push(AssistantContext::Directory(DirectoryContext { - id, - worktree, - entry_id, - last_path, - context_buffers, - })); - cx.notify(); + anyhow::Ok(()) } pub fn add_symbol( &mut self, buffer: Entity, - symbol_name: SharedString, - symbol_range: Range, - symbol_enclosing_range: Range, + symbol: SharedString, + range: Range, + enclosing_range: Range, remove_if_exists: bool, cx: &mut Context, - ) -> Task> { - let buffer_ref = buffer.read(cx); - let Some(project_path) = buffer_ref.project_path(cx) else { - return Task::ready(Err(anyhow!("buffer has no path"))); - }; + ) -> bool { + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Symbol(SymbolContext { + buffer, + symbol, + range, + enclosing_range, + context_id, + }); - if let Some(symbols_for_path) = self.symbols_by_path.get(&project_path) { - let mut matching_symbol_id = None; - for symbol in symbols_for_path { - if &symbol.name == &symbol_name { - let snapshot = buffer_ref.snapshot(); - if symbol.range.to_offset(&snapshot) == symbol_range.to_offset(&snapshot) { - matching_symbol_id = self.symbols.get(symbol).cloned(); - break; - } - } - } - - if let Some(id) = matching_symbol_id { - if remove_if_exists { - self.remove_context(id, cx); - } - return Task::ready(Ok(false)); + if self.has_context(&context) { + if remove_if_exists { + self.remove_context(&context, cx); } + return false; } - let context_buffer_task = - match load_context_buffer_range(buffer, symbol_enclosing_range.clone(), cx) { - Ok((_line_range, context_buffer_task)) => context_buffer_task, - Err(err) => return Task::ready(Err(err)), - }; - - cx.spawn(async move |this, cx| { - let context_buffer = context_buffer_task.await; - - this.update(cx, |this, cx| { - this.insert_symbol( - make_context_symbol( - context_buffer, - project_path, - symbol_name, - symbol_range, - symbol_enclosing_range, - ), - cx, - ) - })?; - anyhow::Ok(true) - }) - } - - fn insert_symbol(&mut self, context_symbol: ContextSymbol, cx: &mut Context) { - let id = self.next_context_id.post_inc(); - self.symbols.insert(context_symbol.id.clone(), id); - self.symbols_by_path - .entry(context_symbol.id.path.clone()) - .or_insert_with(Vec::new) - .push(context_symbol.id.clone()); - self.symbol_buffers - .insert(context_symbol.id.clone(), context_symbol.buffer.clone()); - self.context.push(AssistantContext::Symbol(SymbolContext { - id, - context_symbol, - })); - cx.notify(); + self.insert_context(context, cx) } pub fn add_thread( @@ -336,24 +183,23 @@ impl ContextStore { remove_if_exists: bool, cx: &mut Context, ) { - if let Some(context_id) = self.includes_thread(&thread.read(cx).id()) { + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Thread(ThreadContext { thread, context_id }); + + if self.has_context(&context) { if remove_if_exists { - self.remove_context(context_id, cx); + self.remove_context(&context, cx); } } else { - self.insert_thread(thread, cx); + self.insert_context(context, cx); } } - pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> { - let tasks = std::mem::take(&mut self.thread_summary_tasks); - - cx.spawn(async move |_cx| { - join_all(tasks).await; - }) - } - - fn insert_thread(&mut self, thread: Entity, cx: &mut Context) { + fn start_summarizing_thread_if_needed( + &mut self, + thread: &Entity, + cx: &mut Context, + ) { if let Some(summary_task) = thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx)) { @@ -374,106 +220,60 @@ impl ContextStore { } })); } + } - let id = self.next_context_id.post_inc(); + pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> { + let tasks = std::mem::take(&mut self.thread_summary_tasks); - let text = thread.read(cx).latest_detailed_summary_or_text(); - - self.threads.insert(thread.read(cx).id().clone(), id); - self.context - .push(AssistantContext::Thread(ThreadContext { id, thread, text })); - cx.notify(); + cx.spawn(async move |_cx| { + join_all(tasks).await; + }) } pub fn add_rules( &mut self, prompt_id: UserPromptId, - title: impl Into, - text: impl Into, remove_if_exists: bool, cx: &mut Context, ) { - if let Some(context_id) = self.includes_user_rules(&prompt_id) { + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Rules(RulesContext { + prompt_id, + context_id, + }); + + if self.has_context(&context) { if remove_if_exists { - self.remove_context(context_id, cx); + self.remove_context(&context, cx); } } else { - self.insert_user_rules(prompt_id, title, text, cx); + self.insert_context(context, cx); } } - pub fn insert_user_rules( - &mut self, - prompt_id: UserPromptId, - title: impl Into, - text: impl Into, - cx: &mut Context, - ) { - let id = self.next_context_id.post_inc(); - - self.user_rules.insert(prompt_id, id); - self.context.push(AssistantContext::Rules(RulesContext { - id, - prompt_id, - title: title.into(), - text: text.into(), - })); - cx.notify(); - } - pub fn add_fetched_url( &mut self, url: String, text: impl Into, cx: &mut Context, ) { - if self.includes_url(&url).is_none() { - self.insert_fetched_url(url, text, cx); - } - } + let context = AgentContext::FetchedUrl(FetchedUrlContext { + url: url.into(), + text: text.into(), + context_id: self.next_context_id.post_inc(), + }); - fn insert_fetched_url( - &mut self, - url: String, - text: impl Into, - cx: &mut Context, - ) { - let id = self.next_context_id.post_inc(); - - self.fetched_urls.insert(url.clone(), id); - self.context - .push(AssistantContext::FetchedUrl(FetchedUrlContext { - id, - url: url.into(), - text: text.into(), - })); - cx.notify(); + self.insert_context(context, cx); } pub fn add_image(&mut self, image: Arc, cx: &mut Context) { let image_task = LanguageModelImage::from_image(image.clone(), cx).shared(); - let id = self.next_context_id.post_inc(); - self.context.push(AssistantContext::Image(ImageContext { - id, + let context = AgentContext::Image(ImageContext { original_image: image, image_task, - })); - cx.notify(); - } - - pub fn wait_for_images(&self, cx: &App) -> Task<()> { - let tasks = self - .context - .iter() - .filter_map(|ctx| match ctx { - AssistantContext::Image(ctx) => Some(ctx.image_task.clone()), - _ => None, - }) - .collect::>(); - - cx.spawn(async move |_cx| { - join_all(tasks).await; - }) + context_id: self.next_context_id.post_inc(), + }); + self.insert_context(context, cx); } pub fn add_selection( @@ -481,45 +281,21 @@ impl ContextStore { buffer: Entity, range: Range, cx: &mut Context, - ) -> Task> { - cx.spawn(async move |this, cx| { - let (line_range, context_buffer_task) = this.update(cx, |_, cx| { - load_context_buffer_range(buffer, range.clone(), cx) - })??; - - let context_buffer = context_buffer_task.await; - - this.update(cx, |this, cx| { - this.insert_selection(context_buffer, range, line_range, cx) - })?; - - anyhow::Ok(()) - }) - } - - fn insert_selection( - &mut self, - context_buffer: ContextBuffer, - range: Range, - line_range: Range, - cx: &mut Context, ) { - let id = self.next_context_id.post_inc(); - self.context - .push(AssistantContext::Selection(SelectionContext { - id, - range, - line_range, - context_buffer, - })); - cx.notify(); + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Selection(SelectionContext { + buffer, + range, + context_id, + }); + self.insert_context(context, cx); } - pub fn accept_suggested_context( + pub fn add_suggested_context( &mut self, suggested: &SuggestedContext, cx: &mut Context, - ) -> Task> { + ) { match suggested { SuggestedContext::File { buffer, @@ -527,655 +303,183 @@ impl ContextStore { name: _, } => { if let Some(buffer) = buffer.upgrade() { - return self.add_file_from_buffer(buffer, cx); + let context_id = self.next_context_id.post_inc(); + self.insert_context(AgentContext::File(FileContext { buffer, context_id }), cx); }; } SuggestedContext::Thread { thread, name: _ } => { if let Some(thread) = thread.upgrade() { - self.insert_thread(thread, cx); - }; - } - } - Task::ready(Ok(())) - } - - pub fn remove_context(&mut self, id: ContextId, cx: &mut Context) { - let Some(ix) = self.context.iter().position(|context| context.id() == id) else { - return; - }; - - match self.context.remove(ix) { - AssistantContext::File(_) => { - self.files.retain(|_, context_id| *context_id != id); - } - AssistantContext::Directory(_) => { - self.directories.retain(|_, context_id| *context_id != id); - } - AssistantContext::Symbol(symbol) => { - if let Some(symbols_in_path) = - self.symbols_by_path.get_mut(&symbol.context_symbol.id.path) - { - symbols_in_path.retain(|s| { - self.symbols - .get(s) - .map_or(false, |context_id| *context_id != id) - }); + let context_id = self.next_context_id.post_inc(); + self.insert_context( + AgentContext::Thread(ThreadContext { thread, context_id }), + cx, + ); } - self.symbol_buffers.remove(&symbol.context_symbol.id); - self.symbols.retain(|_, context_id| *context_id != id); } - AssistantContext::Selection(_) => {} - AssistantContext::FetchedUrl(_) => { - self.fetched_urls.retain(|_, context_id| *context_id != id); - } - AssistantContext::Thread(_) => { - self.threads.retain(|_, context_id| *context_id != id); - } - AssistantContext::Rules(RulesContext { prompt_id, .. }) => { - self.user_rules.remove(&prompt_id); - } - AssistantContext::Image(_) => {} } - - cx.notify(); } - /// Returns whether the buffer is already included directly in the context, or if it will be - /// included in the context via a directory. Directory inclusion is based on paths rather than - /// buffer IDs as the directory will be re-scanned. - pub fn will_include_buffer( - &self, - buffer_id: BufferId, - project_path: &ProjectPath, - ) -> Option { - if let Some(context_id) = self.files.get(&buffer_id) { - return Some(FileInclusion::Direct(*context_id)); + fn insert_context(&mut self, context: AgentContext, cx: &mut Context) -> bool { + match &context { + AgentContext::Thread(thread_context) => { + self.context_thread_ids + .insert(thread_context.thread.read(cx).id().clone()); + self.start_summarizing_thread_if_needed(&thread_context.thread, cx); + } + _ => {} } + let inserted = self.context_set.insert(AgentContextKey(context)); + if inserted { + cx.notify(); + } + inserted + } - self.will_include_file_path_via_directory(project_path) + pub fn remove_context(&mut self, context: &AgentContext, cx: &mut Context) { + if self + .context_set + .shift_remove(AgentContextKey::ref_cast(context)) + { + match context { + AgentContext::Thread(thread_context) => { + self.context_thread_ids + .remove(thread_context.thread.read(cx).id()); + } + _ => {} + } + cx.notify(); + } + } + + pub fn has_context(&mut self, context: &AgentContext) -> bool { + self.context_set + .contains(AgentContextKey::ref_cast(context)) } /// Returns whether this file path is already included directly in the context, or if it will be /// included in the context via a directory. - pub fn will_include_file_path( + pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option { + let project = self.project.upgrade()?.read(cx); + self.context().find_map(|context| match context { + AgentContext::File(file_context) => FileInclusion::check_file(file_context, path, cx), + AgentContext::Directory(directory_context) => { + FileInclusion::check_directory(directory_context, path, project, cx) + } + _ => None, + }) + } + + pub fn path_included_in_directory( &self, - project_path: &ProjectPath, + path: &ProjectPath, cx: &App, ) -> Option { - if !self.files.is_empty() { - let found_file_context = self.context.iter().find(|context| match &context { - AssistantContext::File(file_context) => { - let buffer = file_context.context_buffer.buffer.read(cx); - if let Some(context_path) = buffer.project_path(cx) { - &context_path == project_path - } else { - false - } + let project = self.project.upgrade()?.read(cx); + self.context().find_map(|context| match context { + AgentContext::Directory(directory_context) => { + FileInclusion::check_directory(directory_context, path, project, cx) + } + _ => None, + }) + } + + pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool { + self.context().any(|context| match context { + AgentContext::Symbol(context) => { + if context.symbol != symbol.name { + return false; } - _ => false, - }); - if let Some(context) = found_file_context { - return Some(FileInclusion::Direct(context.id())); + let buffer = context.buffer.read(cx); + let Some(context_path) = buffer.project_path(cx) else { + return false; + }; + if context_path != symbol.path { + return false; + } + let context_range = context.range.to_point_utf16(&buffer.snapshot()); + context_range.start == symbol.range.start.0 + && context_range.end == symbol.range.end.0 } - } - - self.will_include_file_path_via_directory(project_path) + _ => false, + }) } - fn will_include_file_path_via_directory( - &self, - project_path: &ProjectPath, - ) -> Option { - if self.directories.is_empty() { - return None; - } - - let mut path_buf = project_path.path.to_path_buf(); - - while path_buf.pop() { - // TODO: This isn't very efficient. Consider using a better representation of the - // directories map. - let directory_project_path = ProjectPath { - worktree_id: project_path.worktree_id, - path: path_buf.clone().into(), - }; - if let Some(_) = self.directories.get(&directory_project_path) { - return Some(FileInclusion::InDirectory(directory_project_path)); - } - } - - None + pub fn includes_thread(&self, thread_id: &ThreadId) -> bool { + self.context_thread_ids.contains(thread_id) } - pub fn includes_directory(&self, project_path: &ProjectPath) -> Option { - if let Some(context_id) = self.directories.get(project_path) { - return Some(FileInclusion::Direct(*context_id)); - } - - self.will_include_file_path_via_directory(project_path) + pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool { + self.context_set + .contains(&RulesContext::lookup_key(prompt_id)) } - pub fn included_symbol(&self, symbol_id: &ContextSymbolId) -> Option { - self.symbols.get(symbol_id).copied() - } - - pub fn included_symbols_by_path(&self) -> &HashMap> { - &self.symbols_by_path - } - - pub fn buffer_for_symbol(&self, symbol_id: &ContextSymbolId) -> Option> { - self.symbol_buffers.get(symbol_id).cloned() - } - - pub fn includes_thread(&self, thread_id: &ThreadId) -> Option { - self.threads.get(thread_id).copied() - } - - pub fn includes_user_rules(&self, prompt_id: &UserPromptId) -> Option { - self.user_rules.get(prompt_id).copied() - } - - pub fn includes_url(&self, url: &str) -> Option { - self.fetched_urls.get(url).copied() - } - - /// Replaces the context that matches the ID of the new context, if any match. - fn replace_context(&mut self, new_context: AssistantContext) { - let id = new_context.id(); - for context in self.context.iter_mut() { - if context.id() == id { - *context = new_context; - break; - } - } + pub fn includes_url(&self, url: impl Into) -> bool { + self.context_set + .contains(&FetchedUrlContext::lookup_key(url.into())) } pub fn file_paths(&self, cx: &App) -> HashSet { - self.context - .iter() + self.context() .filter_map(|context| match context { - AssistantContext::File(file) => { - let buffer = file.context_buffer.buffer.read(cx); + AgentContext::File(file) => { + let buffer = file.buffer.read(cx); buffer.project_path(cx) } - AssistantContext::Directory(_) - | AssistantContext::Symbol(_) - | AssistantContext::Selection(_) - | AssistantContext::FetchedUrl(_) - | AssistantContext::Thread(_) - | AssistantContext::Rules(_) - | AssistantContext::Image(_) => None, + AgentContext::Directory(_) + | AgentContext::Symbol(_) + | AgentContext::Selection(_) + | AgentContext::FetchedUrl(_) + | AgentContext::Thread(_) + | AgentContext::Rules(_) + | AgentContext::Image(_) => None, }) .collect() } - pub fn thread_ids(&self) -> HashSet { - self.threads.keys().cloned().collect() + pub fn thread_ids(&self) -> &HashSet { + &self.context_thread_ids } } pub enum FileInclusion { - Direct(ContextId), - InDirectory(ProjectPath), + Direct, + InDirectory { full_path: PathBuf }, } -fn make_context_symbol( - context_buffer: ContextBuffer, - path: ProjectPath, - name: SharedString, - range: Range, - enclosing_range: Range, -) -> ContextSymbol { - ContextSymbol { - id: ContextSymbolId { name, range, path }, - buffer_version: context_buffer.version, - enclosing_range, - buffer: context_buffer.buffer, - text: context_buffer.text, - } -} - -fn load_context_buffer_range( - buffer: Entity, - range: Range, - cx: &App, -) -> Result<(Range, Task)> { - let buffer_ref = buffer.read(cx); - let id = buffer_ref.remote_id(); - - let file = buffer_ref.file().context("context buffer missing path")?; - let full_path = file.full_path(cx); - - // Important to collect version at the same time as content so that staleness logic is correct. - let version = buffer_ref.version(); - let content = buffer_ref.text_for_range(range.clone()).collect::(); - let line_range = range.to_point(&buffer_ref.snapshot()); - - // Build the text on a background thread. - let task = cx.background_spawn({ - let line_range = line_range.clone(); - async move { - let text = to_fenced_codeblock(&full_path, content, Some(line_range)); - ContextBuffer { - id, - buffer, - last_full_path: full_path.into(), - version, - text, - } - } - }); - - Ok((line_range, task)) -} - -fn load_context_buffer(buffer: Entity, cx: &App) -> Result> { - let buffer_ref = buffer.read(cx); - let id = buffer_ref.remote_id(); - - let file = buffer_ref.file().context("context buffer missing path")?; - let full_path = file.full_path(cx); - - // Important to collect version at the same time as content so that staleness logic is correct. - let version = buffer_ref.version(); - let content = buffer_ref.as_rope().clone(); - - // Build the text on a background thread. - Ok(cx.background_spawn(async move { - let text = to_fenced_codeblock(&full_path, content, None); - ContextBuffer { - id, - buffer, - last_full_path: full_path.into(), - version, - text, - } - })) -} - -fn to_fenced_codeblock( - path: &Path, - content: Rope, - line_range: Option>, -) -> SharedString { - let line_range_text = line_range.map(|range| { - if range.start.row == range.end.row { - format!(":{}", range.start.row + 1) +impl FileInclusion { + fn check_file(file_context: &FileContext, path: &ProjectPath, cx: &App) -> Option { + let file_path = file_context.buffer.read(cx).project_path(cx)?; + if path == &file_path { + Some(FileInclusion::Direct) } else { - format!(":{}-{}", range.start.row + 1, range.end.row + 1) - } - }); - - let path_extension = path.extension().and_then(|ext| ext.to_str()); - let path_string = path.to_string_lossy(); - let capacity = 3 - + path_extension.map_or(0, |extension| extension.len() + 1) - + path_string.len() - + line_range_text.as_ref().map_or(0, |text| text.len()) - + 1 - + content.len() - + 5; - let mut buffer = String::with_capacity(capacity); - - buffer.push_str("```"); - - if let Some(extension) = path_extension { - buffer.push_str(extension); - buffer.push(' '); - } - buffer.push_str(&path_string); - - if let Some(line_range_text) = line_range_text { - buffer.push_str(&line_range_text); - } - - buffer.push('\n'); - for chunk in content.chunks() { - buffer.push_str(&chunk); - } - - if !buffer.ends_with('\n') { - buffer.push('\n'); - } - - buffer.push_str("```\n"); - - debug_assert!( - buffer.len() == capacity - 1 || buffer.len() == capacity, - "to_fenced_codeblock calculated capacity of {}, but length was {}", - capacity, - buffer.len(), - ); - - buffer.into() -} - -fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec> { - let mut files = Vec::new(); - - for entry in worktree.child_entries(path) { - if entry.is_dir() { - files.extend(collect_files_in_path(worktree, &entry.path)); - } else if entry.is_file() { - files.push(entry.path.clone()); - } - } - - files -} - -pub fn refresh_context_store_text( - context_store: Entity, - changed_buffers: &HashSet>, - cx: &App, -) -> impl Future> + use<> { - let mut tasks = Vec::new(); - - for context in &context_store.read(cx).context { - let id = context.id(); - - let task = maybe!({ - match context { - AssistantContext::File(file_context) => { - // TODO: Should refresh if the path has changed, as it's in the text. - if changed_buffers.is_empty() - || changed_buffers.contains(&file_context.context_buffer.buffer) - { - let context_store = context_store.clone(); - return refresh_file_text(context_store, file_context, cx); - } - } - AssistantContext::Directory(directory_context) => { - let directory_path = directory_context.project_path(cx)?; - let should_refresh = directory_path.path != directory_context.last_path - || changed_buffers.is_empty() - || changed_buffers.iter().any(|buffer| { - let Some(buffer_path) = buffer.read(cx).project_path(cx) else { - return false; - }; - buffer_path.starts_with(&directory_path) - }); - - if should_refresh { - let context_store = context_store.clone(); - return refresh_directory_text( - context_store, - directory_context, - directory_path, - cx, - ); - } - } - AssistantContext::Symbol(symbol_context) => { - // TODO: Should refresh if the path has changed, as it's in the text. - if changed_buffers.is_empty() - || changed_buffers.contains(&symbol_context.context_symbol.buffer) - { - let context_store = context_store.clone(); - return refresh_symbol_text(context_store, symbol_context, cx); - } - } - AssistantContext::Selection(selection_context) => { - // TODO: Should refresh if the path has changed, as it's in the text. - if changed_buffers.is_empty() - || changed_buffers.contains(&selection_context.context_buffer.buffer) - { - let context_store = context_store.clone(); - return refresh_selection_text(context_store, selection_context, cx); - } - } - AssistantContext::Thread(thread_context) => { - if changed_buffers.is_empty() { - let context_store = context_store.clone(); - return Some(refresh_thread_text(context_store, thread_context, cx)); - } - } - // Intentionally omit refreshing fetched URLs as it doesn't seem all that useful, - // and doing the caching properly could be tricky (unless it's already handled by - // the HttpClient?). - AssistantContext::FetchedUrl(_) => {} - AssistantContext::Rules(user_rules_context) => { - let context_store = context_store.clone(); - return Some(refresh_user_rules(context_store, user_rules_context, cx)); - } - AssistantContext::Image(_) => {} - } - None - }); - - if let Some(task) = task { - tasks.push(task.map(move |_| id)); } } - future::join_all(tasks) -} - -fn refresh_file_text( - context_store: Entity, - file_context: &FileContext, - cx: &App, -) -> Option> { - let id = file_context.id; - let task = refresh_context_buffer(&file_context.context_buffer, cx); - if let Some(task) = task { - Some(cx.spawn(async move |cx| { - let context_buffer = task.await; - context_store - .update(cx, |context_store, _| { - let new_file_context = FileContext { id, context_buffer }; - context_store.replace_context(AssistantContext::File(new_file_context)); - }) - .ok(); - })) - } else { - None - } -} - -fn refresh_directory_text( - context_store: Entity, - directory_context: &DirectoryContext, - directory_path: ProjectPath, - cx: &App, -) -> Option> { - let mut stale = false; - let futures = directory_context - .context_buffers - .iter() - .map(|context_buffer| { - if let Some(refresh_task) = refresh_context_buffer(context_buffer, cx) { - stale = true; - future::Either::Left(refresh_task) + fn check_directory( + directory_context: &DirectoryContext, + path: &ProjectPath, + project: &Project, + cx: &App, + ) -> Option { + let worktree = project + .worktree_for_entry(directory_context.entry_id, cx)? + .read(cx); + let entry = worktree.entry_for_id(directory_context.entry_id)?; + let directory_path = ProjectPath { + worktree_id: worktree.id(), + path: entry.path.clone(), + }; + if path.starts_with(&directory_path) { + if path == &directory_path { + Some(FileInclusion::Direct) } else { - future::Either::Right(future::ready((*context_buffer).clone())) - } - }) - .collect::>(); - - if !stale { - return None; - } - - let context_buffers = future::join_all(futures); - - let id = directory_context.id; - let worktree = directory_context.worktree.clone(); - let entry_id = directory_context.entry_id; - let last_path = directory_path.path; - Some(cx.spawn(async move |cx| { - let context_buffers = context_buffers.await; - context_store - .update(cx, |context_store, _| { - let new_directory_context = DirectoryContext { - id, - worktree, - entry_id, - last_path, - context_buffers, - }; - context_store.replace_context(AssistantContext::Directory(new_directory_context)); - }) - .ok(); - })) -} - -fn refresh_symbol_text( - context_store: Entity, - symbol_context: &SymbolContext, - cx: &App, -) -> Option> { - let id = symbol_context.id; - let task = refresh_context_symbol(&symbol_context.context_symbol, cx); - if let Some(task) = task { - Some(cx.spawn(async move |cx| { - let context_symbol = task.await; - context_store - .update(cx, |context_store, _| { - let new_symbol_context = SymbolContext { id, context_symbol }; - context_store.replace_context(AssistantContext::Symbol(new_symbol_context)); + Some(FileInclusion::InDirectory { + full_path: worktree.full_path(&entry.path), }) - .ok(); - })) - } else { - None - } -} - -fn refresh_selection_text( - context_store: Entity, - selection_context: &SelectionContext, - cx: &App, -) -> Option> { - let id = selection_context.id; - let range = selection_context.range.clone(); - let task = refresh_context_excerpt(&selection_context.context_buffer, range.clone(), cx); - if let Some(task) = task { - Some(cx.spawn(async move |cx| { - let (line_range, context_buffer) = task.await; - context_store - .update(cx, |context_store, _| { - let new_selection_context = SelectionContext { - id, - range, - line_range, - context_buffer, - }; - context_store - .replace_context(AssistantContext::Selection(new_selection_context)); - }) - .ok(); - })) - } else { - None - } -} - -fn refresh_thread_text( - context_store: Entity, - thread_context: &ThreadContext, - cx: &App, -) -> Task<()> { - let id = thread_context.id; - let thread = thread_context.thread.clone(); - cx.spawn(async move |cx| { - context_store - .update(cx, |context_store, cx| { - let text = thread.read(cx).latest_detailed_summary_or_text(); - context_store.replace_context(AssistantContext::Thread(ThreadContext { - id, - thread, - text, - })); - }) - .ok(); - }) -} - -fn refresh_user_rules( - context_store: Entity, - user_rules_context: &RulesContext, - cx: &App, -) -> Task<()> { - let id = user_rules_context.id; - let prompt_id = user_rules_context.prompt_id; - let Some(thread_store) = context_store.read(cx).thread_store.as_ref() else { - return Task::ready(()); - }; - let Ok(load_task) = thread_store.read_with(cx, |thread_store, cx| { - thread_store.load_rules(prompt_id, cx) - }) else { - return Task::ready(()); - }; - cx.spawn(async move |cx| { - if let Ok((metadata, text)) = load_task.await { - if let Some(title) = metadata.title.clone() { - context_store - .update(cx, |context_store, _cx| { - context_store.replace_context(AssistantContext::Rules(RulesContext { - id, - prompt_id, - title, - text: text.into(), - })); - }) - .ok(); - return; } + } else { + None } - context_store - .update(cx, |context_store, cx| { - context_store.remove_context(id, cx); - }) - .ok(); - }) -} - -fn refresh_context_buffer(context_buffer: &ContextBuffer, cx: &App) -> Option> { - let buffer = context_buffer.buffer.read(cx); - if buffer.version.changed_since(&context_buffer.version) { - load_context_buffer(context_buffer.buffer.clone(), cx).log_err() - } else { - None - } -} - -fn refresh_context_excerpt( - context_buffer: &ContextBuffer, - range: Range, - cx: &App, -) -> Option, ContextBuffer)> + use<>> { - let buffer = context_buffer.buffer.read(cx); - if buffer.version.changed_since(&context_buffer.version) { - let (line_range, context_buffer_task) = - load_context_buffer_range(context_buffer.buffer.clone(), range, cx).log_err()?; - Some(context_buffer_task.map(move |context_buffer| (line_range, context_buffer))) - } else { - None - } -} - -fn refresh_context_symbol( - context_symbol: &ContextSymbol, - cx: &App, -) -> Option + use<>> { - let buffer = context_symbol.buffer.read(cx); - let project_path = buffer.project_path(cx)?; - if buffer.version.changed_since(&context_symbol.buffer_version) { - let (_line_range, context_buffer_task) = load_context_buffer_range( - context_symbol.buffer.clone(), - context_symbol.enclosing_range.clone(), - cx, - ) - .log_err()?; - let name = context_symbol.id.name.clone(); - let range = context_symbol.id.range.clone(); - let enclosing_range = context_symbol.enclosing_range.clone(); - Some(context_buffer_task.map(move |context_buffer| { - make_context_symbol(context_buffer, project_path, name, range, enclosing_range) - })) - } else { - None } } diff --git a/crates/agent/src/context_strip.rs b/crates/agent/src/context_strip.rs index 6245f88998..8e0ca981a4 100644 --- a/crates/agent/src/context_strip.rs +++ b/crates/agent/src/context_strip.rs @@ -12,9 +12,9 @@ use itertools::Itertools; use language::Buffer; use project::ProjectItem; use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*}; -use workspace::{Workspace, notifications::NotifyResultExt}; +use workspace::Workspace; -use crate::context::{ContextId, ContextKind}; +use crate::context::{AgentContext, ContextKind}; use crate::context_picker::ContextPicker; use crate::context_store::ContextStore; use crate::thread::Thread; @@ -32,6 +32,7 @@ pub struct ContextStrip { focus_handle: FocusHandle, suggest_context_kind: SuggestContextKind, workspace: WeakEntity, + thread_store: Option>, _subscriptions: Vec, focused_index: Option, children_bounds: Option>>, @@ -73,12 +74,31 @@ impl ContextStrip { focus_handle, suggest_context_kind, workspace, + thread_store, _subscriptions: subscriptions, focused_index: None, children_bounds: None, } } + fn added_contexts(&self, cx: &App) -> Vec { + if let Some(workspace) = self.workspace.upgrade() { + let project = workspace.read(cx).project().read(cx); + let prompt_store = self + .thread_store + .as_ref() + .and_then(|thread_store| thread_store.upgrade()) + .and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref()); + self.context_store + .read(cx) + .context() + .flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx)) + .collect::>() + } else { + Vec::new() + } + } + fn suggested_context(&self, cx: &Context) -> Option { match self.suggest_context_kind { SuggestContextKind::File => self.suggested_file(cx), @@ -93,22 +113,19 @@ impl ContextStrip { let editor = active_item.to_any().downcast::().ok()?.read(cx); let active_buffer_entity = editor.buffer().read(cx).as_singleton()?; let active_buffer = active_buffer_entity.read(cx); - let project_path = active_buffer.project_path(cx)?; if self .context_store .read(cx) - .will_include_buffer(active_buffer.remote_id(), &project_path) + .file_path_included(&project_path, cx) .is_some() { return None; } let file_name = active_buffer.file()?.file_name(cx); - let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx); - Some(SuggestedContext::File { name: file_name.to_string_lossy().into_owned().into(), buffer: active_buffer_entity.downgrade(), @@ -135,7 +152,6 @@ impl ContextStrip { .context_store .read(cx) .includes_thread(active_thread.id()) - .is_some() { return None; } @@ -272,12 +288,12 @@ impl ContextStrip { best.map(|(index, _, _)| index) } - fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) { + fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) { let Some(workspace) = self.workspace.upgrade() else { return; }; - crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx); + crate::active_thread::open_context(context, workspace, window, cx); } fn remove_focused_context( @@ -287,17 +303,17 @@ impl ContextStrip { cx: &mut Context, ) { if let Some(index) = self.focused_index { - let mut is_empty = false; + let added_contexts = self.added_contexts(cx); + let Some(context) = added_contexts.get(index) else { + return; + }; self.context_store.update(cx, |this, cx| { - if let Some(item) = this.context().get(index) { - this.remove_context(item.id(), cx); - } - - is_empty = this.context().is_empty(); + this.remove_context(&context.context, cx); }); - if is_empty { + let is_now_empty = added_contexts.len() == 1; + if is_now_empty { cx.emit(ContextStripEvent::BlurredEmpty); } else { self.focused_index = Some(index.saturating_sub(1)); @@ -306,49 +322,28 @@ impl ContextStrip { } } - fn is_suggested_focused(&self, context: &Vec) -> bool { + fn is_suggested_focused(&self, added_contexts: &Vec) -> bool { // We only suggest one item after the actual context - self.focused_index == Some(context.len()) + self.focused_index == Some(added_contexts.len()) } fn accept_suggested_context( &mut self, _: &AcceptSuggestedContext, - window: &mut Window, + _window: &mut Window, cx: &mut Context, ) { if let Some(suggested) = self.suggested_context(cx) { - let context_store = self.context_store.read(cx); - - if self.is_suggested_focused(context_store.context()) { - self.add_suggested_context(&suggested, window, cx); + if self.is_suggested_focused(&self.added_contexts(cx)) { + self.add_suggested_context(&suggested, cx); } } } - fn add_suggested_context( - &mut self, - suggested: &SuggestedContext, - window: &mut Window, - cx: &mut Context, - ) { - let task = self.context_store.update(cx, |context_store, cx| { - context_store.accept_suggested_context(&suggested, cx) + fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context) { + self.context_store.update(cx, |context_store, cx| { + context_store.add_suggested_context(&suggested, cx) }); - - cx.spawn_in(window, async move |this, cx| { - match task.await.notify_async_err(cx) { - None => {} - Some(()) => { - if let Some(this) = this.upgrade() { - this.update(cx, |_, cx| cx.notify())?; - } - } - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - cx.notify(); } } @@ -361,17 +356,10 @@ impl Focusable for ContextStrip { impl Render for ContextStrip { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let context_store = self.context_store.read(cx); - let context = context_store.context(); let context_picker = self.context_picker.clone(); let focus_handle = self.focus_handle.clone(); - let suggested_context = self.suggested_context(cx); - - let added_contexts = context - .iter() - .map(|c| AddedContext::new(c, cx)) - .collect::>(); + let added_contexts = self.added_contexts(cx); let dupe_names = added_contexts .iter() .map(|c| c.name.clone()) @@ -380,6 +368,14 @@ impl Render for ContextStrip { .filter(|(a, b)| a == b) .map(|(a, _)| a) .collect::>(); + let no_added_context = added_contexts.is_empty(); + + let suggested_context = self.suggested_context(cx).map(|suggested_context| { + ( + suggested_context, + self.is_suggested_focused(&added_contexts), + ) + }); h_flex() .flex_wrap() @@ -436,7 +432,7 @@ impl Render for ContextStrip { }) .with_handle(self.context_picker_menu_handle.clone()), ) - .when(context.is_empty() && suggested_context.is_none(), { + .when(no_added_context && suggested_context.is_none(), { |parent| { parent.child( h_flex() @@ -466,16 +462,17 @@ impl Render for ContextStrip { .enumerate() .map(|(i, added_context)| { let name = added_context.name.clone(); - let id = added_context.id; + let context = added_context.context.clone(); ContextPill::added( added_context, dupe_names.contains(&name), self.focused_index == Some(i), Some({ + let context = context.clone(); let context_store = self.context_store.clone(); Rc::new(cx.listener(move |_this, _event, _window, cx| { context_store.update(cx, |this, cx| { - this.remove_context(id, cx); + this.remove_context(&context, cx); }); cx.notify(); })) @@ -484,7 +481,7 @@ impl Render for ContextStrip { .on_click({ Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| { if event.down.click_count > 1 { - this.open_context(id, window, cx); + this.open_context(&context, window, cx); } else { this.focused_index = Some(i); } @@ -493,22 +490,22 @@ impl Render for ContextStrip { }) }), ) - .when_some(suggested_context, |el, suggested| { + .when_some(suggested_context, |el, (suggested, focused)| { el.child( ContextPill::suggested( suggested.name().clone(), suggested.icon_path(), suggested.kind(), - self.is_suggested_focused(&context), + focused, ) .on_click(Rc::new(cx.listener( - move |this, _event, window, cx| { - this.add_suggested_context(&suggested, window, cx); + move |this, _event, _window, cx| { + this.add_suggested_context(&suggested, cx); }, ))), ) }) - .when(!context.is_empty(), { + .when(!no_added_context, { move |parent| { parent.child( IconButton::new("remove-all-context", IconName::Eraser) @@ -534,6 +531,7 @@ impl Render for ContextStrip { ) } }) + .into_any() } } diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index 3947a7dc0e..029fe1b381 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -51,7 +51,10 @@ impl HistoryStore { return history_entries; } - for thread in self.thread_store.update(cx, |this, _cx| this.threads()) { + for thread in self + .thread_store + .update(cx, |this, _cx| this.reverse_chronological_threads()) + { history_entries.push(HistoryEntry::Thread(thread)); } diff --git a/crates/agent/src/inline_assistant.rs b/crates/agent/src/inline_assistant.rs index e8d626b82a..6785d08574 100644 --- a/crates/agent/src/inline_assistant.rs +++ b/crates/agent/src/inline_assistant.rs @@ -32,6 +32,7 @@ use project::LspAction; use project::Project; use project::{CodeAction, ProjectTransaction}; use prompt_store::PromptBuilder; +use prompt_store::PromptStore; use settings::{Settings, SettingsStore}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; use terminal_view::{TerminalView, terminal_panel::TerminalPanel}; @@ -245,9 +246,13 @@ impl InlineAssistant { .map_or(false, |model| model.provider.is_authenticated(cx)) }; - let thread_store = workspace + let assistant_panel = workspace .panel::(cx) - .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade()); + .map(|assistant_panel| assistant_panel.read(cx)); + let prompt_store = assistant_panel + .and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned()); + let thread_store = + assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade()); let handle_assist = |window: &mut Window, cx: &mut Context| match inline_assist_target { @@ -257,6 +262,7 @@ impl InlineAssistant { &active_editor, cx.entity().downgrade(), workspace.project().downgrade(), + prompt_store, thread_store, window, cx, @@ -269,6 +275,7 @@ impl InlineAssistant { &active_terminal, cx.entity().downgrade(), workspace.project().downgrade(), + prompt_store, thread_store, window, cx, @@ -323,6 +330,7 @@ impl InlineAssistant { editor: &Entity, workspace: WeakEntity, project: WeakEntity, + prompt_store: Option>, thread_store: Option>, window: &mut Window, cx: &mut App, @@ -437,6 +445,8 @@ impl InlineAssistant { range.clone(), None, context_store.clone(), + project.clone(), + prompt_store.clone(), self.telemetry.clone(), self.prompt_builder.clone(), cx, @@ -525,6 +535,7 @@ impl InlineAssistant { initial_transaction_id: Option, focus: bool, workspace: Entity, + prompt_store: Option>, thread_store: Option>, window: &mut Window, cx: &mut App, @@ -543,7 +554,7 @@ impl InlineAssistant { } let project = workspace.read(cx).project().downgrade(); - let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone())); + let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone())); let codegen = cx.new(|cx| { BufferCodegen::new( @@ -551,6 +562,8 @@ impl InlineAssistant { range.clone(), initial_transaction_id, context_store.clone(), + project, + prompt_store, self.telemetry.clone(), self.prompt_builder.clone(), cx, @@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider { let editor = self.editor.clone(); let workspace = self.workspace.clone(); let thread_store = self.thread_store.clone(); + let prompt_store = PromptStore::global(cx); window.spawn(cx, async move |cx| { let workspace = workspace.upgrade().context("workspace was released")?; let editor = editor.upgrade().context("editor was released")?; @@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider { })? .context("invalid range")?; + let prompt_store = prompt_store.await.ok(); cx.update_global(|assistant: &mut InlineAssistant, window, cx| { let assist_id = assistant.suggest_assist( &editor, @@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider { None, true, workspace, + prompt_store, thread_store, window, cx, diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 05599b536b..ef60b0e6c4 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use std::sync::Arc; use crate::assistant_model_selector::ModelType; -use crate::context::{AssistantContext, format_context_as_string}; +use crate::context::{ContextLoadResult, load_context}; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use buffer_diff::BufferDiff; use collections::HashSet; @@ -13,6 +13,8 @@ use editor::{ }; use file_icons::FileIcons; use fs::Fs; +use futures::future::Shared; +use futures::{FutureExt as _, future}; use gpui::{ Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, @@ -22,6 +24,7 @@ use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelReques use language_model_selector::ToggleModelSelector; use multi_buffer; use project::Project; +use prompt_store::PromptStore; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; @@ -31,7 +34,7 @@ use workspace::Workspace; use crate::assistant_model_selector::AssistantModelSelector; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; -use crate::context_store::{ContextStore, refresh_context_store_text}; +use crate::context_store::ContextStore; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::profile_selector::ProfileSelector; use crate::thread::{Thread, TokenUsageRatio}; @@ -49,9 +52,12 @@ pub struct MessageEditor { workspace: WeakEntity, project: Entity, context_store: Entity, + prompt_store: Option>, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, + last_loaded_context: Option, + context_load_task: Option>>, profile_selector: Entity, edits_expanded: bool, editor_is_expanded: bool, @@ -68,6 +74,7 @@ impl MessageEditor { fs: Arc, workspace: WeakEntity, context_store: Entity, + prompt_store: Option>, thread_store: WeakEntity, thread: Entity, window: &mut Window, @@ -135,13 +142,11 @@ impl MessageEditor { let subscriptions = vec![ cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), cx.subscribe(&editor, |this, _, event, cx| match event { - EditorEvent::BufferEdited => { - this.message_or_context_changed(true, cx); - } + EditorEvent::BufferEdited => this.handle_message_changed(cx), _ => {} }), cx.observe(&context_store, |this, _, cx| { - this.message_or_context_changed(false, cx); + this.handle_context_changed(cx) }), ]; @@ -152,8 +157,11 @@ impl MessageEditor { incompatible_tools_state: incompatible_tools.clone(), workspace, context_store, + prompt_store, context_strip, context_picker_menu_handle, + context_load_task: None, + last_loaded_context: None, model_selector: cx.new(|cx| { AssistantModelSelector::new( fs.clone(), @@ -175,6 +183,10 @@ impl MessageEditor { } } + pub fn context_store(&self) -> &Entity { + &self.context_store + } + fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context) { cx.notify(); } @@ -214,6 +226,7 @@ impl MessageEditor { ) { self.context_picker_menu_handle.toggle(window, cx); } + pub fn remove_all_context( &mut self, _: &RemoveAllContext, @@ -270,68 +283,22 @@ impl MessageEditor { self.last_estimated_token_count.take(); cx.emit(MessageEditorEvent::EstimatedTokenCount); - let refresh_task = - refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx); - let wait_for_images = self.context_store.read(cx).wait_for_images(cx); - let thread = self.thread.clone(); - let context_store = self.context_store.clone(); let git_store = self.project.read(cx).git_store().clone(); let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx)); + let context_task = self.wait_for_context(cx); let window_handle = window.window_handle(); - cx.spawn(async move |this, cx| { - let checkpoint = checkpoint.await.ok(); - refresh_task.await; - wait_for_images.await; + cx.spawn(async move |_this, cx| { + let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await; + let loaded_context = loaded_context.unwrap_or_default(); thread .update(cx, |thread, cx| { - let context = context_store.read(cx).context().clone(); - thread.insert_user_message(user_message, context, checkpoint, cx); + thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx); }) .log_err(); - context_store - .update(cx, |context_store, cx| { - let excerpt_ids = context_store - .context() - .iter() - .filter(|ctx| { - matches!( - ctx, - AssistantContext::Selection(_) | AssistantContext::Image(_) - ) - }) - .map(|ctx| ctx.id()) - .collect::>(); - - for id in excerpt_ids { - context_store.remove_context(id, cx); - } - }) - .log_err(); - - if let Some(wait_for_summaries) = context_store - .update(cx, |context_store, cx| context_store.wait_for_summaries(cx)) - .log_err() - { - this.update(cx, |this, cx| { - this.waiting_for_summaries_to_send = true; - cx.notify(); - }) - .log_err(); - - wait_for_summaries.await; - - this.update(cx, |this, cx| { - this.waiting_for_summaries_to_send = false; - cx.notify(); - }) - .log_err(); - } - - // Send to model after summaries are done thread .update(cx, |thread, cx| { thread.advance_prompt_id(); @@ -342,6 +309,30 @@ impl MessageEditor { .detach(); } + fn wait_for_summaries(&mut self, cx: &mut Context) -> Task<()> { + let context_store = self.context_store.clone(); + cx.spawn(async move |this, cx| { + if let Some(wait_for_summaries) = context_store + .update(cx, |context_store, cx| context_store.wait_for_summaries(cx)) + .ok() + { + this.update(cx, |this, cx| { + this.waiting_for_summaries_to_send = true; + cx.notify(); + }) + .ok(); + + wait_for_summaries.await; + + this.update(cx, |this, cx| { + this.waiting_for_summaries_to_send = false; + cx.notify(); + }) + .ok(); + } + }) + } + fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context) { let cancelled = self.thread.update(cx, |thread, cx| { thread.cancel_last_completion(Some(window.window_handle()), cx) @@ -1015,6 +1006,49 @@ impl MessageEditor { self.update_token_count_task.is_some() } + fn handle_message_changed(&mut self, cx: &mut Context) { + self.message_or_context_changed(true, cx); + } + + fn handle_context_changed(&mut self, cx: &mut Context) { + let summaries_task = self.wait_for_summaries(cx); + let load_task = cx.spawn(async move |this, cx| { + // Waits for detailed summaries before `load_context`, as it directly reads these from + // the thread. TODO: Would be cleaner to have context loading await on summarization. + summaries_task.await; + let Ok(load_task) = this.update(cx, |this, cx| { + let new_context = this.context_store.read_with(cx, |context_store, cx| { + context_store.new_context_for_thread(this.thread.read(cx)) + }); + load_context(new_context, &this.project, &this.prompt_store, cx) + }) else { + return; + }; + let result = load_task.await; + this.update(cx, |this, cx| { + this.last_loaded_context = Some(result); + this.context_load_task = None; + this.message_or_context_changed(false, cx); + }) + .ok(); + }); + // Replace existing load task, if any, causing it to be cancelled. + self.context_load_task = Some(load_task.shared()); + } + + fn wait_for_context(&self, cx: &mut Context) -> Task> { + if let Some(context_load_task) = self.context_load_task.clone() { + cx.spawn(async move |this, cx| { + context_load_task.await; + this.read_with(cx, |this, _cx| this.last_loaded_context.clone()) + .ok() + .flatten() + }) + } else { + Task::ready(self.last_loaded_context.clone()) + } + } + fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context) { cx.emit(MessageEditorEvent::Changed); self.update_token_count_task.take(); @@ -1024,9 +1058,7 @@ impl MessageEditor { return; }; - let context_store = self.context_store.clone(); let editor = self.editor.clone(); - let thread = self.thread.clone(); self.update_token_count_task = Some(cx.spawn(async move |this, cx| { if debounce { @@ -1035,27 +1067,33 @@ impl MessageEditor { .await; } - let token_count = if let Some(task) = cx.update(|cx| { - let context = context_store.read(cx).context().iter(); - let new_context = thread.read(cx).filter_new_context(context); - let context_text = - format_context_as_string(new_context, cx).unwrap_or(String::new()); + let token_count = if let Some(task) = this.update(cx, |this, cx| { + let loaded_context = this + .last_loaded_context + .as_ref() + .map(|context_load_result| &context_load_result.loaded_context); let message_text = editor.read(cx).text(cx); - let content = context_text + &message_text; - - if content.is_empty() { + if message_text.is_empty() + && loaded_context.map_or(true, |loaded_context| loaded_context.is_empty()) + { return None; } + let mut request_message = LanguageModelRequestMessage { + role: language_model::Role::User, + content: Vec::new(), + cache: false, + }; + + if let Some(loaded_context) = loaded_context { + loaded_context.add_to_request_message(&mut request_message); + } + let request = language_model::LanguageModelRequest { thread_id: None, prompt_id: None, - messages: vec![LanguageModelRequestMessage { - role: language_model::Role::User, - content: vec![content.into()], - cache: false, - }], + messages: vec![request_message], tools: vec![], stop: vec![], temperature: None, diff --git a/crates/agent/src/terminal_codegen.rs b/crates/agent/src/terminal_codegen.rs index 8c0e9e1675..925187c7cf 100644 --- a/crates/agent/src/terminal_codegen.rs +++ b/crates/agent/src/terminal_codegen.rs @@ -32,7 +32,7 @@ impl TerminalCodegen { } } - pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context) { + pub fn start(&mut self, prompt_task: Task, cx: &mut Context) { let Some(ConfiguredModel { model, .. }) = LanguageModelRegistry::read_global(cx).inline_assistant_model() else { @@ -45,6 +45,7 @@ impl TerminalCodegen { self.status = CodegenStatus::Pending; self.transaction = Some(TerminalTransaction::start(self.terminal.clone())); self.generation = cx.spawn(async move |this, cx| { + let prompt = prompt_task.await; let model_telemetry_id = model.telemetry_id(); let model_provider_id = model.provider_id(); let response = model.stream_completion_text(prompt, &cx).await; diff --git a/crates/agent/src/terminal_inline_assistant.rs b/crates/agent/src/terminal_inline_assistant.rs index 95099e542c..b7690fa580 100644 --- a/crates/agent/src/terminal_inline_assistant.rs +++ b/crates/agent/src/terminal_inline_assistant.rs @@ -1,4 +1,4 @@ -use crate::context::attach_context_to_message; +use crate::context::load_context; use crate::context_store::ContextStore; use crate::inline_prompt_editor::{ CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId, @@ -10,14 +10,14 @@ use client::telemetry::Telemetry; use collections::{HashMap, VecDeque}; use editor::{MultiBuffer, actions::SelectAll}; use fs::Fs; -use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity}; +use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity}; use language::Buffer; use language_model::{ ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, report_assistant_event, }; use project::Project; -use prompt_store::PromptBuilder; +use prompt_store::{PromptBuilder, PromptStore}; use std::sync::Arc; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; use terminal_view::TerminalView; @@ -69,6 +69,7 @@ impl TerminalInlineAssistant { terminal_view: &Entity, workspace: WeakEntity, project: WeakEntity, + prompt_store: Option>, thread_store: Option>, window: &mut Window, cx: &mut App, @@ -109,6 +110,7 @@ impl TerminalInlineAssistant { prompt_editor, workspace.clone(), context_store, + prompt_store, window, cx, ); @@ -196,11 +198,11 @@ impl TerminalInlineAssistant { .log_err(); let codegen = assist.codegen.clone(); - let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else { + let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else { return; }; - codegen.update(cx, |codegen, cx| codegen.start(request, cx)); + codegen.update(cx, |codegen, cx| codegen.start(request_task, cx)); } fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) { @@ -217,7 +219,7 @@ impl TerminalInlineAssistant { &self, assist_id: TerminalInlineAssistId, cx: &mut App, - ) -> Result { + ) -> Result> { let assist = self.assists.get(&assist_id).context("invalid assist")?; let shell = std::env::var("SHELL").ok(); @@ -246,28 +248,40 @@ impl TerminalInlineAssistant { &latest_output, )?; - let mut request_message = LanguageModelRequestMessage { - role: Role::User, - content: vec![], - cache: false, - }; + let contexts = assist + .context_store + .read(cx) + .context() + .cloned() + .collect::>(); + let context_load_task = assist.workspace.update(cx, |workspace, cx| { + let project = workspace.project(); + load_context(contexts, project, &assist.prompt_store, cx) + })?; - attach_context_to_message( - &mut request_message, - assist.context_store.read(cx).context().iter(), - cx, - ); + Ok(cx.background_spawn(async move { + let mut request_message = LanguageModelRequestMessage { + role: Role::User, + content: vec![], + cache: false, + }; - request_message.content.push(prompt.into()); + context_load_task + .await + .loaded_context + .add_to_request_message(&mut request_message); - Ok(LanguageModelRequest { - thread_id: None, - prompt_id: None, - messages: vec![request_message], - tools: Vec::new(), - stop: Vec::new(), - temperature: None, - }) + request_message.content.push(prompt.into()); + + LanguageModelRequest { + thread_id: None, + prompt_id: None, + messages: vec![request_message], + tools: Vec::new(), + stop: Vec::new(), + temperature: None, + } + })) } fn finish_assist( @@ -380,6 +394,7 @@ struct TerminalInlineAssist { codegen: Entity, workspace: WeakEntity, context_store: Entity, + prompt_store: Option>, _subscriptions: Vec, } @@ -390,6 +405,7 @@ impl TerminalInlineAssist { prompt_editor: Entity>, workspace: WeakEntity, context_store: Entity, + prompt_store: Option>, window: &mut Window, cx: &mut App, ) -> Self { @@ -400,6 +416,7 @@ impl TerminalInlineAssist { codegen: codegen.clone(), workspace: workspace.clone(), context_store, + prompt_store, _subscriptions: vec![ window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| { TerminalInlineAssistant::update_global(cx, |this, cx| { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index cce7d109c0..29709490c9 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -8,7 +8,7 @@ use anyhow::{Result, anyhow}; use assistant_settings::AssistantSettings; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; -use collections::{BTreeMap, HashMap}; +use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; use futures::future::Shared; use futures::{FutureExt, StreamExt as _}; @@ -18,9 +18,9 @@ use gpui::{ }; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, - LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, - LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, + LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, + LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason, TokenUsage, }; @@ -35,7 +35,7 @@ use thiserror::Error; use util::{ResultExt as _, TryFutureExt as _, post_inc}; use uuid::Uuid; -use crate::context::{AssistantContext, ContextId, format_context_as_string}; +use crate::context::{AgentContext, ContextLoadResult, LoadedContext}; use crate::thread_store::{ SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext, @@ -98,8 +98,7 @@ pub struct Message { pub id: MessageId, pub role: Role, pub segments: Vec, - pub context: String, - pub images: Vec, + pub loaded_context: LoadedContext, } impl Message { @@ -138,8 +137,8 @@ impl Message { pub fn to_string(&self) -> String { let mut result = String::new(); - if !self.context.is_empty() { - result.push_str(&self.context); + if !self.loaded_context.text.is_empty() { + result.push_str(&self.loaded_context.text); } for segment in &self.segments { @@ -294,8 +293,6 @@ pub struct Thread { messages: Vec, next_message_id: MessageId, last_prompt_id: PromptId, - context: BTreeMap, - context_by_message: HashMap>, project_context: SharedProjectContext, checkpoints_by_message: HashMap, completion_count: usize, @@ -345,8 +342,6 @@ impl Thread { messages: Vec::new(), next_message_id: MessageId(0), last_prompt_id: PromptId::new(), - context: BTreeMap::default(), - context_by_message: HashMap::default(), project_context: system_prompt, checkpoints_by_message: HashMap::default(), completion_count: 0, @@ -418,14 +413,15 @@ impl Thread { } }) .collect(), - context: message.context, - images: Vec::new(), + loaded_context: LoadedContext { + contexts: Vec::new(), + text: message.context, + images: Vec::new(), + }, }) .collect(), next_message_id, last_prompt_id: PromptId::new(), - context: BTreeMap::default(), - context_by_message: HashMap::default(), project_context, checkpoints_by_message: HashMap::default(), completion_count: 0, @@ -660,21 +656,17 @@ impl Thread { return; }; for deleted_message in self.messages.drain(message_ix..) { - self.context_by_message.remove(&deleted_message.id); self.checkpoints_by_message.remove(&deleted_message.id); } cx.notify(); } - pub fn context_for_message(&self, id: MessageId) -> impl Iterator { - self.context_by_message - .get(&id) + pub fn context_for_message(&self, id: MessageId) -> impl Iterator { + self.messages + .iter() + .find(|message| message.id == id) .into_iter() - .flat_map(|context| { - context - .iter() - .filter_map(|context_id| self.context.get(&context_id)) - }) + .flat_map(|message| message.loaded_context.contexts.iter()) } pub fn is_turn_end(&self, ix: usize) -> bool { @@ -736,91 +728,27 @@ impl Thread { self.tool_use.tool_result_card(id).cloned() } - /// Filter out contexts that have already been included in previous messages - pub fn filter_new_context<'a>( - &self, - context: impl Iterator, - ) -> impl Iterator { - context.filter(|ctx| self.is_context_new(ctx)) - } - - fn is_context_new(&self, context: &AssistantContext) -> bool { - !self.context.contains_key(&context.id()) - } - pub fn insert_user_message( &mut self, text: impl Into, - context: Vec, + loaded_context: ContextLoadResult, git_checkpoint: Option, cx: &mut Context, ) -> MessageId { - let text = text.into(); - - let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx); - - let new_context: Vec<_> = context - .into_iter() - .filter(|ctx| self.is_context_new(ctx)) - .collect(); - - if !new_context.is_empty() { - if let Some(context_string) = format_context_as_string(new_context.iter(), cx) { - if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) { - message.context = context_string; - } - } - - if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) { - message.images = new_context - .iter() - .filter_map(|context| { - if let AssistantContext::Image(image_context) = context { - image_context.image_task.clone().now_or_never().flatten() - } else { - None - } - }) - .collect::>(); - } - + if !loaded_context.referenced_buffers.is_empty() { self.action_log.update(cx, |log, cx| { - // Track all buffers added as context - for ctx in &new_context { - match ctx { - AssistantContext::File(file_ctx) => { - log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx); - } - AssistantContext::Directory(dir_ctx) => { - for context_buffer in &dir_ctx.context_buffers { - log.track_buffer(context_buffer.buffer.clone(), cx); - } - } - AssistantContext::Symbol(symbol_ctx) => { - log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx); - } - AssistantContext::Selection(selection_context) => { - log.track_buffer(selection_context.context_buffer.buffer.clone(), cx); - } - AssistantContext::FetchedUrl(_) - | AssistantContext::Thread(_) - | AssistantContext::Rules(_) - | AssistantContext::Image(_) => {} - } + for buffer in loaded_context.referenced_buffers { + log.track_buffer(buffer, cx); } }); } - let context_ids = new_context - .iter() - .map(|context| context.id()) - .collect::>(); - self.context.extend( - new_context - .into_iter() - .map(|context| (context.id(), context)), + let message_id = self.insert_message( + Role::User, + vec![MessageSegment::Text(text.into())], + loaded_context.loaded_context, + cx, ); - self.context_by_message.insert(message_id, context_ids); if let Some(git_checkpoint) = git_checkpoint { self.pending_checkpoint = Some(ThreadCheckpoint { @@ -834,10 +762,19 @@ impl Thread { message_id } + pub fn insert_assistant_message( + &mut self, + segments: Vec, + cx: &mut Context, + ) -> MessageId { + self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx) + } + pub fn insert_message( &mut self, role: Role, segments: Vec, + loaded_context: LoadedContext, cx: &mut Context, ) -> MessageId { let id = self.next_message_id.post_inc(); @@ -845,8 +782,7 @@ impl Thread { id, role, segments, - context: String::new(), - images: Vec::new(), + loaded_context, }); self.touch_updated_at(); cx.emit(ThreadEvent::MessageAdded(id)); @@ -875,7 +811,6 @@ impl Thread { return false; }; self.messages.remove(index); - self.context_by_message.remove(&id); self.touch_updated_at(); cx.emit(ThreadEvent::MessageDeleted(id)); true @@ -962,7 +897,7 @@ impl Thread { content: tool_result.content.clone(), }) .collect(), - context: message.context.clone(), + context: message.loaded_context.text.clone(), }) .collect(), initial_project_snapshot, @@ -1080,26 +1015,9 @@ impl Thread { cache: false, }; - if !message.context.is_empty() { - request_message - .content - .push(MessageContent::Text(message.context.to_string())); - } - - if !message.images.is_empty() { - // Some providers only support image parts after an initial text part - if request_message.content.is_empty() { - request_message - .content - .push(MessageContent::Text("Images attached by user:".to_string())); - } - - for image in &message.images { - request_message - .content - .push(MessageContent::Image(image.clone())) - } - } + message + .loaded_context + .add_to_request_message(&mut request_message); for segment in &message.segments { match segment { @@ -1301,11 +1219,11 @@ impl Thread { match event { LanguageModelCompletionEvent::StartMessage { .. } => { - request_assistant_message_id = Some(thread.insert_message( - Role::Assistant, - vec![MessageSegment::Text(String::new())], - cx, - )); + request_assistant_message_id = + Some(thread.insert_assistant_message( + vec![MessageSegment::Text(String::new())], + cx, + )); } LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; @@ -1334,11 +1252,11 @@ impl Thread { // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. - request_assistant_message_id = Some(thread.insert_message( - Role::Assistant, - vec![MessageSegment::Text(chunk.to_string())], - cx, - )); + request_assistant_message_id = + Some(thread.insert_assistant_message( + vec![MessageSegment::Text(chunk.to_string())], + cx, + )); }; } } @@ -1361,14 +1279,14 @@ impl Thread { // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. - request_assistant_message_id = Some(thread.insert_message( - Role::Assistant, - vec![MessageSegment::Thinking { - text: chunk.to_string(), - signature, - }], - cx, - )); + request_assistant_message_id = + Some(thread.insert_assistant_message( + vec![MessageSegment::Thinking { + text: chunk.to_string(), + signature, + }], + cx, + )); }; } } @@ -1376,7 +1294,7 @@ impl Thread { let last_assistant_message_id = request_assistant_message_id .unwrap_or_else(|| { let new_assistant_message_id = - thread.insert_message(Role::Assistant, vec![], cx); + thread.insert_assistant_message(vec![], cx); request_assistant_message_id = Some(new_assistant_message_id); new_assistant_message_id @@ -2097,8 +2015,16 @@ impl Thread { } )?; - if !message.context.is_empty() { - writeln!(markdown, "{}", message.context)?; + if !message.loaded_context.text.is_empty() { + writeln!(markdown, "{}", message.loaded_context.text)?; + } + + if !message.loaded_context.images.is_empty() { + writeln!( + markdown, + "\n{} images attached as context.\n", + message.loaded_context.images.len() + )?; } for segment in &message.segments { @@ -2373,7 +2299,7 @@ struct PendingCompletion { #[cfg(test)] mod tests { use super::*; - use crate::{ThreadStore, context_store::ContextStore, thread_store}; + use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store}; use assistant_settings::AssistantSettings; use context_server::ContextServerSettings; use editor::EditorSettings; @@ -2404,12 +2330,14 @@ mod tests { .await .unwrap(); - let context = - context_store.update(cx, |store, _| store.context().first().cloned().unwrap()); + let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap()); + let loaded_context = cx + .update(|cx| load_context(vec![context], &project, &None, cx)) + .await; // Insert user message with context let message_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Please explain this code", vec![context], None, cx) + thread.insert_user_message("Please explain this code", loaded_context, None, cx) }); // Check content and context in message object @@ -2443,7 +2371,7 @@ fn main() {{ message.segments[0], MessageSegment::Text("Please explain this code".to_string()) ); - assert_eq!(message.context, expected_context); + assert_eq!(message.loaded_context.text, expected_context); // Check message in request let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); @@ -2470,48 +2398,50 @@ fn main() {{ let (_, _thread_store, thread, context_store) = setup_test_environment(cx, project.clone()).await; - // Open files individually + // First message with context 1 add_file_to_context(&project, &context_store, "test/file1.rs", cx) .await .unwrap(); - add_file_to_context(&project, &context_store, "test/file2.rs", cx) - .await - .unwrap(); - add_file_to_context(&project, &context_store, "test/file3.rs", cx) - .await - .unwrap(); - - // Get the context objects - let contexts = context_store.update(cx, |store, _| store.context().clone()); - assert_eq!(contexts.len(), 3); - - // First message with context 1 + let new_contexts = context_store.update(cx, |store, cx| { + store.new_context_for_thread(thread.read(cx)) + }); + assert_eq!(new_contexts.len(), 1); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await; let message1_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx) + thread.insert_user_message("Message 1", loaded_context, None, cx) }); // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included) + add_file_to_context(&project, &context_store, "test/file2.rs", cx) + .await + .unwrap(); + let new_contexts = context_store.update(cx, |store, cx| { + store.new_context_for_thread(thread.read(cx)) + }); + assert_eq!(new_contexts.len(), 1); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await; let message2_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Message 2", - vec![contexts[0].clone(), contexts[1].clone()], - None, - cx, - ) + thread.insert_user_message("Message 2", loaded_context, None, cx) }); // Third message with all three contexts (contexts 1 and 2 should be skipped) + // + add_file_to_context(&project, &context_store, "test/file3.rs", cx) + .await + .unwrap(); + let new_contexts = context_store.update(cx, |store, cx| { + store.new_context_for_thread(thread.read(cx)) + }); + assert_eq!(new_contexts.len(), 1); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await; let message3_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Message 3", - vec![ - contexts[0].clone(), - contexts[1].clone(), - contexts[2].clone(), - ], - None, - cx, - ) + thread.insert_user_message("Message 3", loaded_context, None, cx) }); // Check what contexts are included in each message @@ -2524,16 +2454,16 @@ fn main() {{ }); // First message should include context 1 - assert!(message1.context.contains("file1.rs")); + assert!(message1.loaded_context.text.contains("file1.rs")); // Second message should include only context 2 (not 1) - assert!(!message2.context.contains("file1.rs")); - assert!(message2.context.contains("file2.rs")); + assert!(!message2.loaded_context.text.contains("file1.rs")); + assert!(message2.loaded_context.text.contains("file2.rs")); // Third message should include only context 3 (not 1 or 2) - assert!(!message3.context.contains("file1.rs")); - assert!(!message3.context.contains("file2.rs")); - assert!(message3.context.contains("file3.rs")); + assert!(!message3.loaded_context.text.contains("file1.rs")); + assert!(!message3.loaded_context.text.contains("file2.rs")); + assert!(message3.loaded_context.text.contains("file3.rs")); // Check entire request to make sure all contexts are properly included let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); @@ -2570,7 +2500,12 @@ fn main() {{ // Insert user message without any context (empty context vector) let message_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx) + thread.insert_user_message( + "What is the best way to learn Rust?", + ContextLoadResult::default(), + None, + cx, + ) }); // Check content and context in message object @@ -2583,7 +2518,7 @@ fn main() {{ message.segments[0], MessageSegment::Text("What is the best way to learn Rust?".to_string()) ); - assert_eq!(message.context, ""); + assert_eq!(message.loaded_context.text, ""); // Check message in request let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); @@ -2596,12 +2531,17 @@ fn main() {{ // Add second message, also without context let message2_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Are there any good books?", vec![], None, cx) + thread.insert_user_message( + "Are there any good books?", + ContextLoadResult::default(), + None, + cx, + ) }); let message2 = thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone()); - assert_eq!(message2.context, ""); + assert_eq!(message2.loaded_context.text, ""); // Check that both messages appear in the request let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); @@ -2635,12 +2575,14 @@ fn main() {{ .await .unwrap(); - let context = - context_store.update(cx, |store, _| store.context().first().cloned().unwrap()); + let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap()); + let loaded_context = cx + .update(|cx| load_context(vec![context], &project, &None, cx)) + .await; // Insert user message with the buffer as context thread.update(cx, |thread, cx| { - thread.insert_user_message("Explain this code", vec![context], None, cx) + thread.insert_user_message("Explain this code", loaded_context, None, cx) }); // Create a request and check that it doesn't have a stale buffer warning yet @@ -2668,7 +2610,12 @@ fn main() {{ // Insert another user message without context thread.update(cx, |thread, cx| { - thread.insert_user_message("What does the code do now?", vec![], None, cx) + thread.insert_user_message( + "What does the code do now?", + ContextLoadResult::default(), + None, + cx, + ) }); // Create a new request and check for the stale buffer warning @@ -2735,6 +2682,7 @@ fn main() {{ ThreadStore::load( project.clone(), cx.new(|_| ToolWorkingSet::default()), + None, Arc::new(PromptBuilder::new(None).unwrap()), cx, ) @@ -2759,15 +2707,15 @@ fn main() {{ .unwrap(); let buffer = project - .update(cx, |project, cx| project.open_buffer(buffer_path, cx)) + .update(cx, |project, cx| { + project.open_buffer(buffer_path.clone(), cx) + }) .await .unwrap(); - context_store - .update(cx, |store, cx| { - store.add_file_from_buffer(buffer.clone(), cx) - }) - .await?; + context_store.update(cx, |context_store, cx| { + context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx); + }); Ok(buffer) } diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 117be3675c..a907e1d440 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -24,8 +24,8 @@ use heed::types::SerdeBincode; use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use project::{Project, Worktree}; use prompt_store::{ - ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent, - RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext, + ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext, + UserRulesContext, WorktreeContext, }; use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; @@ -82,12 +82,11 @@ impl ThreadStore { pub fn load( project: Entity, tools: Entity, + prompt_store: Option>, prompt_builder: Arc, cx: &mut App, ) -> Task>> { - let prompt_store = PromptStore::global(cx); cx.spawn(async move |cx| { - let prompt_store = prompt_store.await.ok(); let (thread_store, ready_rx) = cx.update(|cx| { let mut option_ready_rx = None; let thread_store = cx.new(|cx| { @@ -349,25 +348,8 @@ impl ThreadStore { self.context_server_manager.clone() } - pub fn prompt_store(&self) -> Option> { - self.prompt_store.clone() - } - - pub fn load_rules( - &self, - prompt_id: UserPromptId, - cx: &App, - ) -> Task> { - let prompt_id = PromptId::User { uuid: prompt_id }; - let Some(prompt_store) = self.prompt_store.as_ref() else { - return Task::ready(Err(anyhow!("Prompt store unexpectedly missing."))); - }; - let prompt_store = prompt_store.read(cx); - let Some(metadata) = prompt_store.metadata(prompt_id) else { - return Task::ready(Err(anyhow!("User rules not found in library."))); - }; - let text_task = prompt_store.load(prompt_id, cx); - cx.background_spawn(async move { Ok((metadata, text_task.await?)) }) + pub fn prompt_store(&self) -> &Option> { + &self.prompt_store } pub fn tools(&self) -> Entity { @@ -379,16 +361,12 @@ impl ThreadStore { self.threads.len() } - pub fn threads(&self) -> Vec { + pub fn reverse_chronological_threads(&self) -> Vec { let mut threads = self.threads.iter().cloned().collect::>(); threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at)); threads } - pub fn recent_threads(&self, limit: usize) -> Vec { - self.threads().into_iter().take(limit).collect() - } - pub fn create_thread(&mut self, cx: &mut Context) -> Entity { cx.new(|cx| { Thread::new( diff --git a/crates/agent/src/ui/context_pill.rs b/crates/agent/src/ui/context_pill.rs index a3c6608179..c31af4b642 100644 --- a/crates/agent/src/ui/context_pill.rs +++ b/crates/agent/src/ui/context_pill.rs @@ -1,14 +1,13 @@ -use std::sync::Arc; use std::{rc::Rc, time::Duration}; use file_icons::FileIcons; -use futures::FutureExt; -use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between}; -use gpui::{ClickEvent, Task}; -use language_model::LanguageModelImage; +use gpui::{Animation, AnimationExt as _, ClickEvent, Entity, MouseButton, pulsating_between}; +use project::Project; +use prompt_store::PromptStore; +use text::OffsetRangeExt; use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container}; -use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext}; +use crate::context::{AgentContext, ContextKind, ImageStatus}; #[derive(IntoElement)] pub enum ContextPill { @@ -73,9 +72,7 @@ impl ContextPill { pub fn id(&self) -> ElementId { match self { - Self::Added { context, .. } => { - ElementId::NamedInteger("context-pill".into(), context.id.0) - } + Self::Added { context, .. } => context.context.element_id("context-pill".into()), Self::Suggested { .. } => "suggested-context-pill".into(), } } @@ -199,14 +196,17 @@ impl RenderOnce for ContextPill { ) .when_some(on_remove.as_ref(), |element, on_remove| { element.child( - IconButton::new(("remove", context.id.0), IconName::Close) - .shape(IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .tooltip(Tooltip::text("Remove Context")) - .on_click({ - let on_remove = on_remove.clone(); - move |event, window, cx| on_remove(event, window, cx) - }), + IconButton::new( + context.context.element_id("remove".into()), + IconName::Close, + ) + .shape(IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .tooltip(Tooltip::text("Remove Context")) + .on_click({ + let on_remove = on_remove.clone(); + move |event, window, cx| on_remove(event, window, cx) + }), ) }) .when_some(on_click.as_ref(), |element, on_click| { @@ -262,9 +262,11 @@ pub enum ContextStatus { Error { message: SharedString }, } -#[derive(RegisterComponent)] +// TODO: Component commented out due to new dependency on `Project`. +// +// #[derive(RegisterComponent)] pub struct AddedContext { - pub id: ContextId, + pub context: AgentContext, pub kind: ContextKind, pub name: SharedString, pub parent: Option, @@ -275,10 +277,19 @@ pub struct AddedContext { } impl AddedContext { - pub fn new(context: &AssistantContext, cx: &App) -> AddedContext { + /// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a + /// `None` if `DirectoryContext` or `RulesContext` no longer exist. + /// + /// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak. + pub fn new( + context: AgentContext, + prompt_store: Option<&Entity>, + project: &Project, + cx: &App, + ) -> Option { match context { - AssistantContext::File(file_context) => { - let full_path = file_context.context_buffer.full_path(cx); + AgentContext::File(ref file_context) => { + let full_path = file_context.buffer.read(cx).file()?.full_path(cx); let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into(); let name = full_path @@ -289,8 +300,7 @@ impl AddedContext { .parent() .and_then(|p| p.file_name()) .map(|n| n.to_string_lossy().into_owned().into()); - AddedContext { - id: file_context.id, + Some(AddedContext { kind: ContextKind::File, name, parent, @@ -298,18 +308,16 @@ impl AddedContext { icon_path: FileIcons::get_icon(&full_path, cx), status: ContextStatus::Ready, render_preview: None, - } + context, + }) } - AssistantContext::Directory(directory_context) => { - let worktree = directory_context.worktree.read(cx); - // If the directory no longer exists, use its last known path. - let full_path = worktree - .entry_for_id(directory_context.entry_id) - .map_or_else( - || directory_context.last_path.clone(), - |entry| worktree.full_path(&entry.path).into(), - ); + AgentContext::Directory(ref directory_context) => { + let worktree = project + .worktree_for_entry(directory_context.entry_id, cx)? + .read(cx); + let entry = worktree.entry_for_id(directory_context.entry_id)?; + let full_path = worktree.full_path(&entry.path); let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into(); let name = full_path @@ -320,8 +328,7 @@ impl AddedContext { .parent() .and_then(|p| p.file_name()) .map(|n| n.to_string_lossy().into_owned().into()); - AddedContext { - id: directory_context.id, + Some(AddedContext { kind: ContextKind::Directory, name, parent, @@ -329,33 +336,34 @@ impl AddedContext { icon_path: None, status: ContextStatus::Ready, render_preview: None, - } + context, + }) } - AssistantContext::Symbol(symbol_context) => AddedContext { - id: symbol_context.id, + AgentContext::Symbol(ref symbol_context) => Some(AddedContext { kind: ContextKind::Symbol, - name: symbol_context.context_symbol.id.name.clone(), + name: symbol_context.symbol.clone(), parent: None, tooltip: None, icon_path: None, status: ContextStatus::Ready, render_preview: None, - }, + context, + }), - AssistantContext::Selection(selection_context) => { - let full_path = selection_context.context_buffer.full_path(cx); + AgentContext::Selection(ref selection_context) => { + let buffer = selection_context.buffer.read(cx); + let full_path = buffer.file()?.full_path(cx); let mut full_path_string = full_path.to_string_lossy().into_owned(); let mut name = full_path .file_name() .map(|n| n.to_string_lossy().into_owned()) .unwrap_or_else(|| full_path_string.clone()); - let line_range_text = format!( - " ({}-{})", - selection_context.line_range.start.row + 1, - selection_context.line_range.end.row + 1 - ); + let line_range = selection_context.range.to_point(&buffer.snapshot()); + + let line_range_text = + format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1); full_path_string.push_str(&line_range_text); name.push_str(&line_range_text); @@ -365,16 +373,17 @@ impl AddedContext { .and_then(|p| p.file_name()) .map(|n| n.to_string_lossy().into_owned().into()); - AddedContext { - id: selection_context.id, + Some(AddedContext { kind: ContextKind::Selection, name: name.into(), parent, tooltip: None, icon_path: FileIcons::get_icon(&full_path, cx), status: ContextStatus::Ready, + render_preview: None, + /* render_preview: Some(Rc::new({ - let content = selection_context.context_buffer.text.clone(); + let content = selection_context.text.clone(); move |_, cx| { div() .id("context-pill-selection-preview") @@ -385,11 +394,12 @@ impl AddedContext { .into_any_element() } })), - } + */ + context, + }) } - AssistantContext::FetchedUrl(fetched_url_context) => AddedContext { - id: fetched_url_context.id, + AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext { kind: ContextKind::FetchedUrl, name: fetched_url_context.url.clone(), parent: None, @@ -397,12 +407,12 @@ impl AddedContext { icon_path: None, status: ContextStatus::Ready, render_preview: None, - }, + context, + }), - AssistantContext::Thread(thread_context) => AddedContext { - id: thread_context.id, + AgentContext::Thread(ref thread_context) => Some(AddedContext { kind: ContextKind::Thread, - name: thread_context.summary(cx), + name: thread_context.name(cx), parent: None, tooltip: None, icon_path: None, @@ -418,36 +428,41 @@ impl AddedContext { ContextStatus::Ready }, render_preview: None, - }, + context, + }), - AssistantContext::Rules(user_rules_context) => AddedContext { - id: user_rules_context.id, - kind: ContextKind::Rules, - name: user_rules_context.title.clone(), - parent: None, - tooltip: None, - icon_path: None, - status: ContextStatus::Ready, - render_preview: None, - }, + AgentContext::Rules(ref user_rules_context) => { + let name = prompt_store + .as_ref()? + .read(cx) + .metadata(user_rules_context.prompt_id.into())? + .title?; + Some(AddedContext { + kind: ContextKind::Rules, + name: name.clone(), + parent: None, + tooltip: None, + icon_path: None, + status: ContextStatus::Ready, + render_preview: None, + context, + }) + } - AssistantContext::Image(image_context) => AddedContext { - id: image_context.id, + AgentContext::Image(ref image_context) => Some(AddedContext { kind: ContextKind::Image, name: "Image".into(), parent: None, tooltip: None, icon_path: None, - status: if image_context.is_loading() { - ContextStatus::Loading { + status: match image_context.status() { + ImageStatus::Loading => ContextStatus::Loading { message: "Loading…".into(), - } - } else if image_context.is_error() { - ContextStatus::Error { + }, + ImageStatus::Error => ContextStatus::Error { message: "Failed to load image".into(), - } - } else { - ContextStatus::Ready + }, + ImageStatus::Ready => ContextStatus::Ready, }, render_preview: Some(Rc::new({ let image = image_context.original_image.clone(); @@ -458,7 +473,8 @@ impl AddedContext { .into_any_element() } })), - }, + context, + }), } } } @@ -478,6 +494,8 @@ impl Render for ContextPillPreview { } } +// TODO: Component commented out due to new dependency on `Project`. +/* impl Component for AddedContext { fn scope() -> ComponentScope { ComponentScope::Agent @@ -487,12 +505,13 @@ impl Component for AddedContext { "AddedContext" } - fn preview(_window: &mut Window, cx: &mut App) -> Option { + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + let next_context_id = ContextId::zero(); let image_ready = ( "Ready", AddedContext::new( - &AssistantContext::Image(ImageContext { - id: ContextId(0), + AgentContext::Image(ImageContext { + context_id: next_context_id.post_inc(), original_image: Arc::new(Image::empty()), image_task: Task::ready(Some(LanguageModelImage::empty())).shared(), }), @@ -503,8 +522,8 @@ impl Component for AddedContext { let image_loading = ( "Loading", AddedContext::new( - &AssistantContext::Image(ImageContext { - id: ContextId(1), + AgentContext::Image(ImageContext { + context_id: next_context_id.post_inc(), original_image: Arc::new(Image::empty()), image_task: cx .background_spawn(async move { @@ -520,8 +539,8 @@ impl Component for AddedContext { let image_error = ( "Error", AddedContext::new( - &AssistantContext::Image(ImageContext { - id: ContextId(2), + AgentContext::Image(ImageContext { + context_id: next_context_id.post_inc(), original_image: Arc::new(Image::empty()), image_task: Task::ready(None).shared(), }), @@ -544,5 +563,8 @@ impl Component for AddedContext { ) .into_any(), ) + + None } } +*/ diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 2e02f3f29a..d58ef2d58e 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -25,7 +25,7 @@ use language_model::{ AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, }; use project::Project; -use prompt_store::{PromptBuilder, PromptId, UserPromptId}; +use prompt_store::{PromptBuilder, UserPromptId}; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search::DivRegistrar}; @@ -1059,9 +1059,9 @@ impl AssistantPanel { None, )) }), - action.prompt_to_select.map(|uuid| PromptId::User { - uuid: UserPromptId(uuid), - }), + action + .prompt_to_select + .map(|uuid| UserPromptId(uuid).into()), cx, ) .detach_and_log_err(cx); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index fccb9de7c8..c0c3c4cd99 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -10,7 +10,7 @@ use crate::{ ToolMetrics, assertions::{AssertionsReport, RanAssertion, RanAssertionResult}, }; -use agent::ThreadEvent; +use agent::{ContextLoadResult, ThreadEvent}; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; @@ -115,7 +115,12 @@ impl ExampleContext { pub fn push_user_message(&mut self, text: impl ToString) { self.app .update_entity(&self.agent_thread, |thread, cx| { - thread.insert_user_message(text.to_string(), vec![], None, cx); + thread.insert_user_message( + text.to_string(), + ContextLoadResult::default(), + None, + cx, + ); }) .unwrap(); } diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 9210d0b818..e165506abf 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -218,8 +218,14 @@ impl ExampleInstance { }); let tools = cx.new(|_| ToolWorkingSet::default()); - let thread_store = - ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx); + let prompt_store = None; + let thread_store = ThreadStore::load( + project.clone(), + tools, + prompt_store, + app_state.prompt_builder.clone(), + cx, + ); let meta = self.thread.meta(); let this = self.clone(); diff --git a/crates/prompt_store/src/prompt_store.rs b/crates/prompt_store/src/prompt_store.rs index 84aaa688cd..b0d68fd416 100644 --- a/crates/prompt_store/src/prompt_store.rs +++ b/crates/prompt_store/src/prompt_store.rs @@ -60,9 +60,7 @@ pub enum PromptId { impl PromptId { pub fn new() -> PromptId { - PromptId::User { - uuid: UserPromptId::new(), - } + UserPromptId::new().into() } pub fn is_built_in(&self) -> bool { @@ -70,6 +68,12 @@ impl PromptId { } } +impl From for PromptId { + fn from(uuid: UserPromptId) -> Self { + PromptId::User { uuid } + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct UserPromptId(pub Uuid); @@ -227,9 +231,7 @@ impl PromptStore { .collect::>>()?; for (prompt_id_v1, metadata_v1) in metadata_v1 { - let prompt_id_v2 = PromptId::User { - uuid: UserPromptId(prompt_id_v1.0), - }; + let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into(); let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else { continue; };