agent: Expose web search tool to beta users (#29273)
This gives all beta users access to the web search tool Release Notes: - agent: Added `web_search` tool
This commit is contained in:
parent
09db31288a
commit
822b6f837d
4
Cargo.lock
generated
4
Cargo.lock
generated
@ -703,9 +703,10 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"assistant_tool",
|
"assistant_tool",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
"client",
|
||||||
|
"clock",
|
||||||
"collections",
|
"collections",
|
||||||
"component",
|
"component",
|
||||||
"feature_flags",
|
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
"html_to_markdown",
|
"html_to_markdown",
|
||||||
@ -16631,7 +16632,6 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"client",
|
"client",
|
||||||
"feature_flags",
|
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
"http_client",
|
"http_client",
|
||||||
|
@ -23,7 +23,6 @@ use gpui::{
|
|||||||
use language::LanguageRegistry;
|
use language::LanguageRegistry;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
|
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||||
ZED_CLOUD_PROVIDER_ID,
|
|
||||||
};
|
};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||||
@ -489,8 +488,8 @@ impl AssistantPanel {
|
|||||||
|
|
||||||
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
|
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
|
||||||
// the provider, we want to show a nudge to sign in.
|
// the provider, we want to show a nudge to sign in.
|
||||||
let show_zed_ai_notice = client_status.is_signed_out()
|
let show_zed_ai_notice =
|
||||||
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
|
client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed());
|
||||||
|
|
||||||
self.show_zed_ai_notice = show_zed_ai_notice;
|
self.show_zed_ai_notice = show_zed_ai_notice;
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -17,7 +17,6 @@ assistant_tool.workspace = true
|
|||||||
chrono.workspace = true
|
chrono.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
component.workspace = true
|
component.workspace = true
|
||||||
feature_flags.workspace = true
|
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
html_to_markdown.workspace = true
|
html_to_markdown.workspace = true
|
||||||
@ -41,6 +40,8 @@ worktree.workspace = true
|
|||||||
zed_llm_client.workspace = true
|
zed_llm_client.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
client = { workspace = true, features = ["test-support"] }
|
||||||
|
clock = { workspace = true, features = ["test-support"] }
|
||||||
collections = { workspace = true, features = ["test-support"] }
|
collections = { workspace = true, features = ["test-support"] }
|
||||||
gpui = { workspace = true, features = ["test-support"] }
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
language = { workspace = true, features = ["test-support"] }
|
language = { workspace = true, features = ["test-support"] }
|
||||||
|
@ -29,9 +29,9 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use assistant_tool::ToolRegistry;
|
use assistant_tool::ToolRegistry;
|
||||||
use copy_path_tool::CopyPathTool;
|
use copy_path_tool::CopyPathTool;
|
||||||
use feature_flags::FeatureFlagAppExt;
|
|
||||||
use gpui::App;
|
use gpui::App;
|
||||||
use http_client::HttpClientWithUrl;
|
use http_client::HttpClientWithUrl;
|
||||||
|
use language_model::LanguageModelRegistry;
|
||||||
use move_path_tool::MovePathTool;
|
use move_path_tool::MovePathTool;
|
||||||
use web_search_tool::WebSearchTool;
|
use web_search_tool::WebSearchTool;
|
||||||
|
|
||||||
@ -85,34 +85,45 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
|||||||
registry.register_tool(ThinkingTool);
|
registry.register_tool(ThinkingTool);
|
||||||
registry.register_tool(FetchTool::new(http_client));
|
registry.register_tool(FetchTool::new(http_client));
|
||||||
|
|
||||||
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
|
cx.subscribe(
|
||||||
move |is_enabled, cx| {
|
&LanguageModelRegistry::global(cx),
|
||||||
if is_enabled {
|
move |registry, event, cx| match event {
|
||||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
language_model::Event::DefaultModelChanged => {
|
||||||
} else {
|
let using_zed_provider = registry
|
||||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
.read(cx)
|
||||||
|
.default_model()
|
||||||
|
.map_or(false, |default| default.is_provided_by_zed());
|
||||||
|
if using_zed_provider {
|
||||||
|
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||||
|
} else {
|
||||||
|
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
_ => {}
|
||||||
})
|
},
|
||||||
|
)
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use client::Client;
|
||||||
|
use clock::FakeSystemClock;
|
||||||
use http_client::FakeHttpClient;
|
use http_client::FakeHttpClient;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
||||||
crate::init(
|
settings::init(cx);
|
||||||
Arc::new(http_client::HttpClientWithUrl::new(
|
|
||||||
FakeHttpClient::with_200_response(),
|
let client = Client::new(
|
||||||
"https://zed.dev",
|
Arc::new(FakeSystemClock::new()),
|
||||||
None,
|
FakeHttpClient::with_200_response(),
|
||||||
)),
|
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
language_model::init(client.clone(), cx);
|
||||||
|
crate::init(client.http_client(), cx);
|
||||||
|
|
||||||
for tool in ToolRegistry::global(cx).tools() {
|
for tool in ToolRegistry::global(cx).tools() {
|
||||||
let actual_schema = tool
|
let actual_schema = tool
|
||||||
|
@ -84,11 +84,6 @@ impl FeatureFlag for ZedPro {
|
|||||||
const NAME: &'static str = "zed-pro";
|
const NAME: &'static str = "zed-pro";
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ZedProWebSearchTool {}
|
|
||||||
impl FeatureFlag for ZedProWebSearchTool {
|
|
||||||
const NAME: &'static str = "zed-pro-web-search-tool";
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct NotebookFeatureFlag;
|
pub struct NotebookFeatureFlag;
|
||||||
|
|
||||||
impl FeatureFlag for NotebookFeatureFlag {
|
impl FeatureFlag for NotebookFeatureFlag {
|
||||||
|
@ -42,6 +42,10 @@ impl ConfiguredModel {
|
|||||||
pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
|
pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
|
||||||
self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
|
self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_provided_by_zed(&self) -> bool {
|
||||||
|
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum Event {
|
pub enum Event {
|
||||||
|
@ -61,4 +61,11 @@ impl WebSearchRegistry {
|
|||||||
self.active_provider = Some(provider);
|
self.active_provider = Some(provider);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn unregister_provider(&mut self, id: WebSearchProviderId) {
|
||||||
|
self.providers.remove(&id);
|
||||||
|
if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) {
|
||||||
|
self.active_provider = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,6 @@ path = "src/web_search_providers.rs"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
client.workspace = true
|
client.workspace = true
|
||||||
feature_flags.workspace = true
|
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
|
@ -50,9 +50,11 @@ impl State {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev";
|
||||||
|
|
||||||
impl WebSearchProvider for CloudWebSearchProvider {
|
impl WebSearchProvider for CloudWebSearchProvider {
|
||||||
fn id(&self) -> WebSearchProviderId {
|
fn id(&self) -> WebSearchProviderId {
|
||||||
WebSearchProviderId("zed.dev".into())
|
WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
|
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
mod cloud;
|
mod cloud;
|
||||||
|
|
||||||
use client::Client;
|
use client::Client;
|
||||||
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
|
|
||||||
use gpui::{App, Context};
|
use gpui::{App, Context};
|
||||||
|
use language_model::LanguageModelRegistry;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use web_search::WebSearchRegistry;
|
use web_search::{WebSearchProviderId, WebSearchRegistry};
|
||||||
|
|
||||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||||
let registry = WebSearchRegistry::global(cx);
|
let registry = WebSearchRegistry::global(cx);
|
||||||
@ -18,18 +18,27 @@ fn register_web_search_providers(
|
|||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
cx: &mut Context<WebSearchRegistry>,
|
cx: &mut Context<WebSearchRegistry>,
|
||||||
) {
|
) {
|
||||||
cx.observe_flag::<ZedProWebSearchTool, _>({
|
cx.subscribe(
|
||||||
let client = client.clone();
|
&LanguageModelRegistry::global(cx),
|
||||||
move |is_enabled, cx| {
|
move |this, registry, event, cx| match event {
|
||||||
if is_enabled {
|
language_model::Event::DefaultModelChanged => {
|
||||||
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
|
let using_zed_provider = registry
|
||||||
registry.register_provider(
|
.read(cx)
|
||||||
|
.default_model()
|
||||||
|
.map_or(false, |default| default.is_provided_by_zed());
|
||||||
|
if using_zed_provider {
|
||||||
|
this.register_provider(
|
||||||
cloud::CloudWebSearchProvider::new(client.clone(), cx),
|
cloud::CloudWebSearchProvider::new(client.clone(), cx),
|
||||||
cx,
|
cx,
|
||||||
);
|
)
|
||||||
});
|
} else {
|
||||||
|
this.unregister_provider(WebSearchProviderId(
|
||||||
|
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
_ => {}
|
||||||
})
|
},
|
||||||
|
)
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user