Display what the tool is doing (#27120)

<img width="639" alt="Screenshot 2025-03-19 at 4 56 47 PM"
src="https://github.com/user-attachments/assets/b997f04d-4aff-4070-87b1-ffdb61019bd1"
/>

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
Richard Feldman 2025-03-20 09:16:39 -04:00 committed by GitHub
parent aae81fd54c
commit e3578fc44a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 349 additions and 132 deletions

View File

@ -35,6 +35,7 @@ pub struct ActiveThread {
list_state: ListState, list_state: ListState,
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>, rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>, rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
rendered_tool_use_labels: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>, editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>, expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
last_error: Option<ThreadError>, last_error: Option<ThreadError>,
@ -70,6 +71,7 @@ impl ActiveThread {
messages: Vec::new(), messages: Vec::new(),
rendered_messages_by_id: HashMap::default(), rendered_messages_by_id: HashMap::default(),
rendered_scripting_tool_uses: HashMap::default(), rendered_scripting_tool_uses: HashMap::default(),
rendered_tool_use_labels: HashMap::default(),
expanded_tool_uses: HashMap::default(), expanded_tool_uses: HashMap::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade(); let this = cx.entity().downgrade();
@ -86,10 +88,29 @@ impl ActiveThread {
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() { for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
this.push_message(&message.id, message.text.clone(), window, cx); this.push_message(&message.id, message.text.clone(), window, cx);
for tool_use in thread.read(cx).scripting_tool_uses_for_message(message.id) { for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
this.render_tool_use_label_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
window,
cx,
);
}
for tool_use in thread
.read(cx)
.scripting_tool_uses_for_message(message.id, cx)
{
this.render_tool_use_label_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
window,
cx,
);
this.render_scripting_tool_use_markdown( this.render_scripting_tool_use_markdown(
tool_use.id.clone(), tool_use.id.clone(),
tool_use.name.as_ref(), tool_use.ui_text.as_ref(),
tool_use.input.clone(), tool_use.input.clone(),
window, window,
cx, cx,
@ -287,6 +308,19 @@ impl ActiveThread {
.insert(tool_use_id, lua_script); .insert(tool_use_id, lua_script);
} }
fn render_tool_use_label_markdown(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_label: impl Into<SharedString>,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.rendered_tool_use_labels.insert(
tool_use_id,
self.render_markdown(tool_label.into(), window, cx),
);
}
fn handle_thread_event( fn handle_thread_event(
&mut self, &mut self,
_thread: &Entity<Thread>, _thread: &Entity<Thread>,
@ -341,9 +375,18 @@ impl ActiveThread {
cx.notify(); cx.notify();
} }
ThreadEvent::UsePendingTools => { ThreadEvent::UsePendingTools => {
self.thread.update(cx, |thread, cx| { let tool_uses = self
thread.use_pending_tools(cx); .thread
}); .update(cx, |thread, cx| thread.use_pending_tools(cx));
for tool_use in tool_uses {
self.render_tool_use_label_markdown(
tool_use.id,
tool_use.ui_text.clone(),
window,
cx,
);
}
} }
ThreadEvent::ToolFinished { ThreadEvent::ToolFinished {
pending_tool_use, pending_tool_use,
@ -352,6 +395,13 @@ impl ActiveThread {
} => { } => {
let canceled = *canceled; let canceled = *canceled;
if let Some(tool_use) = pending_tool_use { if let Some(tool_use) = pending_tool_use {
self.render_tool_use_label_markdown(
tool_use.id.clone(),
SharedString::from(tool_use.ui_text.clone()),
window,
cx,
);
self.render_scripting_tool_use_markdown( self.render_scripting_tool_use_markdown(
tool_use.id.clone(), tool_use.id.clone(),
tool_use.name.as_ref(), tool_use.name.as_ref(),
@ -555,8 +605,8 @@ impl ActiveThread {
// Get all the data we need from thread before we start using it in closures // Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id); let checkpoint = thread.checkpoint_for_message(message_id);
let context = thread.context_for_message(message_id); let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id); let tool_uses = thread.tool_uses_for_message(message_id, cx);
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id); let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id, cx);
// Don't render user messages that are just there for returning tool results. // Don't render user messages that are just there for returning tool results.
if message.role == Role::User if message.role == Role::User
@ -709,27 +759,25 @@ impl ActiveThread {
) )
.child(div().p_2().child(message_content)), .child(div().p_2().child(message_content)),
), ),
Role::Assistant => { Role::Assistant => v_flex()
v_flex() .id(("message-container", ix))
.id(("message-container", ix)) .child(div().py_3().px_4().child(message_content))
.child(div().py_3().px_4().child(message_content)) .when(
.when( !tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
!tool_uses.is_empty() || !scripting_tool_uses.is_empty(), |parent| {
|parent| { parent.child(
parent.child( v_flex()
v_flex() .children(
.children( tool_uses
tool_uses .into_iter()
.into_iter() .map(|tool_use| self.render_tool_use(tool_use, cx)),
.map(|tool_use| self.render_tool_use(tool_use, cx)), )
) .children(scripting_tool_uses.into_iter().map(|tool_use| {
.children(scripting_tool_uses.into_iter().map(|tool_use| { self.render_scripting_tool_use(tool_use, window, cx)
self.render_scripting_tool_use(tool_use, cx) })),
})), )
) },
}, ),
)
}
Role::System => div().id(("message-container", ix)).py_1().px_2().child( Role::System => div().id(("message-container", ix)).py_1().px_2().child(
v_flex() v_flex()
.bg(colors.editor_background) .bg(colors.editor_background)
@ -805,11 +853,10 @@ impl ActiveThread {
} }
}), }),
)) ))
.child( .child(div().text_ui_sm(cx).children(
Label::new(tool_use.name) self.rendered_tool_use_labels.get(&tool_use.id).cloned(),
.size(LabelSize::Small) ))
.buffer_font(cx), .truncate(),
),
) )
.child({ .child({
let (icon_name, color, animated) = match &tool_use.status { let (icon_name, color, animated) = match &tool_use.status {
@ -937,6 +984,7 @@ impl ActiveThread {
fn render_scripting_tool_use( fn render_scripting_tool_use(
&self, &self,
tool_use: ToolUse, tool_use: ToolUse,
window: &Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> impl IntoElement { ) -> impl IntoElement {
let is_open = self let is_open = self
@ -982,7 +1030,12 @@ impl ActiveThread {
} }
}), }),
)) ))
.child(Label::new(tool_use.name)), .child(div().text_ui_sm(cx).child(self.render_markdown(
tool_use.ui_text.clone(),
window,
cx,
)))
.truncate(),
) )
.child( .child(
Label::new(match tool_use.status { Label::new(match tool_use.status {

View File

@ -458,7 +458,7 @@ impl AssistantPanel {
workspace.update_in(cx, |workspace, window, cx| { workspace.update_in(cx, |workspace, window, cx| {
let thread = thread.read(cx); let thread = thread.read(cx);
let markdown = thread.to_markdown()?; let markdown = thread.to_markdown(cx)?;
let thread_summary = thread let thread_summary = thread
.summary() .summary()
.map(|summary| summary.to_string()) .map(|summary| summary.to_string())

View File

@ -146,10 +146,10 @@ impl Thread {
pending_completions: Vec::new(), pending_completions: Vec::new(),
project: project.clone(), project: project.clone(),
prompt_builder, prompt_builder,
tools, tools: tools.clone(),
tool_use: ToolUseState::new(), tool_use: ToolUseState::new(tools.clone()),
scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)), scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
scripting_tool_use: ToolUseState::new(), scripting_tool_use: ToolUseState::new(tools),
action_log: cx.new(|_| ActionLog::new()), action_log: cx.new(|_| ActionLog::new()),
initial_project_snapshot: { initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project, cx); let project_snapshot = Self::project_snapshot(project, cx);
@ -176,11 +176,12 @@ impl Thread {
.map(|message| message.id.0 + 1) .map(|message| message.id.0 + 1)
.unwrap_or(0), .unwrap_or(0),
); );
let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| { let tool_use =
name != ScriptingTool::NAME ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
}); name != ScriptingTool::NAME
});
let scripting_tool_use = let scripting_tool_use =
ToolUseState::from_serialized_messages(&serialized.messages, |name| { ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
name == ScriptingTool::NAME name == ScriptingTool::NAME
}); });
let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx)); let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
@ -328,12 +329,12 @@ impl Thread {
all_pending_tool_uses.all(|tool_use| tool_use.status.is_error()) all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
} }
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> { pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
self.tool_use.tool_uses_for_message(id) self.tool_use.tool_uses_for_message(id, cx)
} }
pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> { pub fn scripting_tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
self.scripting_tool_use.tool_uses_for_message(id) self.scripting_tool_use.tool_uses_for_message(id, cx)
} }
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> { pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
@ -448,7 +449,7 @@ impl Thread {
let initial_project_snapshot = self.initial_project_snapshot.clone(); let initial_project_snapshot = self.initial_project_snapshot.clone();
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let initial_project_snapshot = initial_project_snapshot.await; let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(cx, |this, _| SerializedThread { this.read_with(cx, |this, cx| SerializedThread {
summary: this.summary_or_default(), summary: this.summary_or_default(),
updated_at: this.updated_at(), updated_at: this.updated_at(),
messages: this messages: this
@ -458,9 +459,9 @@ impl Thread {
role: message.role, role: message.role,
text: message.text.clone(), text: message.text.clone(),
tool_uses: this tool_uses: this
.tool_uses_for_message(message.id) .tool_uses_for_message(message.id, cx)
.into_iter() .into_iter()
.chain(this.scripting_tool_uses_for_message(message.id)) .chain(this.scripting_tool_uses_for_message(message.id, cx))
.map(|tool_use| SerializedToolUse { .map(|tool_use| SerializedToolUse {
id: tool_use.id, id: tool_use.id,
name: tool_use.name, name: tool_use.name,
@ -809,13 +810,17 @@ impl Thread {
.rfind(|message| message.role == Role::Assistant) .rfind(|message| message.role == Role::Assistant)
{ {
if tool_use.name.as_ref() == ScriptingTool::NAME { if tool_use.name.as_ref() == ScriptingTool::NAME {
thread thread.scripting_tool_use.request_tool_use(
.scripting_tool_use last_assistant_message.id,
.request_tool_use(last_assistant_message.id, tool_use); tool_use,
cx,
);
} else { } else {
thread thread.tool_use.request_tool_use(
.tool_use last_assistant_message.id,
.request_tool_use(last_assistant_message.id, tool_use); tool_use,
cx,
);
} }
} }
} }
@ -956,7 +961,10 @@ impl Thread {
}); });
} }
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) { pub fn use_pending_tools(
&mut self,
cx: &mut Context<Self>,
) -> impl IntoIterator<Item = PendingToolUse> {
let request = self.to_completion_request(RequestKind::Chat, cx); let request = self.to_completion_request(RequestKind::Chat, cx);
let pending_tool_uses = self let pending_tool_uses = self
.tool_use .tool_use
@ -966,17 +974,22 @@ impl Thread {
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for tool_use in pending_tool_uses { for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) { if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run( let task = tool.run(
tool_use.input, tool_use.input.clone(),
&request.messages, &request.messages,
self.project.clone(), self.project.clone(),
self.action_log.clone(), self.action_log.clone(),
cx, cx,
); );
self.insert_tool_output(tool_use.id.clone(), task, cx); self.insert_tool_output(
tool_use.id.clone(),
tool_use.ui_text.clone().into(),
task,
cx,
);
} }
} }
@ -988,8 +1001,8 @@ impl Thread {
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for scripting_tool_use in pending_scripting_tool_uses { for scripting_tool_use in pending_scripting_tool_uses.iter() {
let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) { let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
Err(err) => Task::ready(Err(err.into())), Err(err) => Task::ready(Err(err.into())),
Ok(input) => { Ok(input) => {
let (script_id, script_task) = let (script_id, script_task) =
@ -1016,13 +1029,20 @@ impl Thread {
} }
}; };
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx); let ui_text: SharedString = scripting_tool_use.name.clone().into();
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
} }
pending_tool_uses
.into_iter()
.chain(pending_scripting_tool_uses)
} }
pub fn insert_tool_output( pub fn insert_tool_output(
&mut self, &mut self,
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
output: Task<Result<String>>, output: Task<Result<String>>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
@ -1047,12 +1067,13 @@ impl Thread {
}); });
self.tool_use self.tool_use
.run_pending_tool(tool_use_id, insert_output_task); .run_pending_tool(tool_use_id, ui_text, insert_output_task);
} }
pub fn insert_scripting_tool_output( pub fn insert_scripting_tool_output(
&mut self, &mut self,
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
output: Task<Result<String>>, output: Task<Result<String>>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
@ -1077,7 +1098,7 @@ impl Thread {
}); });
self.scripting_tool_use self.scripting_tool_use
.run_pending_tool(tool_use_id, insert_output_task); .run_pending_tool(tool_use_id, ui_text, insert_output_task);
} }
pub fn attach_tool_results( pub fn attach_tool_results(
@ -1250,7 +1271,7 @@ impl Thread {
}) })
} }
pub fn to_markdown(&self) -> Result<String> { pub fn to_markdown(&self, cx: &App) -> Result<String> {
let mut markdown = Vec::new(); let mut markdown = Vec::new();
if let Some(summary) = self.summary() { if let Some(summary) = self.summary() {
@ -1269,7 +1290,7 @@ impl Thread {
)?; )?;
writeln!(markdown, "{}\n", message.text)?; writeln!(markdown, "{}\n", message.text)?;
for tool_use in self.tool_uses_for_message(message.id) { for tool_use in self.tool_uses_for_message(message.id, cx) {
writeln!( writeln!(
markdown, markdown,
"**Use Tool: {} ({})**", "**Use Tool: {} ({})**",

View File

@ -1,10 +1,11 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use collections::HashMap; use collections::HashMap;
use futures::future::Shared; use futures::future::Shared;
use futures::FutureExt as _; use futures::FutureExt as _;
use gpui::{SharedString, Task}; use gpui::{App, SharedString, Task};
use language_model::{ use language_model::{
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, Role, LanguageModelToolUseId, MessageContent, Role,
@ -17,6 +18,7 @@ use crate::thread_store::SerializedMessage;
pub struct ToolUse { pub struct ToolUse {
pub id: LanguageModelToolUseId, pub id: LanguageModelToolUseId,
pub name: SharedString, pub name: SharedString,
pub ui_text: SharedString,
pub status: ToolUseStatus, pub status: ToolUseStatus,
pub input: serde_json::Value, pub input: serde_json::Value,
} }
@ -30,6 +32,7 @@ pub enum ToolUseStatus {
} }
pub struct ToolUseState { pub struct ToolUseState {
tools: Arc<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>, tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>, tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>, tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
@ -37,8 +40,9 @@ pub struct ToolUseState {
} }
impl ToolUseState { impl ToolUseState {
pub fn new() -> Self { pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
Self { Self {
tools,
tool_uses_by_assistant_message: HashMap::default(), tool_uses_by_assistant_message: HashMap::default(),
tool_uses_by_user_message: HashMap::default(), tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(), tool_results: HashMap::default(),
@ -50,10 +54,11 @@ impl ToolUseState {
/// ///
/// Accepts a function to filter the tools that should be used to populate the state. /// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_serialized_messages( pub fn from_serialized_messages(
tools: Arc<ToolWorkingSet>,
messages: &[SerializedMessage], messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool, mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self { ) -> Self {
let mut this = Self::new(); let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default(); let mut tool_names_by_id = HashMap::default();
for message in messages { for message in messages {
@ -138,7 +143,7 @@ impl ToolUseState {
self.pending_tool_uses_by_id.values().collect() self.pending_tool_uses_by_id.values().collect()
} }
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> { pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
return Vec::new(); return Vec::new();
}; };
@ -173,6 +178,7 @@ impl ToolUseState {
tool_uses.push(ToolUse { tool_uses.push(ToolUse {
id: tool_use.id.clone(), id: tool_use.id.clone(),
name: tool_use.name.clone().into(), name: tool_use.name.clone().into(),
ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
input: tool_use.input.clone(), input: tool_use.input.clone(),
status, status,
}) })
@ -181,6 +187,19 @@ impl ToolUseState {
tool_uses tool_uses
} }
pub fn tool_ui_label(
&self,
tool_name: &str,
input: &serde_json::Value,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.tool(tool_name, cx) {
tool.ui_text(input).into()
} else {
"Unknown tool".into()
}
}
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> { pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
let empty = Vec::new(); let empty = Vec::new();
@ -209,6 +228,7 @@ impl ToolUseState {
&mut self, &mut self,
assistant_message_id: MessageId, assistant_message_id: MessageId,
tool_use: LanguageModelToolUse, tool_use: LanguageModelToolUse,
cx: &App,
) { ) {
self.tool_uses_by_assistant_message self.tool_uses_by_assistant_message
.entry(assistant_message_id) .entry(assistant_message_id)
@ -228,15 +248,24 @@ impl ToolUseState {
PendingToolUse { PendingToolUse {
assistant_message_id, assistant_message_id,
id: tool_use.id, id: tool_use.id,
name: tool_use.name, name: tool_use.name.clone(),
ui_text: self
.tool_ui_label(&tool_use.name, &tool_use.input, cx)
.into(),
input: tool_use.input, input: tool_use.input,
status: PendingToolUseStatus::Idle, status: PendingToolUseStatus::Idle,
}, },
); );
} }
pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) { pub fn run_pending_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
task: Task<()>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.ui_text = ui_text.into();
tool_use.status = PendingToolUseStatus::Running { tool_use.status = PendingToolUseStatus::Running {
_task: task.shared(), _task: task.shared(),
}; };
@ -335,6 +364,7 @@ pub struct PendingToolUse {
#[allow(unused)] #[allow(unused)]
pub assistant_message_id: MessageId, pub assistant_message_id: MessageId,
pub name: Arc<str>, pub name: Arc<str>,
pub ui_text: Arc<str>,
pub input: serde_json::Value, pub input: serde_json::Value,
pub status: PendingToolUseStatus, pub status: PendingToolUseStatus,
} }

View File

@ -128,12 +128,7 @@ impl HeadlessAssistant {
} }
} }
} }
ThreadEvent::StreamedCompletion _ => {}
| ThreadEvent::SummaryChanged
| ThreadEvent::StreamedAssistantText(_, _)
| ThreadEvent::MessageAdded(_)
| ThreadEvent::MessageEdited(_)
| ThreadEvent::MessageDeleted(_) => {}
} }
} }
} }

View File

@ -5,8 +5,7 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use gpui::Context; use gpui::{App, Context, Entity, SharedString, Task};
use gpui::{App, Entity, SharedString, Task};
use language::Buffer; use language::Buffer;
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::Project; use project::Project;
@ -44,6 +43,9 @@ pub trait Tool: 'static + Send + Sync {
serde_json::Value::Object(serde_json::Map::default()) serde_json::Value::Object(serde_json::Map::default())
} }
/// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &serde_json::Value) -> String;
/// Runs the tool with the provided input. /// Runs the tool with the provided input.
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,

View File

@ -32,6 +32,13 @@ impl Tool for BashTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<BashToolInput>(input.clone()) {
Ok(input) => format!("`$ {}`", input.command),
Err(_) => "Run bash command".to_string(),
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,

View File

@ -39,6 +39,13 @@ impl Tool for DeletePathTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<DeletePathToolInput>(input.clone()) {
Ok(input) => format!("Delete “`{}`”", input.path),
Err(_) => "Delete path".to_string(),
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -59,13 +66,12 @@ impl Tool for DeletePathTool {
{ {
Some(deletion_task) => cx.background_spawn(async move { Some(deletion_task) => cx.background_spawn(async move {
match deletion_task.await { match deletion_task.await {
Ok(()) => Ok(format!("Deleted {}", &path_str)), Ok(()) => Ok(format!("Deleted {path_str}")),
Err(err) => Err(anyhow!("Failed to delete {}: {}", &path_str, err)), Err(err) => Err(anyhow!("Failed to delete {path_str}: {err}")),
} }
}), }),
None => Task::ready(Err(anyhow!( None => Task::ready(Err(anyhow!(
"Couldn't delete {} because that path isn't in this project.", "Couldn't delete {path_str} because that path isn't in this project."
path_str
))), ))),
} }
} }

View File

@ -46,6 +46,17 @@ impl Tool for DiagnosticsTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone())
.ok()
.and_then(|input| input.path)
{
format!("Check diagnostics for “`{}`”", path.display())
} else {
"Check project diagnostics".to_string()
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -54,14 +65,15 @@ impl Tool for DiagnosticsTool {
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let input = match serde_json::from_value::<DiagnosticsToolInput>(input) { if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input)
Ok(input) => input, .ok()
Err(err) => return Task::ready(Err(anyhow!(err))), .and_then(|input| input.path)
}; {
if let Some(path) = input.path {
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else { let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!("Could not find path in project"))); return Task::ready(Err(anyhow!(
"Could not find path {} in project",
path.display()
)));
}; };
let buffer = project.update(cx, |project, cx| project.open_buffer(project_path, cx)); let buffer = project.update(cx, |project, cx| project.open_buffer(project_path, cx));

View File

@ -24,10 +24,7 @@ use util::ResultExt;
pub struct EditFilesToolInput { pub struct EditFilesToolInput {
/// High-level edit instructions. These will be interpreted by a smaller /// High-level edit instructions. These will be interpreted by a smaller
/// model, so explain the changes you want that model to make and which /// model, so explain the changes you want that model to make and which
/// file paths need changing. /// file paths need changing. The description should be concise and clear.
///
/// The description should be concise and clear. We will show this
/// description to the user as well.
/// ///
/// WARNING: When specifying which file paths need changing, you MUST /// WARNING: When specifying which file paths need changing, you MUST
/// start each path with one of the project's root directories. /// start each path with one of the project's root directories.
@ -58,6 +55,21 @@ pub struct EditFilesToolInput {
/// Notice how we never specify code snippets in the instructions! /// Notice how we never specify code snippets in the instructions!
/// </example> /// </example>
pub edit_instructions: String, pub edit_instructions: String,
/// A user-friendly description of what changes are being made.
/// This will be shown to the user in the UI to describe the edit operation. The screen real estate for this UI will be extremely
/// constrained, so make the description extremely terse.
///
/// <example>
/// For fixing a broken authentication system:
/// "Fix auth bug in login flow"
/// </example>
///
/// <example>
/// For adding unit tests to a module:
/// "Add tests for user profile logic"
/// </example>
pub display_description: String,
} }
pub struct EditFilesTool; pub struct EditFilesTool;
@ -76,6 +88,13 @@ impl Tool for EditFilesTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFilesToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Edit files".to_string(),
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,

View File

@ -122,6 +122,13 @@ impl Tool for FetchTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FetchToolInput>(input.clone()) {
Ok(input) => format!("Fetch `{}`", input.url),
Err(_) => "Fetch URL".to_string(),
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,

View File

@ -50,6 +50,13 @@ impl Tool for ListDirectoryTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) {
Ok(input) => format!("List the `{}` directory's contents", input.path.display()),
Err(_) => "List directory".to_string(),
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -64,7 +71,10 @@ impl Tool for ListDirectoryTool {
}; };
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path not found in project"))); return Task::ready(Err(anyhow!(
"Path {} not found in project",
input.path.display()
)));
}; };
let Some(worktree) = project let Some(worktree) = project
.read(cx) .read(cx)
@ -79,7 +89,7 @@ impl Tool for ListDirectoryTool {
}; };
if !entry.is_dir() { if !entry.is_dir() {
return Task::ready(Err(anyhow!("{} is a file.", input.path.display()))); return Task::ready(Err(anyhow!("{} is not a directory.", input.path.display())));
} }
let mut output = String::new(); let mut output = String::new();

