diff --git a/Cargo.lock b/Cargo.lock index bf843943f5..d30e29f83c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -718,7 +718,6 @@ dependencies = [ "itertools 0.14.0", "language", "language_model", - "pretty_assertions", "project", "rand 0.8.5", "release_channel", @@ -728,7 +727,6 @@ dependencies = [ "settings", "theme", "ui", - "unindent", "util", "workspace", "worktree", diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 0facf8e2cb..0c765e8074 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -39,7 +39,5 @@ rand.workspace = true collections = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } -pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } -unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 10a2454c3d..fe2ec0ecc9 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -1,6 +1,5 @@ mod edit_action; pub mod log; -mod resolve_search_block; use anyhow::{anyhow, Context, Result}; use assistant_tool::{ActionLog, Tool}; @@ -8,17 +7,16 @@ use collections::HashSet; use edit_action::{EditAction, EditActionParser}; use futures::StreamExt; use gpui::{App, AsyncApp, Entity, Task}; -use language::OffsetRangeExt; use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, }; use log::{EditToolLog, EditToolRequestId}; -use project::Project; -use resolve_search_block::resolve_search_block; +use project::{search::SearchQuery, Project}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::fmt::Write; use std::sync::Arc; +use util::paths::PathMatcher; use util::ResultExt; #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -131,11 +129,24 @@ struct EditToolRequest { parser: EditActionParser, output: String, changed_buffers: HashSet>, + bad_searches: Vec, project: Entity, action_log: Entity, tool_log: Option<(Entity, EditToolRequestId)>, } +#[derive(Debug)] +enum DiffResult { + BadSearch(BadSearch), + Diff(language::Diff), +} + +#[derive(Debug)] +struct BadSearch { + file_path: String, + search: String, +} + impl EditToolRequest { fn new( input: EditFilesToolInput, @@ -193,6 +204,7 @@ impl EditToolRequest { // we start with the success header so we don't need to shift the output in the common case output: Self::SUCCESS_OUTPUT_HEADER.to_string(), changed_buffers: HashSet::default(), + bad_searches: Vec::new(), action_log, project, tool_log, @@ -239,30 +251,36 @@ impl EditToolRequest { .update(cx, |project, cx| project.open_buffer(project_path, cx))? .await?; - let diff = match action { + let result = match action { EditAction::Replace { old, new, - file_path: _, + file_path, } => { let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let diff = cx - .background_executor() - .spawn(Self::replace_diff(old, new, snapshot)) - .await; - - anyhow::Ok(diff) + cx.background_executor() + .spawn(Self::replace_diff(old, new, file_path, snapshot)) + .await } - EditAction::Write { content, .. } => Ok(buffer - .read_with(cx, |buffer, cx| buffer.diff(content, cx))? - .await), + EditAction::Write { content, .. } => Ok(DiffResult::Diff( + buffer + .read_with(cx, |buffer, cx| buffer.diff(content, cx))? + .await, + )), }?; - let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?; + match result { + DiffResult::BadSearch(invalid_replace) => { + self.bad_searches.push(invalid_replace); + } + DiffResult::Diff(diff) => { + let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?; - write!(&mut self.output, "\n\n{}", source)?; - self.changed_buffers.insert(buffer); + write!(&mut self.output, "\n\n{}", source)?; + self.changed_buffers.insert(buffer); + } + } Ok(()) } @@ -270,9 +288,29 @@ impl EditToolRequest { async fn replace_diff( old: String, new: String, + file_path: std::path::PathBuf, snapshot: language::BufferSnapshot, - ) -> language::Diff { - let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot); + ) -> Result { + let query = SearchQuery::text( + old.clone(), + false, + true, + true, + PathMatcher::new(&[])?, + PathMatcher::new(&[])?, + None, + )?; + + let matches = query.search(&snapshot, None).await; + + if matches.is_empty() { + return Ok(DiffResult::BadSearch(BadSearch { + search: new.clone(), + file_path: file_path.display().to_string(), + })); + } + + let edit_range = matches[0].clone(); let diff = language::text_diff(&old, &new); let edits = diff @@ -290,7 +328,7 @@ impl EditToolRequest { edits, }; - diff + anyhow::Ok(DiffResult::Diff(diff)) } const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:"; @@ -314,7 +352,7 @@ impl EditToolRequest { let errors = self.parser.errors(); - if errors.is_empty() { + if errors.is_empty() && self.bad_searches.is_empty() { if changed_buffer_count == 0 { return Err(anyhow!( "The instructions didn't lead to any changes. You might need to consult the file contents first." @@ -337,6 +375,24 @@ impl EditToolRequest { ); } + if !self.bad_searches.is_empty() { + writeln!( + &mut output, + "\n\nThese searches failed because they didn't match any strings:" + )?; + + for replace in self.bad_searches { + writeln!( + &mut output, + "- '{}' does not appear in `{}`", + replace.search.replace("\r", "\\r").replace("\n", "\\n"), + replace.file_path + )?; + } + + write!(&mut output, "Make sure to use exact searches.")?; + } + if !errors.is_empty() { writeln!( &mut output, diff --git a/crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs b/crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs deleted file mode 100644 index 5d2f61f8bb..0000000000 --- a/crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs +++ /dev/null @@ -1,226 +0,0 @@ -use language::{Anchor, Bias, BufferSnapshot}; -use std::ops::Range; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -enum SearchDirection { - Up, - Left, - Diagonal, -} - -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -struct SearchState { - cost: u32, - direction: SearchDirection, -} - -impl SearchState { - fn new(cost: u32, direction: SearchDirection) -> Self { - Self { cost, direction } - } -} - -struct SearchMatrix { - cols: usize, - data: Vec, -} - -impl SearchMatrix { - fn new(rows: usize, cols: usize) -> Self { - SearchMatrix { - cols, - data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols], - } - } - - fn get(&self, row: usize, col: usize) -> SearchState { - self.data[row * self.cols + col] - } - - fn set(&mut self, row: usize, col: usize, cost: SearchState) { - self.data[row * self.cols + col] = cost; - } -} - -pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range { - const INSERTION_COST: u32 = 3; - const DELETION_COST: u32 = 10; - const WHITESPACE_INSERTION_COST: u32 = 1; - const WHITESPACE_DELETION_COST: u32 = 1; - - let buffer_len = buffer.len(); - let query_len = search_query.len(); - let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1); - let mut leading_deletion_cost = 0_u32; - for (row, query_byte) in search_query.bytes().enumerate() { - let deletion_cost = if query_byte.is_ascii_whitespace() { - WHITESPACE_DELETION_COST - } else { - DELETION_COST - }; - - leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost); - matrix.set( - row + 1, - 0, - SearchState::new(leading_deletion_cost, SearchDirection::Diagonal), - ); - - for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() { - let insertion_cost = if buffer_byte.is_ascii_whitespace() { - WHITESPACE_INSERTION_COST - } else { - INSERTION_COST - }; - - let up = SearchState::new( - matrix.get(row, col + 1).cost.saturating_add(deletion_cost), - SearchDirection::Up, - ); - let left = SearchState::new( - matrix.get(row + 1, col).cost.saturating_add(insertion_cost), - SearchDirection::Left, - ); - let diagonal = SearchState::new( - if query_byte == *buffer_byte { - matrix.get(row, col).cost - } else { - matrix - .get(row, col) - .cost - .saturating_add(deletion_cost + insertion_cost) - }, - SearchDirection::Diagonal, - ); - matrix.set(row + 1, col + 1, up.min(left).min(diagonal)); - } - } - - // Traceback to find the best match - let mut best_buffer_end = buffer_len; - let mut best_cost = u32::MAX; - for col in 1..=buffer_len { - let cost = matrix.get(query_len, col).cost; - if cost < best_cost { - best_cost = cost; - best_buffer_end = col; - } - } - - let mut query_ix = query_len; - let mut buffer_ix = best_buffer_end; - while query_ix > 0 && buffer_ix > 0 { - let current = matrix.get(query_ix, buffer_ix); - match current.direction { - SearchDirection::Diagonal => { - query_ix -= 1; - buffer_ix -= 1; - } - SearchDirection::Up => { - query_ix -= 1; - } - SearchDirection::Left => { - buffer_ix -= 1; - } - } - } - - let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left)); - start.column = 0; - let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right)); - if end.column > 0 { - end.column = buffer.line_len(end.row); - } - - buffer.anchor_after(start)..buffer.anchor_before(end) -} - -#[cfg(test)] -mod tests { - use crate::edit_files_tool::resolve_search_block::resolve_search_block; - use gpui::{prelude::*, App}; - use language::{Buffer, OffsetRangeExt as _}; - use unindent::Unindent as _; - use util::test::{generate_marked_text, marked_text_ranges}; - - #[gpui::test] - fn test_resolve_search_block(cx: &mut App) { - assert_resolved( - concat!( - " Lorem\n", - "« ipsum\n", - " dolor sit amet»\n", - " consecteur", - ), - "ipsum\ndolor", - cx, - ); - - assert_resolved( - &" - «fn foo1(a: usize) -> usize { - 40 - }» - - fn foo2(b: usize) -> usize { - 42 - } - " - .unindent(), - "fn foo1(b: usize) {\n40\n}", - cx, - ); - - assert_resolved( - &" - fn main() { - « Foo - .bar() - .baz() - .qux()» - } - - fn foo2(b: usize) -> usize { - 42 - } - " - .unindent(), - "Foo.bar.baz.qux()", - cx, - ); - - assert_resolved( - &" - class Something { - one() { return 1; } - « two() { return 2222; } - three() { return 333; } - four() { return 4444; } - five() { return 5555; } - six() { return 6666; } - » seven() { return 7; } - eight() { return 8; } - } - " - .unindent(), - &" - two() { return 2222; } - four() { return 4444; } - five() { return 5555; } - six() { return 6666; } - " - .unindent(), - cx, - ); - } - - #[track_caller] - fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) { - let (text, _) = marked_text_ranges(text_with_expected_range, false); - let buffer = cx.new(|cx| Buffer::local(text.clone(), cx)); - let snapshot = buffer.read(cx).snapshot(); - let range = resolve_search_block(&snapshot, query).to_offset(&snapshot); - let text_with_actual_range = generate_marked_text(&text, &[range], false); - pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range); - } -}