language_models: Count Google AI tokens through LLM service (#29319)
This PR wires the counting of Google AI tokens back up. It now goes through the LLM service instead of collab's RPC. This is still only available to Zed staff. Release Notes: - N/A
This commit is contained in:
parent
8b5835de17
commit
fef2681cfa
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -18536,9 +18536,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed_llm_client"
|
||||
version = "0.7.0"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c1666cd923c5eb4635f3743e69c6920d0ed71f29b26920616a5d220607df7c4"
|
||||
checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"serde",
|
||||
|
@@ -606,7 +606,7 @@ wasmtime-wasi = "29"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.7.0"
|
||||
zed_llm_client = "0.7.1"
|
||||
zstd = "0.11"
|
||||
metal = "0.29"
|
||||
|
||||
|
@@ -35,9 +35,9 @@ use strum::IntoEnumIterator;
|
||||
use thiserror::Error;
|
||||
use ui::{TintColor, prelude::*};
|
||||
use zed_llm_client::{
|
||||
CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||
SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||
CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, CountTokensBody, CountTokensResponse,
|
||||
EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
|
||||
MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||
};
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
@@ -686,7 +686,58 @@ impl LanguageModel for CloudLanguageModel {
|
||||
match self.model.clone() {
|
||||
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
|
||||
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
||||
CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
|
||||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let request = into_google(request, model.id().into());
|
||||
async move {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
||||
let request_builder = http_client::Request::builder().method(Method::POST);
|
||||
let request_builder =
|
||||
if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
|
||||
request_builder.uri(completions_url)
|
||||
} else {
|
||||
request_builder.uri(
|
||||
http_client
|
||||
.build_zed_llm_url("/count_tokens", &[])?
|
||||
.as_ref(),
|
||||
)
|
||||
};
|
||||
let request_body = CountTokensBody {
|
||||
provider: zed_llm_client::LanguageModelProvider::Google,
|
||||
model: model.id().into(),
|
||||
provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
|
||||
contents: request.contents,
|
||||
})?,
|
||||
};
|
||||
let request = request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(serde_json::to_string(&request_body)?.into())?;
|
||||
let mut response = http_client.send(request).await?;
|
||||
let status = response.status();
|
||||
let mut response_body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut response_body)
|
||||
.await?;
|
||||
|
||||
if status.is_success() {
|
||||
let response_body: CountTokensResponse =
|
||||
serde_json::from_str(&response_body)?;
|
||||
|
||||
Ok(response_body.tokens)
|
||||
} else {
|
||||
Err(anyhow!(ApiError {
|
||||
status,
|
||||
body: response_body
|
||||
}))
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user