assistant edit tool: Revert fuzzy matching (#26996)

#26935 is leading to bad edits, so let's revert it for now. I'll bring
back a version of this, but it'll likely just focus on indentation
instead of making the whole search fuzzy.

Release Notes: 

- N/A
This commit is contained in:
Agus Zubiaga 2025-03-18 13:08:09 -03:00 committed by GitHub
parent 06e9f0e309
commit 5615be51cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 78 additions and 252 deletions

2
Cargo.lock generated
View File

@ -718,7 +718,6 @@ dependencies = [
"itertools 0.14.0",
"language",
"language_model",
"pretty_assertions",
"project",
"rand 0.8.5",
"release_channel",
@ -728,7 +727,6 @@ dependencies = [
"settings",
"theme",
"ui",
"unindent",
"util",
"workspace",
"worktree",

View File

@ -39,7 +39,5 @@ rand.workspace = true
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
unindent.workspace = true
workspace = { workspace = true, features = ["test-support"] }

View File

@ -1,6 +1,5 @@
mod edit_action;
pub mod log;
mod resolve_search_block;
use anyhow::{anyhow, Context, Result};
use assistant_tool::{ActionLog, Tool};
@ -8,17 +7,16 @@ use collections::HashSet;
use edit_action::{EditAction, EditActionParser};
use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task};
use language::OffsetRangeExt;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
use log::{EditToolLog, EditToolRequestId};
use project::Project;
use resolve_search_block::resolve_search_block;
use project::{search::SearchQuery, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::sync::Arc;
use util::paths::PathMatcher;
use util::ResultExt;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@ -131,11 +129,24 @@ struct EditToolRequest {
parser: EditActionParser,
output: String,
changed_buffers: HashSet<Entity<language::Buffer>>,
bad_searches: Vec<BadSearch>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
#[derive(Debug)]
enum DiffResult {
BadSearch(BadSearch),
Diff(language::Diff),
}
#[derive(Debug)]
struct BadSearch {
file_path: String,
search: String,
}
impl EditToolRequest {
fn new(
input: EditFilesToolInput,
@ -193,6 +204,7 @@ impl EditToolRequest {
// we start with the success header so we don't need to shift the output in the common case
output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
changed_buffers: HashSet::default(),
bad_searches: Vec::new(),
action_log,
project,
tool_log,
@ -239,30 +251,36 @@ impl EditToolRequest {
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let diff = match action {
let result = match action {
EditAction::Replace {
old,
new,
file_path: _,
file_path,
} => {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let diff = cx
.background_executor()
.spawn(Self::replace_diff(old, new, snapshot))
.await;
anyhow::Ok(diff)
cx.background_executor()
.spawn(Self::replace_diff(old, new, file_path, snapshot))
.await
}
EditAction::Write { content, .. } => Ok(buffer
.read_with(cx, |buffer, cx| buffer.diff(content, cx))?
.await),
EditAction::Write { content, .. } => Ok(DiffResult::Diff(
buffer
.read_with(cx, |buffer, cx| buffer.diff(content, cx))?
.await,
)),
}?;
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
match result {
DiffResult::BadSearch(invalid_replace) => {
self.bad_searches.push(invalid_replace);
}
DiffResult::Diff(diff) => {
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
write!(&mut self.output, "\n\n{}", source)?;
self.changed_buffers.insert(buffer);
write!(&mut self.output, "\n\n{}", source)?;
self.changed_buffers.insert(buffer);
}
}
Ok(())
}
@ -270,9 +288,29 @@ impl EditToolRequest {
async fn replace_diff(
old: String,
new: String,
file_path: std::path::PathBuf,
snapshot: language::BufferSnapshot,
) -> language::Diff {
let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot);
) -> Result<DiffResult> {
let query = SearchQuery::text(
old.clone(),
false,
true,
true,
PathMatcher::new(&[])?,
PathMatcher::new(&[])?,
None,
)?;
let matches = query.search(&snapshot, None).await;
if matches.is_empty() {
return Ok(DiffResult::BadSearch(BadSearch {
search: new.clone(),
file_path: file_path.display().to_string(),
}));
}
let edit_range = matches[0].clone();
let diff = language::text_diff(&old, &new);
let edits = diff
@ -290,7 +328,7 @@ impl EditToolRequest {
edits,
};
diff
anyhow::Ok(DiffResult::Diff(diff))
}
const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
@ -314,7 +352,7 @@ impl EditToolRequest {
let errors = self.parser.errors();
if errors.is_empty() {
if errors.is_empty() && self.bad_searches.is_empty() {
if changed_buffer_count == 0 {
return Err(anyhow!(
"The instructions didn't lead to any changes. You might need to consult the file contents first."
@ -337,6 +375,24 @@ impl EditToolRequest {
);
}
if !self.bad_searches.is_empty() {
writeln!(
&mut output,
"\n\nThese searches failed because they didn't match any strings:"
)?;
for replace in self.bad_searches {
writeln!(
&mut output,
"- '{}' does not appear in `{}`",
replace.search.replace("\r", "\\r").replace("\n", "\\n"),
replace.file_path
)?;
}
write!(&mut output, "Make sure to use exact searches.")?;
}
if !errors.is_empty() {
writeln!(
&mut output,

View File

@ -1,226 +0,0 @@
use language::{Anchor, Bias, BufferSnapshot};
use std::ops::Range;
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
cost: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(cost: u32, direction: SearchDirection) -> Self {
Self { cost, direction }
}
}
struct SearchMatrix {
cols: usize,
data: Vec<SearchState>,
}
impl SearchMatrix {
fn new(rows: usize, cols: usize) -> Self {
SearchMatrix {
cols,
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
}
}
fn get(&self, row: usize, col: usize) -> SearchState {
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
self.data[row * self.cols + col] = cost;
}
}
pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;
const WHITESPACE_INSERTION_COST: u32 = 1;
const WHITESPACE_DELETION_COST: u32 = 1;
let buffer_len = buffer.len();
let query_len = search_query.len();
let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
let mut leading_deletion_cost = 0_u32;
for (row, query_byte) in search_query.bytes().enumerate() {
let deletion_cost = if query_byte.is_ascii_whitespace() {
WHITESPACE_DELETION_COST
} else {
DELETION_COST
};
leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
);
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
WHITESPACE_INSERTION_COST
} else {
INSERTION_COST
};
let up = SearchState::new(
matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
SearchDirection::Up,
);
let left = SearchState::new(
matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_byte == *buffer_byte {
matrix.get(row, col).cost
} else {
matrix
.get(row, col)
.cost
.saturating_add(deletion_cost + insertion_cost)
},
SearchDirection::Diagonal,
);
matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
}
}
// Traceback to find the best match
let mut best_buffer_end = buffer_len;
let mut best_cost = u32::MAX;
for col in 1..=buffer_len {
let cost = matrix.get(query_len, col).cost;
if cost < best_cost {
best_cost = cost;
best_buffer_end = col;
}
}
let mut query_ix = query_len;
let mut buffer_ix = best_buffer_end;
while query_ix > 0 && buffer_ix > 0 {
let current = matrix.get(query_ix, buffer_ix);
match current.direction {
SearchDirection::Diagonal => {
query_ix -= 1;
buffer_ix -= 1;
}
SearchDirection::Up => {
query_ix -= 1;
}
SearchDirection::Left => {
buffer_ix -= 1;
}
}
}
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
start.column = 0;
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
if end.column > 0 {
end.column = buffer.line_len(end.row);
}
buffer.anchor_after(start)..buffer.anchor_before(end)
}
#[cfg(test)]
mod tests {
use crate::edit_files_tool::resolve_search_block::resolve_search_block;
use gpui::{prelude::*, App};
use language::{Buffer, OffsetRangeExt as _};
use unindent::Unindent as _;
use util::test::{generate_marked_text, marked_text_ranges};
#[gpui::test]
fn test_resolve_search_block(cx: &mut App) {
assert_resolved(
concat!(
" Lorem\n",
"« ipsum\n",
" dolor sit amet»\n",
" consecteur",
),
"ipsum\ndolor",
cx,
);
assert_resolved(
&"
«fn foo1(a: usize) -> usize {
40
}»
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"fn foo1(b: usize) {\n40\n}",
cx,
);
assert_resolved(
&"
fn main() {
« Foo
.bar()
.baz()
.qux()»
}
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"Foo.bar.baz.qux()",
cx,
);
assert_resolved(
&"
class Something {
one() { return 1; }
« two() { return 2222; }
three() { return 333; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
» seven() { return 7; }
eight() { return 8; }
}
"
.unindent(),
&"
two() { return 2222; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
"
.unindent(),
cx,
);
}
#[track_caller]
fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
let (text, _) = marked_text_ranges(text_with_expected_range, false);
let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
let snapshot = buffer.read(cx).snapshot();
let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
let text_with_actual_range = generate_marked_text(&text, &[range], false);
pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
}
}