assistant edit tool: Revert fuzzy matching (#26996)

#26935 is leading to bad edits, so let's revert it for now. I'll bring
back a version of this, but it'll likely just focus on indentation
instead of making the whole search fuzzy.

Release Notes: 

- N/A
This commit is contained in:
Agus Zubiaga 2025-03-18 13:08:09 -03:00 committed by GitHub
parent 06e9f0e309
commit 5615be51cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 78 additions and 252 deletions

2
Cargo.lock generated
View File

@ -718,7 +718,6 @@ dependencies = [
"itertools 0.14.0", "itertools 0.14.0",
"language", "language",
"language_model", "language_model",
"pretty_assertions",
"project", "project",
"rand 0.8.5", "rand 0.8.5",
"release_channel", "release_channel",
@ -728,7 +727,6 @@ dependencies = [
"settings", "settings",
"theme", "theme",
"ui", "ui",
"unindent",
"util", "util",
"workspace", "workspace",
"worktree", "worktree",

View File

@ -39,7 +39,5 @@ rand.workspace = true
collections = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] }
unindent.workspace = true
workspace = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] }

View File

@ -1,6 +1,5 @@
mod edit_action; mod edit_action;
pub mod log; pub mod log;
mod resolve_search_block;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use assistant_tool::{ActionLog, Tool}; use assistant_tool::{ActionLog, Tool};
@ -8,17 +7,16 @@ use collections::HashSet;
use edit_action::{EditAction, EditActionParser}; use edit_action::{EditAction, EditActionParser};
use futures::StreamExt; use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task}; use gpui::{App, AsyncApp, Entity, Task};
use language::OffsetRangeExt;
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
}; };
use log::{EditToolLog, EditToolRequestId}; use log::{EditToolLog, EditToolRequestId};
use project::Project; use project::{search::SearchQuery, Project};
use resolve_search_block::resolve_search_block;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Write; use std::fmt::Write;
use std::sync::Arc; use std::sync::Arc;
use util::paths::PathMatcher;
use util::ResultExt; use util::ResultExt;
#[derive(Debug, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Serialize, Deserialize, JsonSchema)]
@ -131,11 +129,24 @@ struct EditToolRequest {
parser: EditActionParser, parser: EditActionParser,
output: String, output: String,
changed_buffers: HashSet<Entity<language::Buffer>>, changed_buffers: HashSet<Entity<language::Buffer>>,
bad_searches: Vec<BadSearch>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>, tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
} }
#[derive(Debug)]
enum DiffResult {
BadSearch(BadSearch),
Diff(language::Diff),
}
#[derive(Debug)]
struct BadSearch {
file_path: String,
search: String,
}
impl EditToolRequest { impl EditToolRequest {
fn new( fn new(
input: EditFilesToolInput, input: EditFilesToolInput,
@ -193,6 +204,7 @@ impl EditToolRequest {
// we start with the success header so we don't need to shift the output in the common case // we start with the success header so we don't need to shift the output in the common case
output: Self::SUCCESS_OUTPUT_HEADER.to_string(), output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
changed_buffers: HashSet::default(), changed_buffers: HashSet::default(),
bad_searches: Vec::new(),
action_log, action_log,
project, project,
tool_log, tool_log,
@ -239,30 +251,36 @@ impl EditToolRequest {
.update(cx, |project, cx| project.open_buffer(project_path, cx))? .update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?; .await?;
let diff = match action { let result = match action {
EditAction::Replace { EditAction::Replace {
old, old,
new, new,
file_path: _, file_path,
} => { } => {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let diff = cx cx.background_executor()
.background_executor() .spawn(Self::replace_diff(old, new, file_path, snapshot))
.spawn(Self::replace_diff(old, new, snapshot)) .await
.await;
anyhow::Ok(diff)
} }
EditAction::Write { content, .. } => Ok(buffer EditAction::Write { content, .. } => Ok(DiffResult::Diff(
.read_with(cx, |buffer, cx| buffer.diff(content, cx))? buffer
.await), .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
.await,
)),
}?; }?;
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?; match result {
DiffResult::BadSearch(invalid_replace) => {
self.bad_searches.push(invalid_replace);
}
DiffResult::Diff(diff) => {
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
write!(&mut self.output, "\n\n{}", source)?; write!(&mut self.output, "\n\n{}", source)?;
self.changed_buffers.insert(buffer); self.changed_buffers.insert(buffer);
}
}
Ok(()) Ok(())
} }
@ -270,9 +288,29 @@ impl EditToolRequest {
async fn replace_diff( async fn replace_diff(
old: String, old: String,
new: String, new: String,
file_path: std::path::PathBuf,
snapshot: language::BufferSnapshot, snapshot: language::BufferSnapshot,
) -> language::Diff { ) -> Result<DiffResult> {
let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot); let query = SearchQuery::text(
old.clone(),
false,
true,
true,
PathMatcher::new(&[])?,
PathMatcher::new(&[])?,
None,
)?;
let matches = query.search(&snapshot, None).await;
if matches.is_empty() {
return Ok(DiffResult::BadSearch(BadSearch {
search: new.clone(),
file_path: file_path.display().to_string(),
}));
}
let edit_range = matches[0].clone();
let diff = language::text_diff(&old, &new); let diff = language::text_diff(&old, &new);
let edits = diff let edits = diff
@ -290,7 +328,7 @@ impl EditToolRequest {
edits, edits,
}; };
diff anyhow::Ok(DiffResult::Diff(diff))
} }
const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:"; const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
@ -314,7 +352,7 @@ impl EditToolRequest {
let errors = self.parser.errors(); let errors = self.parser.errors();
if errors.is_empty() { if errors.is_empty() && self.bad_searches.is_empty() {
if changed_buffer_count == 0 { if changed_buffer_count == 0 {
return Err(anyhow!( return Err(anyhow!(
"The instructions didn't lead to any changes. You might need to consult the file contents first." "The instructions didn't lead to any changes. You might need to consult the file contents first."
@ -337,6 +375,24 @@ impl EditToolRequest {
); );
} }
if !self.bad_searches.is_empty() {
writeln!(
&mut output,
"\n\nThese searches failed because they didn't match any strings:"
)?;
for replace in self.bad_searches {
writeln!(
&mut output,
"- '{}' does not appear in `{}`",
replace.search.replace("\r", "\\r").replace("\n", "\\n"),
replace.file_path
)?;
}
write!(&mut output, "Make sure to use exact searches.")?;
}
if !errors.is_empty() { if !errors.is_empty() {
writeln!( writeln!(
&mut output, &mut output,

View File

@ -1,226 +0,0 @@
use language::{Anchor, Bias, BufferSnapshot};
use std::ops::Range;
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
cost: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(cost: u32, direction: SearchDirection) -> Self {
Self { cost, direction }
}
}
struct SearchMatrix {
cols: usize,
data: Vec<SearchState>,
}
impl SearchMatrix {
fn new(rows: usize, cols: usize) -> Self {
SearchMatrix {
cols,
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
}
}
fn get(&self, row: usize, col: usize) -> SearchState {
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
self.data[row * self.cols + col] = cost;
}
}
pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;
const WHITESPACE_INSERTION_COST: u32 = 1;
const WHITESPACE_DELETION_COST: u32 = 1;
let buffer_len = buffer.len();
let query_len = search_query.len();
let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
let mut leading_deletion_cost = 0_u32;
for (row, query_byte) in search_query.bytes().enumerate() {
let deletion_cost = if query_byte.is_ascii_whitespace() {
WHITESPACE_DELETION_COST
} else {
DELETION_COST
};
leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
);
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
WHITESPACE_INSERTION_COST
} else {
INSERTION_COST
};
let up = SearchState::new(
matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
SearchDirection::Up,
);
let left = SearchState::new(
matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_byte == *buffer_byte {
matrix.get(row, col).cost
} else {
matrix
.get(row, col)
.cost
.saturating_add(deletion_cost + insertion_cost)
},
SearchDirection::Diagonal,
);
matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
}
}
// Traceback to find the best match
let mut best_buffer_end = buffer_len;
let mut best_cost = u32::MAX;
for col in 1..=buffer_len {
let cost = matrix.get(query_len, col).cost;
if cost < best_cost {
best_cost = cost;
best_buffer_end = col;
}
}
let mut query_ix = query_len;
let mut buffer_ix = best_buffer_end;
while query_ix > 0 && buffer_ix > 0 {
let current = matrix.get(query_ix, buffer_ix);
match current.direction {
SearchDirection::Diagonal => {
query_ix -= 1;
buffer_ix -= 1;
}
SearchDirection::Up => {
query_ix -= 1;
}
SearchDirection::Left => {
buffer_ix -= 1;
}
}
}
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
start.column = 0;
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
if end.column > 0 {
end.column = buffer.line_len(end.row);
}
buffer.anchor_after(start)..buffer.anchor_before(end)
}
#[cfg(test)]
mod tests {
use crate::edit_files_tool::resolve_search_block::resolve_search_block;
use gpui::{prelude::*, App};
use language::{Buffer, OffsetRangeExt as _};
use unindent::Unindent as _;
use util::test::{generate_marked_text, marked_text_ranges};
#[gpui::test]
fn test_resolve_search_block(cx: &mut App) {
assert_resolved(
concat!(
" Lorem\n",
"« ipsum\n",
" dolor sit amet»\n",
" consecteur",
),
"ipsum\ndolor",
cx,
);
assert_resolved(
&"
«fn foo1(a: usize) -> usize {
40
}»
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"fn foo1(b: usize) {\n40\n}",
cx,
);
assert_resolved(
&"
fn main() {
« Foo
.bar()
.baz()
.qux()»
}
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"Foo.bar.baz.qux()",
cx,
);
assert_resolved(
&"
class Something {
one() { return 1; }
« two() { return 2222; }
three() { return 333; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
» seven() { return 7; }
eight() { return 8; }
}
"
.unindent(),
&"
two() { return 2222; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
"
.unindent(),
cx,
);
}
#[track_caller]
fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
let (text, _) = marked_text_ranges(text_with_expected_range, false);
let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
let snapshot = buffer.read(cx).snapshot();
let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
let text_with_actual_range = generate_marked_text(&text, &[range], false);
pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
}
}