agent: Expose web search tool to beta users (#29273)

This gives all beta users access to the web search tool

Release Notes:

- agent: Added `web_search` tool
This commit is contained in:
Bennet Bo Fenner 2025-04-23 17:30:20 +02:00 committed by GitHub
parent 09db31288a
commit 822b6f837d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 67 additions and 40 deletions

4
Cargo.lock generated
View File

@ -703,9 +703,10 @@ dependencies = [
"anyhow", "anyhow",
"assistant_tool", "assistant_tool",
"chrono", "chrono",
"client",
"clock",
"collections", "collections",
"component", "component",
"feature_flags",
"futures 0.3.31", "futures 0.3.31",
"gpui", "gpui",
"html_to_markdown", "html_to_markdown",
@ -16631,7 +16632,6 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"client", "client",
"feature_flags",
"futures 0.3.31", "futures 0.3.31",
"gpui", "gpui",
"http_client", "http_client",

View File

@ -23,7 +23,6 @@ use gpui::{
use language::LanguageRegistry; use language::LanguageRegistry;
use language_model::{ use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
ZED_CLOUD_PROVIDER_ID,
}; };
use project::Project; use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library}; use prompt_library::{PromptLibrary, open_prompt_library};
@ -489,8 +488,8 @@ impl AssistantPanel {
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is // If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
// the provider, we want to show a nudge to sign in. // the provider, we want to show a nudge to sign in.
let show_zed_ai_notice = client_status.is_signed_out() let show_zed_ai_notice =
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID); client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed());
self.show_zed_ai_notice = show_zed_ai_notice; self.show_zed_ai_notice = show_zed_ai_notice;
cx.notify(); cx.notify();

View File

@ -17,7 +17,6 @@ assistant_tool.workspace = true
chrono.workspace = true chrono.workspace = true
collections.workspace = true collections.workspace = true
component.workspace = true component.workspace = true
feature_flags.workspace = true
futures.workspace = true futures.workspace = true
gpui.workspace = true gpui.workspace = true
html_to_markdown.workspace = true html_to_markdown.workspace = true
@ -41,6 +40,8 @@ worktree.workspace = true
zed_llm_client.workspace = true zed_llm_client.workspace = true
[dev-dependencies] [dev-dependencies]
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] }

View File

@ -29,9 +29,9 @@ use std::sync::Arc;
use assistant_tool::ToolRegistry; use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool; use copy_path_tool::CopyPathTool;
use feature_flags::FeatureFlagAppExt;
use gpui::App; use gpui::App;
use http_client::HttpClientWithUrl; use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool; use move_path_tool::MovePathTool;
use web_search_tool::WebSearchTool; use web_search_tool::WebSearchTool;
@ -85,34 +85,45 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ThinkingTool); registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client)); registry.register_tool(FetchTool::new(http_client));
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({ cx.subscribe(
move |is_enabled, cx| { &LanguageModelRegistry::global(cx),
if is_enabled { move |registry, event, cx| match event {
ToolRegistry::global(cx).register_tool(WebSearchTool); language_model::Event::DefaultModelChanged => {
} else { let using_zed_provider = registry
ToolRegistry::global(cx).unregister_tool(WebSearchTool); .read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
} }
} _ => {}
}) },
)
.detach(); .detach();
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use client::Client;
use clock::FakeSystemClock;
use http_client::FakeHttpClient; use http_client::FakeHttpClient;
use super::*; use super::*;
#[gpui::test] #[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) { fn test_builtin_tool_schema_compatibility(cx: &mut App) {
crate::init( settings::init(cx);
Arc::new(http_client::HttpClientWithUrl::new(
FakeHttpClient::with_200_response(), let client = Client::new(
"https://zed.dev", Arc::new(FakeSystemClock::new()),
None, FakeHttpClient::with_200_response(),
)),
cx, cx,
); );
language_model::init(client.clone(), cx);
crate::init(client.http_client(), cx);
for tool in ToolRegistry::global(cx).tools() { for tool in ToolRegistry::global(cx).tools() {
let actual_schema = tool let actual_schema = tool

View File

@ -84,11 +84,6 @@ impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro"; const NAME: &'static str = "zed-pro";
} }
pub struct ZedProWebSearchTool {}
impl FeatureFlag for ZedProWebSearchTool {
const NAME: &'static str = "zed-pro-web-search-tool";
}
pub struct NotebookFeatureFlag; pub struct NotebookFeatureFlag;
impl FeatureFlag for NotebookFeatureFlag { impl FeatureFlag for NotebookFeatureFlag {

View File

@ -42,6 +42,10 @@ impl ConfiguredModel {
pub fn is_same_as(&self, other: &ConfiguredModel) -> bool { pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
self.model.id() == other.model.id() && self.provider.id() == other.provider.id() self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
} }
pub fn is_provided_by_zed(&self) -> bool {
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
}
} }
pub enum Event { pub enum Event {

View File

@ -61,4 +61,11 @@ impl WebSearchRegistry {
self.active_provider = Some(provider); self.active_provider = Some(provider);
} }
} }
pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
self.providers.remove(&id);
if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
self.active_provider = None;
}
}
} }

View File

@ -14,7 +14,6 @@ path = "src/web_search_providers.rs"
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
client.workspace = true client.workspace = true
feature_flags.workspace = true
futures.workspace = true futures.workspace = true
gpui.workspace = true gpui.workspace = true
http_client.workspace = true http_client.workspace = true

View File

@ -50,9 +50,11 @@ impl State {
} }
} }
pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev";
impl WebSearchProvider for CloudWebSearchProvider { impl WebSearchProvider for CloudWebSearchProvider {
fn id(&self) -> WebSearchProviderId { fn id(&self) -> WebSearchProviderId {
WebSearchProviderId("zed.dev".into()) WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
} }
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> { fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {

View File

@ -1,10 +1,10 @@
mod cloud; mod cloud;
use client::Client; use client::Client;
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
use gpui::{App, Context}; use gpui::{App, Context};
use language_model::LanguageModelRegistry;
use std::sync::Arc; use std::sync::Arc;
use web_search::WebSearchRegistry; use web_search::{WebSearchProviderId, WebSearchRegistry};
pub fn init(client: Arc<Client>, cx: &mut App) { pub fn init(client: Arc<Client>, cx: &mut App) {
let registry = WebSearchRegistry::global(cx); let registry = WebSearchRegistry::global(cx);
@ -18,18 +18,27 @@ fn register_web_search_providers(
client: Arc<Client>, client: Arc<Client>,
cx: &mut Context<WebSearchRegistry>, cx: &mut Context<WebSearchRegistry>,
) { ) {
cx.observe_flag::<ZedProWebSearchTool, _>({ cx.subscribe(
let client = client.clone(); &LanguageModelRegistry::global(cx),
move |is_enabled, cx| { move |this, registry, event, cx| match event {
if is_enabled { language_model::Event::DefaultModelChanged => {
WebSearchRegistry::global(cx).update(cx, |registry, cx| { let using_zed_provider = registry
registry.register_provider( .read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
this.register_provider(
cloud::CloudWebSearchProvider::new(client.clone(), cx), cloud::CloudWebSearchProvider::new(client.clone(), cx),
cx, cx,
); )
}); } else {
this.unregister_provider(WebSearchProviderId(
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
));
}
} }
} _ => {}
}) },
)
.detach(); .detach();
} }