From 17ecf94f6f248716c791053f0d6b05eda50ac705 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Thu, 24 Apr 2025 15:29:33 -0600 Subject: [PATCH] Restructure agent context (#29233) Simplifies the data structures involved in agent context by removing caching and limiting the use of ContextId: * `AssistantContext` enum is now like an ID / handle to context that does not need to be updated. `ContextId` still exists but is only used for generating unique `ElementId`. * `ContextStore` has a `IndexMap`. Only need to keep a `HashSet` consistent with it. `ContextSetEntry` is a newtype wrapper around `AssistantContext` which implements eq / hash on a subset of fields. * Thread `Message` directly stores its context. Fixes the following bugs: * If a context entry is removed from the strip and added again, it was reincluded in the next message. * Clicking file context in the thread that has been removed from the context strip didn't jump to the file. * Refresh of directory context didn't reflect added / removed files. * Deleted directories would remain in the message editor context strip. * Token counting requests didn't include image context. * File, directory, and symbol context deduplication relied on `ProjectPath` for identity, and so didn't handle renames. * Symbol context line numbers didn't update when shifted Known bugs (not fixed): * Deleting a directory causes it to disappear from messages in threads. Fixing this in a nice way is tricky. One easy fix is to store the original path and show that on deletion. It's weird that deletion would cause the name to "revert", though. Another possibility would be to snapshot context metadata on add (ala `AddedContext`), and keep that around despite deletion. Release Notes: - N/A --- Cargo.lock | 22 +- Cargo.toml | 1 + clippy.toml | 5 + crates/agent/Cargo.toml | 2 +- crates/agent/src/active_thread.rs | 278 ++-- crates/agent/src/agent_diff.rs | 2 + crates/agent/src/assistant.rs | 1 + crates/agent/src/assistant_panel.rs | 60 +- crates/agent/src/buffer_codegen.rs | 113 +- crates/agent/src/context.rs | 939 +++++++++---- crates/agent/src/context_picker.rs | 57 +- .../src/context_picker/completion_provider.rs | 95 +- .../context_picker/fetch_context_picker.rs | 2 +- .../src/context_picker/file_context_picker.rs | 19 +- .../context_picker/rules_context_picker.rs | 66 +- .../context_picker/symbol_context_picker.rs | 58 +- .../context_picker/thread_context_picker.rs | 4 +- crates/agent/src/context_store.rs | 1212 ++++------------- crates/agent/src/context_strip.rs | 122 +- crates/agent/src/history_store.rs | 5 +- crates/agent/src/inline_assistant.rs | 22 +- crates/agent/src/message_editor.rs | 182 ++- crates/agent/src/terminal_codegen.rs | 3 +- crates/agent/src/terminal_inline_assistant.rs | 67 +- crates/agent/src/thread.rs | 346 ++--- crates/agent/src/thread_store.rs | 34 +- crates/agent/src/ui/context_pill.rs | 196 +-- crates/assistant/src/assistant_panel.rs | 8 +- crates/eval/src/example.rs | 9 +- crates/eval/src/instance.rs | 10 +- crates/prompt_store/src/prompt_store.rs | 14 +- 31 files changed, 1893 insertions(+), 2061 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 63a57d7df1..98a035ff74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,7 +61,6 @@ dependencies = [ "buffer_diff", "chrono", "client", - "clock", "collections", "command_palette_hooks", "component", @@ -99,6 +98,7 @@ dependencies = [ "prompt_store", "proto", "rand 0.8.5", + "ref-cast", "release_channel", "rope", "rules_library", @@ -11716,6 +11716,26 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "refineable" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 3f4caf37b7..d90a256a9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -500,6 +500,7 @@ prost-types = "0.9" pulldown-cmark = { version = "0.12.0", default-features = false } quote = "1.0.9" rand = "0.8.5" +ref-cast = "1.0.24" rayon = "1.8" regex = "1.5" repair_json = "0.1.0" diff --git a/clippy.toml b/clippy.toml index 8c8da03a26..e606ad4c79 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,2 +1,7 @@ allow-private-module-inception = true avoid-breaking-exported-api = false +ignore-interior-mutability = [ + # Suppresses clippy::mutable_key_type, which is a false positive as the Eq + # and Hash impls do not use fields with interior mutability. + "agent::context::AgentContextKey" +] diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index ec366da1ad..4152843d22 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -28,7 +28,6 @@ async-watch.workspace = true buffer_diff.workspace = true chrono.workspace = true client.workspace = true -clock.workspace = true collections.workspace = true command_palette_hooks.workspace = true component.workspace = true @@ -65,6 +64,7 @@ project.workspace = true rules_library.workspace = true prompt_store.workspace = true proto.workspace = true +ref-cast.workspace = true release_channel.workspace = true rope.workspace = true schemars.workspace = true diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 310a283287..0d0e2dd08d 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1,4 +1,4 @@ -use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string}; +use crate::context::{AgentContext, RULES_ICON}; use crate::context_picker::MentionLink; use crate::thread::{ LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent, @@ -25,8 +25,8 @@ use gpui::{ }; use language::{Buffer, LanguageRegistry}; use language_model::{ - LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, RequestUsage, Role, - StopReason, + LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, + RequestUsage, Role, StopReason, }; use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; @@ -47,13 +47,10 @@ use util::markdown::MarkdownString; use workspace::{OpenOptions, Workspace}; use zed_actions::assistant::OpenRulesLibrary; -use crate::context_store::ContextStore; - pub struct ActiveThread { language_registry: Arc, thread_store: Entity, thread: Entity, - context_store: Entity, workspace: WeakEntity, save_thread_task: Option>, messages: Vec, @@ -717,7 +714,6 @@ impl ActiveThread { thread: Entity, thread_store: Entity, language_registry: Arc, - context_store: Entity, workspace: WeakEntity, window: &mut Window, cx: &mut Context, @@ -740,7 +736,6 @@ impl ActiveThread { language_registry, thread_store, thread: thread.clone(), - context_store, workspace, save_thread_task: None, messages: Vec::new(), @@ -780,10 +775,6 @@ impl ActiveThread { this } - pub fn context_store(&self) -> &Entity { - &self.context_store - } - pub fn thread(&self) -> &Entity { &self.thread } @@ -1273,26 +1264,36 @@ impl ActiveThread { } let token_count = if let Some(task) = cx.update(|cx| { - let context = thread.read(cx).context_for_message(message_id); - let new_context = thread.read(cx).filter_new_context(context); - let context_text = - format_context_as_string(new_context, cx).unwrap_or(String::new()); + let Some(message) = thread.read(cx).message(message_id) else { + log::error!("Message that was being edited no longer exists"); + return None; + }; let message_text = editor.read(cx).text(cx); - let content = context_text + &message_text; - - if content.is_empty() { + if message_text.is_empty() && message.loaded_context.is_empty() { return None; } + let mut request_message = LanguageModelRequestMessage { + role: language_model::Role::User, + content: Vec::new(), + cache: false, + }; + + message + .loaded_context + .add_to_request_message(&mut request_message); + + if !message_text.is_empty() { + request_message + .content + .push(MessageContent::Text(message_text)); + } + let request = language_model::LanguageModelRequest { thread_id: None, prompt_id: None, - messages: vec![LanguageModelRequestMessage { - role: language_model::Role::User, - content: vec![content.into()], - cache: false, - }], + messages: vec![request_message], tools: vec![], stop: vec![], temperature: None, @@ -1487,13 +1488,21 @@ impl ActiveThread { return Empty.into_any(); }; - let context_store = self.context_store.clone(); let workspace = self.workspace.clone(); let thread = self.thread.read(cx); + let prompt_store = self.thread_store.read(cx).prompt_store().as_ref(); // Get all the data we need from thread before we start using it in closures let checkpoint = thread.checkpoint_for_message(message_id); - let context = thread.context_for_message(message_id).collect::>(); + let added_context = if let Some(workspace) = workspace.upgrade() { + let project = workspace.read(cx).project().read(cx); + thread + .context_for_message(message_id) + .flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx)) + .collect::>() + } else { + return Empty.into_any(); + }; let tool_uses = thread.tool_uses_for_message(message_id, cx); let has_tool_uses = !tool_uses.is_empty(); @@ -1641,90 +1650,78 @@ impl ActiveThread { }; let message_is_empty = message.should_display_content(); - let has_content = !message_is_empty || !context.is_empty(); + let has_content = !message_is_empty || !added_context.is_empty(); - let message_content = - has_content.then(|| { - v_flex() - .gap_1() - .when(!message_is_empty, |parent| { - parent.child( - if let Some(edit_message_editor) = edit_message_editor.clone() { - let settings = ThemeSettings::get_global(cx); - let font_size = TextSize::Small.rems(cx); - let line_height = font_size.to_pixels(window.rem_size()) * 1.5; + let message_content = has_content.then(|| { + v_flex() + .gap_1() + .when(!message_is_empty, |parent| { + parent.child( + if let Some(edit_message_editor) = edit_message_editor.clone() { + let settings = ThemeSettings::get_global(cx); + let font_size = TextSize::Small.rems(cx); + let line_height = font_size.to_pixels(window.rem_size()) * 1.5; - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.buffer_font.family.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_features: settings.buffer_font.features.clone(), - font_size: font_size.into(), - line_height: line_height.into(), - ..Default::default() - }; + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; - div() - .key_context("EditMessageEditor") - .on_action(cx.listener(Self::cancel_editing_message)) - .on_action(cx.listener(Self::confirm_editing_message)) - .min_h_6() - .child(EditorElement::new( - &edit_message_editor, - EditorStyle { - background: colors.editor_background, - local_player: cx.theme().players().local(), - text: text_style, - syntax: cx.theme().syntax().clone(), - ..Default::default() - }, - )) - .into_any() - } else { - div() - .min_h_6() - .child(self.render_message_content( - message_id, - rendered_message, - has_tool_uses, - workspace.clone(), - window, - cx, - )) - .into_any() - }, - ) - }) - .when(!context.is_empty(), |parent| { - parent.child(h_flex().flex_wrap().gap_1().children( - context.into_iter().map(|context| { - let context_id = context.id(); - ContextPill::added( - AddedContext::new(context, cx), - false, - false, - None, - ) - .on_click(Rc::new(cx.listener({ + div() + .key_context("EditMessageEditor") + .on_action(cx.listener(Self::cancel_editing_message)) + .on_action(cx.listener(Self::confirm_editing_message)) + .min_h_6() + .child(EditorElement::new( + &edit_message_editor, + EditorStyle { + background: colors.editor_background, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + )) + .into_any() + } else { + div() + .min_h_6() + .child(self.render_message_content( + message_id, + rendered_message, + has_tool_uses, + workspace.clone(), + window, + cx, + )) + .into_any() + }, + ) + }) + .when(!added_context.is_empty(), |parent| { + parent.child(h_flex().flex_wrap().gap_1().children( + added_context.into_iter().map(|added_context| { + let context = added_context.context.clone(); + ContextPill::added(added_context, false, false, None).on_click(Rc::new( + cx.listener({ let workspace = workspace.clone(); - let context_store = context_store.clone(); move |_, _, window, cx| { if let Some(workspace) = workspace.upgrade() { - open_context( - context_id, - context_store.clone(), - workspace, - window, - cx, - ); + open_context(&context, workspace, window, cx); cx.notify(); } } - }))) - }), - )) - }) - }); + }), + )) + }), + )) + }) + }); let styled_message = match message.role { Role::User => v_flex() @@ -3173,20 +3170,14 @@ impl Render for ActiveThread { } pub(crate) fn open_context( - id: ContextId, - context_store: Entity, + context: &AgentContext, workspace: Entity, window: &mut Window, cx: &mut App, ) { - let Some(context) = context_store.read(cx).context_for_id(id) else { - return; - }; - match context { - AssistantContext::File(file_context) => { - if let Some(project_path) = file_context.context_buffer.buffer.read(cx).project_path(cx) - { + AgentContext::File(file_context) => { + if let Some(project_path) = file_context.project_path(cx) { workspace.update(cx, |workspace, cx| { workspace .open_path(project_path, None, true, window, cx) @@ -3194,7 +3185,8 @@ pub(crate) fn open_context( }); } } - AssistantContext::Directory(directory_context) => { + + AgentContext::Directory(directory_context) => { let entry_id = directory_context.entry_id; workspace.update(cx, |workspace, cx| { workspace.project().update(cx, |_project, cx| { @@ -3202,61 +3194,51 @@ pub(crate) fn open_context( }) }) } - AssistantContext::Symbol(symbol_context) => { - if let Some(project_path) = symbol_context - .context_symbol - .buffer - .read(cx) - .project_path(cx) - { - let snapshot = symbol_context.context_symbol.buffer.read(cx).snapshot(); - let target_position = symbol_context - .context_symbol - .id - .range - .start - .to_point(&snapshot); + AgentContext::Symbol(symbol_context) => { + let buffer = symbol_context.buffer.read(cx); + if let Some(project_path) = buffer.project_path(cx) { + let snapshot = buffer.snapshot(); + let target_position = symbol_context.range.start.to_point(&snapshot); open_editor_at_position(project_path, target_position, &workspace, window, cx) .detach(); } } - AssistantContext::Selection(selection_context) => { - if let Some(project_path) = selection_context - .context_buffer - .buffer - .read(cx) - .project_path(cx) - { - let snapshot = selection_context.context_buffer.buffer.read(cx).snapshot(); + + AgentContext::Selection(selection_context) => { + let buffer = selection_context.buffer.read(cx); + if let Some(project_path) = buffer.project_path(cx) { + let snapshot = buffer.snapshot(); let target_position = selection_context.range.start.to_point(&snapshot); open_editor_at_position(project_path, target_position, &workspace, window, cx) .detach(); } } - AssistantContext::FetchedUrl(fetched_url_context) => { + + AgentContext::FetchedUrl(fetched_url_context) => { cx.open_url(&fetched_url_context.url); } - AssistantContext::Thread(thread_context) => { - let thread_id = thread_context.thread.read(cx).id().clone(); - workspace.update(cx, |workspace, cx| { - if let Some(panel) = workspace.panel::(cx) { - panel.update(cx, |panel, cx| { - panel - .open_thread(&thread_id, window, cx) - .detach_and_log_err(cx) - }); - } - }) - } - AssistantContext::Rules(rules_context) => window.dispatch_action( + + AgentContext::Thread(thread_context) => workspace.update(cx, |workspace, cx| { + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + let thread_id = thread_context.thread.read(cx).id().clone(); + panel + .open_thread(&thread_id, window, cx) + .detach_and_log_err(cx) + }); + } + }), + + AgentContext::Rules(rules_context) => window.dispatch_action( Box::new(OpenRulesLibrary { prompt_to_select: Some(rules_context.prompt_id.0), }), cx, ), - AssistantContext::Image(_) => {} + + AgentContext::Image(_) => {} } } diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index ad2aee93ab..0e66308f37 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -962,11 +962,13 @@ mod tests { }) .unwrap(); + let prompt_store = None; let thread_store = cx .update(|cx| { ThreadStore::load( project.clone(), cx.new(|_| ToolWorkingSet::default()), + prompt_store, Arc::new(PromptBuilder::new(None).unwrap()), cx, ) diff --git a/crates/agent/src/assistant.rs b/crates/agent/src/assistant.rs index 03e13d6f68..80c69665fb 100644 --- a/crates/agent/src/assistant.rs +++ b/crates/agent/src/assistant.rs @@ -39,6 +39,7 @@ use thread::ThreadId; pub use crate::active_thread::ActiveThread; use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal}; pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate}; +pub use crate::context::{ContextLoadResult, LoadedContext}; pub use crate::inline_assistant::InlineAssistant; pub use crate::thread::{Message, Thread, ThreadEvent}; pub use crate::thread_store::ThreadStore; diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 3e83c1c224..404702c0b2 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -24,7 +24,7 @@ use language::LanguageRegistry; use language_model::{LanguageModelProviderTosView, LanguageModelRegistry}; use language_model_selector::ToggleModelSelector; use project::Project; -use prompt_store::{PromptBuilder, PromptId, UserPromptId}; +use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; use proto::Plan; use rules_library::{RulesLibrary, open_rules_library}; use settings::{Settings, update_settings_file}; @@ -189,6 +189,7 @@ pub struct AssistantPanel { message_editor: Entity, _active_thread_subscriptions: Vec, context_store: Entity, + prompt_store: Option>, configuration: Option>, configuration_subscription: Option, local_timezone: UtcOffset, @@ -205,14 +206,25 @@ impl AssistantPanel { pub fn load( workspace: WeakEntity, prompt_builder: Arc, - cx: AsyncWindowContext, + mut cx: AsyncWindowContext, ) -> Task>> { + let prompt_store = cx.update(|_window, cx| PromptStore::global(cx)); cx.spawn(async move |cx| { + let prompt_store = match prompt_store { + Ok(prompt_store) => prompt_store.await.ok(), + Err(_) => None, + }; let tools = cx.new(|_| ToolWorkingSet::default())?; let thread_store = workspace .update(cx, |workspace, cx| { let project = workspace.project().clone(); - ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx) + ThreadStore::load( + project, + tools.clone(), + prompt_store.clone(), + prompt_builder.clone(), + cx, + ) })? .await?; @@ -230,7 +242,16 @@ impl AssistantPanel { .await?; workspace.update_in(cx, |workspace, window, cx| { - cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx)) + cx.new(|cx| { + Self::new( + workspace, + thread_store, + context_store, + prompt_store, + window, + cx, + ) + }) }) }) } @@ -239,6 +260,7 @@ impl AssistantPanel { workspace: &Workspace, thread_store: Entity, context_store: Entity, + prompt_store: Option>, window: &mut Window, cx: &mut Context, ) -> Self { @@ -262,6 +284,7 @@ impl AssistantPanel { fs.clone(), workspace.clone(), message_editor_context_store.clone(), + prompt_store.clone(), thread_store.downgrade(), thread.clone(), window, @@ -293,7 +316,6 @@ impl AssistantPanel { thread.clone(), thread_store.clone(), language_registry.clone(), - message_editor_context_store.clone(), workspace.clone(), window, cx, @@ -322,6 +344,7 @@ impl AssistantPanel { message_editor_subscription, ], context_store, + prompt_store, configuration: None, configuration_subscription: None, local_timezone: UtcOffset::from_whole_seconds( @@ -355,6 +378,10 @@ impl AssistantPanel { self.local_timezone } + pub(crate) fn prompt_store(&self) -> &Option> { + &self.prompt_store + } + pub(crate) fn thread_store(&self) -> &Entity { &self.thread_store } @@ -411,7 +438,6 @@ impl AssistantPanel { thread.clone(), self.thread_store.clone(), self.language_registry.clone(), - message_editor_context_store.clone(), self.workspace.clone(), window, cx, @@ -430,6 +456,7 @@ impl AssistantPanel { self.fs.clone(), self.workspace.clone(), message_editor_context_store, + self.prompt_store.clone(), self.thread_store.downgrade(), thread, window, @@ -500,9 +527,9 @@ impl AssistantPanel { None, )) }), - action.prompt_to_select.map(|uuid| PromptId::User { - uuid: UserPromptId(uuid), - }), + action + .prompt_to_select + .map(|uuid| UserPromptId(uuid).into()), cx, ) .detach_and_log_err(cx); @@ -598,7 +625,6 @@ impl AssistantPanel { thread.clone(), this.thread_store.clone(), this.language_registry.clone(), - message_editor_context_store.clone(), this.workspace.clone(), window, cx, @@ -617,6 +643,7 @@ impl AssistantPanel { this.fs.clone(), this.workspace.clone(), message_editor_context_store, + this.prompt_store.clone(), this.thread_store.downgrade(), thread, window, @@ -1876,11 +1903,14 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist { else { return; }; + let prompt_store = None; + let thread_store = None; assistant.assist( &prompt_editor, self.workspace.clone(), project, - None, + prompt_store, + thread_store, window, cx, ) @@ -1959,8 +1989,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate { // being updated. cx.defer_in(window, move |panel, window, cx| { if panel.has_active_thread() { - panel.thread.update(cx, |thread, cx| { - thread.context_store().update(cx, |store, cx| { + panel.message_editor.update(cx, |message_editor, cx| { + message_editor.context_store().update(cx, |store, cx| { let buffer = buffer.read(cx); let selection_ranges = selection_ranges .into_iter() @@ -1977,9 +2007,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate { .collect::>(); for (buffer, range) in selection_ranges { - store - .add_selection(buffer, range, cx) - .detach_and_log_err(cx); + store.add_selection(buffer, range, cx); } }) }) diff --git a/crates/agent/src/buffer_codegen.rs b/crates/agent/src/buffer_codegen.rs index ebdf9e3d9f..4dbadcbaa5 100644 --- a/crates/agent/src/buffer_codegen.rs +++ b/crates/agent/src/buffer_codegen.rs @@ -1,6 +1,6 @@ -use crate::context::attach_context_to_message; -use crate::context_store::ContextStore; +use crate::context::ContextLoadResult; use crate::inline_prompt_editor::CodegenStatus; +use crate::{context::load_context, context_store::ContextStore}; use anyhow::Result; use client::telemetry::Telemetry; use collections::HashSet; @@ -8,7 +8,7 @@ use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset use futures::{ SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join, }; -use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task}; +use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity}; use language::{Buffer, IndentKind, Point, TransactionId, line_diff}; use language_model::{ LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, @@ -16,7 +16,9 @@ use language_model::{ }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; +use project::Project; use prompt_store::PromptBuilder; +use prompt_store::PromptStore; use rope::Rope; use smol::future::FutureExt; use std::{ @@ -41,6 +43,8 @@ pub struct BufferCodegen { range: Range, initial_transaction_id: Option, context_store: Entity, + project: WeakEntity, + prompt_store: Option>, telemetry: Arc, builder: Arc, pub is_insertion: bool, @@ -52,6 +56,8 @@ impl BufferCodegen { range: Range, initial_transaction_id: Option, context_store: Entity, + project: WeakEntity, + prompt_store: Option>, telemetry: Arc, builder: Arc, cx: &mut Context, @@ -62,6 +68,8 @@ impl BufferCodegen { range.clone(), false, Some(context_store.clone()), + project.clone(), + prompt_store.clone(), Some(telemetry.clone()), builder.clone(), cx, @@ -77,6 +85,8 @@ impl BufferCodegen { range, initial_transaction_id, context_store, + project, + prompt_store, telemetry, builder, }; @@ -155,6 +165,8 @@ impl BufferCodegen { self.range.clone(), false, Some(self.context_store.clone()), + self.project.clone(), + self.prompt_store.clone(), Some(self.telemetry.clone()), self.builder.clone(), cx, @@ -231,13 +243,14 @@ pub struct CodegenAlternative { generation: Task<()>, diff: Diff, context_store: Option>, + project: WeakEntity, + prompt_store: Option>, telemetry: Option>, _subscription: gpui::Subscription, builder: Arc, active: bool, edits: Vec<(Range, String)>, line_operations: Vec, - request: Option, elapsed_time: Option, completion: Option, pub message_id: Option, @@ -251,6 +264,8 @@ impl CodegenAlternative { range: Range, active: bool, context_store: Option>, + project: WeakEntity, + prompt_store: Option>, telemetry: Option>, builder: Arc, cx: &mut Context, @@ -292,6 +307,8 @@ impl CodegenAlternative { generation: Task::ready(()), diff: Diff::default(), context_store, + project, + prompt_store, telemetry, _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), builder, @@ -299,7 +316,6 @@ impl CodegenAlternative { edits: Vec::new(), line_operations: Vec::new(), range, - request: None, elapsed_time: None, completion: None, } @@ -368,16 +384,18 @@ impl CodegenAlternative { async { Ok(LanguageModelTextStream::default()) }.boxed_local() } else { let request = self.build_request(user_prompt, cx)?; - self.request = Some(request.clone()); - - cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await) + cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await) .boxed_local() }; self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); Ok(()) } - fn build_request(&self, user_prompt: String, cx: &mut App) -> Result { + fn build_request( + &self, + user_prompt: String, + cx: &mut App, + ) -> Result> { let buffer = self.buffer.read(cx).snapshot(cx); let language = buffer.language_at(self.range.start); let language_name = if let Some(language) = language.as_ref() { @@ -410,30 +428,44 @@ impl CodegenAlternative { .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range) .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?; - let mut request_message = LanguageModelRequestMessage { - role: Role::User, - content: Vec::new(), - cache: false, - }; + let context_task = self.context_store.as_ref().map(|context_store| { + if let Some(project) = self.project.upgrade() { + let context = context_store + .read(cx) + .context() + .cloned() + .collect::>(); + load_context(context, &project, &self.prompt_store, cx) + } else { + Task::ready(ContextLoadResult::default()) + } + }); - if let Some(context_store) = &self.context_store { - attach_context_to_message( - &mut request_message, - context_store.read(cx).context().iter(), - cx, - ); - } + Ok(cx.spawn(async move |_cx| { + let mut request_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + }; - request_message.content.push(prompt.into()); + if let Some(context_task) = context_task { + context_task + .await + .loaded_context + .add_to_request_message(&mut request_message); + } - Ok(LanguageModelRequest { - thread_id: None, - prompt_id: None, - tools: Vec::new(), - stop: Vec::new(), - temperature: None, - messages: vec![request_message], - }) + request_message.content.push(prompt.into()); + + LanguageModelRequest { + thread_id: None, + prompt_id: None, + tools: Vec::new(), + stop: Vec::new(), + temperature: None, + messages: vec![request_message], + } + })) } pub fn handle_stream( @@ -1038,6 +1070,7 @@ impl Diff { #[cfg(test)] mod tests { use super::*; + use fs::FakeFs; use futures::{ Stream, stream::{self}, @@ -1080,12 +1113,16 @@ mod tests { snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), true, None, + project.downgrade(), + None, None, prompt_builder, cx, @@ -1144,12 +1181,16 @@ mod tests { snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), true, None, + project.downgrade(), + None, None, prompt_builder, cx, @@ -1211,12 +1252,16 @@ mod tests { snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), true, None, + project.downgrade(), + None, None, prompt_builder, cx, @@ -1278,12 +1323,16 @@ mod tests { snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), true, None, + project.downgrade(), + None, None, prompt_builder, cx, @@ -1333,12 +1382,16 @@ mod tests { snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14)) }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, vec![], cx).await; let codegen = cx.new(|cx| { CodegenAlternative::new( buffer.clone(), range.clone(), false, None, + project.downgrade(), + None, None, prompt_builder, cx, diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index cf813a1426..c9101c5615 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -1,34 +1,25 @@ -use std::{ - ops::Range, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::hash::{Hash, Hasher}; +use std::usize; +use std::{ops::Range, path::Path, sync::Arc}; +use collections::HashSet; +use futures::future; use futures::{FutureExt, future::Shared}; -use gpui::{App, Entity, SharedString, Task}; +use gpui::{App, AppContext as _, Entity, SharedString, Task}; use language::Buffer; -use language_model::{LanguageModelImage, LanguageModelRequestMessage}; -use project::{ProjectEntryId, ProjectPath, Worktree}; -use prompt_store::UserPromptId; -use rope::Point; -use serde::{Deserialize, Serialize}; -use text::{Anchor, BufferId}; -use ui::IconName; -use util::post_inc; +use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent}; +use project::{Project, ProjectEntryId, ProjectPath, Worktree}; +use prompt_store::{PromptStore, UserPromptId}; +use ref_cast::RefCast; +use rope::{Point, Rope}; +use text::{Anchor, OffsetRangeExt as _}; +use ui::{ElementId, IconName}; +use util::{ResultExt as _, post_inc}; use crate::thread::Thread; pub const RULES_ICON: IconName = IconName::Context; -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -pub struct ContextId(pub(crate) usize); - -impl ContextId { - pub fn post_inc(&mut self) -> Self { - Self(post_inc(&mut self.0)) - } -} - pub enum ContextKind { File, Directory, @@ -55,307 +46,761 @@ impl ContextKind { } } +/// Handle for context that can be added to a user message. +/// +/// This uses IDs that are stable enough for tracking renames and identifying when context has +/// already been added to the thread. To use this in a set, wrap it in `AgentContextKey` to opt in +/// to `PartialEq` and `Hash` impls that use the subset of the fields used for this stable identity. #[derive(Debug, Clone)] -pub enum AssistantContext { +pub enum AgentContext { File(FileContext), Directory(DirectoryContext), Symbol(SymbolContext), + Selection(SelectionContext), FetchedUrl(FetchedUrlContext), Thread(ThreadContext), - Selection(SelectionContext), Rules(RulesContext), Image(ImageContext), } -impl AssistantContext { - pub fn id(&self) -> ContextId { +impl AgentContext { + fn id(&self) -> ContextId { match self { - Self::File(file) => file.id, - Self::Directory(directory) => directory.id, - Self::Symbol(symbol) => symbol.id, - Self::FetchedUrl(url) => url.id, - Self::Thread(thread) => thread.id, - Self::Selection(selection) => selection.id, - Self::Rules(rules) => rules.id, - Self::Image(image) => image.id, + Self::File(context) => context.context_id, + Self::Directory(context) => context.context_id, + Self::Symbol(context) => context.context_id, + Self::Selection(context) => context.context_id, + Self::FetchedUrl(context) => context.context_id, + Self::Thread(context) => context.context_id, + Self::Rules(context) => context.context_id, + Self::Image(context) => context.context_id, } } + + pub fn element_id(&self, name: SharedString) -> ElementId { + ElementId::NamedInteger(name, self.id().0) + } } +/// ID created at time of context add, for use in ElementId. This is not the stable identity of a +/// context, instead that's handled by the `PartialEq` and `Hash` impls of `AgentContextKey`. +#[derive(Debug, Copy, Clone)] +pub struct ContextId(usize); + +impl ContextId { + pub fn zero() -> Self { + ContextId(0) + } + + fn for_lookup() -> Self { + ContextId(usize::MAX) + } + + pub fn post_inc(&mut self) -> Self { + Self(post_inc(&mut self.0)) + } +} + +/// File context provides the entire contents of a file. +/// +/// This holds an `Entity` so that file path renames affect its display and so that it can +/// be opened even if the file has been deleted. An alternative might be to use `ProjectEntryId`, +/// but then when deleted there is no path info or ability to open. #[derive(Debug, Clone)] pub struct FileContext { - pub id: ContextId, - pub context_buffer: ContextBuffer, + pub buffer: Entity, + pub context_id: ContextId, } -#[derive(Debug, Clone)] -pub struct DirectoryContext { - pub id: ContextId, - pub worktree: Entity, - pub entry_id: ProjectEntryId, - pub last_path: Arc, - /// Buffers of the files within the directory. - pub context_buffers: Vec, -} +impl FileContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.buffer == other.buffer + } -impl DirectoryContext { - pub fn entry<'a>(&self, cx: &'a App) -> Option<&'a project::Entry> { - self.worktree.read(cx).entry_for_id(self.entry_id) + pub fn hash_for_key(&self, state: &mut H) { + self.buffer.hash(state) } pub fn project_path(&self, cx: &App) -> Option { - let worktree = self.worktree.read(cx); - worktree - .entry_for_id(self.entry_id) - .map(|entry| ProjectPath { - worktree_id: worktree.id(), - path: entry.path.clone(), - }) + let file = self.buffer.read(cx).file()?; + Some(ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }) + } + + fn load(&self, cx: &App) -> Option)>> { + let buffer_ref = self.buffer.read(cx); + let Some(file) = buffer_ref.file() else { + log::error!("file context missing path"); + return None; + }; + let full_path = file.full_path(cx); + let rope = buffer_ref.as_rope().clone(); + let buffer = self.buffer.clone(); + Some( + cx.background_spawn( + async move { (to_fenced_codeblock(&full_path, rope, None), buffer) }, + ), + ) + } +} + +/// Directory contents provides the entire contents of text files in a directory. +/// +/// This has a `ProjectEntryId` so that it follows renames. +#[derive(Debug, Clone)] +pub struct DirectoryContext { + pub entry_id: ProjectEntryId, + pub context_id: ContextId, +} + +impl DirectoryContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.entry_id == other.entry_id + } + + pub fn hash_for_key(&self, state: &mut H) { + self.entry_id.hash(state) + } + + fn load( + &self, + project: Entity, + cx: &mut App, + ) -> Option)>>> { + let worktree = project.read(cx).worktree_for_entry(self.entry_id, cx)?; + let worktree_ref = worktree.read(cx); + let entry = worktree_ref.entry_for_id(self.entry_id)?; + if entry.is_file() { + log::error!("DirectoryContext unexpectedly refers to a file."); + return None; + } + + let file_paths = collect_files_in_path(worktree_ref, entry.path.as_ref()); + let texts_future = future::join_all(file_paths.into_iter().map(|path| { + load_file_path_text_as_fenced_codeblock(project.clone(), worktree.clone(), path, cx) + })); + + Some(cx.background_spawn(async move { + texts_future.await.into_iter().flatten().collect::>() + })) } } #[derive(Debug, Clone)] pub struct SymbolContext { - pub id: ContextId, - pub context_symbol: ContextSymbol, + pub buffer: Entity, + pub symbol: SharedString, + pub range: Range, + /// The range that fully contain the symbol. e.g. for function symbol, this will include not + /// only the signature, but also the body. Not used by `PartialEq` or `Hash` for `AgentContextKey`. + pub enclosing_range: Range, + pub context_id: ContextId, +} + +impl SymbolContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.buffer == other.buffer && self.symbol == other.symbol && self.range == other.range + } + + pub fn hash_for_key(&self, state: &mut H) { + self.buffer.hash(state); + self.symbol.hash(state); + self.range.hash(state); + } + + fn load(&self, cx: &App) -> Option)>> { + let buffer_ref = self.buffer.read(cx); + let Some(file) = buffer_ref.file() else { + log::error!("symbol context's file has no path"); + return None; + }; + let full_path = file.full_path(cx); + let rope = buffer_ref + .text_for_range(self.enclosing_range.clone()) + .collect::(); + let line_range = self.enclosing_range.to_point(&buffer_ref.snapshot()); + let buffer = self.buffer.clone(); + Some(cx.background_spawn(async move { + ( + to_fenced_codeblock(&full_path, rope, Some(line_range)), + buffer, + ) + })) + } +} + +#[derive(Debug, Clone)] +pub struct SelectionContext { + pub buffer: Entity, + pub range: Range, + pub context_id: ContextId, +} + +impl SelectionContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.buffer == other.buffer && self.range == other.range + } + + pub fn hash_for_key(&self, state: &mut H) { + self.buffer.hash(state); + self.range.hash(state); + } + + fn load(&self, cx: &App) -> Option)>> { + let buffer_ref = self.buffer.read(cx); + let Some(file) = buffer_ref.file() else { + log::error!("selection context's file has no path"); + return None; + }; + let full_path = file.full_path(cx); + let rope = buffer_ref + .text_for_range(self.range.clone()) + .collect::(); + let line_range = self.range.to_point(&buffer_ref.snapshot()); + let buffer = self.buffer.clone(); + Some(cx.background_spawn(async move { + ( + to_fenced_codeblock(&full_path, rope, Some(line_range)), + buffer, + ) + })) + } } #[derive(Debug, Clone)] pub struct FetchedUrlContext { - pub id: ContextId, pub url: SharedString, + /// Text contents of the fetched url. Unlike other context types, the contents of this gets + /// populated when added rather than when sending the message. Not used by `PartialEq` or `Hash` + /// for `AgentContextKey`. pub text: SharedString, + pub context_id: ContextId, +} + +impl FetchedUrlContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.url == other.url + } + + pub fn hash_for_key(&self, state: &mut H) { + self.url.hash(state); + } + + pub fn lookup_key(url: SharedString) -> AgentContextKey { + AgentContextKey(AgentContext::FetchedUrl(FetchedUrlContext { + url, + text: "".into(), + context_id: ContextId::for_lookup(), + })) + } } #[derive(Debug, Clone)] pub struct ThreadContext { - pub id: ContextId, - // TODO: Entity holds onto the thread even if the thread is deleted. Should probably be - // a WeakEntity and handle removal from the UI when it has dropped. pub thread: Entity, - pub text: SharedString, + pub context_id: ContextId, } impl ThreadContext { - pub fn summary(&self, cx: &App) -> SharedString { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.thread == other.thread + } + + pub fn hash_for_key(&self, state: &mut H) { + self.thread.hash(state) + } + + pub fn name(&self, cx: &App) -> SharedString { self.thread .read(cx) .summary() - .unwrap_or("New thread".into()) + .unwrap_or_else(|| "New thread".into()) + } + + pub fn load(&self, cx: &App) -> String { + let name = self.name(cx); + let contents = self.thread.read(cx).latest_detailed_summary_or_text(); + let mut text = String::new(); + text.push_str(&name); + text.push('\n'); + text.push_str(&contents.trim()); + text.push('\n'); + text + } +} + +#[derive(Debug, Clone)] +pub struct RulesContext { + pub prompt_id: UserPromptId, + pub context_id: ContextId, +} + +impl RulesContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.prompt_id == other.prompt_id + } + + pub fn hash_for_key(&self, state: &mut H) { + self.prompt_id.hash(state) + } + + pub fn lookup_key(prompt_id: UserPromptId) -> AgentContextKey { + AgentContextKey(AgentContext::Rules(RulesContext { + prompt_id, + context_id: ContextId::for_lookup(), + })) + } + + pub fn load( + &self, + prompt_store: &Option>, + cx: &App, + ) -> Task> { + let Some(prompt_store) = prompt_store.as_ref() else { + return Task::ready(None); + }; + let prompt_store = prompt_store.read(cx); + let prompt_id = self.prompt_id.into(); + let Some(metadata) = prompt_store.metadata(prompt_id) else { + return Task::ready(None); + }; + let contents_task = prompt_store.load(prompt_id, cx); + cx.background_spawn(async move { + let contents = contents_task.await.ok()?; + let mut text = String::new(); + if let Some(title) = metadata.title { + text.push_str("Rules title: "); + text.push_str(&title); + text.push('\n'); + } + text.push_str("``````\n"); + text.push_str(contents.trim()); + text.push_str("\n``````\n"); + Some(text) + }) } } #[derive(Debug, Clone)] pub struct ImageContext { - pub id: ContextId, pub original_image: Arc, + // TODO: handle this elsewhere and remove `ignore-interior-mutability` opt-out in clippy.toml + // needed due to a false positive of `clippy::mutable_key_type`. pub image_task: Shared>>, + pub context_id: ContextId, +} + +pub enum ImageStatus { + Loading, + Error, + Ready, } impl ImageContext { + pub fn eq_for_key(&self, other: &Self) -> bool { + self.original_image.id == other.original_image.id + } + + pub fn hash_for_key(&self, state: &mut H) { + self.original_image.id.hash(state); + } + pub fn image(&self) -> Option { self.image_task.clone().now_or_never().flatten() } - pub fn is_loading(&self) -> bool { - self.image_task.clone().now_or_never().is_none() - } - - pub fn is_error(&self) -> bool { - self.image_task - .clone() - .now_or_never() - .map(|result| result.is_none()) - .unwrap_or(false) - } -} - -#[derive(Clone)] -pub struct ContextBuffer { - pub id: BufferId, - // TODO: Entity holds onto the buffer even if the buffer is deleted. Should probably be - // a WeakEntity and handle removal from the UI when it has dropped. - pub buffer: Entity, - pub last_full_path: Arc, - pub version: clock::Global, - pub text: SharedString, -} - -impl ContextBuffer { - pub fn full_path(&self, cx: &App) -> PathBuf { - let file = self.buffer.read(cx).file(); - // Note that in practice file can't be `None` because it is present when this is created and - // there's no way for buffers to go from having a file to not. - file.map_or(self.last_full_path.to_path_buf(), |file| file.full_path(cx)) - } -} - -impl std::fmt::Debug for ContextBuffer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ContextBuffer") - .field("id", &self.id) - .field("buffer", &self.buffer) - .field("version", &self.version) - .field("text", &self.text) - .finish() - } -} - -#[derive(Debug, Clone)] -pub struct ContextSymbol { - pub id: ContextSymbolId, - pub buffer: Entity, - pub buffer_version: clock::Global, - /// The range that the symbol encloses, e.g. for function symbol, this will - /// include not only the signature, but also the body - pub enclosing_range: Range, - pub text: SharedString, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ContextSymbolId { - pub path: ProjectPath, - pub name: SharedString, - pub range: Range, -} - -#[derive(Debug, Clone)] -pub struct SelectionContext { - pub id: ContextId, - pub range: Range, - pub line_range: Range, - pub context_buffer: ContextBuffer, -} - -#[derive(Debug, Clone)] -pub struct RulesContext { - pub id: ContextId, - pub prompt_id: UserPromptId, - pub title: SharedString, - pub text: SharedString, -} - -/// Formats a collection of contexts into a string representation -pub fn format_context_as_string<'a>( - contexts: impl Iterator, - cx: &App, -) -> Option { - let mut file_context = Vec::new(); - let mut directory_context = Vec::new(); - let mut symbol_context = Vec::new(); - let mut selection_context = Vec::new(); - let mut fetch_context = Vec::new(); - let mut thread_context = Vec::new(); - let mut rules_context = Vec::new(); - - for context in contexts { - match context { - AssistantContext::File(context) => file_context.push(context), - AssistantContext::Directory(context) => directory_context.push(context), - AssistantContext::Symbol(context) => symbol_context.push(context), - AssistantContext::Selection(context) => selection_context.push(context), - AssistantContext::FetchedUrl(context) => fetch_context.push(context), - AssistantContext::Thread(context) => thread_context.push(context), - AssistantContext::Rules(context) => rules_context.push(context), - AssistantContext::Image(_) => {} + pub fn status(&self) -> ImageStatus { + match self.image_task.clone().now_or_never() { + None => ImageStatus::Loading, + Some(None) => ImageStatus::Error, + Some(Some(_)) => ImageStatus::Ready, } } +} - if file_context.is_empty() - && directory_context.is_empty() - && symbol_context.is_empty() - && selection_context.is_empty() - && fetch_context.is_empty() - && thread_context.is_empty() - && rules_context.is_empty() - { - return None; +#[derive(Debug, Clone, Default)] +pub struct ContextLoadResult { + pub loaded_context: LoadedContext, + pub referenced_buffers: HashSet>, +} + +#[derive(Debug, Clone, Default)] +pub struct LoadedContext { + pub contexts: Vec, + pub text: String, + pub images: Vec, +} + +impl LoadedContext { + pub fn is_empty(&self) -> bool { + self.text.is_empty() && self.images.is_empty() } - let mut result = String::new(); - result.push_str("\n\n\ - The following items were attached by the user. You don't need to use other tools to read them.\n\n"); - - if !file_context.is_empty() { - result.push_str("\n"); - for context in file_context { - result.push_str(&context.context_buffer.text); + pub fn add_to_request_message(&self, request_message: &mut LanguageModelRequestMessage) { + if !self.text.is_empty() { + request_message + .content + .push(MessageContent::Text(self.text.to_string())); } - result.push_str("\n"); - } - if !directory_context.is_empty() { - result.push_str("\n"); - for context in directory_context { - for context_buffer in &context.context_buffers { - result.push_str(&context_buffer.text); + if !self.images.is_empty() { + // Some providers only support image parts after an initial text part + if request_message.content.is_empty() { + request_message + .content + .push(MessageContent::Text("Images attached by user:".to_string())); + } + + for image in &self.images { + request_message + .content + .push(MessageContent::Image(image.clone())) } } - result.push_str("\n"); } +} - if !symbol_context.is_empty() { - result.push_str("\n"); - for context in symbol_context { - result.push_str(&context.context_symbol.text); - result.push('\n'); +/// Loads and formats a collection of contexts. +pub fn load_context( + contexts: Vec, + project: &Entity, + prompt_store: &Option>, + cx: &mut App, +) -> Task { + let mut file_tasks = Vec::new(); + let mut directory_tasks = Vec::new(); + let mut symbol_tasks = Vec::new(); + let mut selection_tasks = Vec::new(); + let mut fetch_context = Vec::new(); + let mut thread_context = Vec::new(); + let mut rules_tasks = Vec::new(); + let mut image_tasks = Vec::new(); + + for context in contexts.iter().cloned() { + match context { + AgentContext::File(context) => file_tasks.extend(context.load(cx)), + AgentContext::Directory(context) => { + directory_tasks.extend(context.load(project.clone(), cx)) + } + AgentContext::Symbol(context) => symbol_tasks.extend(context.load(cx)), + AgentContext::Selection(context) => selection_tasks.extend(context.load(cx)), + AgentContext::FetchedUrl(context) => fetch_context.push(context), + AgentContext::Thread(context) => thread_context.push(context.load(cx)), + AgentContext::Rules(context) => rules_tasks.push(context.load(prompt_store, cx)), + AgentContext::Image(context) => image_tasks.push(context.image_task.clone()), } - result.push_str("\n"); } - if !selection_context.is_empty() { - result.push_str("\n"); - for context in selection_context { - result.push_str(&context.context_buffer.text); - result.push('\n'); - } - result.push_str("\n"); - } - - if !fetch_context.is_empty() { - result.push_str("\n"); - for context in &fetch_context { - result.push_str(&context.url); - result.push('\n'); - result.push_str(&context.text); - result.push('\n'); - } - result.push_str("\n"); - } - - if !thread_context.is_empty() { - result.push_str("\n"); - for context in &thread_context { - result.push_str(&context.summary(cx)); - result.push('\n'); - result.push_str(&context.text); - result.push('\n'); - } - result.push_str("\n"); - } - - if !rules_context.is_empty() { - result.push_str( - "\n\ - The user has specified the following rules that should be applied:\n\n", + cx.background_spawn(async move { + let ( + file_context, + directory_context, + symbol_context, + selection_context, + rules_context, + images, + ) = futures::join!( + future::join_all(file_tasks), + future::join_all(directory_tasks), + future::join_all(symbol_tasks), + future::join_all(selection_tasks), + future::join_all(rules_tasks), + future::join_all(image_tasks) ); - for context in &rules_context { - result.push_str(&context.text); - result.push('\n'); + + let directory_context = directory_context.into_iter().flatten().collect::>(); + let rules_context = rules_context.into_iter().flatten().collect::>(); + let images = images.into_iter().flatten().collect::>(); + + let mut referenced_buffers = HashSet::default(); + let mut text = String::new(); + + if file_context.is_empty() + && directory_context.is_empty() + && symbol_context.is_empty() + && selection_context.is_empty() + && fetch_context.is_empty() + && thread_context.is_empty() + && rules_context.is_empty() + { + return ContextLoadResult { + loaded_context: LoadedContext { + contexts, + text, + images, + }, + referenced_buffers, + }; } - result.push_str("\n"); - } - result.push_str("\n"); - Some(result) + text.push_str( + "\n\n\ + The following items were attached by the user. \ + You don't need to use other tools to read them.\n\n", + ); + + if !file_context.is_empty() { + text.push_str(""); + for (file_text, buffer) in file_context { + text.push('\n'); + text.push_str(&file_text); + referenced_buffers.insert(buffer); + } + text.push_str("\n"); + } + + if !directory_context.is_empty() { + text.push_str(""); + for (file_text, buffer) in directory_context { + text.push('\n'); + text.push_str(&file_text); + referenced_buffers.insert(buffer); + } + text.push_str("\n"); + } + + if !symbol_context.is_empty() { + text.push_str(""); + for (symbol_text, buffer) in symbol_context { + text.push('\n'); + text.push_str(&symbol_text); + referenced_buffers.insert(buffer); + } + text.push_str("\n"); + } + + if !selection_context.is_empty() { + text.push_str(""); + for (selection_text, buffer) in selection_context { + text.push('\n'); + text.push_str(&selection_text); + referenced_buffers.insert(buffer); + } + text.push_str("\n"); + } + + if !fetch_context.is_empty() { + text.push_str(""); + for context in fetch_context { + text.push('\n'); + text.push_str(&context.url); + text.push('\n'); + text.push_str(&context.text); + } + text.push_str("\n"); + } + + if !thread_context.is_empty() { + text.push_str(""); + for thread_text in thread_context { + text.push('\n'); + text.push_str(&thread_text); + } + text.push_str("\n"); + } + + if !rules_context.is_empty() { + text.push_str( + "\n\ + The user has specified the following rules that should be applied:\n", + ); + for rules_text in rules_context { + text.push('\n'); + text.push_str(&rules_text); + } + text.push_str("\n"); + } + + text.push_str("\n"); + + ContextLoadResult { + loaded_context: LoadedContext { + contexts, + text, + images, + }, + referenced_buffers, + } + }) } -pub fn attach_context_to_message<'a>( - message: &mut LanguageModelRequestMessage, - contexts: impl Iterator, - cx: &App, -) { - if let Some(context_string) = format_context_as_string(contexts, cx) { - message.content.push(context_string.into()); +fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec> { + let mut files = Vec::new(); + + for entry in worktree.child_entries(path) { + if entry.is_dir() { + files.extend(collect_files_in_path(worktree, &entry.path)); + } else if entry.is_file() { + files.push(entry.path.clone()); + } + } + + files +} + +fn load_file_path_text_as_fenced_codeblock( + project: Entity, + worktree: Entity, + path: Arc, + cx: &mut App, +) -> Task)>> { + let worktree_ref = worktree.read(cx); + let worktree_id = worktree_ref.id(); + let full_path = worktree_ref.full_path(&path); + + let open_task = project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + let project_path = ProjectPath { worktree_id, path }; + buffer_store.open_buffer(project_path, cx) + }) + }); + + let rope_task = cx.spawn(async move |cx| { + let buffer = open_task.await.log_err()?; + let rope = buffer + .read_with(cx, |buffer, _cx| buffer.as_rope().clone()) + .log_err()?; + Some((rope, buffer)) + }); + + cx.background_spawn(async move { + let (rope, buffer) = rope_task.await?; + Some((to_fenced_codeblock(&full_path, rope, None), buffer)) + }) +} + +fn to_fenced_codeblock( + full_path: &Path, + content: Rope, + line_range: Option>, +) -> String { + let line_range_text = line_range.map(|range| { + if range.start.row == range.end.row { + format!(":{}", range.start.row + 1) + } else { + format!(":{}-{}", range.start.row + 1, range.end.row + 1) + } + }); + + let path_extension = full_path.extension().and_then(|ext| ext.to_str()); + let path_string = full_path.to_string_lossy(); + let capacity = 3 + + path_extension.map_or(0, |extension| extension.len() + 1) + + path_string.len() + + line_range_text.as_ref().map_or(0, |text| text.len()) + + 1 + + content.len() + + 5; + let mut buffer = String::with_capacity(capacity); + + buffer.push_str("```"); + + if let Some(extension) = path_extension { + buffer.push_str(extension); + buffer.push(' '); + } + buffer.push_str(&path_string); + + if let Some(line_range_text) = line_range_text { + buffer.push_str(&line_range_text); + } + + buffer.push('\n'); + for chunk in content.chunks() { + buffer.push_str(chunk); + } + + if !buffer.ends_with('\n') { + buffer.push('\n'); + } + + buffer.push_str("```\n"); + + debug_assert!( + buffer.len() == capacity - 1 || buffer.len() == capacity, + "to_fenced_codeblock calculated capacity of {}, but length was {}", + capacity, + buffer.len(), + ); + + buffer +} + +/// Wraps `AgentContext` to opt-in to `PartialEq` and `Hash` impls which use a subset of fields +/// needed for stable context identity. +#[derive(Debug, Clone, RefCast)] +#[repr(transparent)] +pub struct AgentContextKey(pub AgentContext); + +impl AsRef for AgentContextKey { + fn as_ref(&self) -> &AgentContext { + &self.0 + } +} + +impl Eq for AgentContextKey {} + +impl PartialEq for AgentContextKey { + fn eq(&self, other: &Self) -> bool { + match &self.0 { + AgentContext::File(context) => { + if let AgentContext::File(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Directory(context) => { + if let AgentContext::Directory(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Symbol(context) => { + if let AgentContext::Symbol(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Selection(context) => { + if let AgentContext::Selection(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::FetchedUrl(context) => { + if let AgentContext::FetchedUrl(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Thread(context) => { + if let AgentContext::Thread(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Rules(context) => { + if let AgentContext::Rules(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + AgentContext::Image(context) => { + if let AgentContext::Image(other_context) = &other.0 { + return context.eq_for_key(other_context); + } + } + } + false + } +} + +impl Hash for AgentContextKey { + fn hash(&self, state: &mut H) { + match &self.0 { + AgentContext::File(context) => context.hash_for_key(state), + AgentContext::Directory(context) => context.hash_for_key(state), + AgentContext::Symbol(context) => context.hash_for_key(state), + AgentContext::Selection(context) => context.hash_for_key(state), + AgentContext::FetchedUrl(context) => context.hash_for_key(state), + AgentContext::Thread(context) => context.hash_for_key(state), + AgentContext::Rules(context) => context.hash_for_key(state), + AgentContext::Image(context) => context.hash_for_key(state), + } } } diff --git a/crates/agent/src/context_picker.rs b/crates/agent/src/context_picker.rs index 8e5cca941c..752f8a0af2 100644 --- a/crates/agent/src/context_picker.rs +++ b/crates/agent/src/context_picker.rs @@ -10,8 +10,11 @@ use std::path::PathBuf; use std::sync::Arc; use anyhow::{Result, anyhow}; +pub use completion_provider::ContextPickerCompletionProvider; use editor::display_map::{Crease, FoldId}; use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset}; +use fetch_context_picker::FetchContextPicker; +use file_context_picker::FileContextPicker; use file_context_picker::render_file_context_entry; use gpui::{ App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, @@ -20,10 +23,10 @@ use gpui::{ use language::Buffer; use multi_buffer::MultiBufferRow; use project::{Entry, ProjectPath}; -use prompt_store::UserPromptId; -use rules_context_picker::RulesContextEntry; +use prompt_store::{PromptStore, UserPromptId}; +use rules_context_picker::{RulesContextEntry, RulesContextPicker}; use symbol_context_picker::SymbolContextPicker; -use thread_context_picker::{ThreadContextEntry, render_thread_context_entry}; +use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry}; use ui::{ ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*, }; @@ -32,11 +35,6 @@ use workspace::{Workspace, notifications::NotifyResultExt}; use crate::AssistantPanel; use crate::context::RULES_ICON; -pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider; -use crate::context_picker::fetch_context_picker::FetchContextPicker; -use crate::context_picker::file_context_picker::FileContextPicker; -use crate::context_picker::rules_context_picker::RulesContextPicker; -use crate::context_picker::thread_context_picker::ThreadContextPicker; use crate::context_store::ContextStore; use crate::thread::ThreadId; use crate::thread_store::ThreadStore; @@ -166,6 +164,7 @@ pub(super) struct ContextPicker { workspace: WeakEntity, context_store: WeakEntity, thread_store: Option>, + prompt_store: Option>, _subscriptions: Vec, } @@ -193,6 +192,13 @@ impl ContextPicker { ) .collect::>(); + let prompt_store = thread_store.as_ref().and_then(|thread_store| { + thread_store + .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) + .ok() + .flatten() + }); + ContextPicker { mode: ContextPickerState::Default(ContextMenu::build( window, @@ -202,6 +208,7 @@ impl ContextPicker { workspace, context_store, thread_store, + prompt_store, _subscriptions: subscriptions, } } @@ -226,7 +233,12 @@ impl ContextPicker { .workspace .upgrade() .map(|workspace| { - available_context_picker_entries(&self.thread_store, &workspace, cx) + available_context_picker_entries( + &self.prompt_store, + &self.thread_store, + &workspace, + cx, + ) }) .unwrap_or_default(); @@ -304,10 +316,10 @@ impl ContextPicker { })); } ContextPickerMode::Rules => { - if let Some(thread_store) = self.thread_store.as_ref() { + if let Some(prompt_store) = self.prompt_store.as_ref() { self.mode = ContextPickerState::Rules(cx.new(|cx| { RulesContextPicker::new( - thread_store.clone(), + prompt_store.clone(), context_picker.clone(), self.context_store.clone(), window, @@ -526,6 +538,7 @@ enum RecentEntry { } fn available_context_picker_entries( + prompt_store: &Option>, thread_store: &Option>, workspace: &Entity, cx: &mut App, @@ -550,6 +563,9 @@ fn available_context_picker_entries( if thread_store.is_some() { entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread)); + } + + if prompt_store.is_some() { entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules)); } @@ -585,22 +601,21 @@ fn recent_context_picker_entries( }), ); - let mut current_threads = context_store.read(cx).thread_ids(); + let current_threads = context_store.read(cx).thread_ids(); - if let Some(active_thread) = workspace + let active_thread_id = workspace .panel::(cx) - .map(|panel| panel.read(cx).active_thread(cx)) - { - current_threads.insert(active_thread.read(cx).id().clone()); - } + .map(|panel| panel.read(cx).active_thread(cx).read(cx).id()); if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) { recent.extend( thread_store .read(cx) - .threads() + .reverse_chronological_threads() .into_iter() - .filter(|thread| !current_threads.contains(&thread.id)) + .filter(|thread| { + Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id) + }) .take(2) .map(|thread| { RecentEntry::Thread(ThreadContextEntry { @@ -622,9 +637,7 @@ fn add_selections_as_context( let selection_ranges = selection_ranges(workspace, cx); context_store.update(cx, |context_store, cx| { for (buffer, range) in selection_ranges { - context_store - .add_selection(buffer, range, cx) - .detach_and_log_err(cx); + context_store.add_selection(buffer, range, cx); } }) } diff --git a/crates/agent/src/context_picker/completion_provider.rs b/crates/agent/src/context_picker/completion_provider.rs index f82dab0a1b..05cecc68e7 100644 --- a/crates/agent/src/context_picker/completion_provider.rs +++ b/crates/agent/src/context_picker/completion_provider.rs @@ -15,22 +15,21 @@ use itertools::Itertools; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId}; -use prompt_store::PromptId; +use prompt_store::PromptStore; use rope::Point; use text::{Anchor, OffsetRangeExt, ToPoint}; use ui::prelude::*; use workspace::Workspace; use crate::context::RULES_ICON; -use crate::context_picker::file_context_picker::search_files; -use crate::context_picker::symbol_context_picker::search_symbols; use crate::context_store::ContextStore; use crate::thread_store::ThreadStore; use super::fetch_context_picker::fetch_url_content; -use super::file_context_picker::FileMatch; +use super::file_context_picker::{FileMatch, search_files}; use super::rules_context_picker::{RulesContextEntry, search_rules}; use super::symbol_context_picker::SymbolMatch; +use super::symbol_context_picker::search_symbols; use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads}; use super::{ ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry, @@ -38,8 +37,8 @@ use super::{ }; pub(crate) enum Match { - Symbol(SymbolMatch), File(FileMatch), + Symbol(SymbolMatch), Thread(ThreadMatch), Fetch(SharedString), Rules(RulesContextEntry), @@ -69,6 +68,7 @@ fn search( query: String, cancellation_flag: Arc, recent_entries: Vec, + prompt_store: Option>, thread_store: Option>, workspace: Entity, cx: &mut App, @@ -85,6 +85,7 @@ fn search( .collect() }) } + Some(ContextPickerMode::Symbol) => { let search_symbols_task = search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx); @@ -96,6 +97,7 @@ fn search( .collect() }) } + Some(ContextPickerMode::Thread) => { if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) { let search_threads_task = @@ -111,6 +113,7 @@ fn search( Task::ready(Vec::new()) } } + Some(ContextPickerMode::Fetch) => { if !query.is_empty() { Task::ready(vec![Match::Fetch(query.into())]) @@ -118,10 +121,11 @@ fn search( Task::ready(Vec::new()) } } + Some(ContextPickerMode::Rules) => { - if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) { + if let Some(prompt_store) = prompt_store.as_ref() { let search_rules_task = - search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx); + search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx); cx.background_spawn(async move { search_rules_task .await @@ -133,6 +137,7 @@ fn search( Task::ready(Vec::new()) } } + None => { if query.is_empty() { let mut matches = recent_entries @@ -163,7 +168,7 @@ fn search( .collect::>(); matches.extend( - available_context_picker_entries(&thread_store, &workspace, cx) + available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx) .into_iter() .map(|mode| { Match::Entry(EntryMatch { @@ -180,7 +185,8 @@ fn search( let search_files_task = search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); - let entries = available_context_picker_entries(&thread_store, &workspace, cx); + let entries = + available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx); let entry_candidates = entries .iter() .enumerate() @@ -307,9 +313,11 @@ impl ContextPickerCompletionProvider { move |_, _: &mut Window, cx: &mut App| { context_store.update(cx, |context_store, cx| { for (buffer, range) in &selections { - context_store - .add_selection(buffer.clone(), range.clone(), cx) - .detach_and_log_err(cx) + context_store.add_selection( + buffer.clone(), + range.clone(), + cx, + ); } }); @@ -437,7 +445,6 @@ impl ContextPickerCompletionProvider { source_range: Range, editor: Entity, context_store: Entity, - thread_store: Entity, ) -> Completion { let new_text = MentionLink::for_rules(&rules); let new_text_len = new_text.len(); @@ -457,29 +464,10 @@ impl ContextPickerCompletionProvider { new_text_len, editor.clone(), move |cx| { - let prompt_uuid = rules.prompt_id; - let prompt_id = PromptId::User { uuid: prompt_uuid }; - let context_store = context_store.clone(); - let Some(prompt_store) = thread_store.read(cx).prompt_store() else { - log::error!("Can't add user rules as prompt store is missing."); - return; - }; - let prompt_store = prompt_store.read(cx); - let Some(metadata) = prompt_store.metadata(prompt_id) else { - return; - }; - let Some(title) = metadata.title else { - return; - }; - let text_task = prompt_store.load(prompt_id, cx); - - cx.spawn(async move |cx| { - let text = text_task.await?; - context_store.update(cx, |context_store, cx| { - context_store.add_rules(prompt_uuid, title, text, false, cx) - }) - }) - .detach_and_log_err(cx); + let user_prompt_id = rules.prompt_id; + context_store.update(cx, |context_store, cx| { + context_store.add_rules(user_prompt_id, false, cx); + }); }, )), } @@ -516,7 +504,7 @@ impl ContextPickerCompletionProvider { let url_to_fetch = url_to_fetch.clone(); cx.spawn(async move |cx| { if context_store.update(cx, |context_store, _| { - context_store.includes_url(&url_to_fetch).is_some() + context_store.includes_url(&url_to_fetch) })? { return Ok(()); } @@ -592,7 +580,7 @@ impl ContextPickerCompletionProvider { move |cx| { context_store.update(cx, |context_store, cx| { let task = if is_directory { - context_store.add_directory(project_path.clone(), false, cx) + Task::ready(context_store.add_directory(&project_path, false, cx)) } else { context_store.add_file_from_path(project_path.clone(), false, cx) }; @@ -732,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider { cx, ); + let prompt_store = thread_store.as_ref().and_then(|thread_store| { + thread_store + .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) + .ok() + .flatten() + }); + let search_task = search( mode, query, Arc::::default(), recent_entries, + prompt_store, thread_store.clone(), workspace.clone(), cx, @@ -768,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { cx, )) } + Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol( symbol, excerpt_id, @@ -777,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { workspace.clone(), cx, ), + Match::Thread(ThreadMatch { thread, is_recent, .. }) => { @@ -791,17 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider { thread_store, )) } - Match::Rules(user_rules) => { - let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?; - Some(Self::completion_for_rules( - user_rules, - excerpt_id, - source_range.clone(), - editor.clone(), - context_store.clone(), - thread_store, - )) - } + + Match::Rules(user_rules) => Some(Self::completion_for_rules( + user_rules, + excerpt_id, + source_range.clone(), + editor.clone(), + context_store.clone(), + )), + Match::Fetch(url) => Some(Self::completion_for_fetch( source_range.clone(), url, @@ -810,6 +806,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { context_store.clone(), http_client.clone(), )), + Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry( entry, excerpt_id, diff --git a/crates/agent/src/context_picker/fetch_context_picker.rs b/crates/agent/src/context_picker/fetch_context_picker.rs index 5c7795237b..5df47e5a28 100644 --- a/crates/agent/src/context_picker/fetch_context_picker.rs +++ b/crates/agent/src/context_picker/fetch_context_picker.rs @@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate { cx: &mut Context>, ) -> Option { let added = self.context_store.upgrade().map_or(false, |context_store| { - context_store.read(cx).includes_url(&self.url).is_some() + context_store.read(cx).includes_url(&self.url) }); Some( diff --git a/crates/agent/src/context_picker/file_context_picker.rs b/crates/agent/src/context_picker/file_context_picker.rs index 5981b471c2..1dbd209850 100644 --- a/crates/agent/src/context_picker/file_context_picker.rs +++ b/crates/agent/src/context_picker/file_context_picker.rs @@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate { .context_store .update(cx, |context_store, cx| { if is_directory { - context_store.add_directory(project_path, true, cx) + Task::ready(context_store.add_directory(&project_path, true, cx)) } else { - context_store.add_file_from_path(project_path, true, cx) + context_store.add_file_from_path(project_path.clone(), true, cx) } }) .ok() @@ -325,11 +325,11 @@ pub fn render_file_context_entry( path: path.clone(), }; if is_directory { - context_store.read(cx).includes_directory(&project_path) - } else { context_store .read(cx) - .will_include_file_path(&project_path, cx) + .path_included_in_directory(&project_path, cx) + } else { + context_store.read(cx).file_path_included(&project_path, cx) } }); @@ -357,7 +357,7 @@ pub fn render_file_context_entry( })), ) .when_some(added, |el, added| match added { - FileInclusion::Direct(_) => el.child( + FileInclusion::Direct => el.child( h_flex() .w_full() .justify_end() @@ -369,9 +369,8 @@ pub fn render_file_context_entry( ) .child(Label::new("Added").size(LabelSize::Small)), ), - FileInclusion::InDirectory(directory_project_path) => { - // TODO: Consider using worktree full_path to include worktree name. - let directory_path = directory_project_path.path.to_string_lossy().into_owned(); + FileInclusion::InDirectory { full_path } => { + let directory_full_path = full_path.to_string_lossy().into_owned(); el.child( h_flex() @@ -385,7 +384,7 @@ pub fn render_file_context_entry( ) .child(Label::new("Included").size(LabelSize::Small)), ) - .tooltip(Tooltip::text(format!("in {directory_path}"))) + .tooltip(Tooltip::text(format!("in {directory_full_path}"))) } }) } diff --git a/crates/agent/src/context_picker/rules_context_picker.rs b/crates/agent/src/context_picker/rules_context_picker.rs index 4c1fc65303..ef4676e4c3 100644 --- a/crates/agent/src/context_picker/rules_context_picker.rs +++ b/crates/agent/src/context_picker/rules_context_picker.rs @@ -1,16 +1,15 @@ use std::sync::Arc; use std::sync::atomic::AtomicBool; -use anyhow::anyhow; use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity}; use picker::{Picker, PickerDelegate}; -use prompt_store::{PromptId, UserPromptId}; +use prompt_store::{PromptId, PromptStore, UserPromptId}; use ui::{ListItem, prelude::*}; +use util::ResultExt as _; use crate::context::RULES_ICON; use crate::context_picker::ContextPicker; use crate::context_store::{self, ContextStore}; -use crate::thread_store::ThreadStore; pub struct RulesContextPicker { picker: Entity>, @@ -18,13 +17,13 @@ pub struct RulesContextPicker { impl RulesContextPicker { pub fn new( - thread_store: WeakEntity, + prompt_store: Entity, context_picker: WeakEntity, context_store: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { - let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store); + let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store); let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); RulesContextPicker { picker } @@ -50,7 +49,7 @@ pub struct RulesContextEntry { } pub struct RulesContextPickerDelegate { - thread_store: WeakEntity, + prompt_store: Entity, context_picker: WeakEntity, context_store: WeakEntity, matches: Vec, @@ -59,12 +58,12 @@ pub struct RulesContextPickerDelegate { impl RulesContextPickerDelegate { pub fn new( - thread_store: WeakEntity, + prompt_store: Entity, context_picker: WeakEntity, context_store: WeakEntity, ) -> Self { RulesContextPickerDelegate { - thread_store, + prompt_store, context_picker, context_store, matches: Vec::new(), @@ -103,11 +102,12 @@ impl PickerDelegate for RulesContextPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let Some(thread_store) = self.thread_store.upgrade() else { - return Task::ready(()); - }; - - let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx); + let search_task = search_rules( + query, + Arc::new(AtomicBool::default()), + &self.prompt_store, + cx, + ); cx.spawn_in(window, async move |this, cx| { let matches = search_task.await; this.update(cx, |this, cx| { @@ -124,31 +124,11 @@ impl PickerDelegate for RulesContextPickerDelegate { return; }; - let Some(thread_store) = self.thread_store.upgrade() else { - return; - }; - - let prompt_id = entry.prompt_id; - - let load_rules_task = thread_store.update(cx, |thread_store, cx| { - thread_store.load_rules(prompt_id, cx) - }); - - cx.spawn(async move |this, cx| { - let (metadata, text) = load_rules_task.await?; - let Some(title) = metadata.title else { - return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context.")); - }; - this.update(cx, |this, cx| { - this.delegate - .context_store - .update(cx, |context_store, cx| { - context_store.add_rules(prompt_id, title, text, true, cx) - }) - .ok(); + self.context_store + .update(cx, |context_store, cx| { + context_store.add_rules(entry.prompt_id, true, cx) }) - }) - .detach_and_log_err(cx); + .log_err(); } fn dismissed(&mut self, _window: &mut Window, cx: &mut Context>) { @@ -179,11 +159,10 @@ pub fn render_thread_context_entry( context_store: WeakEntity, cx: &mut App, ) -> Div { - let added = context_store.upgrade().map_or(false, |ctx_store| { - ctx_store + let added = context_store.upgrade().map_or(false, |context_store| { + context_store .read(cx) - .includes_user_rules(&user_rules.prompt_id) - .is_some() + .includes_user_rules(user_rules.prompt_id) }); h_flex() @@ -218,12 +197,9 @@ pub fn render_thread_context_entry( pub(crate) fn search_rules( query: String, cancellation_flag: Arc, - thread_store: Entity, + prompt_store: &Entity, cx: &mut App, ) -> Task> { - let Some(prompt_store) = thread_store.read(cx).prompt_store() else { - return Task::ready(vec![]); - }; let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx); cx.background_spawn(async move { search_task diff --git a/crates/agent/src/context_picker/symbol_context_picker.rs b/crates/agent/src/context_picker/symbol_context_picker.rs index b76d4a8093..bc70c237a4 100644 --- a/crates/agent/src/context_picker/symbol_context_picker.rs +++ b/crates/agent/src/context_picker/symbol_context_picker.rs @@ -10,7 +10,6 @@ use gpui::{ use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; use project::{DocumentSymbol, Symbol}; -use text::OffsetRangeExt; use ui::{ListItem, prelude::*}; use util::ResultExt as _; use workspace::Workspace; @@ -228,18 +227,16 @@ pub(crate) fn add_symbol( ) })?; - context_store - .update(cx, move |context_store, cx| { - context_store.add_symbol( - buffer, - name.into(), - range, - enclosing_range, - remove_if_exists, - cx, - ) - })? - .await + context_store.update(cx, move |context_store, cx| { + context_store.add_symbol( + buffer, + name.into(), + range, + enclosing_range, + remove_if_exists, + cx, + ) + }) }) } @@ -353,38 +350,13 @@ fn compute_symbol_entries( context_store: &ContextStore, cx: &App, ) -> Vec { - let mut symbol_entries = Vec::with_capacity(symbols.len()); - for SymbolMatch { symbol, .. } in symbols { - let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path); - let is_included = if let Some(symbols_for_path) = symbols_for_path { - let mut is_included = false; - for included_symbol_id in symbols_for_path { - if included_symbol_id.name.as_ref() == symbol.name.as_str() { - if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) { - let snapshot = buffer.read(cx).snapshot(); - let included_symbol_range = - included_symbol_id.range.to_point_utf16(&snapshot); - - if included_symbol_range.start == symbol.range.start.0 - && included_symbol_range.end == symbol.range.end.0 - { - is_included = true; - break; - } - } - } - } - is_included - } else { - false - }; - - symbol_entries.push(SymbolEntry { + symbols + .into_iter() + .map(|SymbolMatch { symbol, .. }| SymbolEntry { + is_included: context_store.includes_symbol(&symbol, cx), symbol, - is_included, }) - } - symbol_entries + .collect::>() } pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful
{ diff --git a/crates/agent/src/context_picker/thread_context_picker.rs b/crates/agent/src/context_picker/thread_context_picker.rs index 030eaf06af..90c21b1c93 100644 --- a/crates/agent/src/context_picker/thread_context_picker.rs +++ b/crates/agent/src/context_picker/thread_context_picker.rs @@ -173,7 +173,7 @@ pub fn render_thread_context_entry( cx: &mut App, ) -> Div { let added = context_store.upgrade().map_or(false, |ctx_store| { - ctx_store.read(cx).includes_thread(&thread.id).is_some() + ctx_store.read(cx).includes_thread(&thread.id) }); h_flex() @@ -219,7 +219,7 @@ pub(crate) fn search_threads( ) -> Task> { let threads = thread_store .read(cx) - .threads() + .reverse_chronological_threads() .into_iter() .map(|thread| ThreadContextEntry { id: thread.id, diff --git a/crates/agent/src/context_store.rs b/crates/agent/src/context_store.rs index c66cad3ef2..f8f60dc911 100644 --- a/crates/agent/src/context_store.rs +++ b/crates/agent/src/context_store.rs @@ -1,43 +1,35 @@ use std::ops::Range; -use std::path::Path; +use std::path::PathBuf; use std::sync::Arc; -use anyhow::{Context as _, Result, anyhow}; -use collections::{BTreeMap, HashMap, HashSet}; +use anyhow::{Result, anyhow}; +use collections::{HashSet, IndexSet}; use futures::future::join_all; -use futures::{self, Future, FutureExt, future}; -use gpui::{App, AppContext as _, Context, Entity, Image, SharedString, Task, WeakEntity}; +use futures::{self, FutureExt}; +use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity}; use language::Buffer; use language_model::LanguageModelImage; -use project::{Project, ProjectEntryId, ProjectItem, ProjectPath, Worktree}; +use project::{Project, ProjectItem, ProjectPath, Symbol}; use prompt_store::UserPromptId; -use rope::{Point, Rope}; -use text::{Anchor, BufferId, OffsetRangeExt}; -use util::{ResultExt as _, maybe}; +use ref_cast::RefCast as _; +use text::{Anchor, OffsetRangeExt}; +use util::ResultExt as _; use crate::ThreadStore; use crate::context::{ - AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext, - FetchedUrlContext, FileContext, ImageContext, RulesContext, SelectionContext, SymbolContext, - ThreadContext, + AgentContext, AgentContextKey, ContextId, DirectoryContext, FetchedUrlContext, FileContext, + ImageContext, RulesContext, SelectionContext, SymbolContext, ThreadContext, }; use crate::context_strip::SuggestedContext; use crate::thread::{Thread, ThreadId}; pub struct ContextStore { project: WeakEntity, - context: Vec, thread_store: Option>, - next_context_id: ContextId, - files: BTreeMap, - directories: HashMap, - symbols: HashMap, - symbol_buffers: HashMap>, - symbols_by_path: HashMap>, - threads: HashMap, thread_summary_tasks: Vec>, - fetched_urls: HashMap, - user_rules: HashMap, + next_context_id: ContextId, + context_set: IndexSet, + context_thread_ids: HashSet, } impl ContextStore { @@ -48,35 +40,33 @@ impl ContextStore { Self { project, thread_store, - context: Vec::new(), - next_context_id: ContextId(0), - files: BTreeMap::default(), - directories: HashMap::default(), - symbols: HashMap::default(), - symbol_buffers: HashMap::default(), - symbols_by_path: HashMap::default(), - threads: HashMap::default(), thread_summary_tasks: Vec::new(), - fetched_urls: HashMap::default(), - user_rules: HashMap::default(), + next_context_id: ContextId::zero(), + context_set: IndexSet::default(), + context_thread_ids: HashSet::default(), } } - pub fn context(&self) -> &Vec { - &self.context - } - - pub fn context_for_id(&self, id: ContextId) -> Option<&AssistantContext> { - self.context().iter().find(|context| context.id() == id) + pub fn context(&self) -> impl Iterator { + self.context_set.iter().map(|entry| entry.as_ref()) } pub fn clear(&mut self) { - self.context.clear(); - self.files.clear(); - self.directories.clear(); - self.threads.clear(); - self.fetched_urls.clear(); - self.user_rules.clear(); + self.context_set.clear(); + self.context_thread_ids.clear(); + } + + pub fn new_context_for_thread(&self, thread: &Thread) -> Vec { + let existing_context = thread + .messages() + .flat_map(|message| &message.loaded_context.contexts) + .map(AgentContextKey::ref_cast) + .collect::>(); + self.context_set + .iter() + .filter(|context| !existing_context.contains(context)) + .map(|entry| entry.0.clone()) + .collect::>() } pub fn add_file_from_path( @@ -93,241 +83,98 @@ impl ContextStore { let open_buffer_task = project.update(cx, |project, cx| { project.open_buffer(project_path.clone(), cx) })?; - let buffer = open_buffer_task.await?; - let buffer_id = this.update(cx, |_, cx| buffer.read(cx).remote_id())?; - - let already_included = this.update(cx, |this, cx| { - match this.will_include_buffer(buffer_id, &project_path) { - Some(FileInclusion::Direct(context_id)) => { - if remove_if_exists { - this.remove_context(context_id, cx); - } - true - } - Some(FileInclusion::InDirectory(_)) => true, - None => false, - } - })?; - - if already_included { - return anyhow::Ok(()); - } - - let context_buffer = this - .update(cx, |_, cx| load_context_buffer(buffer, cx))?? - .await; - this.update(cx, |this, cx| { - this.insert_file(context_buffer, cx); - })?; - - anyhow::Ok(()) + this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx) + }) }) } pub fn add_file_from_buffer( &mut self, + project_path: &ProjectPath, buffer: Entity, + remove_if_exists: bool, cx: &mut Context, - ) -> Task> { - cx.spawn(async move |this, cx| { - let context_buffer = this - .update(cx, |_, cx| load_context_buffer(buffer, cx))?? - .await; + ) { + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::File(FileContext { buffer, context_id }); - this.update(cx, |this, cx| this.insert_file(context_buffer, cx))?; + let already_included = if self.has_context(&context) { + if remove_if_exists { + self.remove_context(&context, cx); + } + true + } else { + self.path_included_in_directory(project_path, cx).is_some() + }; - anyhow::Ok(()) - }) - } - - fn insert_file(&mut self, context_buffer: ContextBuffer, cx: &mut Context) { - let id = self.next_context_id.post_inc(); - self.files.insert(context_buffer.id, id); - self.context - .push(AssistantContext::File(FileContext { id, context_buffer })); - cx.notify(); + if !already_included { + self.insert_context(context, cx); + } } pub fn add_directory( &mut self, - project_path: ProjectPath, + project_path: &ProjectPath, remove_if_exists: bool, cx: &mut Context, - ) -> Task> { + ) -> Result<()> { let Some(project) = self.project.upgrade() else { - return Task::ready(Err(anyhow!("failed to read project"))); + return Err(anyhow!("failed to read project")); }; let Some(entry_id) = project .read(cx) - .entry_for_path(&project_path, cx) + .entry_for_path(project_path, cx) .map(|entry| entry.id) else { - return Task::ready(Err(anyhow!("no entry found for directory context"))); + return Err(anyhow!("no entry found for directory context")); }; - let already_included = match self.includes_directory(&project_path) { - Some(FileInclusion::Direct(context_id)) => { - if remove_if_exists { - self.remove_context(context_id, cx); - } - true + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Directory(DirectoryContext { + entry_id, + context_id, + }); + + if self.has_context(&context) { + if remove_if_exists { + self.remove_context(&context, cx); } - Some(FileInclusion::InDirectory(_)) => true, - None => false, - }; - if already_included { - return Task::ready(Ok(())); + } else if self.path_included_in_directory(project_path, cx).is_none() { + self.insert_context(context, cx); } - let worktree_id = project_path.worktree_id; - cx.spawn(async move |this, cx| { - let worktree = project.update(cx, |project, cx| { - project - .worktree_for_id(worktree_id, cx) - .ok_or_else(|| anyhow!("no worktree found for {worktree_id:?}")) - })??; - - let files = worktree.update(cx, |worktree, _cx| { - collect_files_in_path(worktree, &project_path.path) - })?; - - let open_buffers_task = project.update(cx, |project, cx| { - let tasks = files.iter().map(|file_path| { - project.open_buffer( - ProjectPath { - worktree_id, - path: file_path.clone(), - }, - cx, - ) - }); - future::join_all(tasks) - })?; - - let buffers = open_buffers_task.await; - - let context_buffer_tasks = this.update(cx, |_, cx| { - buffers - .into_iter() - .flatten() - .flat_map(move |buffer| load_context_buffer(buffer, cx).log_err()) - .collect::>() - })?; - - let context_buffers = future::join_all(context_buffer_tasks).await; - - if context_buffers.is_empty() { - let full_path = cx.update(|cx| worktree.read(cx).full_path(&project_path.path))?; - return Err(anyhow!("No text files found in {}", &full_path.display())); - } - - this.update(cx, |this, cx| { - this.insert_directory(worktree, entry_id, project_path, context_buffers, cx); - })?; - - anyhow::Ok(()) - }) - } - - fn insert_directory( - &mut self, - worktree: Entity, - entry_id: ProjectEntryId, - project_path: ProjectPath, - context_buffers: Vec, - cx: &mut Context, - ) { - let id = self.next_context_id.post_inc(); - let last_path = project_path.path.clone(); - self.directories.insert(project_path, id); - - self.context - .push(AssistantContext::Directory(DirectoryContext { - id, - worktree, - entry_id, - last_path, - context_buffers, - })); - cx.notify(); + anyhow::Ok(()) } pub fn add_symbol( &mut self, buffer: Entity, - symbol_name: SharedString, - symbol_range: Range, - symbol_enclosing_range: Range, + symbol: SharedString, + range: Range, + enclosing_range: Range, remove_if_exists: bool, cx: &mut Context, - ) -> Task> { - let buffer_ref = buffer.read(cx); - let Some(project_path) = buffer_ref.project_path(cx) else { - return Task::ready(Err(anyhow!("buffer has no path"))); - }; + ) -> bool { + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Symbol(SymbolContext { + buffer, + symbol, + range, + enclosing_range, + context_id, + }); - if let Some(symbols_for_path) = self.symbols_by_path.get(&project_path) { - let mut matching_symbol_id = None; - for symbol in symbols_for_path { - if &symbol.name == &symbol_name { - let snapshot = buffer_ref.snapshot(); - if symbol.range.to_offset(&snapshot) == symbol_range.to_offset(&snapshot) { - matching_symbol_id = self.symbols.get(symbol).cloned(); - break; - } - } - } - - if let Some(id) = matching_symbol_id { - if remove_if_exists { - self.remove_context(id, cx); - } - return Task::ready(Ok(false)); + if self.has_context(&context) { + if remove_if_exists { + self.remove_context(&context, cx); } + return false; } - let context_buffer_task = - match load_context_buffer_range(buffer, symbol_enclosing_range.clone(), cx) { - Ok((_line_range, context_buffer_task)) => context_buffer_task, - Err(err) => return Task::ready(Err(err)), - }; - - cx.spawn(async move |this, cx| { - let context_buffer = context_buffer_task.await; - - this.update(cx, |this, cx| { - this.insert_symbol( - make_context_symbol( - context_buffer, - project_path, - symbol_name, - symbol_range, - symbol_enclosing_range, - ), - cx, - ) - })?; - anyhow::Ok(true) - }) - } - - fn insert_symbol(&mut self, context_symbol: ContextSymbol, cx: &mut Context) { - let id = self.next_context_id.post_inc(); - self.symbols.insert(context_symbol.id.clone(), id); - self.symbols_by_path - .entry(context_symbol.id.path.clone()) - .or_insert_with(Vec::new) - .push(context_symbol.id.clone()); - self.symbol_buffers - .insert(context_symbol.id.clone(), context_symbol.buffer.clone()); - self.context.push(AssistantContext::Symbol(SymbolContext { - id, - context_symbol, - })); - cx.notify(); + self.insert_context(context, cx) } pub fn add_thread( @@ -336,24 +183,23 @@ impl ContextStore { remove_if_exists: bool, cx: &mut Context, ) { - if let Some(context_id) = self.includes_thread(&thread.read(cx).id()) { + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Thread(ThreadContext { thread, context_id }); + + if self.has_context(&context) { if remove_if_exists { - self.remove_context(context_id, cx); + self.remove_context(&context, cx); } } else { - self.insert_thread(thread, cx); + self.insert_context(context, cx); } } - pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> { - let tasks = std::mem::take(&mut self.thread_summary_tasks); - - cx.spawn(async move |_cx| { - join_all(tasks).await; - }) - } - - fn insert_thread(&mut self, thread: Entity, cx: &mut Context) { + fn start_summarizing_thread_if_needed( + &mut self, + thread: &Entity, + cx: &mut Context, + ) { if let Some(summary_task) = thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx)) { @@ -374,106 +220,60 @@ impl ContextStore { } })); } + } - let id = self.next_context_id.post_inc(); + pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> { + let tasks = std::mem::take(&mut self.thread_summary_tasks); - let text = thread.read(cx).latest_detailed_summary_or_text(); - - self.threads.insert(thread.read(cx).id().clone(), id); - self.context - .push(AssistantContext::Thread(ThreadContext { id, thread, text })); - cx.notify(); + cx.spawn(async move |_cx| { + join_all(tasks).await; + }) } pub fn add_rules( &mut self, prompt_id: UserPromptId, - title: impl Into, - text: impl Into, remove_if_exists: bool, cx: &mut Context, ) { - if let Some(context_id) = self.includes_user_rules(&prompt_id) { + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Rules(RulesContext { + prompt_id, + context_id, + }); + + if self.has_context(&context) { if remove_if_exists { - self.remove_context(context_id, cx); + self.remove_context(&context, cx); } } else { - self.insert_user_rules(prompt_id, title, text, cx); + self.insert_context(context, cx); } } - pub fn insert_user_rules( - &mut self, - prompt_id: UserPromptId, - title: impl Into, - text: impl Into, - cx: &mut Context, - ) { - let id = self.next_context_id.post_inc(); - - self.user_rules.insert(prompt_id, id); - self.context.push(AssistantContext::Rules(RulesContext { - id, - prompt_id, - title: title.into(), - text: text.into(), - })); - cx.notify(); - } - pub fn add_fetched_url( &mut self, url: String, text: impl Into, cx: &mut Context, ) { - if self.includes_url(&url).is_none() { - self.insert_fetched_url(url, text, cx); - } - } + let context = AgentContext::FetchedUrl(FetchedUrlContext { + url: url.into(), + text: text.into(), + context_id: self.next_context_id.post_inc(), + }); - fn insert_fetched_url( - &mut self, - url: String, - text: impl Into, - cx: &mut Context, - ) { - let id = self.next_context_id.post_inc(); - - self.fetched_urls.insert(url.clone(), id); - self.context - .push(AssistantContext::FetchedUrl(FetchedUrlContext { - id, - url: url.into(), - text: text.into(), - })); - cx.notify(); + self.insert_context(context, cx); } pub fn add_image(&mut self, image: Arc, cx: &mut Context) { let image_task = LanguageModelImage::from_image(image.clone(), cx).shared(); - let id = self.next_context_id.post_inc(); - self.context.push(AssistantContext::Image(ImageContext { - id, + let context = AgentContext::Image(ImageContext { original_image: image, image_task, - })); - cx.notify(); - } - - pub fn wait_for_images(&self, cx: &App) -> Task<()> { - let tasks = self - .context - .iter() - .filter_map(|ctx| match ctx { - AssistantContext::Image(ctx) => Some(ctx.image_task.clone()), - _ => None, - }) - .collect::>(); - - cx.spawn(async move |_cx| { - join_all(tasks).await; - }) + context_id: self.next_context_id.post_inc(), + }); + self.insert_context(context, cx); } pub fn add_selection( @@ -481,45 +281,21 @@ impl ContextStore { buffer: Entity, range: Range, cx: &mut Context, - ) -> Task> { - cx.spawn(async move |this, cx| { - let (line_range, context_buffer_task) = this.update(cx, |_, cx| { - load_context_buffer_range(buffer, range.clone(), cx) - })??; - - let context_buffer = context_buffer_task.await; - - this.update(cx, |this, cx| { - this.insert_selection(context_buffer, range, line_range, cx) - })?; - - anyhow::Ok(()) - }) - } - - fn insert_selection( - &mut self, - context_buffer: ContextBuffer, - range: Range, - line_range: Range, - cx: &mut Context, ) { - let id = self.next_context_id.post_inc(); - self.context - .push(AssistantContext::Selection(SelectionContext { - id, - range, - line_range, - context_buffer, - })); - cx.notify(); + let context_id = self.next_context_id.post_inc(); + let context = AgentContext::Selection(SelectionContext { + buffer, + range, + context_id, + }); + self.insert_context(context, cx); } - pub fn accept_suggested_context( + pub fn add_suggested_context( &mut self, suggested: &SuggestedContext, cx: &mut Context, - ) -> Task> { + ) { match suggested { SuggestedContext::File { buffer, @@ -527,655 +303,183 @@ impl ContextStore { name: _, } => { if let Some(buffer) = buffer.upgrade() { - return self.add_file_from_buffer(buffer, cx); + let context_id = self.next_context_id.post_inc(); + self.insert_context(AgentContext::File(FileContext { buffer, context_id }), cx); }; } SuggestedContext::Thread { thread, name: _ } => { if let Some(thread) = thread.upgrade() { - self.insert_thread(thread, cx); - }; - } - } - Task::ready(Ok(())) - } - - pub fn remove_context(&mut self, id: ContextId, cx: &mut Context) { - let Some(ix) = self.context.iter().position(|context| context.id() == id) else { - return; - }; - - match self.context.remove(ix) { - AssistantContext::File(_) => { - self.files.retain(|_, context_id| *context_id != id); - } - AssistantContext::Directory(_) => { - self.directories.retain(|_, context_id| *context_id != id); - } - AssistantContext::Symbol(symbol) => { - if let Some(symbols_in_path) = - self.symbols_by_path.get_mut(&symbol.context_symbol.id.path) - { - symbols_in_path.retain(|s| { - self.symbols - .get(s) - .map_or(false, |context_id| *context_id != id) - }); + let context_id = self.next_context_id.post_inc(); + self.insert_context( + AgentContext::Thread(ThreadContext { thread, context_id }), + cx, + ); } - self.symbol_buffers.remove(&symbol.context_symbol.id); - self.symbols.retain(|_, context_id| *context_id != id); } - AssistantContext::Selection(_) => {} - AssistantContext::FetchedUrl(_) => { - self.fetched_urls.retain(|_, context_id| *context_id != id); - } - AssistantContext::Thread(_) => { - self.threads.retain(|_, context_id| *context_id != id); - } - AssistantContext::Rules(RulesContext { prompt_id, .. }) => { - self.user_rules.remove(&prompt_id); - } - AssistantContext::Image(_) => {} } - - cx.notify(); } - /// Returns whether the buffer is already included directly in the context, or if it will be - /// included in the context via a directory. Directory inclusion is based on paths rather than - /// buffer IDs as the directory will be re-scanned. - pub fn will_include_buffer( - &self, - buffer_id: BufferId, - project_path: &ProjectPath, - ) -> Option { - if let Some(context_id) = self.files.get(&buffer_id) { - return Some(FileInclusion::Direct(*context_id)); + fn insert_context(&mut self, context: AgentContext, cx: &mut Context) -> bool { + match &context { + AgentContext::Thread(thread_context) => { + self.context_thread_ids + .insert(thread_context.thread.read(cx).id().clone()); + self.start_summarizing_thread_if_needed(&thread_context.thread, cx); + } + _ => {} } + let inserted = self.context_set.insert(AgentContextKey(context)); + if inserted { + cx.notify(); + } + inserted + } - self.will_include_file_path_via_directory(project_path) + pub fn remove_context(&mut self, context: &AgentContext, cx: &mut Context) { + if self + .context_set + .shift_remove(AgentContextKey::ref_cast(context)) + { + match context { + AgentContext::Thread(thread_context) => { + self.context_thread_ids + .remove(thread_context.thread.read(cx).id()); + } + _ => {} + } + cx.notify(); + } + } + + pub fn has_context(&mut self, context: &AgentContext) -> bool { + self.context_set + .contains(AgentContextKey::ref_cast(context)) } /// Returns whether this file path is already included directly in the context, or if it will be /// included in the context via a directory. - pub fn will_include_file_path( + pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option { + let project = self.project.upgrade()?.read(cx); + self.context().find_map(|context| match context { + AgentContext::File(file_context) => FileInclusion::check_file(file_context, path, cx), + AgentContext::Directory(directory_context) => { + FileInclusion::check_directory(directory_context, path, project, cx) + } + _ => None, + }) + } + + pub fn path_included_in_directory( &self, - project_path: &ProjectPath, + path: &ProjectPath, cx: &App, ) -> Option { - if !self.files.is_empty() { - let found_file_context = self.context.iter().find(|context| match &context { - AssistantContext::File(file_context) => { - let buffer = file_context.context_buffer.buffer.read(cx); - if let Some(context_path) = buffer.project_path(cx) { - &context_path == project_path - } else { - false - } + let project = self.project.upgrade()?.read(cx); + self.context().find_map(|context| match context { + AgentContext::Directory(directory_context) => { + FileInclusion::check_directory(directory_context, path, project, cx) + } + _ => None, + }) + } + + pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool { + self.context().any(|context| match context { + AgentContext::Symbol(context) => { + if context.symbol != symbol.name { + return false; } - _ => false, - }); - if let Some(context) = found_file_context { - return Some(FileInclusion::Direct(context.id())); + let buffer = context.buffer.read(cx); + let Some(context_path) = buffer.project_path(cx) else { + return false; + }; + if context_path != symbol.path { + return false; + } + let context_range = context.range.to_point_utf16(&buffer.snapshot()); + context_range.start == symbol.range.start.0 + && context_range.end == symbol.range.end.0 } - } - - self.will_include_file_path_via_directory(project_path) + _ => false, + }) } - fn will_include_file_path_via_directory( - &self, - project_path: &ProjectPath, - ) -> Option { - if self.directories.is_empty() { - return None; - } - - let mut path_buf = project_path.path.to_path_buf(); - - while path_buf.pop() { - // TODO: This isn't very efficient. Consider using a better representation of the - // directories map. - let directory_project_path = ProjectPath { - worktree_id: project_path.worktree_id, - path: path_buf.clone().into(), - }; - if let Some(_) = self.directories.get(&directory_project_path) { - return Some(FileInclusion::InDirectory(directory_project_path)); - } - } - - None + pub fn includes_thread(&self, thread_id: &ThreadId) -> bool { + self.context_thread_ids.contains(thread_id) } - pub fn includes_directory(&self, project_path: &ProjectPath) -> Option { - if let Some(context_id) = self.directories.get(project_path) { - return Some(FileInclusion::Direct(*context_id)); - } - - self.will_include_file_path_via_directory(project_path) + pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool { + self.context_set + .contains(&RulesContext::lookup_key(prompt_id)) } - pub fn included_symbol(&self, symbol_id: &ContextSymbolId) -> Option { - self.symbols.get(symbol_id).copied() - } - - pub fn included_symbols_by_path(&self) -> &HashMap> { - &self.symbols_by_path - } - - pub fn buffer_for_symbol(&self, symbol_id: &ContextSymbolId) -> Option> { - self.symbol_buffers.get(symbol_id).cloned() - } - - pub fn includes_thread(&self, thread_id: &ThreadId) -> Option { - self.threads.get(thread_id).copied() - } - - pub fn includes_user_rules(&self, prompt_id: &UserPromptId) -> Option { - self.user_rules.get(prompt_id).copied() - } - - pub fn includes_url(&self, url: &str) -> Option { - self.fetched_urls.get(url).copied() - } - - /// Replaces the context that matches the ID of the new context, if any match. - fn replace_context(&mut self, new_context: AssistantContext) { - let id = new_context.id(); - for context in self.context.iter_mut() { - if context.id() == id { - *context = new_context; - break; - } - } + pub fn includes_url(&self, url: impl Into) -> bool { + self.context_set + .contains(&FetchedUrlContext::lookup_key(url.into())) } pub fn file_paths(&self, cx: &App) -> HashSet { - self.context - .iter() + self.context() .filter_map(|context| match context { - AssistantContext::File(file) => { - let buffer = file.context_buffer.buffer.read(cx); + AgentContext::File(file) => { + let buffer = file.buffer.read(cx); buffer.project_path(cx) } - AssistantContext::Directory(_) - | AssistantContext::Symbol(_) - | AssistantContext::Selection(_) - | AssistantContext::FetchedUrl(_) - | AssistantContext::Thread(_) - | AssistantContext::Rules(_) - | AssistantContext::Image(_) => None, + AgentContext::Directory(_) + | AgentContext::Symbol(_) + | AgentContext::Selection(_) + | AgentContext::FetchedUrl(_) + | AgentContext::Thread(_) + | AgentContext::Rules(_) + | AgentContext::Image(_) => None, }) .collect() } - pub fn thread_ids(&self) -> HashSet { - self.threads.keys().cloned().collect() + pub fn thread_ids(&self) -> &HashSet { + &self.context_thread_ids } } pub enum FileInclusion { - Direct(ContextId), - InDirectory(ProjectPath), + Direct, + InDirectory { full_path: PathBuf }, } -fn make_context_symbol( - context_buffer: ContextBuffer, - path: ProjectPath, - name: SharedString, - range: Range, - enclosing_range: Range, -) -> ContextSymbol { - ContextSymbol { - id: ContextSymbolId { name, range, path }, - buffer_version: context_buffer.version, - enclosing_range, - buffer: context_buffer.buffer, - text: context_buffer.text, - } -} - -fn load_context_buffer_range( - buffer: Entity, - range: Range, - cx: &App, -) -> Result<(Range, Task)> { - let buffer_ref = buffer.read(cx); - let id = buffer_ref.remote_id(); - - let file = buffer_ref.file().context("context buffer missing path")?; - let full_path = file.full_path(cx); - - // Important to collect version at the same time as content so that staleness logic is correct. - let version = buffer_ref.version(); - let content = buffer_ref.text_for_range(range.clone()).collect::(); - let line_range = range.to_point(&buffer_ref.snapshot()); - - // Build the text on a background thread. - let task = cx.background_spawn({ - let line_range = line_range.clone(); - async move { - let text = to_fenced_codeblock(&full_path, content, Some(line_range)); - ContextBuffer { - id, - buffer, - last_full_path: full_path.into(), - version, - text, - } - } - }); - - Ok((line_range, task)) -} - -fn load_context_buffer(buffer: Entity, cx: &App) -> Result> { - let buffer_ref = buffer.read(cx); - let id = buffer_ref.remote_id(); - - let file = buffer_ref.file().context("context buffer missing path")?; - let full_path = file.full_path(cx); - - // Important to collect version at the same time as content so that staleness logic is correct. - let version = buffer_ref.version(); - let content = buffer_ref.as_rope().clone(); - - // Build the text on a background thread. - Ok(cx.background_spawn(async move { - let text = to_fenced_codeblock(&full_path, content, None); - ContextBuffer { - id, - buffer, - last_full_path: full_path.into(), - version, - text, - } - })) -} - -fn to_fenced_codeblock( - path: &Path, - content: Rope, - line_range: Option>, -) -> SharedString { - let line_range_text = line_range.map(|range| { - if range.start.row == range.end.row { - format!(":{}", range.start.row + 1) +impl FileInclusion { + fn check_file(file_context: &FileContext, path: &ProjectPath, cx: &App) -> Option { + let file_path = file_context.buffer.read(cx).project_path(cx)?; + if path == &file_path { + Some(FileInclusion::Direct) } else { - format!(":{}-{}", range.start.row + 1, range.end.row + 1) - } - }); - - let path_extension = path.extension().and_then(|ext| ext.to_str()); - let path_string = path.to_string_lossy(); - let capacity = 3 - + path_extension.map_or(0, |extension| extension.len() + 1) - + path_string.len() - + line_range_text.as_ref().map_or(0, |text| text.len()) - + 1 - + content.len() - + 5; - let mut buffer = String::with_capacity(capacity); - - buffer.push_str("```"); - - if let Some(extension) = path_extension { - buffer.push_str(extension); - buffer.push(' '); - } - buffer.push_str(&path_string); - - if let Some(line_range_text) = line_range_text { - buffer.push_str(&line_range_text); - } - - buffer.push('\n'); - for chunk in content.chunks() { - buffer.push_str(&chunk); - } - - if !buffer.ends_with('\n') { - buffer.push('\n'); - } - - buffer.push_str("```\n"); - - debug_assert!( - buffer.len() == capacity - 1 || buffer.len() == capacity, - "to_fenced_codeblock calculated capacity of {}, but length was {}", - capacity, - buffer.len(), - ); - - buffer.into() -} - -fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec> { - let mut files = Vec::new(); - - for entry in worktree.child_entries(path) { - if entry.is_dir() { - files.extend(collect_files_in_path(worktree, &entry.path)); - } else if entry.is_file() { - files.push(entry.path.clone()); - } - } - - files -} - -pub fn refresh_context_store_text( - context_store: Entity, - changed_buffers: &HashSet>, - cx: &App, -) -> impl Future> + use<> { - let mut tasks = Vec::new(); - - for context in &context_store.read(cx).context { - let id = context.id(); - - let task = maybe!({ - match context { - AssistantContext::File(file_context) => { - // TODO: Should refresh if the path has changed, as it's in the text. - if changed_buffers.is_empty() - || changed_buffers.contains(&file_context.context_buffer.buffer) - { - let context_store = context_store.clone(); - return refresh_file_text(context_store, file_context, cx); - } - } - AssistantContext::Directory(directory_context) => { - let directory_path = directory_context.project_path(cx)?; - let should_refresh = directory_path.path != directory_context.last_path - || changed_buffers.is_empty() - || changed_buffers.iter().any(|buffer| { - let Some(buffer_path) = buffer.read(cx).project_path(cx) else { - return false; - }; - buffer_path.starts_with(&directory_path) - }); - - if should_refresh { - let context_store = context_store.clone(); - return refresh_directory_text( - context_store, - directory_context, - directory_path, - cx, - ); - } - } - AssistantContext::Symbol(symbol_context) => { - // TODO: Should refresh if the path has changed, as it's in the text. - if changed_buffers.is_empty() - || changed_buffers.contains(&symbol_context.context_symbol.buffer) - { - let context_store = context_store.clone(); - return refresh_symbol_text(context_store, symbol_context, cx); - } - } - AssistantContext::Selection(selection_context) => { - // TODO: Should refresh if the path has changed, as it's in the text. - if changed_buffers.is_empty() - || changed_buffers.contains(&selection_context.context_buffer.buffer) - { - let context_store = context_store.clone(); - return refresh_selection_text(context_store, selection_context, cx); - } - } - AssistantContext::Thread(thread_context) => { - if changed_buffers.is_empty() { - let context_store = context_store.clone(); - return Some(refresh_thread_text(context_store, thread_context, cx)); - } - } - // Intentionally omit refreshing fetched URLs as it doesn't seem all that useful, - // and doing the caching properly could be tricky (unless it's already handled by - // the HttpClient?). - AssistantContext::FetchedUrl(_) => {} - AssistantContext::Rules(user_rules_context) => { - let context_store = context_store.clone(); - return Some(refresh_user_rules(context_store, user_rules_context, cx)); - } - AssistantContext::Image(_) => {} - } - None - }); - - if let Some(task) = task { - tasks.push(task.map(move |_| id)); } } - future::join_all(tasks) -} - -fn refresh_file_text( - context_store: Entity, - file_context: &FileContext, - cx: &App, -) -> Option> { - let id = file_context.id; - let task = refresh_context_buffer(&file_context.context_buffer, cx); - if let Some(task) = task { - Some(cx.spawn(async move |cx| { - let context_buffer = task.await; - context_store - .update(cx, |context_store, _| { - let new_file_context = FileContext { id, context_buffer }; - context_store.replace_context(AssistantContext::File(new_file_context)); - }) - .ok(); - })) - } else { - None - } -} - -fn refresh_directory_text( - context_store: Entity, - directory_context: &DirectoryContext, - directory_path: ProjectPath, - cx: &App, -) -> Option> { - let mut stale = false; - let futures = directory_context - .context_buffers - .iter() - .map(|context_buffer| { - if let Some(refresh_task) = refresh_context_buffer(context_buffer, cx) { - stale = true; - future::Either::Left(refresh_task) + fn check_directory( + directory_context: &DirectoryContext, + path: &ProjectPath, + project: &Project, + cx: &App, + ) -> Option { + let worktree = project + .worktree_for_entry(directory_context.entry_id, cx)? + .read(cx); + let entry = worktree.entry_for_id(directory_context.entry_id)?; + let directory_path = ProjectPath { + worktree_id: worktree.id(), + path: entry.path.clone(), + }; + if path.starts_with(&directory_path) { + if path == &directory_path { + Some(FileInclusion::Direct) } else { - future::Either::Right(future::ready((*context_buffer).clone())) - } - }) - .collect::>(); - - if !stale { - return None; - } - - let context_buffers = future::join_all(futures); - - let id = directory_context.id; - let worktree = directory_context.worktree.clone(); - let entry_id = directory_context.entry_id; - let last_path = directory_path.path; - Some(cx.spawn(async move |cx| { - let context_buffers = context_buffers.await; - context_store - .update(cx, |context_store, _| { - let new_directory_context = DirectoryContext { - id, - worktree, - entry_id, - last_path, - context_buffers, - }; - context_store.replace_context(AssistantContext::Directory(new_directory_context)); - }) - .ok(); - })) -} - -fn refresh_symbol_text( - context_store: Entity, - symbol_context: &SymbolContext, - cx: &App, -) -> Option> { - let id = symbol_context.id; - let task = refresh_context_symbol(&symbol_context.context_symbol, cx); - if let Some(task) = task { - Some(cx.spawn(async move |cx| { - let context_symbol = task.await; - context_store - .update(cx, |context_store, _| { - let new_symbol_context = SymbolContext { id, context_symbol }; - context_store.replace_context(AssistantContext::Symbol(new_symbol_context)); + Some(FileInclusion::InDirectory { + full_path: worktree.full_path(&entry.path), }) - .ok(); - })) - } else { - None - } -} - -fn refresh_selection_text( - context_store: Entity, - selection_context: &SelectionContext, - cx: &App, -) -> Option> { - let id = selection_context.id; - let range = selection_context.range.clone(); - let task = refresh_context_excerpt(&selection_context.context_buffer, range.clone(), cx); - if let Some(task) = task { - Some(cx.spawn(async move |cx| { - let (line_range, context_buffer) = task.await; - context_store - .update(cx, |context_store, _| { - let new_selection_context = SelectionContext { - id, - range, - line_range, - context_buffer, - }; - context_store - .replace_context(AssistantContext::Selection(new_selection_context)); - }) - .ok(); - })) - } else { - None - } -} - -fn refresh_thread_text( - context_store: Entity, - thread_context: &ThreadContext, - cx: &App, -) -> Task<()> { - let id = thread_context.id; - let thread = thread_context.thread.clone(); - cx.spawn(async move |cx| { - context_store - .update(cx, |context_store, cx| { - let text = thread.read(cx).latest_detailed_summary_or_text(); - context_store.replace_context(AssistantContext::Thread(ThreadContext { - id, - thread, - text, - })); - }) - .ok(); - }) -} - -fn refresh_user_rules( - context_store: Entity, - user_rules_context: &RulesContext, - cx: &App, -) -> Task<()> { - let id = user_rules_context.id; - let prompt_id = user_rules_context.prompt_id; - let Some(thread_store) = context_store.read(cx).thread_store.as_ref() else { - return Task::ready(()); - }; - let Ok(load_task) = thread_store.read_with(cx, |thread_store, cx| { - thread_store.load_rules(prompt_id, cx) - }) else { - return Task::ready(()); - }; - cx.spawn(async move |cx| { - if let Ok((metadata, text)) = load_task.await { - if let Some(title) = metadata.title.clone() { - context_store - .update(cx, |context_store, _cx| { - context_store.replace_context(AssistantContext::Rules(RulesContext { - id, - prompt_id, - title, - text: text.into(), - })); - }) - .ok(); - return; } + } else { + None } - context_store - .update(cx, |context_store, cx| { - context_store.remove_context(id, cx); - }) - .ok(); - }) -} - -fn refresh_context_buffer(context_buffer: &ContextBuffer, cx: &App) -> Option> { - let buffer = context_buffer.buffer.read(cx); - if buffer.version.changed_since(&context_buffer.version) { - load_context_buffer(context_buffer.buffer.clone(), cx).log_err() - } else { - None - } -} - -fn refresh_context_excerpt( - context_buffer: &ContextBuffer, - range: Range, - cx: &App, -) -> Option, ContextBuffer)> + use<>> { - let buffer = context_buffer.buffer.read(cx); - if buffer.version.changed_since(&context_buffer.version) { - let (line_range, context_buffer_task) = - load_context_buffer_range(context_buffer.buffer.clone(), range, cx).log_err()?; - Some(context_buffer_task.map(move |context_buffer| (line_range, context_buffer))) - } else { - None - } -} - -fn refresh_context_symbol( - context_symbol: &ContextSymbol, - cx: &App, -) -> Option + use<>> { - let buffer = context_symbol.buffer.read(cx); - let project_path = buffer.project_path(cx)?; - if buffer.version.changed_since(&context_symbol.buffer_version) { - let (_line_range, context_buffer_task) = load_context_buffer_range( - context_symbol.buffer.clone(), - context_symbol.enclosing_range.clone(), - cx, - ) - .log_err()?; - let name = context_symbol.id.name.clone(); - let range = context_symbol.id.range.clone(); - let enclosing_range = context_symbol.enclosing_range.clone(); - Some(context_buffer_task.map(move |context_buffer| { - make_context_symbol(context_buffer, project_path, name, range, enclosing_range) - })) - } else { - None } } diff --git a/crates/agent/src/context_strip.rs b/crates/agent/src/context_strip.rs index 6245f88998..8e0ca981a4 100644 --- a/crates/agent/src/context_strip.rs +++ b/crates/agent/src/context_strip.rs @@ -12,9 +12,9 @@ use itertools::Itertools; use language::Buffer; use project::ProjectItem; use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*}; -use workspace::{Workspace, notifications::NotifyResultExt}; +use workspace::Workspace; -use crate::context::{ContextId, ContextKind}; +use crate::context::{AgentContext, ContextKind}; use crate::context_picker::ContextPicker; use crate::context_store::ContextStore; use crate::thread::Thread; @@ -32,6 +32,7 @@ pub struct ContextStrip { focus_handle: FocusHandle, suggest_context_kind: SuggestContextKind, workspace: WeakEntity, + thread_store: Option>, _subscriptions: Vec, focused_index: Option, children_bounds: Option>>, @@ -73,12 +74,31 @@ impl ContextStrip { focus_handle, suggest_context_kind, workspace, + thread_store, _subscriptions: subscriptions, focused_index: None, children_bounds: None, } } + fn added_contexts(&self, cx: &App) -> Vec { + if let Some(workspace) = self.workspace.upgrade() { + let project = workspace.read(cx).project().read(cx); + let prompt_store = self + .thread_store + .as_ref() + .and_then(|thread_store| thread_store.upgrade()) + .and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref()); + self.context_store + .read(cx) + .context() + .flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx)) + .collect::>() + } else { + Vec::new() + } + } + fn suggested_context(&self, cx: &Context) -> Option { match self.suggest_context_kind { SuggestContextKind::File => self.suggested_file(cx), @@ -93,22 +113,19 @@ impl ContextStrip { let editor = active_item.to_any().downcast::().ok()?.read(cx); let active_buffer_entity = editor.buffer().read(cx).as_singleton()?; let active_buffer = active_buffer_entity.read(cx); - let project_path = active_buffer.project_path(cx)?; if self .context_store .read(cx) - .will_include_buffer(active_buffer.remote_id(), &project_path) + .file_path_included(&project_path, cx) .is_some() { return None; } let file_name = active_buffer.file()?.file_name(cx); - let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx); - Some(SuggestedContext::File { name: file_name.to_string_lossy().into_owned().into(), buffer: active_buffer_entity.downgrade(), @@ -135,7 +152,6 @@ impl ContextStrip { .context_store .read(cx) .includes_thread(active_thread.id()) - .is_some() { return None; } @@ -272,12 +288,12 @@ impl ContextStrip { best.map(|(index, _, _)| index) } - fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) { + fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) { let Some(workspace) = self.workspace.upgrade() else { return; }; - crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx); + crate::active_thread::open_context(context, workspace, window, cx); } fn remove_focused_context( @@ -287,17 +303,17 @@ impl ContextStrip { cx: &mut Context, ) { if let Some(index) = self.focused_index { - let mut is_empty = false; + let added_contexts = self.added_contexts(cx); + let Some(context) = added_contexts.get(index) else { + return; + }; self.context_store.update(cx, |this, cx| { - if let Some(item) = this.context().get(index) { - this.remove_context(item.id(), cx); - } - - is_empty = this.context().is_empty(); + this.remove_context(&context.context, cx); }); - if is_empty { + let is_now_empty = added_contexts.len() == 1; + if is_now_empty { cx.emit(ContextStripEvent::BlurredEmpty); } else { self.focused_index = Some(index.saturating_sub(1)); @@ -306,49 +322,28 @@ impl ContextStrip { } } - fn is_suggested_focused(&self, context: &Vec) -> bool { + fn is_suggested_focused(&self, added_contexts: &Vec) -> bool { // We only suggest one item after the actual context - self.focused_index == Some(context.len()) + self.focused_index == Some(added_contexts.len()) } fn accept_suggested_context( &mut self, _: &AcceptSuggestedContext, - window: &mut Window, + _window: &mut Window, cx: &mut Context, ) { if let Some(suggested) = self.suggested_context(cx) { - let context_store = self.context_store.read(cx); - - if self.is_suggested_focused(context_store.context()) { - self.add_suggested_context(&suggested, window, cx); + if self.is_suggested_focused(&self.added_contexts(cx)) { + self.add_suggested_context(&suggested, cx); } } } - fn add_suggested_context( - &mut self, - suggested: &SuggestedContext, - window: &mut Window, - cx: &mut Context, - ) { - let task = self.context_store.update(cx, |context_store, cx| { - context_store.accept_suggested_context(&suggested, cx) + fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context) { + self.context_store.update(cx, |context_store, cx| { + context_store.add_suggested_context(&suggested, cx) }); - - cx.spawn_in(window, async move |this, cx| { - match task.await.notify_async_err(cx) { - None => {} - Some(()) => { - if let Some(this) = this.upgrade() { - this.update(cx, |_, cx| cx.notify())?; - } - } - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - cx.notify(); } } @@ -361,17 +356,10 @@ impl Focusable for ContextStrip { impl Render for ContextStrip { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let context_store = self.context_store.read(cx); - let context = context_store.context(); let context_picker = self.context_picker.clone(); let focus_handle = self.focus_handle.clone(); - let suggested_context = self.suggested_context(cx); - - let added_contexts = context - .iter() - .map(|c| AddedContext::new(c, cx)) - .collect::>(); + let added_contexts = self.added_contexts(cx); let dupe_names = added_contexts .iter() .map(|c| c.name.clone()) @@ -380,6 +368,14 @@ impl Render for ContextStrip { .filter(|(a, b)| a == b) .map(|(a, _)| a) .collect::>(); + let no_added_context = added_contexts.is_empty(); + + let suggested_context = self.suggested_context(cx).map(|suggested_context| { + ( + suggested_context, + self.is_suggested_focused(&added_contexts), + ) + }); h_flex() .flex_wrap() @@ -436,7 +432,7 @@ impl Render for ContextStrip { }) .with_handle(self.context_picker_menu_handle.clone()), ) - .when(context.is_empty() && suggested_context.is_none(), { + .when(no_added_context && suggested_context.is_none(), { |parent| { parent.child( h_flex() @@ -466,16 +462,17 @@ impl Render for ContextStrip { .enumerate() .map(|(i, added_context)| { let name = added_context.name.clone(); - let id = added_context.id; + let context = added_context.context.clone(); ContextPill::added( added_context, dupe_names.contains(&name), self.focused_index == Some(i), Some({ + let context = context.clone(); let context_store = self.context_store.clone(); Rc::new(cx.listener(move |_this, _event, _window, cx| { context_store.update(cx, |this, cx| { - this.remove_context(id, cx); + this.remove_context(&context, cx); }); cx.notify(); })) @@ -484,7 +481,7 @@ impl Render for ContextStrip { .on_click({ Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| { if event.down.click_count > 1 { - this.open_context(id, window, cx); + this.open_context(&context, window, cx); } else { this.focused_index = Some(i); } @@ -493,22 +490,22 @@ impl Render for ContextStrip { }) }), ) - .when_some(suggested_context, |el, suggested| { + .when_some(suggested_context, |el, (suggested, focused)| { el.child( ContextPill::suggested( suggested.name().clone(), suggested.icon_path(), suggested.kind(), - self.is_suggested_focused(&context), + focused, ) .on_click(Rc::new(cx.listener( - move |this, _event, window, cx| { - this.add_suggested_context(&suggested, window, cx); + move |this, _event, _window, cx| { + this.add_suggested_context(&suggested, cx); }, ))), ) }) - .when(!context.is_empty(), { + .when(!no_added_context, { move |parent| { parent.child( IconButton::new("remove-all-context", IconName::Eraser) @@ -534,6 +531,7 @@ impl Render for ContextStrip { ) } }) + .into_any() } } diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index 3947a7dc0e..029fe1b381 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -51,7 +51,10 @@ impl HistoryStore { return history_entries; } - for thread in self.thread_store.update(cx, |this, _cx| this.threads()) { + for thread in self + .thread_store + .update(cx, |this, _cx| this.reverse_chronological_threads()) + { history_entries.push(HistoryEntry::Thread(thread)); } diff --git a/crates/agent/src/inline_assistant.rs b/crates/agent/src/inline_assistant.rs index e8d626b82a..6785d08574 100644 --- a/crates/agent/src/inline_assistant.rs +++ b/crates/agent/src/inline_assistant.rs @@ -32,6 +32,7 @@ use project::LspAction; use project::Project; use project::{CodeAction, ProjectTransaction}; use prompt_store::PromptBuilder; +use prompt_store::PromptStore; use settings::{Settings, SettingsStore}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; use terminal_view::{TerminalView, terminal_panel::TerminalPanel}; @@ -245,9 +246,13 @@ impl InlineAssistant { .map_or(false, |model| model.provider.is_authenticated(cx)) }; - let thread_store = workspace + let assistant_panel = workspace .panel::(cx) - .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade()); + .map(|assistant_panel| assistant_panel.read(cx)); + let prompt_store = assistant_panel + .and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned()); + let thread_store = + assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade()); let handle_assist = |window: &mut Window, cx: &mut Context| match inline_assist_target { @@ -257,6 +262,7 @@ impl InlineAssistant { &active_editor, cx.entity().downgrade(), workspace.project().downgrade(), + prompt_store, thread_store, window, cx, @@ -269,6 +275,7 @@ impl InlineAssistant { &active_terminal, cx.entity().downgrade(), workspace.project().downgrade(), + prompt_store, thread_store, window, cx, @@ -323,6 +330,7 @@ impl InlineAssistant { editor: &Entity, workspace: WeakEntity, project: WeakEntity, + prompt_store: Option>, thread_store: Option>, window: &mut Window, cx: &mut App, @@ -437,6 +445,8 @@ impl InlineAssistant { range.clone(), None, context_store.clone(), + project.clone(), + prompt_store.clone(), self.telemetry.clone(), self.prompt_builder.clone(), cx, @@ -525,6 +535,7 @@ impl InlineAssistant { initial_transaction_id: Option, focus: bool, workspace: Entity, + prompt_store: Option>, thread_store: Option>, window: &mut Window, cx: &mut App, @@ -543,7 +554,7 @@ impl InlineAssistant { } let project = workspace.read(cx).project().downgrade(); - let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone())); + let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone())); let codegen = cx.new(|cx| { BufferCodegen::new( @@ -551,6 +562,8 @@ impl InlineAssistant { range.clone(), initial_transaction_id, context_store.clone(), + project, + prompt_store, self.telemetry.clone(), self.prompt_builder.clone(), cx, @@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider { let editor = self.editor.clone(); let workspace = self.workspace.clone(); let thread_store = self.thread_store.clone(); + let prompt_store = PromptStore::global(cx); window.spawn(cx, async move |cx| { let workspace = workspace.upgrade().context("workspace was released")?; let editor = editor.upgrade().context("editor was released")?; @@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider { })? .context("invalid range")?; + let prompt_store = prompt_store.await.ok(); cx.update_global(|assistant: &mut InlineAssistant, window, cx| { let assist_id = assistant.suggest_assist( &editor, @@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider { None, true, workspace, + prompt_store, thread_store, window, cx, diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 05599b536b..ef60b0e6c4 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use std::sync::Arc; use crate::assistant_model_selector::ModelType; -use crate::context::{AssistantContext, format_context_as_string}; +use crate::context::{ContextLoadResult, load_context}; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use buffer_diff::BufferDiff; use collections::HashSet; @@ -13,6 +13,8 @@ use editor::{ }; use file_icons::FileIcons; use fs::Fs; +use futures::future::Shared; +use futures::{FutureExt as _, future}; use gpui::{ Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, @@ -22,6 +24,7 @@ use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelReques use language_model_selector::ToggleModelSelector; use multi_buffer; use project::Project; +use prompt_store::PromptStore; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; @@ -31,7 +34,7 @@ use workspace::Workspace; use crate::assistant_model_selector::AssistantModelSelector; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; -use crate::context_store::{ContextStore, refresh_context_store_text}; +use crate::context_store::ContextStore; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::profile_selector::ProfileSelector; use crate::thread::{Thread, TokenUsageRatio}; @@ -49,9 +52,12 @@ pub struct MessageEditor { workspace: WeakEntity, project: Entity, context_store: Entity, + prompt_store: Option>, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, + last_loaded_context: Option, + context_load_task: Option>>, profile_selector: Entity, edits_expanded: bool, editor_is_expanded: bool, @@ -68,6 +74,7 @@ impl MessageEditor { fs: Arc, workspace: WeakEntity, context_store: Entity, + prompt_store: Option>, thread_store: WeakEntity, thread: Entity, window: &mut Window, @@ -135,13 +142,11 @@ impl MessageEditor { let subscriptions = vec![ cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), cx.subscribe(&editor, |this, _, event, cx| match event { - EditorEvent::BufferEdited => { - this.message_or_context_changed(true, cx); - } + EditorEvent::BufferEdited => this.handle_message_changed(cx), _ => {} }), cx.observe(&context_store, |this, _, cx| { - this.message_or_context_changed(false, cx); + this.handle_context_changed(cx) }), ]; @@ -152,8 +157,11 @@ impl MessageEditor { incompatible_tools_state: incompatible_tools.clone(), workspace, context_store, + prompt_store, context_strip, context_picker_menu_handle, + context_load_task: None, + last_loaded_context: None, model_selector: cx.new(|cx| { AssistantModelSelector::new( fs.clone(), @@ -175,6 +183,10 @@ impl MessageEditor { } } + pub fn context_store(&self) -> &Entity { + &self.context_store + } + fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context) { cx.notify(); } @@ -214,6 +226,7 @@ impl MessageEditor { ) { self.context_picker_menu_handle.toggle(window, cx); } + pub fn remove_all_context( &mut self, _: &RemoveAllContext, @@ -270,68 +283,22 @@ impl MessageEditor { self.last_estimated_token_count.take(); cx.emit(MessageEditorEvent::EstimatedTokenCount); - let refresh_task = - refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx); - let wait_for_images = self.context_store.read(cx).wait_for_images(cx); - let thread = self.thread.clone(); - let context_store = self.context_store.clone(); let git_store = self.project.read(cx).git_store().clone(); let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx)); + let context_task = self.wait_for_context(cx); let window_handle = window.window_handle(); - cx.spawn(async move |this, cx| { - let checkpoint = checkpoint.await.ok(); - refresh_task.await; - wait_for_images.await; + cx.spawn(async move |_this, cx| { + let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await; + let loaded_context = loaded_context.unwrap_or_default(); thread .update(cx, |thread, cx| { - let context = context_store.read(cx).context().clone(); - thread.insert_user_message(user_message, context, checkpoint, cx); + thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx); }) .log_err(); - context_store - .update(cx, |context_store, cx| { - let excerpt_ids = context_store - .context() - .iter() - .filter(|ctx| { - matches!( - ctx, - AssistantContext::Selection(_) | AssistantContext::Image(_) - ) - }) - .map(|ctx| ctx.id()) - .collect::>(); - - for id in excerpt_ids { - context_store.remove_context(id, cx); - } - }) - .log_err(); - - if let Some(wait_for_summaries) = context_store - .update(cx, |context_store, cx| context_store.wait_for_summaries(cx)) - .log_err() - { - this.update(cx, |this, cx| { - this.waiting_for_summaries_to_send = true; - cx.notify(); - }) - .log_err(); - - wait_for_summaries.await; - - this.update(cx, |this, cx| { - this.waiting_for_summaries_to_send = false; - cx.notify(); - }) - .log_err(); - } - - // Send to model after summaries are done thread .update(cx, |thread, cx| { thread.advance_prompt_id(); @@ -342,6 +309,30 @@ impl MessageEditor { .detach(); } + fn wait_for_summaries(&mut self, cx: &mut Context) -> Task<()> { + let context_store = self.context_store.clone(); + cx.spawn(async move |this, cx| { + if let Some(wait_for_summaries) = context_store + .update(cx, |context_store, cx| context_store.wait_for_summaries(cx)) + .ok() + { + this.update(cx, |this, cx| { + this.waiting_for_summaries_to_send = true; + cx.notify(); + }) + .ok(); + + wait_for_summaries.await; + + this.update(cx, |this, cx| { + this.waiting_for_summaries_to_send = false; + cx.notify(); + }) + .ok(); + } + }) + } + fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context) { let cancelled = self.thread.update(cx, |thread, cx| { thread.cancel_last_completion(Some(window.window_handle()), cx) @@ -1015,6 +1006,49 @@ impl MessageEditor { self.update_token_count_task.is_some() } + fn handle_message_changed(&mut self, cx: &mut Context) { + self.message_or_context_changed(true, cx); + } + + fn handle_context_changed(&mut self, cx: &mut Context) { + let summaries_task = self.wait_for_summaries(cx); + let load_task = cx.spawn(async move |this, cx| { + // Waits for detailed summaries before `load_context`, as it directly reads these from + // the thread. TODO: Would be cleaner to have context loading await on summarization. + summaries_task.await; + let Ok(load_task) = this.update(cx, |this, cx| { + let new_context = this.context_store.read_with(cx, |context_store, cx| { + context_store.new_context_for_thread(this.thread.read(cx)) + }); + load_context(new_context, &this.project, &this.prompt_store, cx) + }) else { + return; + }; + let result = load_task.await; + this.update(cx, |this, cx| { + this.last_loaded_context = Some(result); + this.context_load_task = None; + this.message_or_context_changed(false, cx); + }) + .ok(); + }); + // Replace existing load task, if any, causing it to be cancelled. + self.context_load_task = Some(load_task.shared()); + } + + fn wait_for_context(&self, cx: &mut Context) -> Task> { + if let Some(context_load_task) = self.context_load_task.clone() { + cx.spawn(async move |this, cx| { + context_load_task.await; + this.read_with(cx, |this, _cx| this.last_loaded_context.clone()) + .ok() + .flatten() + }) + } else { + Task::ready(self.last_loaded_context.clone()) + } + } + fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context) { cx.emit(MessageEditorEvent::Changed); self.update_token_count_task.take(); @@ -1024,9 +1058,7 @@ impl MessageEditor { return; }; - let context_store = self.context_store.clone(); let editor = self.editor.clone(); - let thread = self.thread.clone(); self.update_token_count_task = Some(cx.spawn(async move |this, cx| { if debounce { @@ -1035,27 +1067,33 @@ impl MessageEditor { .await; } - let token_count = if let Some(task) = cx.update(|cx| { - let context = context_store.read(cx).context().iter(); - let new_context = thread.read(cx).filter_new_context(context); - let context_text = - format_context_as_string(new_context, cx).unwrap_or(String::new()); + let token_count = if let Some(task) = this.update(cx, |this, cx| { + let loaded_context = this + .last_loaded_context + .as_ref() + .map(|context_load_result| &context_load_result.loaded_context); let message_text = editor.read(cx).text(cx); - let content = context_text + &message_text; - - if content.is_empty() { + if message_text.is_empty() + && loaded_context.map_or(true, |loaded_context| loaded_context.is_empty()) + { return None; } + let mut request_message = LanguageModelRequestMessage { + role: language_model::Role::User, + content: Vec::new(), + cache: false, + }; + + if let Some(loaded_context) = loaded_context { + loaded_context.add_to_request_message(&mut request_message); + } + let request = language_model::LanguageModelRequest { thread_id: None, prompt_id: None, - messages: vec![LanguageModelRequestMessage { - role: language_model::Role::User, - content: vec![content.into()], - cache: false, - }], + messages: vec![request_message], tools: vec![], stop: vec![], temperature: None, diff --git a/crates/agent/src/terminal_codegen.rs b/crates/agent/src/terminal_codegen.rs index 8c0e9e1675..925187c7cf 100644 --- a/crates/agent/src/terminal_codegen.rs +++ b/crates/agent/src/terminal_codegen.rs @@ -32,7 +32,7 @@ impl TerminalCodegen { } } - pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context) { + pub fn start(&mut self, prompt_task: Task, cx: &mut Context) { let Some(ConfiguredModel { model, .. }) = LanguageModelRegistry::read_global(cx).inline_assistant_model() else { @@ -45,6 +45,7 @@ impl TerminalCodegen { self.status = CodegenStatus::Pending; self.transaction = Some(TerminalTransaction::start(self.terminal.clone())); self.generation = cx.spawn(async move |this, cx| { + let prompt = prompt_task.await; let model_telemetry_id = model.telemetry_id(); let model_provider_id = model.provider_id(); let response = model.stream_completion_text(prompt, &cx).await; diff --git a/crates/agent/src/terminal_inline_assistant.rs b/crates/agent/src/terminal_inline_assistant.rs index 95099e542c..b7690fa580 100644 --- a/crates/agent/src/terminal_inline_assistant.rs +++ b/crates/agent/src/terminal_inline_assistant.rs @@ -1,4 +1,4 @@ -use crate::context::attach_context_to_message; +use crate::context::load_context; use crate::context_store::ContextStore; use crate::inline_prompt_editor::{ CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId, @@ -10,14 +10,14 @@ use client::telemetry::Telemetry; use collections::{HashMap, VecDeque}; use editor::{MultiBuffer, actions::SelectAll}; use fs::Fs; -use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity}; +use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity}; use language::Buffer; use language_model::{ ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, report_assistant_event, }; use project::Project; -use prompt_store::PromptBuilder; +use prompt_store::{PromptBuilder, PromptStore}; use std::sync::Arc; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; use terminal_view::TerminalView; @@ -69,6 +69,7 @@ impl TerminalInlineAssistant { terminal_view: &Entity, workspace: WeakEntity, project: WeakEntity, + prompt_store: Option>, thread_store: Option>, window: &mut Window, cx: &mut App, @@ -109,6 +110,7 @@ impl TerminalInlineAssistant { prompt_editor, workspace.clone(), context_store, + prompt_store, window, cx, ); @@ -196,11 +198,11 @@ impl TerminalInlineAssistant { .log_err(); let codegen = assist.codegen.clone(); - let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else { + let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else { return; }; - codegen.update(cx, |codegen, cx| codegen.start(request, cx)); + codegen.update(cx, |codegen, cx| codegen.start(request_task, cx)); } fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) { @@ -217,7 +219,7 @@ impl TerminalInlineAssistant { &self, assist_id: TerminalInlineAssistId, cx: &mut App, - ) -> Result { + ) -> Result> { let assist = self.assists.get(&assist_id).context("invalid assist")?; let shell = std::env::var("SHELL").ok(); @@ -246,28 +248,40 @@ impl TerminalInlineAssistant { &latest_output, )?; - let mut request_message = LanguageModelRequestMessage { - role: Role::User, - content: vec![], - cache: false, - }; + let contexts = assist + .context_store + .read(cx) + .context() + .cloned() + .collect::>(); + let context_load_task = assist.workspace.update(cx, |workspace, cx| { + let project = workspace.project(); + load_context(contexts, project, &assist.prompt_store, cx) + })?; - attach_context_to_message( - &mut request_message, - assist.context_store.read(cx).context().iter(), - cx, - ); + Ok(cx.background_spawn(async move { + let mut request_message = LanguageModelRequestMessage { + role: Role::User, + content: vec![], + cache: false, + }; - request_message.content.push(prompt.into()); + context_load_task + .await + .loaded_context + .add_to_request_message(&mut request_message); - Ok(LanguageModelRequest { - thread_id: None, - prompt_id: None, - messages: vec![request_message], - tools: Vec::new(), - stop: Vec::new(), - temperature: None, - }) + request_message.content.push(prompt.into()); + + LanguageModelRequest { + thread_id: None, + prompt_id: None, + messages: vec![request_message], + tools: Vec::new(), + stop: Vec::new(), + temperature: None, + } + })) } fn finish_assist( @@ -380,6 +394,7 @@ struct TerminalInlineAssist { codegen: Entity, workspace: WeakEntity, context_store: Entity, + prompt_store: Option>, _subscriptions: Vec, } @@ -390,6 +405,7 @@ impl TerminalInlineAssist { prompt_editor: Entity>, workspace: WeakEntity, context_store: Entity, + prompt_store: Option>, window: &mut Window, cx: &mut App, ) -> Self { @@ -400,6 +416,7 @@ impl TerminalInlineAssist { codegen: codegen.clone(), workspace: workspace.clone(), context_store, + prompt_store, _subscriptions: vec![ window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| { TerminalInlineAssistant::update_global(cx, |this, cx| { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index cce7d109c0..29709490c9 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -8,7 +8,7 @@ use anyhow::{Result, anyhow}; use assistant_settings::AssistantSettings; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; -use collections::{BTreeMap, HashMap}; +use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; use futures::future::Shared; use futures::{FutureExt, StreamExt as _}; @@ -18,9 +18,9 @@ use gpui::{ }; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, - LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, - LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, + LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, + LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason, TokenUsage, }; @@ -35,7 +35,7 @@ use thiserror::Error; use util::{ResultExt as _, TryFutureExt as _, post_inc}; use uuid::Uuid; -use crate::context::{AssistantContext, ContextId, format_context_as_string}; +use crate::context::{AgentContext, ContextLoadResult, LoadedContext}; use crate::thread_store::{ SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext, @@ -98,8 +98,7 @@ pub struct Message { pub id: MessageId, pub role: Role, pub segments: Vec, - pub context: String, - pub images: Vec, + pub loaded_context: LoadedContext, } impl Message { @@ -138,8 +137,8 @@ impl Message { pub fn to_string(&self) -> String { let mut result = String::new(); - if !self.context.is_empty() { - result.push_str(&self.context); + if !self.loaded_context.text.is_empty() { + result.push_str(&self.loaded_context.text); } for segment in &self.segments { @@ -294,8 +293,6 @@ pub struct Thread { messages: Vec, next_message_id: MessageId, last_prompt_id: PromptId, - context: BTreeMap, - context_by_message: HashMap>, project_context: SharedProjectContext, checkpoints_by_message: HashMap, completion_count: usize, @@ -345,8 +342,6 @@ impl Thread { messages: Vec::new(), next_message_id: MessageId(0), last_prompt_id: PromptId::new(), - context: BTreeMap::default(), - context_by_message: HashMap::default(), project_context: system_prompt, checkpoints_by_message: HashMap::default(), completion_count: 0, @@ -418,14 +413,15 @@ impl Thread { } }) .collect(), - context: message.context, - images: Vec::new(), + loaded_context: LoadedContext { + contexts: Vec::new(), + text: message.context, + images: Vec::new(), + }, }) .collect(), next_message_id, last_prompt_id: PromptId::new(), - context: BTreeMap::default(), - context_by_message: HashMap::default(), project_context, checkpoints_by_message: HashMap::default(), completion_count: 0, @@ -660,21 +656,17 @@ impl Thread { return; }; for deleted_message in self.messages.drain(message_ix..) { - self.context_by_message.remove(&deleted_message.id); self.checkpoints_by_message.remove(&deleted_message.id); } cx.notify(); } - pub fn context_for_message(&self, id: MessageId) -> impl Iterator { - self.context_by_message - .get(&id) + pub fn context_for_message(&self, id: MessageId) -> impl Iterator { + self.messages + .iter() + .find(|message| message.id == id) .into_iter() - .flat_map(|context| { - context - .iter() - .filter_map(|context_id| self.context.get(&context_id)) - }) + .flat_map(|message| message.loaded_context.contexts.iter()) } pub fn is_turn_end(&self, ix: usize) -> bool { @@ -736,91 +728,27 @@ impl Thread { self.tool_use.tool_result_card(id).cloned() } - /// Filter out contexts that have already been included in previous messages - pub fn filter_new_context<'a>( - &self, - context: impl Iterator, - ) -> impl Iterator { - context.filter(|ctx| self.is_context_new(ctx)) - } - - fn is_context_new(&self, context: &AssistantContext) -> bool { - !self.context.contains_key(&context.id()) - } - pub fn insert_user_message( &mut self, text: impl Into, - context: Vec, + loaded_context: ContextLoadResult, git_checkpoint: Option, cx: &mut Context, ) -> MessageId { - let text = text.into(); - - let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx); - - let new_context: Vec<_> = context - .into_iter() - .filter(|ctx| self.is_context_new(ctx)) - .collect(); - - if !new_context.is_empty() { - if let Some(context_string) = format_context_as_string(new_context.iter(), cx) { - if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) { - message.context = context_string; - } - } - - if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) { - message.images = new_context - .iter() - .filter_map(|context| { - if let AssistantContext::Image(image_context) = context { - image_context.image_task.clone().now_or_never().flatten() - } else { - None - } - }) - .collect::>(); - } - + if !loaded_context.referenced_buffers.is_empty() { self.action_log.update(cx, |log, cx| { - // Track all buffers added as context - for ctx in &new_context { - match ctx { - AssistantContext::File(file_ctx) => { - log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx); - } - AssistantContext::Directory(dir_ctx) => { - for context_buffer in &dir_ctx.context_buffers { - log.track_buffer(context_buffer.buffer.clone(), cx); - } - } - AssistantContext::Symbol(symbol_ctx) => { - log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx); - } - AssistantContext::Selection(selection_context) => { - log.track_buffer(selection_context.context_buffer.buffer.clone(), cx); - } - AssistantContext::FetchedUrl(_) - | AssistantContext::Thread(_) - | AssistantContext::Rules(_) - | AssistantContext::Image(_) => {} - } + for buffer in loaded_context.referenced_buffers { + log.track_buffer(buffer, cx); } }); } - let context_ids = new_context - .iter() - .map(|context| context.id()) - .collect::>(); - self.context.extend( - new_context - .into_iter() - .map(|context| (context.id(), context)), + let message_id = self.insert_message( + Role::User, + vec![MessageSegment::Text(text.into())], + loaded_context.loaded_context, + cx, ); - self.context_by_message.insert(message_id, context_ids); if let Some(git_checkpoint) = git_checkpoint { self.pending_checkpoint = Some(ThreadCheckpoint { @@ -834,10 +762,19 @@ impl Thread { message_id } + pub fn insert_assistant_message( + &mut self, + segments: Vec, + cx: &mut Context, + ) -> MessageId { + self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx) + } + pub fn insert_message( &mut self, role: Role, segments: Vec, + loaded_context: LoadedContext, cx: &mut Context, ) -> MessageId { let id = self.next_message_id.post_inc(); @@ -845,8 +782,7 @@ impl Thread { id, role, segments, - context: String::new(), - images: Vec::new(), + loaded_context, }); self.touch_updated_at(); cx.emit(ThreadEvent::MessageAdded(id)); @@ -875,7 +811,6 @@ impl Thread { return false; }; self.messages.remove(index); - self.context_by_message.remove(&id); self.touch_updated_at(); cx.emit(ThreadEvent::MessageDeleted(id)); true @@ -962,7 +897,7 @@ impl Thread { content: tool_result.content.clone(), }) .collect(), - context: message.context.clone(), + context: message.loaded_context.text.clone(), }) .collect(), initial_project_snapshot, @@ -1080,26 +1015,9 @@ impl Thread { cache: false, }; - if !message.context.is_empty() { - request_message - .content - .push(MessageContent::Text(message.context.to_string())); - } - - if !message.images.is_empty() { - // Some providers only support image parts after an initial text part - if request_message.content.is_empty() { - request_message - .content - .push(MessageContent::Text("Images attached by user:".to_string())); - } - - for image in &message.images { - request_message - .content - .push(MessageContent::Image(image.clone())) - } - } + message + .loaded_context + .add_to_request_message(&mut request_message); for segment in &message.segments { match segment { @@ -1301,11 +1219,11 @@ impl Thread { match event { LanguageModelCompletionEvent::StartMessage { .. } => { - request_assistant_message_id = Some(thread.insert_message( - Role::Assistant, - vec![MessageSegment::Text(String::new())], - cx, - )); + request_assistant_message_id = + Some(thread.insert_assistant_message( + vec![MessageSegment::Text(String::new())], + cx, + )); } LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; @@ -1334,11 +1252,11 @@ impl Thread { // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. - request_assistant_message_id = Some(thread.insert_message( - Role::Assistant, - vec![MessageSegment::Text(chunk.to_string())], - cx, - )); + request_assistant_message_id = + Some(thread.insert_assistant_message( + vec![MessageSegment::Text(chunk.to_string())], + cx, + )); }; } } @@ -1361,14 +1279,14 @@ impl Thread { // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. - request_assistant_message_id = Some(thread.insert_message( - Role::Assistant, - vec![MessageSegment::Thinking { - text: chunk.to_string(), - signature, - }], - cx, - )); + request_assistant_message_id = + Some(thread.insert_assistant_message( + vec![MessageSegment::Thinking { + text: chunk.to_string(), + signature, + }], + cx, + )); }; } } @@ -1376,7 +1294,7 @@ impl Thread { let last_assistant_message_id = request_assistant_message_id .unwrap_or_else(|| { let new_assistant_message_id = - thread.insert_message(Role::Assistant, vec![], cx); + thread.insert_assistant_message(vec![], cx); request_assistant_message_id = Some(new_assistant_message_id); new_assistant_message_id @@ -2097,8 +2015,16 @@ impl Thread { } )?; - if !message.context.is_empty() { - writeln!(markdown, "{}", message.context)?; + if !message.loaded_context.text.is_empty() { + writeln!(markdown, "{}", message.loaded_context.text)?; + } + + if !message.loaded_context.images.is_empty() { + writeln!( + markdown, + "\n{} images attached as context.\n", + message.loaded_context.images.len() + )?; } for segment in &message.segments { @@ -2373,7 +2299,7 @@ struct PendingCompletion { #[cfg(test)] mod tests { use super::*; - use crate::{ThreadStore, context_store::ContextStore, thread_store}; + use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store}; use assistant_settings::AssistantSettings; use context_server::ContextServerSettings; use editor::EditorSettings; @@ -2404,12 +2330,14 @@ mod tests { .await .unwrap(); - let context = - context_store.update(cx, |store, _| store.context().first().cloned().unwrap()); + let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap()); + let loaded_context = cx + .update(|cx| load_context(vec![context], &project, &None, cx)) + .await; // Insert user message with context let message_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Please explain this code", vec![context], None, cx) + thread.insert_user_message("Please explain this code", loaded_context, None, cx) }); // Check content and context in message object @@ -2443,7 +2371,7 @@ fn main() {{ message.segments[0], MessageSegment::Text("Please explain this code".to_string()) ); - assert_eq!(message.context, expected_context); + assert_eq!(message.loaded_context.text, expected_context); // Check message in request let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); @@ -2470,48 +2398,50 @@ fn main() {{ let (_, _thread_store, thread, context_store) = setup_test_environment(cx, project.clone()).await; - // Open files individually + // First message with context 1 add_file_to_context(&project, &context_store, "test/file1.rs", cx) .await .unwrap(); - add_file_to_context(&project, &context_store, "test/file2.rs", cx) - .await - .unwrap(); - add_file_to_context(&project, &context_store, "test/file3.rs", cx) - .await - .unwrap(); - - // Get the context objects - let contexts = context_store.update(cx, |store, _| store.context().clone()); - assert_eq!(contexts.len(), 3); - - // First message with context 1 + let new_contexts = context_store.update(cx, |store, cx| { + store.new_context_for_thread(thread.read(cx)) + }); + assert_eq!(new_contexts.len(), 1); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await; let message1_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx) + thread.insert_user_message("Message 1", loaded_context, None, cx) }); // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included) + add_file_to_context(&project, &context_store, "test/file2.rs", cx) + .await + .unwrap(); + let new_contexts = context_store.update(cx, |store, cx| { + store.new_context_for_thread(thread.read(cx)) + }); + assert_eq!(new_contexts.len(), 1); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await; let message2_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Message 2", - vec![contexts[0].clone(), contexts[1].clone()], - None, - cx, - ) + thread.insert_user_message("Message 2", loaded_context, None, cx) }); // Third message with all three contexts (contexts 1 and 2 should be skipped) + // + add_file_to_context(&project, &context_store, "test/file3.rs", cx) + .await + .unwrap(); + let new_contexts = context_store.update(cx, |store, cx| { + store.new_context_for_thread(thread.read(cx)) + }); + assert_eq!(new_contexts.len(), 1); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await; let message3_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Message 3", - vec![ - contexts[0].clone(), - contexts[1].clone(), - contexts[2].clone(), - ], - None, - cx, - ) + thread.insert_user_message("Message 3", loaded_context, None, cx) }); // Check what contexts are included in each message @@ -2524,16 +2454,16 @@ fn main() {{ }); // First message should include context 1 - assert!(message1.context.contains("file1.rs")); + assert!(message1.loaded_context.text.contains("file1.rs")); // Second message should include only context 2 (not 1) - assert!(!message2.context.contains("file1.rs")); - assert!(message2.context.contains("file2.rs")); + assert!(!message2.loaded_context.text.contains("file1.rs")); + assert!(message2.loaded_context.text.contains("file2.rs")); // Third message should include only context 3 (not 1 or 2) - assert!(!message3.context.contains("file1.rs")); - assert!(!message3.context.contains("file2.rs")); - assert!(message3.context.contains("file3.rs")); + assert!(!message3.loaded_context.text.contains("file1.rs")); + assert!(!message3.loaded_context.text.contains("file2.rs")); + assert!(message3.loaded_context.text.contains("file3.rs")); // Check entire request to make sure all contexts are properly included let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); @@ -2570,7 +2500,12 @@ fn main() {{ // Insert user message without any context (empty context vector) let message_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx) + thread.insert_user_message( + "What is the best way to learn Rust?", + ContextLoadResult::default(), + None, + cx, + ) }); // Check content and context in message object @@ -2583,7 +2518,7 @@ fn main() {{ message.segments[0], MessageSegment::Text("What is the best way to learn Rust?".to_string()) ); - assert_eq!(message.context, ""); + assert_eq!(message.loaded_context.text, ""); // Check message in request let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); @@ -2596,12 +2531,17 @@ fn main() {{ // Add second message, also without context let message2_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Are there any good books?", vec![], None, cx) + thread.insert_user_message( + "Are there any good books?", + ContextLoadResult::default(), + None, + cx, + ) }); let message2 = thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone()); - assert_eq!(message2.context, ""); + assert_eq!(message2.loaded_context.text, ""); // Check that both messages appear in the request let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); @@ -2635,12 +2575,14 @@ fn main() {{ .await .unwrap(); - let context = - context_store.update(cx, |store, _| store.context().first().cloned().unwrap()); + let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap()); + let loaded_context = cx + .update(|cx| load_context(vec![context], &project, &None, cx)) + .await; // Insert user message with the buffer as context thread.update(cx, |thread, cx| { - thread.insert_user_message("Explain this code", vec![context], None, cx) + thread.insert_user_message("Explain this code", loaded_context, None, cx) }); // Create a request and check that it doesn't have a stale buffer warning yet @@ -2668,7 +2610,12 @@ fn main() {{ // Insert another user message without context thread.update(cx, |thread, cx| { - thread.insert_user_message("What does the code do now?", vec![], None, cx) + thread.insert_user_message( + "What does the code do now?", + ContextLoadResult::default(), + None, + cx, + ) }); // Create a new request and check for the stale buffer warning @@ -2735,6 +2682,7 @@ fn main() {{ ThreadStore::load( project.clone(), cx.new(|_| ToolWorkingSet::default()), + None, Arc::new(PromptBuilder::new(None).unwrap()), cx, ) @@ -2759,15 +2707,15 @@ fn main() {{ .unwrap(); let buffer = project - .update(cx, |project, cx| project.open_buffer(buffer_path, cx)) + .update(cx, |project, cx| { + project.open_buffer(buffer_path.clone(), cx) + }) .await .unwrap(); - context_store - .update(cx, |store, cx| { - store.add_file_from_buffer(buffer.clone(), cx) - }) - .await?; + context_store.update(cx, |context_store, cx| { + context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx); + }); Ok(buffer) } diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 117be3675c..a907e1d440 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -24,8 +24,8 @@ use heed::types::SerdeBincode; use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use project::{Project, Worktree}; use prompt_store::{ - ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent, - RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext, + ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext, + UserRulesContext, WorktreeContext, }; use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; @@ -82,12 +82,11 @@ impl ThreadStore { pub fn load( project: Entity, tools: Entity, + prompt_store: Option>, prompt_builder: Arc, cx: &mut App, ) -> Task>> { - let prompt_store = PromptStore::global(cx); cx.spawn(async move |cx| { - let prompt_store = prompt_store.await.ok(); let (thread_store, ready_rx) = cx.update(|cx| { let mut option_ready_rx = None; let thread_store = cx.new(|cx| { @@ -349,25 +348,8 @@ impl ThreadStore { self.context_server_manager.clone() } - pub fn prompt_store(&self) -> Option> { - self.prompt_store.clone() - } - - pub fn load_rules( - &self, - prompt_id: UserPromptId, - cx: &App, - ) -> Task> { - let prompt_id = PromptId::User { uuid: prompt_id }; - let Some(prompt_store) = self.prompt_store.as_ref() else { - return Task::ready(Err(anyhow!("Prompt store unexpectedly missing."))); - }; - let prompt_store = prompt_store.read(cx); - let Some(metadata) = prompt_store.metadata(prompt_id) else { - return Task::ready(Err(anyhow!("User rules not found in library."))); - }; - let text_task = prompt_store.load(prompt_id, cx); - cx.background_spawn(async move { Ok((metadata, text_task.await?)) }) + pub fn prompt_store(&self) -> &Option> { + &self.prompt_store } pub fn tools(&self) -> Entity { @@ -379,16 +361,12 @@ impl ThreadStore { self.threads.len() } - pub fn threads(&self) -> Vec { + pub fn reverse_chronological_threads(&self) -> Vec { let mut threads = self.threads.iter().cloned().collect::>(); threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at)); threads } - pub fn recent_threads(&self, limit: usize) -> Vec { - self.threads().into_iter().take(limit).collect() - } - pub fn create_thread(&mut self, cx: &mut Context) -> Entity { cx.new(|cx| { Thread::new( diff --git a/crates/agent/src/ui/context_pill.rs b/crates/agent/src/ui/context_pill.rs index a3c6608179..c31af4b642 100644 --- a/crates/agent/src/ui/context_pill.rs +++ b/crates/agent/src/ui/context_pill.rs @@ -1,14 +1,13 @@ -use std::sync::Arc; use std::{rc::Rc, time::Duration}; use file_icons::FileIcons; -use futures::FutureExt; -use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between}; -use gpui::{ClickEvent, Task}; -use language_model::LanguageModelImage; +use gpui::{Animation, AnimationExt as _, ClickEvent, Entity, MouseButton, pulsating_between}; +use project::Project; +use prompt_store::PromptStore; +use text::OffsetRangeExt; use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container}; -use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext}; +use crate::context::{AgentContext, ContextKind, ImageStatus}; #[derive(IntoElement)] pub enum ContextPill { @@ -73,9 +72,7 @@ impl ContextPill { pub fn id(&self) -> ElementId { match self { - Self::Added { context, .. } => { - ElementId::NamedInteger("context-pill".into(), context.id.0) - } + Self::Added { context, .. } => context.context.element_id("context-pill".into()), Self::Suggested { .. } => "suggested-context-pill".into(), } } @@ -199,14 +196,17 @@ impl RenderOnce for ContextPill { ) .when_some(on_remove.as_ref(), |element, on_remove| { element.child( - IconButton::new(("remove", context.id.0), IconName::Close) - .shape(IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .tooltip(Tooltip::text("Remove Context")) - .on_click({ - let on_remove = on_remove.clone(); - move |event, window, cx| on_remove(event, window, cx) - }), + IconButton::new( + context.context.element_id("remove".into()), + IconName::Close, + ) + .shape(IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .tooltip(Tooltip::text("Remove Context")) + .on_click({ + let on_remove = on_remove.clone(); + move |event, window, cx| on_remove(event, window, cx) + }), ) }) .when_some(on_click.as_ref(), |element, on_click| { @@ -262,9 +262,11 @@ pub enum ContextStatus { Error { message: SharedString }, } -#[derive(RegisterComponent)] +// TODO: Component commented out due to new dependency on `Project`. +// +// #[derive(RegisterComponent)] pub struct AddedContext { - pub id: ContextId, + pub context: AgentContext, pub kind: ContextKind, pub name: SharedString, pub parent: Option, @@ -275,10 +277,19 @@ pub struct AddedContext { } impl AddedContext { - pub fn new(context: &AssistantContext, cx: &App) -> AddedContext { + /// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a + /// `None` if `DirectoryContext` or `RulesContext` no longer exist. + /// + /// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak. + pub fn new( + context: AgentContext, + prompt_store: Option<&Entity>, + project: &Project, + cx: &App, + ) -> Option { match context { - AssistantContext::File(file_context) => { - let full_path = file_context.context_buffer.full_path(cx); + AgentContext::File(ref file_context) => { + let full_path = file_context.buffer.read(cx).file()?.full_path(cx); let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into(); let name = full_path @@ -289,8 +300,7 @@ impl AddedContext { .parent() .and_then(|p| p.file_name()) .map(|n| n.to_string_lossy().into_owned().into()); - AddedContext { - id: file_context.id, + Some(AddedContext { kind: ContextKind::File, name, parent, @@ -298,18 +308,16 @@ impl AddedContext { icon_path: FileIcons::get_icon(&full_path, cx), status: ContextStatus::Ready, render_preview: None, - } + context, + }) } - AssistantContext::Directory(directory_context) => { - let worktree = directory_context.worktree.read(cx); - // If the directory no longer exists, use its last known path. - let full_path = worktree - .entry_for_id(directory_context.entry_id) - .map_or_else( - || directory_context.last_path.clone(), - |entry| worktree.full_path(&entry.path).into(), - ); + AgentContext::Directory(ref directory_context) => { + let worktree = project + .worktree_for_entry(directory_context.entry_id, cx)? + .read(cx); + let entry = worktree.entry_for_id(directory_context.entry_id)?; + let full_path = worktree.full_path(&entry.path); let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into(); let name = full_path @@ -320,8 +328,7 @@ impl AddedContext { .parent() .and_then(|p| p.file_name()) .map(|n| n.to_string_lossy().into_owned().into()); - AddedContext { - id: directory_context.id, + Some(AddedContext { kind: ContextKind::Directory, name, parent, @@ -329,33 +336,34 @@ impl AddedContext { icon_path: None, status: ContextStatus::Ready, render_preview: None, - } + context, + }) } - AssistantContext::Symbol(symbol_context) => AddedContext { - id: symbol_context.id, + AgentContext::Symbol(ref symbol_context) => Some(AddedContext { kind: ContextKind::Symbol, - name: symbol_context.context_symbol.id.name.clone(), + name: symbol_context.symbol.clone(), parent: None, tooltip: None, icon_path: None, status: ContextStatus::Ready, render_preview: None, - }, + context, + }), - AssistantContext::Selection(selection_context) => { - let full_path = selection_context.context_buffer.full_path(cx); + AgentContext::Selection(ref selection_context) => { + let buffer = selection_context.buffer.read(cx); + let full_path = buffer.file()?.full_path(cx); let mut full_path_string = full_path.to_string_lossy().into_owned(); let mut name = full_path .file_name() .map(|n| n.to_string_lossy().into_owned()) .unwrap_or_else(|| full_path_string.clone()); - let line_range_text = format!( - " ({}-{})", - selection_context.line_range.start.row + 1, - selection_context.line_range.end.row + 1 - ); + let line_range = selection_context.range.to_point(&buffer.snapshot()); + + let line_range_text = + format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1); full_path_string.push_str(&line_range_text); name.push_str(&line_range_text); @@ -365,16 +373,17 @@ impl AddedContext { .and_then(|p| p.file_name()) .map(|n| n.to_string_lossy().into_owned().into()); - AddedContext { - id: selection_context.id, + Some(AddedContext { kind: ContextKind::Selection, name: name.into(), parent, tooltip: None, icon_path: FileIcons::get_icon(&full_path, cx), status: ContextStatus::Ready, + render_preview: None, + /* render_preview: Some(Rc::new({ - let content = selection_context.context_buffer.text.clone(); + let content = selection_context.text.clone(); move |_, cx| { div() .id("context-pill-selection-preview") @@ -385,11 +394,12 @@ impl AddedContext { .into_any_element() } })), - } + */ + context, + }) } - AssistantContext::FetchedUrl(fetched_url_context) => AddedContext { - id: fetched_url_context.id, + AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext { kind: ContextKind::FetchedUrl, name: fetched_url_context.url.clone(), parent: None, @@ -397,12 +407,12 @@ impl AddedContext { icon_path: None, status: ContextStatus::Ready, render_preview: None, - }, + context, + }), - AssistantContext::Thread(thread_context) => AddedContext { - id: thread_context.id, + AgentContext::Thread(ref thread_context) => Some(AddedContext { kind: ContextKind::Thread, - name: thread_context.summary(cx), + name: thread_context.name(cx), parent: None, tooltip: None, icon_path: None, @@ -418,36 +428,41 @@ impl AddedContext { ContextStatus::Ready }, render_preview: None, - }, + context, + }), - AssistantContext::Rules(user_rules_context) => AddedContext { - id: user_rules_context.id, - kind: ContextKind::Rules, - name: user_rules_context.title.clone(), - parent: None, - tooltip: None, - icon_path: None, - status: ContextStatus::Ready, - render_preview: None, - }, + AgentContext::Rules(ref user_rules_context) => { + let name = prompt_store + .as_ref()? + .read(cx) + .metadata(user_rules_context.prompt_id.into())? + .title?; + Some(AddedContext { + kind: ContextKind::Rules, + name: name.clone(), + parent: None, + tooltip: None, + icon_path: None, + status: ContextStatus::Ready, + render_preview: None, + context, + }) + } - AssistantContext::Image(image_context) => AddedContext { - id: image_context.id, + AgentContext::Image(ref image_context) => Some(AddedContext { kind: ContextKind::Image, name: "Image".into(), parent: None, tooltip: None, icon_path: None, - status: if image_context.is_loading() { - ContextStatus::Loading { + status: match image_context.status() { + ImageStatus::Loading => ContextStatus::Loading { message: "Loading…".into(), - } - } else if image_context.is_error() { - ContextStatus::Error { + }, + ImageStatus::Error => ContextStatus::Error { message: "Failed to load image".into(), - } - } else { - ContextStatus::Ready + }, + ImageStatus::Ready => ContextStatus::Ready, }, render_preview: Some(Rc::new({ let image = image_context.original_image.clone(); @@ -458,7 +473,8 @@ impl AddedContext { .into_any_element() } })), - }, + context, + }), } } } @@ -478,6 +494,8 @@ impl Render for ContextPillPreview { } } +// TODO: Component commented out due to new dependency on `Project`. +/* impl Component for AddedContext { fn scope() -> ComponentScope { ComponentScope::Agent @@ -487,12 +505,13 @@ impl Component for AddedContext { "AddedContext" } - fn preview(_window: &mut Window, cx: &mut App) -> Option { + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + let next_context_id = ContextId::zero(); let image_ready = ( "Ready", AddedContext::new( - &AssistantContext::Image(ImageContext { - id: ContextId(0), + AgentContext::Image(ImageContext { + context_id: next_context_id.post_inc(), original_image: Arc::new(Image::empty()), image_task: Task::ready(Some(LanguageModelImage::empty())).shared(), }), @@ -503,8 +522,8 @@ impl Component for AddedContext { let image_loading = ( "Loading", AddedContext::new( - &AssistantContext::Image(ImageContext { - id: ContextId(1), + AgentContext::Image(ImageContext { + context_id: next_context_id.post_inc(), original_image: Arc::new(Image::empty()), image_task: cx .background_spawn(async move { @@ -520,8 +539,8 @@ impl Component for AddedContext { let image_error = ( "Error", AddedContext::new( - &AssistantContext::Image(ImageContext { - id: ContextId(2), + AgentContext::Image(ImageContext { + context_id: next_context_id.post_inc(), original_image: Arc::new(Image::empty()), image_task: Task::ready(None).shared(), }), @@ -544,5 +563,8 @@ impl Component for AddedContext { ) .into_any(), ) + + None } } +*/ diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 2e02f3f29a..d58ef2d58e 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -25,7 +25,7 @@ use language_model::{ AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, }; use project::Project; -use prompt_store::{PromptBuilder, PromptId, UserPromptId}; +use prompt_store::{PromptBuilder, UserPromptId}; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search::DivRegistrar}; @@ -1059,9 +1059,9 @@ impl AssistantPanel { None, )) }), - action.prompt_to_select.map(|uuid| PromptId::User { - uuid: UserPromptId(uuid), - }), + action + .prompt_to_select + .map(|uuid| UserPromptId(uuid).into()), cx, ) .detach_and_log_err(cx); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index fccb9de7c8..c0c3c4cd99 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -10,7 +10,7 @@ use crate::{ ToolMetrics, assertions::{AssertionsReport, RanAssertion, RanAssertionResult}, }; -use agent::ThreadEvent; +use agent::{ContextLoadResult, ThreadEvent}; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; @@ -115,7 +115,12 @@ impl ExampleContext { pub fn push_user_message(&mut self, text: impl ToString) { self.app .update_entity(&self.agent_thread, |thread, cx| { - thread.insert_user_message(text.to_string(), vec![], None, cx); + thread.insert_user_message( + text.to_string(), + ContextLoadResult::default(), + None, + cx, + ); }) .unwrap(); } diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 9210d0b818..e165506abf 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -218,8 +218,14 @@ impl ExampleInstance { }); let tools = cx.new(|_| ToolWorkingSet::default()); - let thread_store = - ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx); + let prompt_store = None; + let thread_store = ThreadStore::load( + project.clone(), + tools, + prompt_store, + app_state.prompt_builder.clone(), + cx, + ); let meta = self.thread.meta(); let this = self.clone(); diff --git a/crates/prompt_store/src/prompt_store.rs b/crates/prompt_store/src/prompt_store.rs index 84aaa688cd..b0d68fd416 100644 --- a/crates/prompt_store/src/prompt_store.rs +++ b/crates/prompt_store/src/prompt_store.rs @@ -60,9 +60,7 @@ pub enum PromptId { impl PromptId { pub fn new() -> PromptId { - PromptId::User { - uuid: UserPromptId::new(), - } + UserPromptId::new().into() } pub fn is_built_in(&self) -> bool { @@ -70,6 +68,12 @@ impl PromptId { } } +impl From for PromptId { + fn from(uuid: UserPromptId) -> Self { + PromptId::User { uuid } + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct UserPromptId(pub Uuid); @@ -227,9 +231,7 @@ impl PromptStore { .collect::>>()?; for (prompt_id_v1, metadata_v1) in metadata_v1 { - let prompt_id_v2 = PromptId::User { - uuid: UserPromptId(prompt_id_v1.0), - }; + let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into(); let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else { continue; };