diff --git a/Cargo.lock b/Cargo.lock index c49a71c5eb..52398176d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -324,7 +324,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "thiserror 2.0.12", "workspace-hack", ] @@ -567,7 +567,7 @@ dependencies = [ "settings", "smallvec", "smol", - "strum 0.26.3", + "strum 0.27.1", "telemetry_events", "text", "theme", @@ -1884,7 +1884,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "thiserror 2.0.12", "tokio", "workspace-hack", @@ -3031,7 +3031,7 @@ dependencies = [ "settings", "sha2", "sqlx", - "strum 0.26.3", + "strum 0.27.1", "subtle", "supermaven_api", "telemetry_events", @@ -3051,6 +3051,7 @@ dependencies = [ "workspace", "workspace-hack", "worktree", + "zed_llm_client", ] [[package]] @@ -3363,7 +3364,7 @@ dependencies = [ "serde", "serde_json", "settings", - "strum 0.26.3", + "strum 0.27.1", "task", "theme", "ui", @@ -5125,7 +5126,7 @@ dependencies = [ "serde", "settings", "smallvec", - "strum 0.26.3", + "strum 0.27.1", "telemetry", "theme", "ui", @@ -5976,7 +5977,7 @@ dependencies = [ "serde_derive", "serde_json", "settings", - "strum 0.26.3", + "strum 0.27.1", "telemetry", "theme", "time", @@ -6069,7 +6070,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "workspace-hack", ] @@ -6175,7 +6176,7 @@ dependencies = [ "slotmap", "smallvec", "smol", - "strum 0.26.3", + "strum 0.27.1", "sum_tree", "taffy", "thiserror 2.0.12", @@ -6823,7 +6824,7 @@ name = "icons" version = "0.1.0" dependencies = [ "serde", - "strum 0.26.3", + "strum 0.27.1", "workspace-hack", ] @@ -7091,7 +7092,7 @@ dependencies = [ "paths", "pretty_assertions", "serde", - "strum 0.26.3", + "strum 0.27.1", "util", "workspace-hack", ] @@ -7677,7 +7678,7 @@ dependencies = [ "serde", "serde_json", "smol", - "strum 0.26.3", + "strum 0.27.1", "telemetry_events", "thiserror 2.0.12", "util", @@ -7737,7 +7738,7 @@ dependencies = [ "serde_json", "settings", "smol", - "strum 0.26.3", + "strum 0.27.1", "theme", "thiserror 2.0.12", "tiktoken-rs", @@ -8710,7 +8711,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "workspace-hack", ] @@ -9557,7 +9558,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "strum 0.26.3", + "strum 0.27.1", "workspace-hack", ] @@ -12136,7 +12137,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "strum 0.26.3", + "strum 0.27.1", "tracing", "util", "workspace-hack", @@ -13709,7 +13710,7 @@ dependencies = [ "settings", "simplelog", "story", - "strum 0.26.3", + "strum 0.27.1", "theme", "title_bar", "ui", @@ -14444,7 +14445,7 @@ dependencies = [ "serde_json_lenient", "serde_repr", "settings", - "strum 0.26.3", + "strum 0.27.1", "thiserror 2.0.12", "util", "uuid", @@ -14478,7 +14479,7 @@ dependencies = [ "serde_json", "serde_json_lenient", "simplelog", - "strum 0.26.3", + "strum 0.27.1", "theme", "vscode_theme", "workspace-hack", @@ -15479,7 +15480,7 @@ dependencies = [ "settings", "smallvec", "story", - "strum 0.26.3", + "strum 0.27.1", "theme", "ui_macros", "util", @@ -17680,7 +17681,7 @@ dependencies = [ "settings", "smallvec", "sqlez", - "strum 0.26.3", + "strum 0.27.1", "task", "telemetry", "tempfile", diff --git a/Cargo.toml b/Cargo.toml index b002855337..099ea3ddad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -540,7 +540,7 @@ smol = "2.0" sqlformat = "0.2" streaming-iterator = "0.1" strsim = "0.11" -strum = { version = "0.26.0", features = ["derive"] } +strum = { version = "0.27.0", features = ["derive"] } subtle = "2.5.0" syn = { version = "1.0.72", features = ["full", "extra-traits"] } sys-locale = "0.3.1" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index c4aa90e2c2..4d6787f857 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -75,6 +75,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re util.workspace = true uuid.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] assistant = { workspace = true, features = ["test-support"] } diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 36843ced56..bc2b508e84 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -330,8 +330,10 @@ async fn create_billing_subscription( .await? } None => { - let default_model = - llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?; + let default_model = llm_db.model( + zed_llm_client::LanguageModelProvider::Anthropic, + "claude-3-7-sonnet", + )?; let stripe_model = stripe_billing.register_model(default_model).await?; stripe_billing .checkout(customer_id, &user.github_login, &stripe_model, &success_url) diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index f56e9e61e3..e445450ff4 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -8,9 +8,9 @@ mod tests; use collections::HashMap; pub use ids::*; -use rpc::LanguageModelProvider; pub use seed::*; pub use tables::*; +use zed_llm_client::LanguageModelProvider; #[cfg(test)] pub use tests::TestLlmDb; diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs index 0bb55ee4b6..7d52964b93 100644 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ b/crates/collab/src/llm/db/tests/provider_tests.rs @@ -1,5 +1,5 @@ use pretty_assertions::assert_eq; -use rpc::LanguageModelProvider; +use zed_llm_client::LanguageModelProvider; use crate::llm::db::LlmDatabase; use crate::test_llm_db; diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 32dc5f3f99..ee4c8540d8 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,9 +1,6 @@ use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long}; use anyhow::{Result, anyhow}; -use client::{ - Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, - PerformCompletionParams, UserStore, zed_urls, -}; +use client::{Client, UserStore, zed_urls}; use collections::BTreeMap; use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro}; use futures::{ @@ -26,7 +23,6 @@ use language_model::{ use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use serde_json::value::RawValue; use settings::{Settings, SettingsStore}; use smol::Timer; use smol::io::{AsyncReadExt, BufReader}; @@ -38,7 +34,10 @@ use std::{ use strum::IntoEnumIterator; use thiserror::Error; use ui::{TintColor, prelude::*}; -use zed_llm_client::{CURRENT_PLAN_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME}; +use zed_llm_client::{ + CURRENT_PLAN_HEADER_NAME, CompletionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, + MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, +}; use crate::AllLanguageModelSettings; use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic}; @@ -517,7 +516,7 @@ impl CloudLanguageModel { async fn perform_llm_completion( client: Arc, llm_api_token: LlmApiToken, - body: PerformCompletionParams, + body: CompletionBody, ) -> Result> { let http_client = &client.http_client(); @@ -724,12 +723,10 @@ impl LanguageModel for CloudLanguageModel { let response = Self::perform_llm_completion( client.clone(), llm_api_token, - PerformCompletionParams { - provider: client::LanguageModelProvider::Anthropic, + CompletionBody { + provider: zed_llm_client::LanguageModelProvider::Anthropic, model: request.model.clone(), - provider_request: RawValue::from_string(serde_json::to_string( - &request, - )?)?, + provider_request: serde_json::to_value(&request)?, }, ) .await @@ -765,12 +762,10 @@ impl LanguageModel for CloudLanguageModel { let response = Self::perform_llm_completion( client.clone(), llm_api_token, - PerformCompletionParams { - provider: client::LanguageModelProvider::OpenAi, + CompletionBody { + provider: zed_llm_client::LanguageModelProvider::OpenAi, model: request.model.clone(), - provider_request: RawValue::from_string(serde_json::to_string( - &request, - )?)?, + provider_request: serde_json::to_value(&request)?, }, ) .await?; @@ -790,12 +785,10 @@ impl LanguageModel for CloudLanguageModel { let response = Self::perform_llm_completion( client.clone(), llm_api_token, - PerformCompletionParams { - provider: client::LanguageModelProvider::Google, + CompletionBody { + provider: zed_llm_client::LanguageModelProvider::Google, model: request.model.clone(), - provider_request: RawValue::from_string(serde_json::to_string( - &request, - )?)?, + provider_request: serde_json::to_value(&request)?, }, ) .await?; diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs deleted file mode 100644 index 0a7510d891..0000000000 --- a/crates/rpc/src/llm.rs +++ /dev/null @@ -1,35 +0,0 @@ -use serde::{Deserialize, Serialize}; -use strum::{Display, EnumIter, EnumString}; - -pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; - -pub const MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME: &str = "x-zed-llm-max-monthly-spend-reached"; - -#[derive( - Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, -)] -#[serde(rename_all = "snake_case")] -#[strum(serialize_all = "snake_case")] -pub enum LanguageModelProvider { - Anthropic, - OpenAi, - Google, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct LanguageModel { - pub provider: LanguageModelProvider, - pub name: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ListModelsResponse { - pub models: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct PerformCompletionParams { - pub provider: LanguageModelProvider, - pub model: String, - pub provider_request: Box, -} diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index c10ee9d6c8..ad1ebb757c 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -1,14 +1,12 @@ pub mod auth; mod conn; mod extension; -mod llm; mod message_stream; mod notification; mod peer; pub use conn::Connection; pub use extension::*; -pub use llm::*; pub use notification::*; pub use peer::*; pub use proto;