View File

@ -40,6 +40,10 @@ impl Tool for NowTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Get current time".to_string()
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,

View File

@ -48,6 +48,13 @@ impl Tool for PathSearchTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
Err(_) => "Search paths".to_string(),
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -62,7 +69,7 @@ impl Tool for PathSearchTool {
}; };
let path_matcher = match PathMatcher::new(&[glob.clone()]) { let path_matcher = match PathMatcher::new(&[glob.clone()]) {
Ok(matcher) => matcher, Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {}", err))), Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
}; };
let snapshots: Vec<Snapshot> = project let snapshots: Vec<Snapshot> = project
.read(cx) .read(cx)

View File

@ -53,6 +53,13 @@ impl Tool for ReadFileTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
Ok(input) => format!("Read file `{}`", input.path.display()),
Err(_) => "Read file".to_string(),
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -67,7 +74,10 @@ impl Tool for ReadFileTool {
}; };
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path not found in project"))); return Task::ready(Err(anyhow!(
"Path {} not found in project",
&input.path.display()
)));
}; };
cx.spawn(async move |cx| { cx.spawn(async move |cx| {

View File

@ -22,10 +22,17 @@ pub struct RegexSearchToolInput {
/// Optional starting position for paginated results (0-based). /// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning. /// When not provided, starts from the beginning.
#[serde(default)] #[serde(default)]
pub offset: Option<usize>, pub offset: Option<u32>,
} }
const RESULTS_PER_PAGE: usize = 20; impl RegexSearchToolInput {
/// Which page of search results this is.
pub fn page(&self) -> u32 {
1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE)
}
}
const RESULTS_PER_PAGE: u32 = 20;
pub struct RegexSearchTool; pub struct RegexSearchTool;
@ -43,6 +50,24 @@ impl Tool for RegexSearchTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
Ok(input) => {
let page = input.page();
if page > 1 {
format!(
"Get page {page} of search results for regex “`{}`”",
input.regex
)
} else {
format!("Search files for regex “`{}`”", input.regex)
}
}
Err(_) => "Search with regex".to_string(),
}
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -154,7 +179,7 @@ impl Tool for RegexSearchTool {
offset + matches_found, offset + matches_found,
offset + RESULTS_PER_PAGE, offset + RESULTS_PER_PAGE,
)) ))
} else { } else {
Ok(format!("Found {matches_found} matches:\n{output}")) Ok(format!("Found {matches_found} matches:\n{output}"))
} }
}) })

