agent: Expose web search tool to beta users (#29273)
This gives all beta users access to the web search tool Release Notes: - agent: Added `web_search` tool
This commit is contained in:
parent
09db31288a
commit
822b6f837d
4
Cargo.lock
generated
4
Cargo.lock
generated
@ -703,9 +703,10 @@ dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"chrono",
|
||||
"client",
|
||||
"clock",
|
||||
"collections",
|
||||
"component",
|
||||
"feature_flags",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"html_to_markdown",
|
||||
@ -16631,7 +16632,6 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"feature_flags",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"http_client",
|
||||
|
@ -23,7 +23,6 @@ use gpui::{
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
@ -489,8 +488,8 @@ impl AssistantPanel {
|
||||
|
||||
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
|
||||
// the provider, we want to show a nudge to sign in.
|
||||
let show_zed_ai_notice = client_status.is_signed_out()
|
||||
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
|
||||
let show_zed_ai_notice =
|
||||
client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed());
|
||||
|
||||
self.show_zed_ai_notice = show_zed_ai_notice;
|
||||
cx.notify();
|
||||
|
@ -17,7 +17,6 @@ assistant_tool.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
@ -41,6 +40,8 @@ worktree.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
client = { workspace = true, features = ["test-support"] }
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
|
@ -29,9 +29,9 @@ use std::sync::Arc;
|
||||
|
||||
use assistant_tool::ToolRegistry;
|
||||
use copy_path_tool::CopyPathTool;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use gpui::App;
|
||||
use http_client::HttpClientWithUrl;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use move_path_tool::MovePathTool;
|
||||
use web_search_tool::WebSearchTool;
|
||||
|
||||
@ -85,34 +85,45 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
registry.register_tool(ThinkingTool);
|
||||
registry.register_tool(FetchTool::new(http_client));
|
||||
|
||||
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||
} else {
|
||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
move |registry, event, cx| match event {
|
||||
language_model::Event::DefaultModelChanged => {
|
||||
let using_zed_provider = registry
|
||||
.read(cx)
|
||||
.default_model()
|
||||
.map_or(false, |default| default.is_provided_by_zed());
|
||||
if using_zed_provider {
|
||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||
} else {
|
||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
_ => {}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use client::Client;
|
||||
use clock::FakeSystemClock;
|
||||
use http_client::FakeHttpClient;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
||||
crate::init(
|
||||
Arc::new(http_client::HttpClientWithUrl::new(
|
||||
FakeHttpClient::with_200_response(),
|
||||
"https://zed.dev",
|
||||
None,
|
||||
)),
|
||||
settings::init(cx);
|
||||
|
||||
let client = Client::new(
|
||||
Arc::new(FakeSystemClock::new()),
|
||||
FakeHttpClient::with_200_response(),
|
||||
cx,
|
||||
);
|
||||
language_model::init(client.clone(), cx);
|
||||
crate::init(client.http_client(), cx);
|
||||
|
||||
for tool in ToolRegistry::global(cx).tools() {
|
||||
let actual_schema = tool
|
||||
|
@ -84,11 +84,6 @@ impl FeatureFlag for ZedPro {
|
||||
const NAME: &'static str = "zed-pro";
|
||||
}
|
||||
|
||||
pub struct ZedProWebSearchTool {}
|
||||
impl FeatureFlag for ZedProWebSearchTool {
|
||||
const NAME: &'static str = "zed-pro-web-search-tool";
|
||||
}
|
||||
|
||||
pub struct NotebookFeatureFlag;
|
||||
|
||||
impl FeatureFlag for NotebookFeatureFlag {
|
||||
|
@ -42,6 +42,10 @@ impl ConfiguredModel {
|
||||
pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
|
||||
self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
|
||||
}
|
||||
|
||||
pub fn is_provided_by_zed(&self) -> bool {
|
||||
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
|
@ -61,4 +61,11 @@ impl WebSearchRegistry {
|
||||
self.active_provider = Some(provider);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
|
||||
self.providers.remove(&id);
|
||||
if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
|
||||
self.active_provider = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,7 +14,6 @@ path = "src/web_search_providers.rs"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
|
@ -50,9 +50,11 @@ impl State {
|
||||
}
|
||||
}
|
||||
|
||||
pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev";
|
||||
|
||||
impl WebSearchProvider for CloudWebSearchProvider {
|
||||
fn id(&self) -> WebSearchProviderId {
|
||||
WebSearchProviderId("zed.dev".into())
|
||||
WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
|
||||
|
@ -1,10 +1,10 @@
|
||||
mod cloud;
|
||||
|
||||
use client::Client;
|
||||
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
|
||||
use gpui::{App, Context};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use std::sync::Arc;
|
||||
use web_search::WebSearchRegistry;
|
||||
use web_search::{WebSearchProviderId, WebSearchRegistry};
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
let registry = WebSearchRegistry::global(cx);
|
||||
@ -18,18 +18,27 @@ fn register_web_search_providers(
|
||||
client: Arc<Client>,
|
||||
cx: &mut Context<WebSearchRegistry>,
|
||||
) {
|
||||
cx.observe_flag::<ZedProWebSearchTool, _>({
|
||||
let client = client.clone();
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.register_provider(
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
move |this, registry, event, cx| match event {
|
||||
language_model::Event::DefaultModelChanged => {
|
||||
let using_zed_provider = registry
|
||||
.read(cx)
|
||||
.default_model()
|
||||
.map_or(false, |default| default.is_provided_by_zed());
|
||||
if using_zed_provider {
|
||||
this.register_provider(
|
||||
cloud::CloudWebSearchProvider::new(client.clone(), cx),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
)
|
||||
} else {
|
||||
this.unregister_provider(WebSearchProviderId(
|
||||
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
_ => {}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user