Allow StreamingEditFileTool to also create files (#29785)

Refs #29733 

This pull request introduces a new field to the `StreamingEditFileTool`
that lets the model create or overwrite a file in a streaming way. When
one of the `assistant.stream_edits` setting / `agent-stream-edits`
feature flag is enabled, we are going to disable the `CreateFileTool` so
that the agent model can only use `StreamingEditFileTool` for file
creation.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-05-02 11:57:04 +02:00 committed by GitHub
parent f619d5f02a
commit 35539847a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 2914 additions and 140 deletions

View File

@ -77,7 +77,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(SymbolInfoTool);
@ -125,12 +124,14 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx);
registry.unregister_tool(CreateFileTool);
registry.unregister_tool(EditFileTool);
registry.unregister_tool(StreamingEditFileTool);
if AssistantSettings::get_global(cx).stream_edits(cx) {
registry.register_tool(StreamingEditFileTool);
} else {
registry.register_tool(CreateFileTool);
registry.register_tool(EditFileTool);
}
}

View File

@ -10,6 +10,7 @@ use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
use futures::{
Stream, StreamExt,
channel::mpsc::{self, UnboundedReceiver},
pin_mut,
stream::BoxStream,
};
use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
@ -23,19 +24,29 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
#[derive(Serialize)]
pub struct EditAgentTemplate {
struct CreateFilePromptTemplate {
path: Option<PathBuf>,
edit_description: String,
}
impl Template for EditAgentTemplate {
const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
impl Template for CreateFilePromptTemplate {
const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
#[derive(Serialize)]
struct EditFilePromptTemplate {
path: Option<PathBuf>,
edit_description: String,
}
impl Template for EditFilePromptTemplate {
const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
Edited,
HallucinatedOldText(SharedString),
OldTextNotFound(SharedString),
}
#[derive(Clone, Debug)]
@ -64,6 +75,82 @@ impl EditAgent {
}
}
pub fn overwrite(
&self,
buffer: Entity<Buffer>,
edit_description: String,
previous_messages: Vec<LanguageModelRequestMessage>,
cx: &mut AsyncApp,
) -> (
Task<Result<EditAgentOutput>>,
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
) {
let this = self.clone();
let (events_tx, events_rx) = mpsc::unbounded();
let output = cx.spawn(async move |cx| {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = CreateFilePromptTemplate {
path,
edit_description,
}
.render(&this.templates)?;
let new_chunks = this.request(previous_messages, prompt, cx).await?;
let (output, mut inner_events) = this.replace_text_with_chunks(buffer, new_chunks, cx);
while let Some(event) = inner_events.next().await {
events_tx.unbounded_send(event).ok();
}
output.await
});
(output, events_rx)
}
fn replace_text_with_chunks(
&self,
buffer: Entity<Buffer>,
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
cx: &mut AsyncApp,
) -> (
Task<Result<EditAgentOutput>>,
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
) {
let (output_events_tx, output_events_rx) = mpsc::unbounded();
let this = self.clone();
let task = cx.spawn(async move |cx| {
// Ensure the buffer is tracked by the action log.
this.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
let mut raw_edits = String::new();
pin_mut!(edit_chunks);
while let Some(chunk) = edit_chunks.next().await {
let chunk = chunk?;
raw_edits.push_str(&chunk);
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
}
Ok(EditAgentOutput {
_raw_edits: raw_edits,
_parser_metrics: EditParserMetrics::default(),
})
});
(task, output_events_rx)
}
pub fn edit(
&self,
buffer: Entity<Buffer>,
@ -78,10 +165,15 @@ impl EditAgent {
let (events_tx, events_rx) = mpsc::unbounded();
let output = cx.spawn(async move |cx| {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let edit_chunks = this
.request_edits(snapshot, edit_description, previous_messages, cx)
.await?;
let (output, mut inner_events) = this.apply_edits(buffer, edit_chunks, cx);
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = EditFilePromptTemplate {
path,
edit_description,
}
.render(&this.templates)?;
let edit_chunks = this.request(previous_messages, prompt, cx).await?;
let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
while let Some(event) = inner_events.next().await {
events_tx.unbounded_send(event).ok();
}
@ -90,7 +182,7 @@ impl EditAgent {
(output, events_rx)
}
fn apply_edits(
fn apply_edit_chunks(
&self,
buffer: Entity<Buffer>,
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
@ -138,7 +230,7 @@ impl EditAgent {
let Some(old_range) = old_range else {
// We couldn't find the old text in the buffer. Report the error.
output_events
.unbounded_send(EditAgentOutputEvent::HallucinatedOldText(old_text_query))
.unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
.ok();
continue;
};
@ -232,7 +324,7 @@ impl EditAgent {
) {
let (tx, rx) = mpsc::unbounded();
let output = cx.background_spawn(async move {
futures::pin_mut!(chunks);
pin_mut!(chunks);
let mut parser = EditParser::new();
let mut raw_edits = String::new();
@ -336,20 +428,12 @@ impl EditAgent {
})
}
async fn request_edits(
async fn request(
&self,
snapshot: BufferSnapshot,
edit_description: String,
mut messages: Vec<LanguageModelRequestMessage>,
prompt: String,
cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = EditAgentTemplate {
path,
edit_description,
}
.render(&self.templates)?;
let mut message_content = Vec::new();
if let Some(last_message) = messages.last_mut() {
if last_message.role == Role::Assistant {
@ -611,7 +695,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
pretty_assertions::assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -648,7 +733,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -679,7 +765,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -692,7 +779,7 @@ mod tests {
let agent = init_test(cx).await;
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
let (apply, mut events) = agent.apply_edits(
let (apply, mut events) = agent.apply_edit_chunks(
buffer.clone(),
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
&mut cx.to_async(),
@ -744,7 +831,7 @@ mod tests {
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::HallucinatedOldText(
vec![EditAgentOutputEvent::OldTextNotFound(
"hallucinated old".into()
)]
);

View File

@ -4,10 +4,11 @@ use crate::{
streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use anyhow::anyhow;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
@ -71,14 +72,15 @@ fn eval_extract_handle_command_output() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -126,14 +128,15 @@ fn eval_delete_run_git_blame() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -240,14 +243,15 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
assertion: EvalAssertion::judge_diff(indoc! {"
- The compile_parser_to_wasm method has been changed to use wasi-sdk
- ureq is used to download the SDK for current platform and architecture
"}),
@ -315,14 +319,15 @@ fn eval_disable_cursor_blinking() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -504,14 +509,15 @@ fn eval_from_pixels_constructor() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
assertion: EvalAssertion::assert_eq(indoc! {"
- The diff contains a new `from_pixels` constructor
- The diff contains new tests for the `from_pixels` constructor
"}),
@ -519,6 +525,104 @@ fn eval_from_pixels_constructor() {
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
let input_file_path = "root/zode.py";
let edit_description = "Create the main Zode CLI script";
eval(
200,
1.,
EvalInput {
conversation: vec![
message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
message(
Assistant,
[
tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: "root/eval/react.py".into(),
start_line: None,
end_line: None,
},
),
tool_use(
"tool_2",
"read_file",
ReadFileToolInput {
path: "root/eval/react_test.py".into(),
start_line: None,
end_line: None,
},
),
],
),
message(
User,
[
tool_result(
"tool_1",
"read_file",
include_str!("evals/fixtures/zode/react.py"),
),
tool_result(
"tool_2",
"read_file",
include_str!("evals/fixtures/zode/react_test.py"),
),
],
),
message(
Assistant,
[
text(
"Now that I understand what we need to build, I'll create the main Python script:",
),
tool_use(
"tool_3",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: true,
},
),
],
),
],
input_path: input_file_path.into(),
input_content: None,
edit_description: edit_description.into(),
assertion: EvalAssertion::new(async move |sample, _, _cx| {
let invalid_starts = [' ', '`', '\n'];
let mut message = String::new();
for start in invalid_starts {
if sample.text.starts_with(start) {
message.push_str(&format!("The sample starts with a {:?}\n", start));
break;
}
}
// Remove trailing newline.
message.pop();
if message.is_empty() {
Ok(EvalAssertionOutcome {
score: 100,
message: None,
})
} else {
Ok(EvalAssertionOutcome {
score: 0,
message: Some(message),
})
}
}),
},
);
}
fn message(
role: Role,
contents: impl IntoIterator<Item = MessageContent>,
@ -574,11 +678,135 @@ fn tool_result(
struct EvalInput {
conversation: Vec<LanguageModelRequestMessage>,
input_path: PathBuf,
input_content: String,
input_content: Option<String>,
edit_description: String,
assertion: EvalAssertion,
}
#[derive(Clone)]
struct EvalSample {
text: String,
edit_output: EditAgentOutput,
diff: String,
}
trait AssertionFn: 'static + Send + Sync {
fn assert<'a>(
&'a self,
sample: &'a EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &'a mut TestAppContext,
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
impl<F> AssertionFn for F
where
F: 'static
+ Send
+ Sync
+ AsyncFn(
&EvalSample,
Arc<dyn LanguageModel>,
&mut TestAppContext,
) -> Result<EvalAssertionOutcome>,
{
fn assert<'a>(
&'a self,
sample: &'a EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &'a mut TestAppContext,
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
(self)(sample, judge_model, cx).boxed_local()
}
}
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
impl EvalAssertion {
fn new<F>(f: F) -> Self
where
F: 'static
+ Send
+ Sync
+ AsyncFn(
&EvalSample,
Arc<dyn LanguageModel>,
&mut TestAppContext,
) -> Result<EvalAssertionOutcome>,
{
EvalAssertion(Arc::new(f))
}
fn assert_eq(expected: impl Into<String>) -> Self {
let expected = expected.into();
Self::new(async move |sample, _judge, _cx| {
Ok(EvalAssertionOutcome {
score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
100
} else {
0
},
message: None,
})
})
}
fn judge_diff(assertions: &'static str) -> Self {
Self::new(async move |sample, judge, cx| {
let prompt = DiffJudgeTemplate {
diff: sample.diff.clone(),
assertions,
}
.render(&Templates::new())
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
};
let mut response = judge
.stream_completion_text(request, &cx.to_async())
.await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionOutcome {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
})
}
async fn run(
&self,
input: &EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &mut TestAppContext,
) -> Result<EvalAssertionOutcome> {
self.0.assert(input, judge_model, cx).await
}
}
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
let mut evaluated_count = 0;
report_progress(evaluated_count, iterations);
@ -606,12 +834,12 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
while let Ok(output) = rx.recv() {
match output {
Ok(output) => {
cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
eval_outputs.push(output.clone());
if output.assertion.score < 80 {
failed_count += 1;
failed_evals
.entry(output.buffer_text.clone())
.entry(output.sample.text.clone())
.or_insert(Vec::new())
.push(output);
}
@ -671,10 +899,8 @@ fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
#[derive(Clone)]
struct EvalOutput {
assertion: EvalAssertionResult,
buffer_text: String,
edit_output: EditAgentOutput,
diff: String,
sample: EvalSample,
assertion: EvalAssertionOutcome,
}
impl Display for EvalOutput {
@ -684,14 +910,14 @@ impl Display for EvalOutput {
writeln!(f, "Message: {}", message)?;
}
writeln!(f, "Diff:\n{}", self.diff)?;
writeln!(f, "Diff:\n{}", self.sample.diff)?;
writeln!(
f,
"Parser Metrics:\n{:#?}",
self.edit_output._parser_metrics
self.sample.edit_output._parser_metrics
)?;
writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
Ok(())
}
}
@ -777,96 +1003,45 @@ impl EditAgentTest {
.update(cx, |project, cx| project.open_buffer(path, cx))
.await
.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text(eval.input_content.clone(), cx)
});
let (edit_output, _events) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
let edit_output = edit_output.await?;
let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
let (edit_output, _) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
edit_output.await?
} else {
let (edit_output, _) = self.agent.overwrite(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
edit_output.await?
};
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
let assertion = match eval.assertion {
EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
100
} else {
0
},
message: None,
},
EvalAssertion::JudgeDiff(assertions) => self
.judge_diff(&actual_diff, assertions, &cx.to_async())
.await
.context("failed comparing diffs")?,
};
Ok(EvalOutput {
assertion,
diff: actual_diff,
buffer_text,
let sample = EvalSample {
edit_output,
})
}
async fn judge_diff(
&self,
diff: &str,
assertions: &'static str,
cx: &AsyncApp,
) -> Result<EvalAssertionResult> {
let prompt = DiffJudgeTemplate {
diff: diff.to_string(),
assertions,
}
.render(&self.agent.templates)
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
diff: language::unified_diff(
eval.input_content.as_deref().unwrap_or_default(),
&buffer_text,
),
text: buffer_text,
};
let mut response = self.judge_model.stream_completion_text(request, cx).await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
let assertion = eval
.assertion
.run(&sample, self.judge_model.clone(), cx)
.await?;
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionResult {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
Ok(EvalOutput { assertion, sample })
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
AssertEqual(String),
JudgeDiff(&'static str),
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
struct EvalAssertionOutcome {
score: usize,
message: Option<String>,
}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,14 @@
class InputCell:
def __init__(self, initial_value):
self.value = None
class ComputeCell:
def __init__(self, inputs, compute_function):
self.value = None
def add_callback(self, callback):
pass
def remove_callback(self, callback):
pass

View File

@ -0,0 +1,271 @@
# These tests are auto-generated with test data from:
# https://github.com/exercism/problem-specifications/tree/main/exercises/react/canonical-data.json
# File last updated on 2023-07-19
from functools import partial
import unittest
from react import (
InputCell,
ComputeCell,
)
class ReactTest(unittest.TestCase):
def test_input_cells_have_a_value(self):
input = InputCell(10)
self.assertEqual(input.value, 10)
def test_an_input_cell_s_value_can_be_set(self):
input = InputCell(4)
input.value = 20
self.assertEqual(input.value, 20)
def test_compute_cells_calculate_initial_value(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
self.assertEqual(output.value, 2)
def test_compute_cells_take_inputs_in_the_right_order(self):
one = InputCell(1)
two = InputCell(2)
output = ComputeCell(
[
one,
two,
],
lambda inputs: inputs[0] + inputs[1] * 10,
)
self.assertEqual(output.value, 21)
def test_compute_cells_update_value_when_dependencies_are_changed(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
input.value = 3
self.assertEqual(output.value, 4)
def test_compute_cells_can_depend_on_other_compute_cells(self):
input = InputCell(1)
times_two = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 2,
)
times_thirty = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 30,
)
output = ComputeCell(
[
times_two,
times_thirty,
],
lambda inputs: inputs[0] + inputs[1],
)
self.assertEqual(output.value, 32)
input.value = 3
self.assertEqual(output.value, 96)
def test_compute_cells_fire_callbacks(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callback_cells_only_fire_on_change(self):
input = InputCell(1)
output = ComputeCell([input], lambda inputs: 111 if inputs[0] < 3 else 222)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer[-1], 222)
def test_callbacks_do_not_report_already_reported_values(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer[-1], 3)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callbacks_can_fire_from_multiple_cells(self):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
plus_one.add_callback(callback1)
minus_one.add_callback(callback2)
input.value = 10
self.assertEqual(cb1_observer[-1], 11)
self.assertEqual(cb2_observer[-1], 9)
def test_callbacks_can_be_added_and_removed(self):
input = InputCell(11)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
cb3_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
callback3 = self.callback_factory(cb3_observer)
output.add_callback(callback1)
output.add_callback(callback2)
input.value = 31
self.assertEqual(cb1_observer[-1], 32)
self.assertEqual(cb2_observer[-1], 32)
output.remove_callback(callback1)
output.add_callback(callback3)
input.value = 41
self.assertEqual(len(cb1_observer), 1)
self.assertEqual(cb2_observer[-1], 42)
self.assertEqual(cb3_observer[-1], 42)
def test_removing_a_callback_multiple_times_doesn_t_interfere_with_other_callbacks(
self,
):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
output.add_callback(callback1)
output.add_callback(callback2)
output.remove_callback(callback1)
output.remove_callback(callback1)
output.remove_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
self.assertEqual(cb2_observer[-1], 3)
def test_callbacks_should_only_be_called_once_even_if_multiple_dependencies_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one1 = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
minus_one2 = ComputeCell(
[
minus_one1,
],
lambda inputs: inputs[0] - 1,
)
output = ComputeCell(
[
plus_one,
minus_one2,
],
lambda inputs: inputs[0] * inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 4
self.assertEqual(cb1_observer[-1], 10)
def test_callbacks_should_not_be_called_if_dependencies_change_but_output_value_doesn_t_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
always_two = ComputeCell(
[
plus_one,
minus_one,
],
lambda inputs: inputs[0] - inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
always_two.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 3
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer, [])
input.value = 5
self.assertEqual(cb1_observer, [])
# Utility functions.
def callback_factory(self, observer):
def callback(observer, value):
observer.append(value)
return partial(callback, observer)

View File

@ -38,7 +38,7 @@ pub struct StreamingEditFileToolInput {
/// so that we can display it immediately.
pub display_description: String,
/// The full path of the file to modify in the project.
/// The full path of the file to create or modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
@ -58,6 +58,10 @@ pub struct StreamingEditFileToolInput {
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// If true, this tool will recreate the file from scratch.
/// If false, this tool will produce granular edits to an existing file.
pub create_or_overwrite: bool,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@ -158,7 +162,7 @@ impl Tool for StreamingEditFileTool {
let card_clone = card.clone();
let messages = messages.to_vec();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
if !exists.await? {
if !input.create_or_overwrite && !exists.await? {
return Err(anyhow!("{} not found", input.path.display()));
}
@ -182,12 +186,21 @@ impl Tool for StreamingEditFileTool {
})
.await;
let (output, mut events) = edit_agent.edit(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
);
let (output, mut events) = if input.create_or_overwrite {
edit_agent.overwrite(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
)
} else {
edit_agent.edit(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
)
};
let mut hallucinated_old_text = false;
while let Some(event) = events.next().await {
@ -213,7 +226,7 @@ impl Tool for StreamingEditFileTool {
.log_err();
}
}
EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
}
}
output.await?;

View File

@ -1,4 +1,4 @@
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
Before using this tool:

View File

@ -0,0 +1,12 @@
You are an expert engineer and your task is to write a new file from scratch.
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>
You MUST respond directly with the file's content, without explanations, additional text or triple backticks.
The text you output will be saved verbatim as the content of the file.

View File

@ -2141,6 +2141,14 @@ impl Buffer {
self.edit([(0..self.len(), text)], None, cx)
}
/// Appends the given text to the end of the buffer.
pub fn append<T>(&mut self, text: T, cx: &mut Context<Self>) -> Option<clock::Lamport>
where
T: Into<Arc<str>>,
{
self.edit([(self.len()..self.len(), text)], None, cx)
}
/// Applies the given edits to the buffer. Each edit is specified as a range of text to
/// delete, and a string of text to insert at that location.
///