language_models: Count Google AI tokens through LLM service (#29319)

This PR wires Google AI token counting back up.

Token counting now goes through the LLM service instead of collab's RPC.

It is still only available to Zed staff.
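
For illustration, here is a minimal sketch of the request and response shapes involved, using stand-in types that mirror how the `zed_llm_client` types are used in the diff below. Field names follow that usage; the serialized form of the provider enum, the numeric width of `tokens`, and the model id are assumptions, not the crate's actual definitions:

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    // Stand-ins for the zed_llm_client types used in the diff below; the real
    // definitions live in the zed_llm_client crate and may differ in detail.
    #[derive(Serialize)]
    struct CountTokensBody {
        provider: &'static str, // the crate uses a LanguageModelProvider enum; "google" is illustrative
        model: String,
        provider_request: serde_json::Value, // the provider-specific payload, forwarded verbatim
    }

    #[derive(Deserialize)]
    struct CountTokensResponse {
        tokens: u64, // assumed width; the diff only shows the field being read
    }

    fn main() -> serde_json::Result<()> {
        // The editor POSTs a body like this to the LLM service's /count_tokens endpoint.
        let body = CountTokensBody {
            provider: "google",
            model: "gemini-1.5-pro".into(), // hypothetical model id
            provider_request: json!({ "contents": [] }),
        };
        println!("{}", serde_json::to_string_pretty(&body)?);

        // The service answers with a single count.
        let response: CountTokensResponse = serde_json::from_str(r#"{"tokens":42}"#)?;
        assert_eq!(response.tokens, 42);
        Ok(())
    }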

Release Notes:

- N/A
Marshall Bowers 2025-04-23 21:21:53 -04:00 committed by GitHub
parent 8b5835de17
commit fef2681cfa
3 changed files with 58 additions and 7 deletions

Cargo.lock

@@ -18536,9 +18536,9 @@ dependencies = [
 [[package]]
 name = "zed_llm_client"
-version = "0.7.0"
+version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c1666cd923c5eb4635f3743e69c6920d0ed71f29b26920616a5d220607df7c4"
+checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
 dependencies = [
  "anyhow",
  "serde",

Cargo.toml

@@ -606,7 +606,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.7.0"
+zed_llm_client = "0.7.1"
 zstd = "0.11"
 metal = "0.29"

crates/language_models/src/provider/cloud.rs

@@ -35,9 +35,9 @@ use strum::IntoEnumIterator;
 use thiserror::Error;
 use ui::{TintColor, prelude::*};
 use zed_llm_client::{
-    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, EXPIRED_LLM_TOKEN_HEADER_NAME,
-    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
-    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, CountTokensBody, CountTokensResponse,
+    EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
+    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
 };

 use crate::AllLanguageModelSettings;
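
The `provider_request` field wraps Google's native countTokens payload. A hypothetical, simplified mirror of the `google_ai::CountTokensRequest` used in the next hunk (only `contents` appears in the diff; `Content` and `Part` here follow the shape of Google's public API, not necessarily the google_ai crate's exact types):

    use serde::Serialize;

    // Simplified sketch; the real types live in the google_ai crate.
    #[derive(Serialize)]
    struct CountTokensRequest {
        contents: Vec<Content>,
    }

    #[derive(Serialize)]
    struct Content {
        role: String, // "user" or "model" in Google's generateContent/countTokens APIs
        parts: Vec<Part>,
    }

    #[derive(Serialize)]
    struct Part {
        text: String,
    }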
@@ -686,7 +686,58 @@ impl LanguageModel for CloudLanguageModel {
         match self.model.clone() {
             CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
             CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
-            CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
+            CloudModel::Google(model) => {
+                let client = self.client.clone();
+                let llm_api_token = self.llm_api_token.clone();
+                let request = into_google(request, model.id().into());
+                async move {
+                    let http_client = &client.http_client();
+                    let token = llm_api_token.acquire(&client).await?;
+
+                    let request_builder = http_client::Request::builder().method(Method::POST);
+                    let request_builder =
+                        if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
+                            request_builder.uri(completions_url)
+                        } else {
+                            request_builder.uri(
+                                http_client
+                                    .build_zed_llm_url("/count_tokens", &[])?
+                                    .as_ref(),
+                            )
+                        };
+                    let request_body = CountTokensBody {
+                        provider: zed_llm_client::LanguageModelProvider::Google,
+                        model: model.id().into(),
+                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
+                            contents: request.contents,
+                        })?,
+                    };
+                    let request = request_builder
+                        .header("Content-Type", "application/json")
+                        .header("Authorization", format!("Bearer {token}"))
+                        .body(serde_json::to_string(&request_body)?.into())?;
+                    let mut response = http_client.send(request).await?;
+                    let status = response.status();
+                    let mut response_body = String::new();
+                    response
+                        .body_mut()
+                        .read_to_string(&mut response_body)
+                        .await?;
+
+                    if status.is_success() {
+                        let response_body: CountTokensResponse =
+                            serde_json::from_str(&response_body)?;
+
+                        Ok(response_body.tokens)
+                    } else {
+                        Err(anyhow!(ApiError {
+                            status,
+                            body: response_body
+                        }))
+                    }
+                }
+                .boxed()
+            }
         }
     }
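
One small design note: the ZED_COUNT_TOKENS_URL branch above lets a developer point token counting at a locally running LLM service without touching the default URL construction. The pattern in isolation (the fallback URL here is illustrative, not the service's actual address):

    // Prefer an explicit override from the environment, otherwise fall back
    // to the URL the client would normally build.
    fn count_tokens_url(default_url: &str) -> String {
        std::env::var("ZED_COUNT_TOKENS_URL").unwrap_or_else(|_| default_url.to_owned())
    }

    fn main() {
        // e.g. run with ZED_COUNT_TOKENS_URL=http://localhost:8080/count_tokens
        println!("{}", count_tokens_url("https://llm.example.com/count_tokens"));
    }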