View File

@ -31,6 +31,10 @@ impl Tool for ThinkingTool {
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Thinking".to_string()
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,

View File

@ -56,6 +56,10 @@ impl Tool for ContextServerTool {
} }
} }
fn ui_text(&self, _input: &serde_json::Value) -> String {
format!("Run MCP tool `{}`", self.tool.name)
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -65,42 +69,43 @@ impl Tool for ContextServerTool {
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) { if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
cx.foreground_executor().spawn({ let tool_name = self.tool.name.clone();
let tool_name = self.tool.name.clone(); let server_clone = server.clone();
async move { let input_clone = input.clone();
let Some(protocol) = server.client() else {
bail!("Context server not initialized");
};
let arguments = if let serde_json::Value::Object(map) = input { cx.spawn(async move |_cx| {
Some(map.into_iter().collect()) let Some(protocol) = server_clone.client() else {
} else { bail!("Context server not initialized");
None };
};
log::trace!( let arguments = if let serde_json::Value::Object(map) = input_clone {
"Running tool: {} with arguments: {:?}", Some(map.into_iter().collect())
tool_name, } else {
arguments None
); };
let response = protocol.run_tool(tool_name, arguments).await?;
let mut result = String::new(); log::trace!(
for content in response.content { "Running tool: {} with arguments: {:?}",
match content { tool_name,
types::ToolResponseContent::Text { text } => { arguments
result.push_str(&text); );
} let response = protocol.run_tool(tool_name, arguments).await?;
types::ToolResponseContent::Image { .. } => {
log::warn!("Ignoring image content from tool response"); let mut result = String::new();
} for content in response.content {
types::ToolResponseContent::Resource { .. } => { match content {
log::warn!("Ignoring resource content from tool response"); types::ToolResponseContent::Text { text } => {
} result.push_str(&text);
}
types::ToolResponseContent::Image { .. } => {
log::warn!("Ignoring image content from tool response");
}
types::ToolResponseContent::Resource { .. } => {
log::warn!("Ignoring resource content from tool response");
} }
} }
Ok(result)
} }
Ok(result)
}) })
} else { } else {
Task::ready(Err(anyhow!("Context server not found"))) Task::ready(Err(anyhow!("Context server not found")))