Restructure agent context (#29233)

Simplifies the data structures involved in agent context by removing
caching and limiting the use of ContextId:

* `AssistantContext` enum is now like an ID / handle to context that
does not need to be updated. `ContextId` still exists but is only used
for generating unique `ElementId`.
* `ContextStore` has a `IndexMap<ContextSetEntry>`. Only need to keep a
`HashSet<ThreadId>` consistent with it. `ContextSetEntry` is a newtype
wrapper around `AssistantContext` which implements eq / hash on a subset
of fields.
* Thread `Message` directly stores its context.

Fixes the following bugs:

* If a context entry is removed from the strip and added again, it was
reincluded in the next message.
* Clicking file context in the thread that has been removed from the
context strip didn't jump to the file.
* Refresh of directory context didn't reflect added / removed files.
* Deleted directories would remain in the message editor context strip.
* Token counting requests didn't include image context.
* File, directory, and symbol context deduplication relied on
`ProjectPath` for identity, and so didn't handle renames.
* Symbol context line numbers didn't update when shifted

Known bugs (not fixed):

* Deleting a directory causes it to disappear from messages in threads.
Fixing this in a nice way is tricky. One easy fix is to store the
original path and show that on deletion. It's weird that deletion would
cause the name to "revert", though. Another possibility would be to
snapshot context metadata on add (ala `AddedContext`), and keep that
around despite deletion.

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-04-24 15:29:33 -06:00 committed by GitHub
parent d492939bed
commit 17ecf94f6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 1893 additions and 2061 deletions

22
Cargo.lock generated
View File

@ -61,7 +61,6 @@ dependencies = [
"buffer_diff",
"chrono",
"client",
"clock",
"collections",
"command_palette_hooks",
"component",
@ -99,6 +98,7 @@ dependencies = [
"prompt_store",
"proto",
"rand 0.8.5",
"ref-cast",
"release_channel",
"rope",
"rules_library",
@ -11716,6 +11716,26 @@ dependencies = [
"thiserror 2.0.12",
]
[[package]]
name = "ref-cast"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "refineable"
version = "0.1.0"

View File

@ -500,6 +500,7 @@ prost-types = "0.9"
pulldown-cmark = { version = "0.12.0", default-features = false }
quote = "1.0.9"
rand = "0.8.5"
ref-cast = "1.0.24"
rayon = "1.8"
regex = "1.5"
repair_json = "0.1.0"

View File

@ -1,2 +1,7 @@
allow-private-module-inception = true
avoid-breaking-exported-api = false
ignore-interior-mutability = [
# Suppresses clippy::mutable_key_type, which is a false positive as the Eq
# and Hash impls do not use fields with interior mutability.
"agent::context::AgentContextKey"
]

View File

@ -28,7 +28,6 @@ async-watch.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
@ -65,6 +64,7 @@ project.workspace = true
rules_library.workspace = true
prompt_store.workspace = true
proto.workspace = true
ref-cast.workspace = true
release_channel.workspace = true
rope.workspace = true
schemars.workspace = true

View File

@ -1,4 +1,4 @@
use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string};
use crate::context::{AgentContext, RULES_ICON};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
@ -25,8 +25,8 @@ use gpui::{
};
use language::{Buffer, LanguageRegistry};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, RequestUsage, Role,
StopReason,
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
RequestUsage, Role, StopReason,
};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@ -47,13 +47,10 @@ use util::markdown::MarkdownString;
use workspace::{OpenOptions, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
use crate::context_store::ContextStore;
pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>,
@ -717,7 +714,6 @@ impl ActiveThread {
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>,
context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut Context<Self>,
@ -740,7 +736,6 @@ impl ActiveThread {
language_registry,
thread_store,
thread: thread.clone(),
context_store,
workspace,
save_thread_task: None,
messages: Vec::new(),
@ -780,10 +775,6 @@ impl ActiveThread {
this
}
pub fn context_store(&self) -> &Entity<ContextStore> {
&self.context_store
}
pub fn thread(&self) -> &Entity<Thread> {
&self.thread
}
@ -1273,26 +1264,36 @@ impl ActiveThread {
}
let token_count = if let Some(task) = cx.update(|cx| {
let context = thread.read(cx).context_for_message(message_id);
let new_context = thread.read(cx).filter_new_context(context);
let context_text =
format_context_as_string(new_context, cx).unwrap_or(String::new());
let Some(message) = thread.read(cx).message(message_id) else {
log::error!("Message that was being edited no longer exists");
return None;
};
let message_text = editor.read(cx).text(cx);
let content = context_text + &message_text;
if content.is_empty() {
if message_text.is_empty() && message.loaded_context.is_empty() {
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: language_model::Role::User,
content: Vec::new(),
cache: false,
};
message
.loaded_context
.add_to_request_message(&mut request_message);
if !message_text.is_empty() {
request_message
.content
.push(MessageContent::Text(message_text));
}
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![content.into()],
cache: false,
}],
messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,
@ -1487,13 +1488,21 @@ impl ActiveThread {
return Empty.into_any();
};
let context_store = self.context_store.clone();
let workspace = self.workspace.clone();
let thread = self.thread.read(cx);
let prompt_store = self.thread_store.read(cx).prompt_store().as_ref();
// Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id);
let context = thread.context_for_message(message_id).collect::<Vec<_>>();
let added_context = if let Some(workspace) = workspace.upgrade() {
let project = workspace.read(cx).project().read(cx);
thread
.context_for_message(message_id)
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
.collect::<Vec<_>>()
} else {
return Empty.into_any();
};
let tool_uses = thread.tool_uses_for_message(message_id, cx);
let has_tool_uses = !tool_uses.is_empty();
@ -1641,90 +1650,78 @@ impl ActiveThread {
};
let message_is_empty = message.should_display_content();
let has_content = !message_is_empty || !context.is_empty();
let has_content = !message_is_empty || !added_context.is_empty();
let message_content =
has_content.then(|| {
v_flex()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
let message_content = has_content.then(|| {
v_flex()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
.min_h_6()
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
))
.into_any()
},
)
})
.when(!context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
context.into_iter().map(|context| {
let context_id = context.id();
ContextPill::added(
AddedContext::new(context, cx),
false,
false,
None,
)
.on_click(Rc::new(cx.listener({
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
.min_h_6()
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
))
.into_any()
},
)
})
.when(!added_context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
let context = added_context.context.clone();
ContextPill::added(added_context, false, false, None).on_click(Rc::new(
cx.listener({
let workspace = workspace.clone();
let context_store = context_store.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(
context_id,
context_store.clone(),
workspace,
window,
cx,
);
open_context(&context, workspace, window, cx);
cx.notify();
}
}
})))
}),
))
})
});
}),
))
}),
))
})
});
let styled_message = match message.role {
Role::User => v_flex()
@ -3173,20 +3170,14 @@ impl Render for ActiveThread {
}
pub(crate) fn open_context(
id: ContextId,
context_store: Entity<ContextStore>,
context: &AgentContext,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut App,
) {
let Some(context) = context_store.read(cx).context_for_id(id) else {
return;
};
match context {
AssistantContext::File(file_context) => {
if let Some(project_path) = file_context.context_buffer.buffer.read(cx).project_path(cx)
{
AgentContext::File(file_context) => {
if let Some(project_path) = file_context.project_path(cx) {
workspace.update(cx, |workspace, cx| {
workspace
.open_path(project_path, None, true, window, cx)
@ -3194,7 +3185,8 @@ pub(crate) fn open_context(
});
}
}
AssistantContext::Directory(directory_context) => {
AgentContext::Directory(directory_context) => {
let entry_id = directory_context.entry_id;
workspace.update(cx, |workspace, cx| {
workspace.project().update(cx, |_project, cx| {
@ -3202,61 +3194,51 @@ pub(crate) fn open_context(
})
})
}
AssistantContext::Symbol(symbol_context) => {
if let Some(project_path) = symbol_context
.context_symbol
.buffer
.read(cx)
.project_path(cx)
{
let snapshot = symbol_context.context_symbol.buffer.read(cx).snapshot();
let target_position = symbol_context
.context_symbol
.id
.range
.start
.to_point(&snapshot);
AgentContext::Symbol(symbol_context) => {
let buffer = symbol_context.buffer.read(cx);
if let Some(project_path) = buffer.project_path(cx) {
let snapshot = buffer.snapshot();
let target_position = symbol_context.range.start.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
.detach();
}
}
AssistantContext::Selection(selection_context) => {
if let Some(project_path) = selection_context
.context_buffer
.buffer
.read(cx)
.project_path(cx)
{
let snapshot = selection_context.context_buffer.buffer.read(cx).snapshot();
AgentContext::Selection(selection_context) => {
let buffer = selection_context.buffer.read(cx);
if let Some(project_path) = buffer.project_path(cx) {
let snapshot = buffer.snapshot();
let target_position = selection_context.range.start.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
.detach();
}
}
AssistantContext::FetchedUrl(fetched_url_context) => {
AgentContext::FetchedUrl(fetched_url_context) => {
cx.open_url(&fetched_url_context.url);
}
AssistantContext::Thread(thread_context) => {
let thread_id = thread_context.thread.read(cx).id().clone();
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
panel.update(cx, |panel, cx| {
panel
.open_thread(&thread_id, window, cx)
.detach_and_log_err(cx)
});
}
})
}
AssistantContext::Rules(rules_context) => window.dispatch_action(
AgentContext::Thread(thread_context) => workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
panel.update(cx, |panel, cx| {
let thread_id = thread_context.thread.read(cx).id().clone();
panel
.open_thread(&thread_id, window, cx)
.detach_and_log_err(cx)
});
}
}),
AgentContext::Rules(rules_context) => window.dispatch_action(
Box::new(OpenRulesLibrary {
prompt_to_select: Some(rules_context.prompt_id.0),
}),
cx,
),
AssistantContext::Image(_) => {}
AgentContext::Image(_) => {}
}
}

View File

@ -962,11 +962,13 @@ mod tests {
})
.unwrap();
let prompt_store = None;
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)

View File

@ -39,6 +39,7 @@ use thread::ThreadId;
pub use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
pub use crate::context::{ContextLoadResult, LoadedContext};
pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;

View File

@ -24,7 +24,7 @@ use language::LanguageRegistry;
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use project::Project;
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
use proto::Plan;
use rules_library::{RulesLibrary, open_rules_library};
use settings::{Settings, update_settings_file};
@ -189,6 +189,7 @@ pub struct AssistantPanel {
message_editor: Entity<MessageEditor>,
_active_thread_subscriptions: Vec<Subscription>,
context_store: Entity<assistant_context_editor::ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
configuration: Option<Entity<AssistantConfiguration>>,
configuration_subscription: Option<Subscription>,
local_timezone: UtcOffset,
@ -205,14 +206,25 @@ impl AssistantPanel {
pub fn load(
workspace: WeakEntity<Workspace>,
prompt_builder: Arc<PromptBuilder>,
cx: AsyncWindowContext,
mut cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
cx.spawn(async move |cx| {
let prompt_store = match prompt_store {
Ok(prompt_store) => prompt_store.await.ok(),
Err(_) => None,
};
let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
ThreadStore::load(
project,
tools.clone(),
prompt_store.clone(),
prompt_builder.clone(),
cx,
)
})?
.await?;
@ -230,7 +242,16 @@ impl AssistantPanel {
.await?;
workspace.update_in(cx, |workspace, window, cx| {
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
cx.new(|cx| {
Self::new(
workspace,
thread_store,
context_store,
prompt_store,
window,
cx,
)
})
})
})
}
@ -239,6 +260,7 @@ impl AssistantPanel {
workspace: &Workspace,
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -262,6 +284,7 @@ impl AssistantPanel {
fs.clone(),
workspace.clone(),
message_editor_context_store.clone(),
prompt_store.clone(),
thread_store.downgrade(),
thread.clone(),
window,
@ -293,7 +316,6 @@ impl AssistantPanel {
thread.clone(),
thread_store.clone(),
language_registry.clone(),
message_editor_context_store.clone(),
workspace.clone(),
window,
cx,
@ -322,6 +344,7 @@ impl AssistantPanel {
message_editor_subscription,
],
context_store,
prompt_store,
configuration: None,
configuration_subscription: None,
local_timezone: UtcOffset::from_whole_seconds(
@ -355,6 +378,10 @@ impl AssistantPanel {
self.local_timezone
}
pub(crate) fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
&self.prompt_store
}
pub(crate) fn thread_store(&self) -> &Entity<ThreadStore> {
&self.thread_store
}
@ -411,7 +438,6 @@ impl AssistantPanel {
thread.clone(),
self.thread_store.clone(),
self.language_registry.clone(),
message_editor_context_store.clone(),
self.workspace.clone(),
window,
cx,
@ -430,6 +456,7 @@ impl AssistantPanel {
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
window,
@ -500,9 +527,9 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
action
.prompt_to_select
.map(|uuid| UserPromptId(uuid).into()),
cx,
)
.detach_and_log_err(cx);
@ -598,7 +625,6 @@ impl AssistantPanel {
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),
message_editor_context_store.clone(),
this.workspace.clone(),
window,
cx,
@ -617,6 +643,7 @@ impl AssistantPanel {
this.fs.clone(),
this.workspace.clone(),
message_editor_context_store,
this.prompt_store.clone(),
this.thread_store.downgrade(),
thread,
window,
@ -1876,11 +1903,14 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
else {
return;
};
let prompt_store = None;
let thread_store = None;
assistant.assist(
&prompt_editor,
self.workspace.clone(),
project,
None,
prompt_store,
thread_store,
window,
cx,
)
@ -1959,8 +1989,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
// being updated.
cx.defer_in(window, move |panel, window, cx| {
if panel.has_active_thread() {
panel.thread.update(cx, |thread, cx| {
thread.context_store().update(cx, |store, cx| {
panel.message_editor.update(cx, |message_editor, cx| {
message_editor.context_store().update(cx, |store, cx| {
let buffer = buffer.read(cx);
let selection_ranges = selection_ranges
.into_iter()
@ -1977,9 +2007,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
.collect::<Vec<_>>();
for (buffer, range) in selection_ranges {
store
.add_selection(buffer, range, cx)
.detach_and_log_err(cx);
store.add_selection(buffer, range, cx);
}
})
})

View File

@ -1,6 +1,6 @@
use crate::context::attach_context_to_message;
use crate::context_store::ContextStore;
use crate::context::ContextLoadResult;
use crate::inline_prompt_editor::CodegenStatus;
use crate::{context::load_context, context_store::ContextStore};
use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
@ -8,7 +8,7 @@ use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
@ -16,7 +16,9 @@ use language_model::{
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::Project;
use prompt_store::PromptBuilder;
use prompt_store::PromptStore;
use rope::Rope;
use smol::future::FutureExt;
use std::{
@ -41,6 +43,8 @@ pub struct BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
pub is_insertion: bool,
@ -52,6 +56,8 @@ impl BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@ -62,6 +68,8 @@ impl BufferCodegen {
range.clone(),
false,
Some(context_store.clone()),
project.clone(),
prompt_store.clone(),
Some(telemetry.clone()),
builder.clone(),
cx,
@ -77,6 +85,8 @@ impl BufferCodegen {
range,
initial_transaction_id,
context_store,
project,
prompt_store,
telemetry,
builder,
};
@ -155,6 +165,8 @@ impl BufferCodegen {
self.range.clone(),
false,
Some(self.context_store.clone()),
self.project.clone(),
self.prompt_store.clone(),
Some(self.telemetry.clone()),
self.builder.clone(),
cx,
@ -231,13 +243,14 @@ pub struct CodegenAlternative {
generation: Task<()>,
diff: Diff,
context_store: Option<Entity<ContextStore>>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
builder: Arc<PromptBuilder>,
active: bool,
edits: Vec<(Range<Anchor>, String)>,
line_operations: Vec<LineOperation>,
request: Option<LanguageModelRequest>,
elapsed_time: Option<f64>,
completion: Option<String>,
pub message_id: Option<String>,
@ -251,6 +264,8 @@ impl CodegenAlternative {
range: Range<Anchor>,
active: bool,
context_store: Option<Entity<ContextStore>>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@ -292,6 +307,8 @@ impl CodegenAlternative {
generation: Task::ready(()),
diff: Diff::default(),
context_store,
project,
prompt_store,
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
builder,
@ -299,7 +316,6 @@ impl CodegenAlternative {
edits: Vec::new(),
line_operations: Vec::new(),
range,
request: None,
elapsed_time: None,
completion: None,
}
@ -368,16 +384,18 @@ impl CodegenAlternative {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(user_prompt, cx)?;
self.request = Some(request.clone());
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
Ok(())
}
fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
fn build_request(
&self,
user_prompt: String,
cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
@ -410,30 +428,44 @@ impl CodegenAlternative {
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
let context_task = self.context_store.as_ref().map(|context_store| {
if let Some(project) = self.project.upgrade() {
let context = context_store
.read(cx)
.context()
.cloned()
.collect::<Vec<_>>();
load_context(context, &project, &self.prompt_store, cx)
} else {
Task::ready(ContextLoadResult::default())
}
});
if let Some(context_store) = &self.context_store {
attach_context_to_message(
&mut request_message,
context_store.read(cx).context().iter(),
cx,
);
}
Ok(cx.spawn(async move |_cx| {
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
request_message.content.push(prompt.into());
if let Some(context_task) = context_task {
context_task
.await
.loaded_context
.add_to_request_message(&mut request_message);
}
Ok(LanguageModelRequest {
thread_id: None,
prompt_id: None,
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![request_message],
})
request_message.content.push(prompt.into());
LanguageModelRequest {
thread_id: None,
prompt_id: None,
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![request_message],
}
}))
}
pub fn handle_stream(
@ -1038,6 +1070,7 @@ impl Diff {
#[cfg(test)]
mod tests {
use super::*;
use fs::FakeFs;
use futures::{
Stream,
stream::{self},
@ -1080,12 +1113,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@ -1144,12 +1181,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@ -1211,12 +1252,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@ -1278,12 +1323,16 @@ mod tests {
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@ -1333,12 +1382,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
false,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,

File diff suppressed because it is too large Load Diff

View File

@ -10,8 +10,11 @@ use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Result, anyhow};
pub use completion_provider::ContextPickerCompletionProvider;
use editor::display_map::{Crease, FoldId};
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
use fetch_context_picker::FetchContextPicker;
use file_context_picker::FileContextPicker;
use file_context_picker::render_file_context_entry;
use gpui::{
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
@ -20,10 +23,10 @@ use gpui::{
use language::Buffer;
use multi_buffer::MultiBufferRow;
use project::{Entry, ProjectPath};
use prompt_store::UserPromptId;
use rules_context_picker::RulesContextEntry;
use prompt_store::{PromptStore, UserPromptId};
use rules_context_picker::{RulesContextEntry, RulesContextPicker};
use symbol_context_picker::SymbolContextPicker;
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry};
use ui::{
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
};
@ -32,11 +35,6 @@ use workspace::{Workspace, notifications::NotifyResultExt};
use crate::AssistantPanel;
use crate::context::RULES_ICON;
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
use crate::context_picker::fetch_context_picker::FetchContextPicker;
use crate::context_picker::file_context_picker::FileContextPicker;
use crate::context_picker::rules_context_picker::RulesContextPicker;
use crate::context_picker::thread_context_picker::ThreadContextPicker;
use crate::context_store::ContextStore;
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
@ -166,6 +164,7 @@ pub(super) struct ContextPicker {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
@ -193,6 +192,13 @@ impl ContextPicker {
)
.collect::<Vec<Subscription>>();
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
thread_store
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
.ok()
.flatten()
});
ContextPicker {
mode: ContextPickerState::Default(ContextMenu::build(
window,
@ -202,6 +208,7 @@ impl ContextPicker {
workspace,
context_store,
thread_store,
prompt_store,
_subscriptions: subscriptions,
}
}
@ -226,7 +233,12 @@ impl ContextPicker {
.workspace
.upgrade()
.map(|workspace| {
available_context_picker_entries(&self.thread_store, &workspace, cx)
available_context_picker_entries(
&self.prompt_store,
&self.thread_store,
&workspace,
cx,
)
})
.unwrap_or_default();
@ -304,10 +316,10 @@ impl ContextPicker {
}));
}
ContextPickerMode::Rules => {
if let Some(thread_store) = self.thread_store.as_ref() {
if let Some(prompt_store) = self.prompt_store.as_ref() {
self.mode = ContextPickerState::Rules(cx.new(|cx| {
RulesContextPicker::new(
thread_store.clone(),
prompt_store.clone(),
context_picker.clone(),
self.context_store.clone(),
window,
@ -526,6 +538,7 @@ enum RecentEntry {
}
fn available_context_picker_entries(
prompt_store: &Option<Entity<PromptStore>>,
thread_store: &Option<WeakEntity<ThreadStore>>,
workspace: &Entity<Workspace>,
cx: &mut App,
@ -550,6 +563,9 @@ fn available_context_picker_entries(
if thread_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
}
if prompt_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
}
@ -585,22 +601,21 @@ fn recent_context_picker_entries(
}),
);
let mut current_threads = context_store.read(cx).thread_ids();
let current_threads = context_store.read(cx).thread_ids();
if let Some(active_thread) = workspace
let active_thread_id = workspace
.panel::<AssistantPanel>(cx)
.map(|panel| panel.read(cx).active_thread(cx))
{
current_threads.insert(active_thread.read(cx).id().clone());
}
.map(|panel| panel.read(cx).active_thread(cx).read(cx).id());
if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) {
recent.extend(
thread_store
.read(cx)
.threads()
.reverse_chronological_threads()
.into_iter()
.filter(|thread| !current_threads.contains(&thread.id))
.filter(|thread| {
Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id)
})
.take(2)
.map(|thread| {
RecentEntry::Thread(ThreadContextEntry {
@ -622,9 +637,7 @@ fn add_selections_as_context(
let selection_ranges = selection_ranges(workspace, cx);
context_store.update(cx, |context_store, cx| {
for (buffer, range) in selection_ranges {
context_store
.add_selection(buffer, range, cx)
.detach_and_log_err(cx);
context_store.add_selection(buffer, range, cx);
}
})
}

View File

@ -15,22 +15,21 @@ use itertools::Itertools;
use language::{Buffer, CodeLabel, HighlightId};
use lsp::CompletionContext;
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
use prompt_store::PromptId;
use prompt_store::PromptStore;
use rope::Point;
use text::{Anchor, OffsetRangeExt, ToPoint};
use ui::prelude::*;
use workspace::Workspace;
use crate::context::RULES_ICON;
use crate::context_picker::file_context_picker::search_files;
use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore;
use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content;
use super::file_context_picker::FileMatch;
use super::file_context_picker::{FileMatch, search_files};
use super::rules_context_picker::{RulesContextEntry, search_rules};
use super::symbol_context_picker::SymbolMatch;
use super::symbol_context_picker::search_symbols;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
@ -38,8 +37,8 @@ use super::{
};
pub(crate) enum Match {
Symbol(SymbolMatch),
File(FileMatch),
Symbol(SymbolMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Rules(RulesContextEntry),
@ -69,6 +68,7 @@ fn search(
query: String,
cancellation_flag: Arc<AtomicBool>,
recent_entries: Vec<RecentEntry>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
workspace: Entity<Workspace>,
cx: &mut App,
@ -85,6 +85,7 @@ fn search(
.collect()
})
}
Some(ContextPickerMode::Symbol) => {
let search_symbols_task =
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
@ -96,6 +97,7 @@ fn search(
.collect()
})
}
Some(ContextPickerMode::Thread) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_threads_task =
@ -111,6 +113,7 @@ fn search(
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Fetch) => {
if !query.is_empty() {
Task::ready(vec![Match::Fetch(query.into())])
@ -118,10 +121,11 @@ fn search(
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Rules) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
if let Some(prompt_store) = prompt_store.as_ref() {
let search_rules_task =
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
@ -133,6 +137,7 @@ fn search(
Task::ready(Vec::new())
}
}
None => {
if query.is_empty() {
let mut matches = recent_entries
@ -163,7 +168,7 @@ fn search(
.collect::<Vec<_>>();
matches.extend(
available_context_picker_entries(&thread_store, &workspace, cx)
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx)
.into_iter()
.map(|mode| {
Match::Entry(EntryMatch {
@ -180,7 +185,8 @@ fn search(
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
let entries = available_context_picker_entries(&thread_store, &workspace, cx);
let entries =
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx);
let entry_candidates = entries
.iter()
.enumerate()
@ -307,9 +313,11 @@ impl ContextPickerCompletionProvider {
move |_, _: &mut Window, cx: &mut App| {
context_store.update(cx, |context_store, cx| {
for (buffer, range) in &selections {
context_store
.add_selection(buffer.clone(), range.clone(), cx)
.detach_and_log_err(cx)
context_store.add_selection(
buffer.clone(),
range.clone(),
cx,
);
}
});
@ -437,7 +445,6 @@ impl ContextPickerCompletionProvider {
source_range: Range<Anchor>,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
thread_store: Entity<ThreadStore>,
) -> Completion {
let new_text = MentionLink::for_rules(&rules);
let new_text_len = new_text.len();
@ -457,29 +464,10 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor.clone(),
move |cx| {
let prompt_uuid = rules.prompt_id;
let prompt_id = PromptId::User { uuid: prompt_uuid };
let context_store = context_store.clone();
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
log::error!("Can't add user rules as prompt store is missing.");
return;
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return;
};
let Some(title) = metadata.title else {
return;
};
let text_task = prompt_store.load(prompt_id, cx);
cx.spawn(async move |cx| {
let text = text_task.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(prompt_uuid, title, text, false, cx)
})
})
.detach_and_log_err(cx);
let user_prompt_id = rules.prompt_id;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(user_prompt_id, false, cx);
});
},
)),
}
@ -516,7 +504,7 @@ impl ContextPickerCompletionProvider {
let url_to_fetch = url_to_fetch.clone();
cx.spawn(async move |cx| {
if context_store.update(cx, |context_store, _| {
context_store.includes_url(&url_to_fetch).is_some()
context_store.includes_url(&url_to_fetch)
})? {
return Ok(());
}
@ -592,7 +580,7 @@ impl ContextPickerCompletionProvider {
move |cx| {
context_store.update(cx, |context_store, cx| {
let task = if is_directory {
context_store.add_directory(project_path.clone(), false, cx)
Task::ready(context_store.add_directory(&project_path, false, cx))
} else {
context_store.add_file_from_path(project_path.clone(), false, cx)
};
@ -732,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
);
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
thread_store
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
.ok()
.flatten()
});
let search_task = search(
mode,
query,
Arc::<AtomicBool>::default(),
recent_entries,
prompt_store,
thread_store.clone(),
workspace.clone(),
cx,
@ -768,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
))
}
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
symbol,
excerpt_id,
@ -777,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
workspace.clone(),
cx,
),
Match::Thread(ThreadMatch {
thread, is_recent, ..
}) => {
@ -791,17 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider {
thread_store,
))
}
Match::Rules(user_rules) => {
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
thread_store,
))
}
Match::Rules(user_rules) => Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
)),
Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(),
url,
@ -810,6 +806,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
context_store.clone(),
http_client.clone(),
)),
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
entry,
excerpt_id,

View File

@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store.read(cx).includes_url(&self.url).is_some()
context_store.read(cx).includes_url(&self.url)
});
Some(

View File

@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate {
.context_store
.update(cx, |context_store, cx| {
if is_directory {
context_store.add_directory(project_path, true, cx)
Task::ready(context_store.add_directory(&project_path, true, cx))
} else {
context_store.add_file_from_path(project_path, true, cx)
context_store.add_file_from_path(project_path.clone(), true, cx)
}
})
.ok()
@ -325,11 +325,11 @@ pub fn render_file_context_entry(
path: path.clone(),
};
if is_directory {
context_store.read(cx).includes_directory(&project_path)
} else {
context_store
.read(cx)
.will_include_file_path(&project_path, cx)
.path_included_in_directory(&project_path, cx)
} else {
context_store.read(cx).file_path_included(&project_path, cx)
}
});
@ -357,7 +357,7 @@ pub fn render_file_context_entry(
})),
)
.when_some(added, |el, added| match added {
FileInclusion::Direct(_) => el.child(
FileInclusion::Direct => el.child(
h_flex()
.w_full()
.justify_end()
@ -369,9 +369,8 @@ pub fn render_file_context_entry(
)
.child(Label::new("Added").size(LabelSize::Small)),
),
FileInclusion::InDirectory(directory_project_path) => {
// TODO: Consider using worktree full_path to include worktree name.
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
FileInclusion::InDirectory { full_path } => {
let directory_full_path = full_path.to_string_lossy().into_owned();
el.child(
h_flex()
@ -385,7 +384,7 @@ pub fn render_file_context_entry(
)
.child(Label::new("Included").size(LabelSize::Small)),
)
.tooltip(Tooltip::text(format!("in {directory_path}")))
.tooltip(Tooltip::text(format!("in {directory_full_path}")))
}
})
}

View File

@ -1,16 +1,15 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::anyhow;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use prompt_store::{PromptId, UserPromptId};
use prompt_store::{PromptId, PromptStore, UserPromptId};
use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use crate::context::RULES_ICON;
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
use crate::thread_store::ThreadStore;
pub struct RulesContextPicker {
picker: Entity<Picker<RulesContextPickerDelegate>>,
@ -18,13 +17,13 @@ pub struct RulesContextPicker {
impl RulesContextPicker {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
RulesContextPicker { picker }
@ -50,7 +49,7 @@ pub struct RulesContextEntry {
}
pub struct RulesContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
matches: Vec<RulesContextEntry>,
@ -59,12 +58,12 @@ pub struct RulesContextPickerDelegate {
impl RulesContextPickerDelegate {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
) -> Self {
RulesContextPickerDelegate {
thread_store,
prompt_store,
context_picker,
context_store,
matches: Vec::new(),
@ -103,11 +102,12 @@ impl PickerDelegate for RulesContextPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(thread_store) = self.thread_store.upgrade() else {
return Task::ready(());
};
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
let search_task = search_rules(
query,
Arc::new(AtomicBool::default()),
&self.prompt_store,
cx,
);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
@ -124,31 +124,11 @@ impl PickerDelegate for RulesContextPickerDelegate {
return;
};
let Some(thread_store) = self.thread_store.upgrade() else {
return;
};
let prompt_id = entry.prompt_id;
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
});
cx.spawn(async move |this, cx| {
let (metadata, text) = load_rules_task.await?;
let Some(title) = metadata.title else {
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
};
this.update(cx, |this, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(prompt_id, title, text, true, cx)
})
.ok();
self.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(entry.prompt_id, true, cx)
})
})
.detach_and_log_err(cx);
.log_err();
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
@ -179,11 +159,10 @@ pub fn render_thread_context_entry(
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store
let added = context_store.upgrade().map_or(false, |context_store| {
context_store
.read(cx)
.includes_user_rules(&user_rules.prompt_id)
.is_some()
.includes_user_rules(user_rules.prompt_id)
});
h_flex()
@ -218,12 +197,9 @@ pub fn render_thread_context_entry(
pub(crate) fn search_rules(
query: String,
cancellation_flag: Arc<AtomicBool>,
thread_store: Entity<ThreadStore>,
prompt_store: &Entity<PromptStore>,
cx: &mut App,
) -> Task<Vec<RulesContextEntry>> {
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
return Task::ready(vec![]);
};
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
cx.background_spawn(async move {
search_task

View File

@ -10,7 +10,6 @@ use gpui::{
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use project::{DocumentSymbol, Symbol};
use text::OffsetRangeExt;
use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
@ -228,18 +227,16 @@ pub(crate) fn add_symbol(
)
})?;
context_store
.update(cx, move |context_store, cx| {
context_store.add_symbol(
buffer,
name.into(),
range,
enclosing_range,
remove_if_exists,
cx,
)
})?
.await
context_store.update(cx, move |context_store, cx| {
context_store.add_symbol(
buffer,
name.into(),
range,
enclosing_range,
remove_if_exists,
cx,
)
})
})
}
@ -353,38 +350,13 @@ fn compute_symbol_entries(
context_store: &ContextStore,
cx: &App,
) -> Vec<SymbolEntry> {
let mut symbol_entries = Vec::with_capacity(symbols.len());
for SymbolMatch { symbol, .. } in symbols {
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
let is_included = if let Some(symbols_for_path) = symbols_for_path {
let mut is_included = false;
for included_symbol_id in symbols_for_path {
if included_symbol_id.name.as_ref() == symbol.name.as_str() {
if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
let snapshot = buffer.read(cx).snapshot();
let included_symbol_range =
included_symbol_id.range.to_point_utf16(&snapshot);
if included_symbol_range.start == symbol.range.start.0
&& included_symbol_range.end == symbol.range.end.0
{
is_included = true;
break;
}
}
}
}
is_included
} else {
false
};
symbol_entries.push(SymbolEntry {
symbols
.into_iter()
.map(|SymbolMatch { symbol, .. }| SymbolEntry {
is_included: context_store.includes_symbol(&symbol, cx),
symbol,
is_included,
})
}
symbol_entries
.collect::<Vec<_>>()
}
pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {

View File

@ -173,7 +173,7 @@ pub fn render_thread_context_entry(
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store.read(cx).includes_thread(&thread.id).is_some()
ctx_store.read(cx).includes_thread(&thread.id)
});
h_flex()
@ -219,7 +219,7 @@ pub(crate) fn search_threads(
) -> Task<Vec<ThreadMatch>> {
let threads = thread_store
.read(cx)
.threads()
.reverse_chronological_threads()
.into_iter()
.map(|thread| ThreadContextEntry {
id: thread.id,

File diff suppressed because it is too large Load Diff

View File

@ -12,9 +12,9 @@ use itertools::Itertools;
use language::Buffer;
use project::ProjectItem;
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::{Workspace, notifications::NotifyResultExt};
use workspace::Workspace;
use crate::context::{ContextId, ContextKind};
use crate::context::{AgentContext, ContextKind};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
use crate::thread::Thread;
@ -32,6 +32,7 @@ pub struct ContextStrip {
focus_handle: FocusHandle,
suggest_context_kind: SuggestContextKind,
workspace: WeakEntity<Workspace>,
thread_store: Option<WeakEntity<ThreadStore>>,
_subscriptions: Vec<Subscription>,
focused_index: Option<usize>,
children_bounds: Option<Vec<Bounds<Pixels>>>,
@ -73,12 +74,31 @@ impl ContextStrip {
focus_handle,
suggest_context_kind,
workspace,
thread_store,
_subscriptions: subscriptions,
focused_index: None,
children_bounds: None,
}
}
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
if let Some(workspace) = self.workspace.upgrade() {
let project = workspace.read(cx).project().read(cx);
let prompt_store = self
.thread_store
.as_ref()
.and_then(|thread_store| thread_store.upgrade())
.and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
self.context_store
.read(cx)
.context()
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
.collect::<Vec<_>>()
} else {
Vec::new()
}
}
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
match self.suggest_context_kind {
SuggestContextKind::File => self.suggested_file(cx),
@ -93,22 +113,19 @@ impl ContextStrip {
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
let active_buffer = active_buffer_entity.read(cx);
let project_path = active_buffer.project_path(cx)?;
if self
.context_store
.read(cx)
.will_include_buffer(active_buffer.remote_id(), &project_path)
.file_path_included(&project_path, cx)
.is_some()
{
return None;
}
let file_name = active_buffer.file()?.file_name(cx);
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
Some(SuggestedContext::File {
name: file_name.to_string_lossy().into_owned().into(),
buffer: active_buffer_entity.downgrade(),
@ -135,7 +152,6 @@ impl ContextStrip {
.context_store
.read(cx)
.includes_thread(active_thread.id())
.is_some()
{
return None;
}
@ -272,12 +288,12 @@ impl ContextStrip {
best.map(|(index, _, _)| index)
}
fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) {
fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) {
let Some(workspace) = self.workspace.upgrade() else {
return;
};
crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx);
crate::active_thread::open_context(context, workspace, window, cx);
}
fn remove_focused_context(
@ -287,17 +303,17 @@ impl ContextStrip {
cx: &mut Context<Self>,
) {
if let Some(index) = self.focused_index {
let mut is_empty = false;
let added_contexts = self.added_contexts(cx);
let Some(context) = added_contexts.get(index) else {
return;
};
self.context_store.update(cx, |this, cx| {
if let Some(item) = this.context().get(index) {
this.remove_context(item.id(), cx);
}
is_empty = this.context().is_empty();
this.remove_context(&context.context, cx);
});
if is_empty {
let is_now_empty = added_contexts.len() == 1;
if is_now_empty {
cx.emit(ContextStripEvent::BlurredEmpty);
} else {
self.focused_index = Some(index.saturating_sub(1));
@ -306,49 +322,28 @@ impl ContextStrip {
}
}
fn is_suggested_focused<T>(&self, context: &Vec<T>) -> bool {
fn is_suggested_focused(&self, added_contexts: &Vec<AddedContext>) -> bool {
// We only suggest one item after the actual context
self.focused_index == Some(context.len())
self.focused_index == Some(added_contexts.len())
}
fn accept_suggested_context(
&mut self,
_: &AcceptSuggestedContext,
window: &mut Window,
_window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(suggested) = self.suggested_context(cx) {
let context_store = self.context_store.read(cx);
if self.is_suggested_focused(context_store.context()) {
self.add_suggested_context(&suggested, window, cx);
if self.is_suggested_focused(&self.added_contexts(cx)) {
self.add_suggested_context(&suggested, cx);
}
}
}
fn add_suggested_context(
&mut self,
suggested: &SuggestedContext,
window: &mut Window,
cx: &mut Context<Self>,
) {
let task = self.context_store.update(cx, |context_store, cx| {
context_store.accept_suggested_context(&suggested, cx)
fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
self.context_store.update(cx, |context_store, cx| {
context_store.add_suggested_context(&suggested, cx)
});
cx.spawn_in(window, async move |this, cx| {
match task.await.notify_async_err(cx) {
None => {}
Some(()) => {
if let Some(this) = this.upgrade() {
this.update(cx, |_, cx| cx.notify())?;
}
}
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
cx.notify();
}
}
@ -361,17 +356,10 @@ impl Focusable for ContextStrip {
impl Render for ContextStrip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let context_store = self.context_store.read(cx);
let context = context_store.context();
let context_picker = self.context_picker.clone();
let focus_handle = self.focus_handle.clone();
let suggested_context = self.suggested_context(cx);
let added_contexts = context
.iter()
.map(|c| AddedContext::new(c, cx))
.collect::<Vec<_>>();
let added_contexts = self.added_contexts(cx);
let dupe_names = added_contexts
.iter()
.map(|c| c.name.clone())
@ -380,6 +368,14 @@ impl Render for ContextStrip {
.filter(|(a, b)| a == b)
.map(|(a, _)| a)
.collect::<HashSet<SharedString>>();
let no_added_context = added_contexts.is_empty();
let suggested_context = self.suggested_context(cx).map(|suggested_context| {
(
suggested_context,
self.is_suggested_focused(&added_contexts),
)
});
h_flex()
.flex_wrap()
@ -436,7 +432,7 @@ impl Render for ContextStrip {
})
.with_handle(self.context_picker_menu_handle.clone()),
)
.when(context.is_empty() && suggested_context.is_none(), {
.when(no_added_context && suggested_context.is_none(), {
|parent| {
parent.child(
h_flex()
@ -466,16 +462,17 @@ impl Render for ContextStrip {
.enumerate()
.map(|(i, added_context)| {
let name = added_context.name.clone();
let id = added_context.id;
let context = added_context.context.clone();
ContextPill::added(
added_context,
dupe_names.contains(&name),
self.focused_index == Some(i),
Some({
let context = context.clone();
let context_store = self.context_store.clone();
Rc::new(cx.listener(move |_this, _event, _window, cx| {
context_store.update(cx, |this, cx| {
this.remove_context(id, cx);
this.remove_context(&context, cx);
});
cx.notify();
}))
@ -484,7 +481,7 @@ impl Render for ContextStrip {
.on_click({
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
if event.down.click_count > 1 {
this.open_context(id, window, cx);
this.open_context(&context, window, cx);
} else {
this.focused_index = Some(i);
}
@ -493,22 +490,22 @@ impl Render for ContextStrip {
})
}),
)
.when_some(suggested_context, |el, suggested| {
.when_some(suggested_context, |el, (suggested, focused)| {
el.child(
ContextPill::suggested(
suggested.name().clone(),
suggested.icon_path(),
suggested.kind(),
self.is_suggested_focused(&context),
focused,
)
.on_click(Rc::new(cx.listener(
move |this, _event, window, cx| {
this.add_suggested_context(&suggested, window, cx);
move |this, _event, _window, cx| {
this.add_suggested_context(&suggested, cx);
},
))),
)
})
.when(!context.is_empty(), {
.when(!no_added_context, {
move |parent| {
parent.child(
IconButton::new("remove-all-context", IconName::Eraser)
@ -534,6 +531,7 @@ impl Render for ContextStrip {
)
}
})
.into_any()
}
}

View File

@ -51,7 +51,10 @@ impl HistoryStore {
return history_entries;
}
for thread in self.thread_store.update(cx, |this, _cx| this.threads()) {
for thread in self
.thread_store
.update(cx, |this, _cx| this.reverse_chronological_threads())
{
history_entries.push(HistoryEntry::Thread(thread));
}

View File

@ -32,6 +32,7 @@ use project::LspAction;
use project::Project;
use project::{CodeAction, ProjectTransaction};
use prompt_store::PromptBuilder;
use prompt_store::PromptStore;
use settings::{Settings, SettingsStore};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
@ -245,9 +246,13 @@ impl InlineAssistant {
.map_or(false, |model| model.provider.is_authenticated(cx))
};
let thread_store = workspace
let assistant_panel = workspace
.panel::<AssistantPanel>(cx)
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
.map(|assistant_panel| assistant_panel.read(cx));
let prompt_store = assistant_panel
.and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned());
let thread_store =
assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade());
let handle_assist =
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
@ -257,6 +262,7 @@ impl InlineAssistant {
&active_editor,
cx.entity().downgrade(),
workspace.project().downgrade(),
prompt_store,
thread_store,
window,
cx,
@ -269,6 +275,7 @@ impl InlineAssistant {
&active_terminal,
cx.entity().downgrade(),
workspace.project().downgrade(),
prompt_store,
thread_store,
window,
cx,
@ -323,6 +330,7 @@ impl InlineAssistant {
editor: &Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@ -437,6 +445,8 @@ impl InlineAssistant {
range.clone(),
None,
context_store.clone(),
project.clone(),
prompt_store.clone(),
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@ -525,6 +535,7 @@ impl InlineAssistant {
initial_transaction_id: Option<TransactionId>,
focus: bool,
workspace: Entity<Workspace>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@ -543,7 +554,7 @@ impl InlineAssistant {
}
let project = workspace.read(cx).project().downgrade();
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone()));
let codegen = cx.new(|cx| {
BufferCodegen::new(
@ -551,6 +562,8 @@ impl InlineAssistant {
range.clone(),
initial_transaction_id,
context_store.clone(),
project,
prompt_store,
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
let editor = self.editor.clone();
let workspace = self.workspace.clone();
let thread_store = self.thread_store.clone();
let prompt_store = PromptStore::global(cx);
window.spawn(cx, async move |cx| {
let workspace = workspace.upgrade().context("workspace was released")?;
let editor = editor.upgrade().context("editor was released")?;
@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
})?
.context("invalid range")?;
let prompt_store = prompt_store.await.ok();
cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
let assist_id = assistant.suggest_assist(
&editor,
@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
None,
true,
workspace,
prompt_store,
thread_store,
window,
cx,

View File

@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use crate::context::{AssistantContext, format_context_as_string};
use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
use collections::HashSet;
@ -13,6 +13,8 @@ use editor::{
};
use file_icons::FileIcons;
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt as _, future};
use gpui::{
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
@ -22,6 +24,7 @@ use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelReques
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
use prompt_store::PromptStore;
use settings::Settings;
use std::time::Duration;
use theme::ThemeSettings;
@ -31,7 +34,7 @@ use workspace::Workspace;
use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_store::{ContextStore, refresh_context_store_text};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
use crate::thread::{Thread, TokenUsageRatio};
@ -49,9 +52,12 @@ pub struct MessageEditor {
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>,
last_loaded_context: Option<ContextLoadResult>,
context_load_task: Option<Shared<Task<()>>>,
profile_selector: Entity<ProfileSelector>,
edits_expanded: bool,
editor_is_expanded: bool,
@ -68,6 +74,7 @@ impl MessageEditor {
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
thread: Entity<Thread>,
window: &mut Window,
@ -135,13 +142,11 @@ impl MessageEditor {
let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.message_or_context_changed(true, cx);
}
EditorEvent::BufferEdited => this.handle_message_changed(cx),
_ => {}
}),
cx.observe(&context_store, |this, _, cx| {
this.message_or_context_changed(false, cx);
this.handle_context_changed(cx)
}),
];
@ -152,8 +157,11 @@ impl MessageEditor {
incompatible_tools_state: incompatible_tools.clone(),
workspace,
context_store,
prompt_store,
context_strip,
context_picker_menu_handle,
context_load_task: None,
last_loaded_context: None,
model_selector: cx.new(|cx| {
AssistantModelSelector::new(
fs.clone(),
@ -175,6 +183,10 @@ impl MessageEditor {
}
}
pub fn context_store(&self) -> &Entity<ContextStore> {
&self.context_store
}
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
cx.notify();
}
@ -214,6 +226,7 @@ impl MessageEditor {
) {
self.context_picker_menu_handle.toggle(window, cx);
}
pub fn remove_all_context(
&mut self,
_: &RemoveAllContext,
@ -270,68 +283,22 @@ impl MessageEditor {
self.last_estimated_token_count.take();
cx.emit(MessageEditorEvent::EstimatedTokenCount);
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let wait_for_images = self.context_store.read(cx).wait_for_images(cx);
let thread = self.thread.clone();
let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store().clone();
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
let context_task = self.wait_for_context(cx);
let window_handle = window.window_handle();
cx.spawn(async move |this, cx| {
let checkpoint = checkpoint.await.ok();
refresh_task.await;
wait_for_images.await;
cx.spawn(async move |_this, cx| {
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
let loaded_context = loaded_context.unwrap_or_default();
thread
.update(cx, |thread, cx| {
let context = context_store.read(cx).context().clone();
thread.insert_user_message(user_message, context, checkpoint, cx);
thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
})
.log_err();
context_store
.update(cx, |context_store, cx| {
let excerpt_ids = context_store
.context()
.iter()
.filter(|ctx| {
matches!(
ctx,
AssistantContext::Selection(_) | AssistantContext::Image(_)
)
})
.map(|ctx| ctx.id())
.collect::<Vec<_>>();
for id in excerpt_ids {
context_store.remove_context(id, cx);
}
})
.log_err();
if let Some(wait_for_summaries) = context_store
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
.log_err()
{
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = true;
cx.notify();
})
.log_err();
wait_for_summaries.await;
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = false;
cx.notify();
})
.log_err();
}
// Send to model after summaries are done
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
@ -342,6 +309,30 @@ impl MessageEditor {
.detach();
}
fn wait_for_summaries(&mut self, cx: &mut Context<Self>) -> Task<()> {
let context_store = self.context_store.clone();
cx.spawn(async move |this, cx| {
if let Some(wait_for_summaries) = context_store
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
.ok()
{
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = true;
cx.notify();
})
.ok();
wait_for_summaries.await;
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = false;
cx.notify();
})
.ok();
}
})
}
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let cancelled = self.thread.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
@ -1015,6 +1006,49 @@ impl MessageEditor {
self.update_token_count_task.is_some()
}
fn handle_message_changed(&mut self, cx: &mut Context<Self>) {
self.message_or_context_changed(true, cx);
}
fn handle_context_changed(&mut self, cx: &mut Context<Self>) {
let summaries_task = self.wait_for_summaries(cx);
let load_task = cx.spawn(async move |this, cx| {
// Waits for detailed summaries before `load_context`, as it directly reads these from
// the thread. TODO: Would be cleaner to have context loading await on summarization.
summaries_task.await;
let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this.context_store.read_with(cx, |context_store, cx| {
context_store.new_context_for_thread(this.thread.read(cx))
});
load_context(new_context, &this.project, &this.prompt_store, cx)
}) else {
return;
};
let result = load_task.await;
this.update(cx, |this, cx| {
this.last_loaded_context = Some(result);
this.context_load_task = None;
this.message_or_context_changed(false, cx);
})
.ok();
});
// Replace existing load task, if any, causing it to be cancelled.
self.context_load_task = Some(load_task.shared());
}
fn wait_for_context(&self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
if let Some(context_load_task) = self.context_load_task.clone() {
cx.spawn(async move |this, cx| {
context_load_task.await;
this.read_with(cx, |this, _cx| this.last_loaded_context.clone())
.ok()
.flatten()
})
} else {
Task::ready(self.last_loaded_context.clone())
}
}
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take();
@ -1024,9 +1058,7 @@ impl MessageEditor {
return;
};
let context_store = self.context_store.clone();
let editor = self.editor.clone();
let thread = self.thread.clone();
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
if debounce {
@ -1035,27 +1067,33 @@ impl MessageEditor {
.await;
}
let token_count = if let Some(task) = cx.update(|cx| {
let context = context_store.read(cx).context().iter();
let new_context = thread.read(cx).filter_new_context(context);
let context_text =
format_context_as_string(new_context, cx).unwrap_or(String::new());
let token_count = if let Some(task) = this.update(cx, |this, cx| {
let loaded_context = this
.last_loaded_context
.as_ref()
.map(|context_load_result| &context_load_result.loaded_context);
let message_text = editor.read(cx).text(cx);
let content = context_text + &message_text;
if content.is_empty() {
if message_text.is_empty()
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
{
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: language_model::Role::User,
content: Vec::new(),
cache: false,
};
if let Some(loaded_context) = loaded_context {
loaded_context.add_to_request_message(&mut request_message);
}
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![content.into()],
cache: false,
}],
messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,

View File

@ -32,7 +32,7 @@ impl TerminalCodegen {
}
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
@ -45,6 +45,7 @@ impl TerminalCodegen {
self.status = CodegenStatus::Pending;
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
self.generation = cx.spawn(async move |this, cx| {
let prompt = prompt_task.await;
let model_telemetry_id = model.telemetry_id();
let model_provider_id = model.provider_id();
let response = model.stream_completion_text(prompt, &cx).await;

View File

@ -1,4 +1,4 @@
use crate::context::attach_context_to_message;
use crate::context::load_context;
use crate::context_store::ContextStore;
use crate::inline_prompt_editor::{
CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
@ -10,14 +10,14 @@ use client::telemetry::Telemetry;
use collections::{HashMap, VecDeque};
use editor::{MultiBuffer, actions::SelectAll};
use fs::Fs;
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
use language::Buffer;
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use project::Project;
use prompt_store::PromptBuilder;
use prompt_store::{PromptBuilder, PromptStore};
use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use terminal_view::TerminalView;
@ -69,6 +69,7 @@ impl TerminalInlineAssistant {
terminal_view: &Entity<TerminalView>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@ -109,6 +110,7 @@ impl TerminalInlineAssistant {
prompt_editor,
workspace.clone(),
context_store,
prompt_store,
window,
cx,
);
@ -196,11 +198,11 @@ impl TerminalInlineAssistant {
.log_err();
let codegen = assist.codegen.clone();
let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
return;
};
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
}
fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
@ -217,7 +219,7 @@ impl TerminalInlineAssistant {
&self,
assist_id: TerminalInlineAssistId,
cx: &mut App,
) -> Result<LanguageModelRequest> {
) -> Result<Task<LanguageModelRequest>> {
let assist = self.assists.get(&assist_id).context("invalid assist")?;
let shell = std::env::var("SHELL").ok();
@ -246,28 +248,40 @@ impl TerminalInlineAssistant {
&latest_output,
)?;
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
let contexts = assist
.context_store
.read(cx)
.context()
.cloned()
.collect::<Vec<_>>();
let context_load_task = assist.workspace.update(cx, |workspace, cx| {
let project = workspace.project();
load_context(contexts, project, &assist.prompt_store, cx)
})?;
attach_context_to_message(
&mut request_message,
assist.context_store.read(cx).context().iter(),
cx,
);
Ok(cx.background_spawn(async move {
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
request_message.content.push(prompt.into());
context_load_task
.await
.loaded_context
.add_to_request_message(&mut request_message);
Ok(LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![request_message],
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
})
request_message.content.push(prompt.into());
LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![request_message],
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
}
}))
}
fn finish_assist(
@ -380,6 +394,7 @@ struct TerminalInlineAssist {
codegen: Entity<TerminalCodegen>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
@ -390,6 +405,7 @@ impl TerminalInlineAssist {
prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut App,
) -> Self {
@ -400,6 +416,7 @@ impl TerminalInlineAssist {
codegen: codegen.clone(),
workspace: workspace.clone(),
context_store,
prompt_store,
_subscriptions: vec![
window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
TerminalInlineAssistant::update_global(cx, |this, cx| {

View File

@ -8,7 +8,7 @@ use anyhow::{Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
@ -18,9 +18,9 @@ use gpui::{
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage,
};
@ -35,7 +35,7 @@ use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse, SharedProjectContext,
@ -98,8 +98,7 @@ pub struct Message {
pub id: MessageId,
pub role: Role,
pub segments: Vec<MessageSegment>,
pub context: String,
pub images: Vec<LanguageModelImage>,
pub loaded_context: LoadedContext,
}
impl Message {
@ -138,8 +137,8 @@ impl Message {
pub fn to_string(&self) -> String {
let mut result = String::new();
if !self.context.is_empty() {
result.push_str(&self.context);
if !self.loaded_context.text.is_empty() {
result.push_str(&self.loaded_context.text);
}
for segment in &self.segments {
@ -294,8 +293,6 @@ pub struct Thread {
messages: Vec<Message>,
next_message_id: MessageId,
last_prompt_id: PromptId,
context: BTreeMap<ContextId, AssistantContext>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
project_context: SharedProjectContext,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
completion_count: usize,
@ -345,8 +342,6 @@ impl Thread {
messages: Vec::new(),
next_message_id: MessageId(0),
last_prompt_id: PromptId::new(),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
project_context: system_prompt,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
@ -418,14 +413,15 @@ impl Thread {
}
})
.collect(),
context: message.context,
images: Vec::new(),
loaded_context: LoadedContext {
contexts: Vec::new(),
text: message.context,
images: Vec::new(),
},
})
.collect(),
next_message_id,
last_prompt_id: PromptId::new(),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
project_context,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
@ -660,21 +656,17 @@ impl Thread {
return;
};
for deleted_message in self.messages.drain(message_ix..) {
self.context_by_message.remove(&deleted_message.id);
self.checkpoints_by_message.remove(&deleted_message.id);
}
cx.notify();
}
pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
self.context_by_message
.get(&id)
pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
self.messages
.iter()
.find(|message| message.id == id)
.into_iter()
.flat_map(|context| {
context
.iter()
.filter_map(|context_id| self.context.get(&context_id))
})
.flat_map(|message| message.loaded_context.contexts.iter())
}
pub fn is_turn_end(&self, ix: usize) -> bool {
@ -736,91 +728,27 @@ impl Thread {
self.tool_use.tool_result_card(id).cloned()
}
/// Filter out contexts that have already been included in previous messages
pub fn filter_new_context<'a>(
&self,
context: impl Iterator<Item = &'a AssistantContext>,
) -> impl Iterator<Item = &'a AssistantContext> {
context.filter(|ctx| self.is_context_new(ctx))
}
fn is_context_new(&self, context: &AssistantContext) -> bool {
!self.context.contains_key(&context.id())
}
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
context: Vec<AssistantContext>,
loaded_context: ContextLoadResult,
git_checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
let text = text.into();
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
let new_context: Vec<_> = context
.into_iter()
.filter(|ctx| self.is_context_new(ctx))
.collect();
if !new_context.is_empty() {
if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
message.context = context_string;
}
}
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
message.images = new_context
.iter()
.filter_map(|context| {
if let AssistantContext::Image(image_context) = context {
image_context.image_task.clone().now_or_never().flatten()
} else {
None
}
})
.collect::<Vec<_>>();
}
if !loaded_context.referenced_buffers.is_empty() {
self.action_log.update(cx, |log, cx| {
// Track all buffers added as context
for ctx in &new_context {
match ctx {
AssistantContext::File(file_ctx) => {
log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
}
AssistantContext::Directory(dir_ctx) => {
for context_buffer in &dir_ctx.context_buffers {
log.track_buffer(context_buffer.buffer.clone(), cx);
}
}
AssistantContext::Symbol(symbol_ctx) => {
log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
}
AssistantContext::Selection(selection_context) => {
log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
}
AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_)
| AssistantContext::Rules(_)
| AssistantContext::Image(_) => {}
}
for buffer in loaded_context.referenced_buffers {
log.track_buffer(buffer, cx);
}
});
}
let context_ids = new_context
.iter()
.map(|context| context.id())
.collect::<Vec<_>>();
self.context.extend(
new_context
.into_iter()
.map(|context| (context.id(), context)),
let message_id = self.insert_message(
Role::User,
vec![MessageSegment::Text(text.into())],
loaded_context.loaded_context,
cx,
);
self.context_by_message.insert(message_id, context_ids);
if let Some(git_checkpoint) = git_checkpoint {
self.pending_checkpoint = Some(ThreadCheckpoint {
@ -834,10 +762,19 @@ impl Thread {
message_id
}
pub fn insert_assistant_message(
&mut self,
segments: Vec<MessageSegment>,
cx: &mut Context<Self>,
) -> MessageId {
self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
}
pub fn insert_message(
&mut self,
role: Role,
segments: Vec<MessageSegment>,
loaded_context: LoadedContext,
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
@ -845,8 +782,7 @@ impl Thread {
id,
role,
segments,
context: String::new(),
images: Vec::new(),
loaded_context,
});
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
@ -875,7 +811,6 @@ impl Thread {
return false;
};
self.messages.remove(index);
self.context_by_message.remove(&id);
self.touch_updated_at();
cx.emit(ThreadEvent::MessageDeleted(id));
true
@ -962,7 +897,7 @@ impl Thread {
content: tool_result.content.clone(),
})
.collect(),
context: message.context.clone(),
context: message.loaded_context.text.clone(),
})
.collect(),
initial_project_snapshot,
@ -1080,26 +1015,9 @@ impl Thread {
cache: false,
};
if !message.context.is_empty() {
request_message
.content
.push(MessageContent::Text(message.context.to_string()));
}
if !message.images.is_empty() {
// Some providers only support image parts after an initial text part
if request_message.content.is_empty() {
request_message
.content
.push(MessageContent::Text("Images attached by user:".to_string()));
}
for image in &message.images {
request_message
.content
.push(MessageContent::Image(image.clone()))
}
}
message
.loaded_context
.add_to_request_message(&mut request_message);
for segment in &message.segments {
match segment {
@ -1301,11 +1219,11 @@ impl Thread {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id = Some(thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text(String::new())],
cx,
));
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Text(String::new())],
cx,
));
}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
@ -1334,11 +1252,11 @@ impl Thread {
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
request_assistant_message_id = Some(thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text(chunk.to_string())],
cx,
));
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Text(chunk.to_string())],
cx,
));
};
}
}
@ -1361,14 +1279,14 @@ impl Thread {
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
request_assistant_message_id = Some(thread.insert_message(
Role::Assistant,
vec![MessageSegment::Thinking {
text: chunk.to_string(),
signature,
}],
cx,
));
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Thinking {
text: chunk.to_string(),
signature,
}],
cx,
));
};
}
}
@ -1376,7 +1294,7 @@ impl Thread {
let last_assistant_message_id = request_assistant_message_id
.unwrap_or_else(|| {
let new_assistant_message_id =
thread.insert_message(Role::Assistant, vec![], cx);
thread.insert_assistant_message(vec![], cx);
request_assistant_message_id =
Some(new_assistant_message_id);
new_assistant_message_id
@ -2097,8 +2015,16 @@ impl Thread {
}
)?;
if !message.context.is_empty() {
writeln!(markdown, "{}", message.context)?;
if !message.loaded_context.text.is_empty() {
writeln!(markdown, "{}", message.loaded_context.text)?;
}
if !message.loaded_context.images.is_empty() {
writeln!(
markdown,
"\n{} images attached as context.\n",
message.loaded_context.images.len()
)?;
}
for segment in &message.segments {
@ -2373,7 +2299,7 @@ struct PendingCompletion {
#[cfg(test)]
mod tests {
use super::*;
use crate::{ThreadStore, context_store::ContextStore, thread_store};
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
use assistant_settings::AssistantSettings;
use context_server::ContextServerSettings;
use editor::EditorSettings;
@ -2404,12 +2330,14 @@ mod tests {
.await
.unwrap();
let context =
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
let loaded_context = cx
.update(|cx| load_context(vec![context], &project, &None, cx))
.await;
// Insert user message with context
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Please explain this code", vec![context], None, cx)
thread.insert_user_message("Please explain this code", loaded_context, None, cx)
});
// Check content and context in message object
@ -2443,7 +2371,7 @@ fn main() {{
message.segments[0],
MessageSegment::Text("Please explain this code".to_string())
);
assert_eq!(message.context, expected_context);
assert_eq!(message.loaded_context.text, expected_context);
// Check message in request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@ -2470,48 +2398,50 @@ fn main() {{
let (_, _thread_store, thread, context_store) =
setup_test_environment(cx, project.clone()).await;
// Open files individually
// First message with context 1
add_file_to_context(&project, &context_store, "test/file1.rs", cx)
.await
.unwrap();
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
.await
.unwrap();
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
.await
.unwrap();
// Get the context objects
let contexts = context_store.update(cx, |store, _| store.context().clone());
assert_eq!(contexts.len(), 3);
// First message with context 1
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message1_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
thread.insert_user_message("Message 1", loaded_context, None, cx)
});
// Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message2_id = thread.update(cx, |thread, cx| {
thread.insert_user_message(
"Message 2",
vec![contexts[0].clone(), contexts[1].clone()],
None,
cx,
)
thread.insert_user_message("Message 2", loaded_context, None, cx)
});
// Third message with all three contexts (contexts 1 and 2 should be skipped)
//
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message3_id = thread.update(cx, |thread, cx| {
thread.insert_user_message(
"Message 3",
vec![
contexts[0].clone(),
contexts[1].clone(),
contexts[2].clone(),
],
None,
cx,
)
thread.insert_user_message("Message 3", loaded_context, None, cx)
});
// Check what contexts are included in each message
@ -2524,16 +2454,16 @@ fn main() {{
});
// First message should include context 1
assert!(message1.context.contains("file1.rs"));
assert!(message1.loaded_context.text.contains("file1.rs"));
// Second message should include only context 2 (not 1)
assert!(!message2.context.contains("file1.rs"));
assert!(message2.context.contains("file2.rs"));
assert!(!message2.loaded_context.text.contains("file1.rs"));
assert!(message2.loaded_context.text.contains("file2.rs"));
// Third message should include only context 3 (not 1 or 2)
assert!(!message3.context.contains("file1.rs"));
assert!(!message3.context.contains("file2.rs"));
assert!(message3.context.contains("file3.rs"));
assert!(!message3.loaded_context.text.contains("file1.rs"));
assert!(!message3.loaded_context.text.contains("file2.rs"));
assert!(message3.loaded_context.text.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@ -2570,7 +2500,12 @@ fn main() {{
// Insert user message without any context (empty context vector)
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
thread.insert_user_message(
"What is the best way to learn Rust?",
ContextLoadResult::default(),
None,
cx,
)
});
// Check content and context in message object
@ -2583,7 +2518,7 @@ fn main() {{
message.segments[0],
MessageSegment::Text("What is the best way to learn Rust?".to_string())
);
assert_eq!(message.context, "");
assert_eq!(message.loaded_context.text, "");
// Check message in request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@ -2596,12 +2531,17 @@ fn main() {{
// Add second message, also without context
let message2_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Are there any good books?", vec![], None, cx)
thread.insert_user_message(
"Are there any good books?",
ContextLoadResult::default(),
None,
cx,
)
});
let message2 =
thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
assert_eq!(message2.context, "");
assert_eq!(message2.loaded_context.text, "");
// Check that both messages appear in the request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@ -2635,12 +2575,14 @@ fn main() {{
.await
.unwrap();
let context =
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
let loaded_context = cx
.update(|cx| load_context(vec![context], &project, &None, cx))
.await;
// Insert user message with the buffer as context
thread.update(cx, |thread, cx| {
thread.insert_user_message("Explain this code", vec![context], None, cx)
thread.insert_user_message("Explain this code", loaded_context, None, cx)
});
// Create a request and check that it doesn't have a stale buffer warning yet
@ -2668,7 +2610,12 @@ fn main() {{
// Insert another user message without context
thread.update(cx, |thread, cx| {
thread.insert_user_message("What does the code do now?", vec![], None, cx)
thread.insert_user_message(
"What does the code do now?",
ContextLoadResult::default(),
None,
cx,
)
});
// Create a new request and check for the stale buffer warning
@ -2735,6 +2682,7 @@ fn main() {{
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
@ -2759,15 +2707,15 @@ fn main() {{
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
.update(cx, |project, cx| {
project.open_buffer(buffer_path.clone(), cx)
})
.await
.unwrap();
context_store
.update(cx, |store, cx| {
store.add_file_from_buffer(buffer.clone(), cx)
})
.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
});
Ok(buffer)
}

View File

@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree};
use prompt_store::{
ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
@ -82,12 +82,11 @@ impl ThreadStore {
pub fn load(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_store: Option<Entity<PromptStore>>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Task<Result<Entity<Self>>> {
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
let prompt_store = prompt_store.await.ok();
let (thread_store, ready_rx) = cx.update(|cx| {
let mut option_ready_rx = None;
let thread_store = cx.new(|cx| {
@ -349,25 +348,8 @@ impl ThreadStore {
self.context_server_manager.clone()
}
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
self.prompt_store.clone()
}
pub fn load_rules(
&self,
prompt_id: UserPromptId,
cx: &App,
) -> Task<Result<(PromptMetadata, String)>> {
let prompt_id = PromptId::User { uuid: prompt_id };
let Some(prompt_store) = self.prompt_store.as_ref() else {
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return Task::ready(Err(anyhow!("User rules not found in library.")));
};
let text_task = prompt_store.load(prompt_id, cx);
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
&self.prompt_store
}
pub fn tools(&self) -> Entity<ToolWorkingSet> {
@ -379,16 +361,12 @@ impl ThreadStore {
self.threads.len()
}
pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
threads
}
pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
self.threads().into_iter().take(limit).collect()
}
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
cx.new(|cx| {
Thread::new(

View File

@ -1,14 +1,13 @@
use std::sync::Arc;
use std::{rc::Rc, time::Duration};
use file_icons::FileIcons;
use futures::FutureExt;
use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between};
use gpui::{ClickEvent, Task};
use language_model::LanguageModelImage;
use gpui::{Animation, AnimationExt as _, ClickEvent, Entity, MouseButton, pulsating_between};
use project::Project;
use prompt_store::PromptStore;
use text::OffsetRangeExt;
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext};
use crate::context::{AgentContext, ContextKind, ImageStatus};
#[derive(IntoElement)]
pub enum ContextPill {
@ -73,9 +72,7 @@ impl ContextPill {
pub fn id(&self) -> ElementId {
match self {
Self::Added { context, .. } => {
ElementId::NamedInteger("context-pill".into(), context.id.0)
}
Self::Added { context, .. } => context.context.element_id("context-pill".into()),
Self::Suggested { .. } => "suggested-context-pill".into(),
}
}
@ -199,14 +196,17 @@ impl RenderOnce for ContextPill {
)
.when_some(on_remove.as_ref(), |element, on_remove| {
element.child(
IconButton::new(("remove", context.id.0), IconName::Close)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.tooltip(Tooltip::text("Remove Context"))
.on_click({
let on_remove = on_remove.clone();
move |event, window, cx| on_remove(event, window, cx)
}),
IconButton::new(
context.context.element_id("remove".into()),
IconName::Close,
)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.tooltip(Tooltip::text("Remove Context"))
.on_click({
let on_remove = on_remove.clone();
move |event, window, cx| on_remove(event, window, cx)
}),
)
})
.when_some(on_click.as_ref(), |element, on_click| {
@ -262,9 +262,11 @@ pub enum ContextStatus {
Error { message: SharedString },
}
#[derive(RegisterComponent)]
// TODO: Component commented out due to new dependency on `Project`.
//
// #[derive(RegisterComponent)]
pub struct AddedContext {
pub id: ContextId,
pub context: AgentContext,
pub kind: ContextKind,
pub name: SharedString,
pub parent: Option<SharedString>,
@ -275,10 +277,19 @@ pub struct AddedContext {
}
impl AddedContext {
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
/// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
/// `None` if `DirectoryContext` or `RulesContext` no longer exist.
///
/// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak.
pub fn new(
context: AgentContext,
prompt_store: Option<&Entity<PromptStore>>,
project: &Project,
cx: &App,
) -> Option<AddedContext> {
match context {
AssistantContext::File(file_context) => {
let full_path = file_context.context_buffer.full_path(cx);
AgentContext::File(ref file_context) => {
let full_path = file_context.buffer.read(cx).file()?.full_path(cx);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@ -289,8 +300,7 @@ impl AddedContext {
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: file_context.id,
Some(AddedContext {
kind: ContextKind::File,
name,
parent,
@ -298,18 +308,16 @@ impl AddedContext {
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: None,
}
context,
})
}
AssistantContext::Directory(directory_context) => {
let worktree = directory_context.worktree.read(cx);
// If the directory no longer exists, use its last known path.
let full_path = worktree
.entry_for_id(directory_context.entry_id)
.map_or_else(
|| directory_context.last_path.clone(),
|entry| worktree.full_path(&entry.path).into(),
);
AgentContext::Directory(ref directory_context) => {
let worktree = project
.worktree_for_entry(directory_context.entry_id, cx)?
.read(cx);
let entry = worktree.entry_for_id(directory_context.entry_id)?;
let full_path = worktree.full_path(&entry.path);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@ -320,8 +328,7 @@ impl AddedContext {
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: directory_context.id,
Some(AddedContext {
kind: ContextKind::Directory,
name,
parent,
@ -329,33 +336,34 @@ impl AddedContext {
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
}
context,
})
}
AssistantContext::Symbol(symbol_context) => AddedContext {
id: symbol_context.id,
AgentContext::Symbol(ref symbol_context) => Some(AddedContext {
kind: ContextKind::Symbol,
name: symbol_context.context_symbol.id.name.clone(),
name: symbol_context.symbol.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
context,
}),
AssistantContext::Selection(selection_context) => {
let full_path = selection_context.context_buffer.full_path(cx);
AgentContext::Selection(ref selection_context) => {
let buffer = selection_context.buffer.read(cx);
let full_path = buffer.file()?.full_path(cx);
let mut full_path_string = full_path.to_string_lossy().into_owned();
let mut name = full_path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| full_path_string.clone());
let line_range_text = format!(
" ({}-{})",
selection_context.line_range.start.row + 1,
selection_context.line_range.end.row + 1
);
let line_range = selection_context.range.to_point(&buffer.snapshot());
let line_range_text =
format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1);
full_path_string.push_str(&line_range_text);
name.push_str(&line_range_text);
@ -365,16 +373,17 @@ impl AddedContext {
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: selection_context.id,
Some(AddedContext {
kind: ContextKind::Selection,
name: name.into(),
parent,
tooltip: None,
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: None,
/*
render_preview: Some(Rc::new({
let content = selection_context.context_buffer.text.clone();
let content = selection_context.text.clone();
move |_, cx| {
div()
.id("context-pill-selection-preview")
@ -385,11 +394,12 @@ impl AddedContext {
.into_any_element()
}
})),
}
*/
context,
})
}
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
id: fetched_url_context.id,
AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext {
kind: ContextKind::FetchedUrl,
name: fetched_url_context.url.clone(),
parent: None,
@ -397,12 +407,12 @@ impl AddedContext {
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
context,
}),
AssistantContext::Thread(thread_context) => AddedContext {
id: thread_context.id,
AgentContext::Thread(ref thread_context) => Some(AddedContext {
kind: ContextKind::Thread,
name: thread_context.summary(cx),
name: thread_context.name(cx),
parent: None,
tooltip: None,
icon_path: None,
@ -418,36 +428,41 @@ impl AddedContext {
ContextStatus::Ready
},
render_preview: None,
},
context,
}),
AssistantContext::Rules(user_rules_context) => AddedContext {
id: user_rules_context.id,
kind: ContextKind::Rules,
name: user_rules_context.title.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
AgentContext::Rules(ref user_rules_context) => {
let name = prompt_store
.as_ref()?
.read(cx)
.metadata(user_rules_context.prompt_id.into())?
.title?;
Some(AddedContext {
kind: ContextKind::Rules,
name: name.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
context,
})
}
AssistantContext::Image(image_context) => AddedContext {
id: image_context.id,
AgentContext::Image(ref image_context) => Some(AddedContext {
kind: ContextKind::Image,
name: "Image".into(),
parent: None,
tooltip: None,
icon_path: None,
status: if image_context.is_loading() {
ContextStatus::Loading {
status: match image_context.status() {
ImageStatus::Loading => ContextStatus::Loading {
message: "Loading…".into(),
}
} else if image_context.is_error() {
ContextStatus::Error {
},
ImageStatus::Error => ContextStatus::Error {
message: "Failed to load image".into(),
}
} else {
ContextStatus::Ready
},
ImageStatus::Ready => ContextStatus::Ready,
},
render_preview: Some(Rc::new({
let image = image_context.original_image.clone();
@ -458,7 +473,8 @@ impl AddedContext {
.into_any_element()
}
})),
},
context,
}),
}
}
}
@ -478,6 +494,8 @@ impl Render for ContextPillPreview {
}
}
// TODO: Component commented out due to new dependency on `Project`.
/*
impl Component for AddedContext {
fn scope() -> ComponentScope {
ComponentScope::Agent
@ -487,12 +505,13 @@ impl Component for AddedContext {
"AddedContext"
}
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
let next_context_id = ContextId::zero();
let image_ready = (
"Ready",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(0),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
}),
@ -503,8 +522,8 @@ impl Component for AddedContext {
let image_loading = (
"Loading",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(1),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: cx
.background_spawn(async move {
@ -520,8 +539,8 @@ impl Component for AddedContext {
let image_error = (
"Error",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(2),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(None).shared(),
}),
@ -544,5 +563,8 @@ impl Component for AddedContext {
)
.into_any(),
)
None
}
}
*/

View File

@ -25,7 +25,7 @@ use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
};
use project::Project;
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use prompt_store::{PromptBuilder, UserPromptId};
use rules_library::{RulesLibrary, open_rules_library};
use search::{BufferSearchBar, buffer_search::DivRegistrar};
@ -1059,9 +1059,9 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
action
.prompt_to_select
.map(|uuid| UserPromptId(uuid).into()),
cx,
)
.detach_and_log_err(cx);

View File

@ -10,7 +10,7 @@ use crate::{
ToolMetrics,
assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
};
use agent::ThreadEvent;
use agent::{ContextLoadResult, ThreadEvent};
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use buffer_diff::DiffHunkStatus;
@ -115,7 +115,12 @@ impl ExampleContext {
pub fn push_user_message(&mut self, text: impl ToString) {
self.app
.update_entity(&self.agent_thread, |thread, cx| {
thread.insert_user_message(text.to_string(), vec![], None, cx);
thread.insert_user_message(
text.to_string(),
ContextLoadResult::default(),
None,
cx,
);
})
.unwrap();
}

View File

@ -218,8 +218,14 @@ impl ExampleInstance {
});
let tools = cx.new(|_| ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
let prompt_store = None;
let thread_store = ThreadStore::load(
project.clone(),
tools,
prompt_store,
app_state.prompt_builder.clone(),
cx,
);
let meta = self.thread.meta();
let this = self.clone();

View File

@ -60,9 +60,7 @@ pub enum PromptId {
impl PromptId {
pub fn new() -> PromptId {
PromptId::User {
uuid: UserPromptId::new(),
}
UserPromptId::new().into()
}
pub fn is_built_in(&self) -> bool {
@ -70,6 +68,12 @@ impl PromptId {
}
}
impl From<UserPromptId> for PromptId {
fn from(uuid: UserPromptId) -> Self {
PromptId::User { uuid }
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
@ -227,9 +231,7 @@ impl PromptStore {
.collect::<heed::Result<HashMap<_, _>>>()?;
for (prompt_id_v1, metadata_v1) in metadata_v1 {
let prompt_id_v2 = PromptId::User {
uuid: UserPromptId(prompt_id_v1.0),
};
let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
continue;
};