From fef2681cfaa9fcaa72436a2ca9da79f84f5bee4c Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 23 Apr 2025 21:21:53 -0400 Subject: [PATCH] language_models: Count Google AI tokens through LLM service (#29319) This PR wires the counting of Google AI tokens back up. It now goes through the LLM service instead of collab's RPC. Still only available for Zed staff. Release Notes: - N/A --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/language_models/src/provider/cloud.rs | 59 ++++++++++++++++++-- 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0cae49fa66..09d3090251 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18536,9 +18536,9 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c1666cd923c5eb4635f3743e69c6920d0ed71f29b26920616a5d220607df7c4" +checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1" dependencies = [ "anyhow", "serde", diff --git a/Cargo.toml b/Cargo.toml index ba9fc3673d..00cabcc79b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -606,7 +606,7 @@ wasmtime-wasi = "29" which = "6.0.0" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "0.7.0" +zed_llm_client = "0.7.1" zstd = "0.11" metal = "0.29" diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index b8bc86d406..4c581194d3 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -35,9 +35,9 @@ use strum::IntoEnumIterator; use thiserror::Error; use ui::{TintColor, prelude::*}; use zed_llm_client::{ - CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, EXPIRED_LLM_TOKEN_HEADER_NAME, - MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, - SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, + CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, CountTokensBody, CountTokensResponse, + EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, + MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, }; use crate::AllLanguageModelSettings; @@ -686,7 +686,58 @@ impl LanguageModel for CloudLanguageModel { match self.model.clone() { CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx), CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx), - CloudModel::Google(_model) => async move { Ok(0) }.boxed(), + CloudModel::Google(model) => { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + let request = into_google(request, model.id().into()); + async move { + let http_client = &client.http_client(); + let token = llm_api_token.acquire(&client).await?; + + let request_builder = http_client::Request::builder().method(Method::POST); + let request_builder = + if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") { + request_builder.uri(completions_url) + } else { + request_builder.uri( + http_client + .build_zed_llm_url("/count_tokens", &[])? + .as_ref(), + ) + }; + let request_body = CountTokensBody { + provider: zed_llm_client::LanguageModelProvider::Google, + model: model.id().into(), + provider_request: serde_json::to_value(&google_ai::CountTokensRequest { + contents: request.contents, + })?, + }; + let request = request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .body(serde_json::to_string(&request_body)?.into())?; + let mut response = http_client.send(request).await?; + let status = response.status(); + let mut response_body = String::new(); + response + .body_mut() + .read_to_string(&mut response_body) + .await?; + + if status.is_success() { + let response_body: CountTokensResponse = + serde_json::from_str(&response_body)?; + + Ok(response_body.tokens) + } else { + Err(anyhow!(ApiError { + status, + body: response_body + })) + } + } + .boxed() + } } }