agent: Expose web search tool to beta users (#29273)

This gives all beta users access to the web search tool

Release Notes:

- agent: Added `web_search` tool
This commit is contained in:
Bennet Bo Fenner 2025-04-23 17:30:20 +02:00 committed by GitHub
parent 09db31288a
commit 822b6f837d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 67 additions and 40 deletions

4
Cargo.lock generated
View File

@ -703,9 +703,10 @@ dependencies = [
"anyhow",
"assistant_tool",
"chrono",
"client",
"clock",
"collections",
"component",
"feature_flags",
"futures 0.3.31",
"gpui",
"html_to_markdown",
@ -16631,7 +16632,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
"feature_flags",
"futures 0.3.31",
"gpui",
"http_client",

View File

@ -23,7 +23,6 @@ use gpui::{
use language::LanguageRegistry;
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
@ -489,8 +488,8 @@ impl AssistantPanel {
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
// the provider, we want to show a nudge to sign in.
let show_zed_ai_notice = client_status.is_signed_out()
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
let show_zed_ai_notice =
client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed());
self.show_zed_ai_notice = show_zed_ai_notice;
cx.notify();

View File

@ -17,7 +17,6 @@ assistant_tool.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
@ -41,6 +40,8 @@ worktree.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }

View File

@ -29,9 +29,9 @@ use std::sync::Arc;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::FeatureFlagAppExt;
use gpui::App;
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
use web_search_tool::WebSearchTool;
@ -85,34 +85,45 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
move |is_enabled, cx| {
if is_enabled {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
language_model::Event::DefaultModelChanged => {
let using_zed_provider = registry
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
}
}
})
_ => {}
},
)
.detach();
}
#[cfg(test)]
mod tests {
use client::Client;
use clock::FakeSystemClock;
use http_client::FakeHttpClient;
use super::*;
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
crate::init(
Arc::new(http_client::HttpClientWithUrl::new(
FakeHttpClient::with_200_response(),
"https://zed.dev",
None,
)),
settings::init(cx);
let client = Client::new(
Arc::new(FakeSystemClock::new()),
FakeHttpClient::with_200_response(),
cx,
);
language_model::init(client.clone(), cx);
crate::init(client.http_client(), cx);
for tool in ToolRegistry::global(cx).tools() {
let actual_schema = tool

View File

@ -84,11 +84,6 @@ impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro";
}
pub struct ZedProWebSearchTool {}
impl FeatureFlag for ZedProWebSearchTool {
const NAME: &'static str = "zed-pro-web-search-tool";
}
pub struct NotebookFeatureFlag;
impl FeatureFlag for NotebookFeatureFlag {

View File

@ -42,6 +42,10 @@ impl ConfiguredModel {
pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
}
pub fn is_provided_by_zed(&self) -> bool {
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
}
}
pub enum Event {

View File

@ -61,4 +61,11 @@ impl WebSearchRegistry {
self.active_provider = Some(provider);
}
}
pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
self.providers.remove(&id);
if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
self.active_provider = None;
}
}
}

View File

@ -14,7 +14,6 @@ path = "src/web_search_providers.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true

View File

@ -50,9 +50,11 @@ impl State {
}
}
pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev";
impl WebSearchProvider for CloudWebSearchProvider {
fn id(&self) -> WebSearchProviderId {
WebSearchProviderId("zed.dev".into())
WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
}
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {

View File

@ -1,10 +1,10 @@
mod cloud;
use client::Client;
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
use gpui::{App, Context};
use language_model::LanguageModelRegistry;
use std::sync::Arc;
use web_search::WebSearchRegistry;
use web_search::{WebSearchProviderId, WebSearchRegistry};
pub fn init(client: Arc<Client>, cx: &mut App) {
let registry = WebSearchRegistry::global(cx);
@ -18,18 +18,27 @@ fn register_web_search_providers(
client: Arc<Client>,
cx: &mut Context<WebSearchRegistry>,
) {
cx.observe_flag::<ZedProWebSearchTool, _>({
let client = client.clone();
move |is_enabled, cx| {
if is_enabled {
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
registry.register_provider(
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |this, registry, event, cx| match event {
language_model::Event::DefaultModelChanged => {
let using_zed_provider = registry
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
this.register_provider(
cloud::CloudWebSearchProvider::new(client.clone(), cx),
cx,
);
});
)
} else {
this.unregister_provider(WebSearchProviderId(
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
));
}
}
}
})
_ => {}
},
)
.detach();
}