language_models: Count Google AI tokens through LLM service (#29319)
This PR wires the counting of Google AI tokens back up. It now goes through the LLM service instead of collab's RPC. Still only available to Zed staff.

Release Notes:

- N/A
parent 8b5835de17
commit fef2681cfa
Cargo.lock (generated, 4 lines changed)
@@ -18536,9 +18536,9 @@ dependencies = [
 
 [[package]]
 name = "zed_llm_client"
-version = "0.7.0"
+version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c1666cd923c5eb4635f3743e69c6920d0ed71f29b26920616a5d220607df7c4"
+checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
 dependencies = [
  "anyhow",
  "serde",
@@ -606,7 +606,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.7.0"
+zed_llm_client = "0.7.1"
 zstd = "0.11"
 metal = "0.29"
 
@@ -35,9 +35,9 @@ use strum::IntoEnumIterator;
 use thiserror::Error;
 use ui::{TintColor, prelude::*};
 use zed_llm_client::{
-    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, EXPIRED_LLM_TOKEN_HEADER_NAME,
-    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
-    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, CountTokensBody, CountTokensResponse,
+    EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
+    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
 };
 
 use crate::AllLanguageModelSettings;
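For context between the two hunks: the newly imported CountTokensBody and CountTokensResponse are the request/response envelope for the LLM service's count-tokens endpoint. Below is a minimal sketch of their shape, inferred only from how they are used in the next hunk; the real definitions in zed_llm_client 0.7.1 may differ in field types and serde attributes.

use serde::{Deserialize, Serialize};

// Sketch only: the variant list and wire casing are assumptions.
#[derive(Serialize)]
#[serde(rename_all = "lowercase")]
enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
}

// Built by the client and POSTed to /count_tokens (see the hunk below).
#[derive(Serialize)]
struct CountTokensBody {
    provider: LanguageModelProvider,
    // Provider-specific model id, e.g. the Google model id in the new arm.
    model: String,
    // The raw provider payload; the diff serializes a google_ai::CountTokensRequest here.
    provider_request: serde_json::Value,
}

// Deserialized from the service's response; only `tokens` is read.
#[derive(Deserialize)]
struct CountTokensResponse {
    // Integer width is an assumption.
    tokens: usize,
}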
@@ -686,7 +686,58 @@ impl LanguageModel for CloudLanguageModel {
         match self.model.clone() {
             CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
             CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
-            CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
+            CloudModel::Google(model) => {
+                let client = self.client.clone();
+                let llm_api_token = self.llm_api_token.clone();
+                let request = into_google(request, model.id().into());
+                async move {
+                    let http_client = &client.http_client();
+                    let token = llm_api_token.acquire(&client).await?;
+
+                    let request_builder = http_client::Request::builder().method(Method::POST);
+                    let request_builder =
+                        if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
+                            request_builder.uri(completions_url)
+                        } else {
+                            request_builder.uri(
+                                http_client
+                                    .build_zed_llm_url("/count_tokens", &[])?
+                                    .as_ref(),
+                            )
+                        };
+                    let request_body = CountTokensBody {
+                        provider: zed_llm_client::LanguageModelProvider::Google,
+                        model: model.id().into(),
+                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
+                            contents: request.contents,
+                        })?,
+                    };
+                    let request = request_builder
+                        .header("Content-Type", "application/json")
+                        .header("Authorization", format!("Bearer {token}"))
+                        .body(serde_json::to_string(&request_body)?.into())?;
+                    let mut response = http_client.send(request).await?;
+                    let status = response.status();
+                    let mut response_body = String::new();
+                    response
+                        .body_mut()
+                        .read_to_string(&mut response_body)
+                        .await?;
+
+                    if status.is_success() {
+                        let response_body: CountTokensResponse =
+                            serde_json::from_str(&response_body)?;
+
+                        Ok(response_body.tokens)
+                    } else {
+                        Err(anyhow!(ApiError {
+                            status,
+                            body: response_body
+                        }))
+                    }
+                }
+                .boxed()
+            }
         }
     }
 
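Taken end to end, the new Google arm translates the editor's request into Google's format, wraps it in the envelope sketched above, and lets the LLM service perform the actual count; the client never calls Google directly here. The sketch below shows roughly what travels over the wire. The endpoint path and the ZED_COUNT_TOKENS_URL override come from the diff, while the provider casing, model id, and contents shape are illustrative assumptions, not the service's documented schema.

use serde_json::json;

fn main() {
    // Hypothetical body POSTed to the LLM service's /count_tokens endpoint
    // (or to whatever ZED_COUNT_TOKENS_URL points at when that override is set),
    // with the acquired LLM token sent in the Authorization header.
    let request_body = json!({
        "provider": "google",
        "model": "gemini-2.0-flash",
        "provider_request": {
            "contents": [
                { "role": "user", "parts": [{ "text": "How many tokens is this?" }] }
            ]
        }
    });
    println!("{}", serde_json::to_string_pretty(&request_body).unwrap());

    // Hypothetical success response; its `tokens` field is what count_tokens
    // now returns for CloudModel::Google instead of the old hard-coded Ok(0).
    let response_body = json!({ "tokens": 7 });
    assert!(response_body["tokens"].is_u64());
}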