From cbc0a83059a93a99d88e3b83c4c200546f31ed27 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:01:37 +0000 Subject: [PATCH] auto: save WIP progress from rcc session --- rust/crates/api/src/error.rs | 43 +- rust/crates/api/src/providers/anthropic.rs | 994 +++++++++++++++++++++ rust/crates/api/src/providers/mod.rs | 202 +++++ 3 files changed, 1218 insertions(+), 21 deletions(-) create mode 100644 rust/crates/api/src/providers/anthropic.rs create mode 100644 rust/crates/api/src/providers/mod.rs diff --git a/rust/crates/api/src/error.rs b/rust/crates/api/src/error.rs index 2c31691..7649889 100644 --- a/rust/crates/api/src/error.rs +++ b/rust/crates/api/src/error.rs @@ -4,7 +4,10 @@ use std::time::Duration; #[derive(Debug)] pub enum ApiError { - MissingApiKey, + MissingCredentials { + provider: &'static str, + env_vars: &'static [&'static str], + }, ExpiredOAuthToken, Auth(String), InvalidApiKeyEnv(VarError), @@ -30,13 +33,21 @@ pub enum ApiError { } impl ApiError { + #[must_use] + pub const fn missing_credentials( + provider: &'static str, + env_vars: &'static [&'static str], + ) -> Self { + Self::MissingCredentials { provider, env_vars } + } + #[must_use] pub fn is_retryable(&self) -> bool { match self { Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(), Self::Api { retryable, .. } => *retryable, Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), - Self::MissingApiKey + Self::MissingCredentials { .. } | Self::ExpiredOAuthToken | Self::Auth(_) | Self::InvalidApiKeyEnv(_) @@ -51,12 +62,11 @@ impl ApiError { impl Display for ApiError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::MissingApiKey => { - write!( - f, - "ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API" - ) - } + Self::MissingCredentials { provider, env_vars } => write!( + f, + "missing {provider} credentials; export {} before calling the {provider} API", + env_vars.join(" or ") + ), Self::ExpiredOAuthToken => { write!( f, @@ -65,10 +75,7 @@ impl Display for ApiError { } Self::Auth(message) => write!(f, "auth error: {message}"), Self::InvalidApiKeyEnv(error) => { - write!( - f, - "failed to read ANTHROPIC_AUTH_TOKEN / ANTHROPIC_API_KEY: {error}" - ) + write!(f, "failed to read credential environment variable: {error}") } Self::Http(error) => write!(f, "http error: {error}"), Self::Io(error) => write!(f, "io error: {error}"), @@ -81,20 +88,14 @@ impl Display for ApiError { .. } => match (error_type, message) { (Some(error_type), Some(message)) => { - write!( - f, - "anthropic api returned {status} ({error_type}): {message}" - ) + write!(f, "api returned {status} ({error_type}): {message}") } - _ => write!(f, "anthropic api returned {status}: {body}"), + _ => write!(f, "api returned {status}: {body}"), }, Self::RetriesExhausted { attempts, last_error, - } => write!( - f, - "anthropic api failed after {attempts} attempts: {last_error}" - ), + } => write!(f, "api failed after {attempts} attempts: {last_error}"), Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"), Self::BackoffOverflow { attempt, diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs new file mode 100644 index 0000000..4f6dd98 --- /dev/null +++ b/rust/crates/api/src/providers/anthropic.rs @@ -0,0 +1,994 @@ +use std::collections::VecDeque; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use runtime::{ + load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, + OAuthTokenExchangeRequest, +}; +use serde::Deserialize; + +use crate::error::ApiError; +use crate::sse::SseParser; +use crate::types::{MessageRequest, MessageResponse, StreamEvent}; + +const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthSource { + None, + ApiKey(String), + BearerToken(String), + ApiKeyAndBearer { + api_key: String, + bearer_token: String, + }, +} + +impl AuthSource { + pub fn from_env() -> Result { + let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; + let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; + match (api_key, auth_token) { + (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + (Some(api_key), None) => Ok(Self::ApiKey(api_key)), + (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), + (None, None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])), + } + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + match self { + Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), + Self::None | Self::BearerToken(_) => None, + } + } + + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Self::BearerToken(token) + | Self::ApiKeyAndBearer { + bearer_token: token, + .. + } => Some(token), + Self::None | Self::ApiKey(_) => None, + } + } + + #[must_use] + pub fn masked_authorization_header(&self) -> &'static str { + if self.bearer_token().is_some() { + "Bearer [REDACTED]" + } else { + "" + } + } + + pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(api_key) = self.api_key() { + request_builder = request_builder.header("x-api-key", api_key); + } + if let Some(token) = self.bearer_token() { + request_builder = request_builder.bearer_auth(token); + } + request_builder + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + #[serde(default)] + pub scopes: Vec, +} + +impl From for AuthSource { + fn from(value: OAuthTokenSet) -> Self { + Self::BearerToken(value.access_token) + } +} + +#[derive(Debug, Clone)] +pub struct AnthropicClient { + http: reqwest::Client, + auth: AuthSource, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl AnthropicClient { + #[must_use] + pub fn new(api_key: impl Into) -> Self { + Self { + http: reqwest::Client::new(), + auth: AuthSource::ApiKey(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + #[must_use] + pub fn from_auth(auth: AuthSource) -> Self { + Self { + http: reqwest::Client::new(), + auth, + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env() -> Result { + Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) + } + + #[must_use] + pub fn with_auth_source(mut self, auth: AuthSource) -> Self { + self.auth = auth; + self + } + + #[must_use] + pub fn with_auth_token(mut self, auth_token: Option) -> Self { + match ( + self.auth.api_key().map(ToOwned::to_owned), + auth_token.filter(|token| !token.is_empty()), + ) { + (Some(api_key), Some(bearer_token)) => { + self.auth = AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }; + } + (Some(api_key), None) => { + self.auth = AuthSource::ApiKey(api_key); + } + (None, Some(bearer_token)) => { + self.auth = AuthSource::BearerToken(bearer_token); + } + (None, None) => { + self.auth = AuthSource::None; + } + } + self + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + #[must_use] + pub fn auth_source(&self) -> &AuthSource { + &self.auth + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let mut response = response + .json::() + .await + .map_err(ApiError::from)?; + if response.request_id.is_none() { + response.request_id = request_id; + } + Ok(response) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: SseParser::new(), + pending: VecDeque::new(), + done: false, + }) + } + + pub async fn exchange_oauth_code( + &self, + config: &OAuthConfig, + request: &OAuthTokenExchangeRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + pub async fn refresh_oauth_token( + &self, + config: &OAuthConfig, + request: &OAuthRefreshRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + let mut last_error: Option; + + loop { + attempts += 1; + match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + } + + if attempts > self.max_retries { + break; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + } + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error.expect("retry loop must capture an error")), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + let request_builder = self + .http + .post(&request_url) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json"); + let mut request_builder = self.auth.apply(request_builder); + + request_builder = request_builder.json(request); + request_builder.send().await.map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl AuthSource { + pub fn from_env_or_saved() -> Result { + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(Self::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(Self::BearerToken(bearer_token)); + } + match load_saved_oauth_token() { + Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { + if token_set.refresh_token.is_some() { + Err(ApiError::Auth( + "saved OAuth token is expired; load runtime OAuth config to refresh it" + .to_string(), + )) + } else { + Err(ApiError::ExpiredOAuthToken) + } + } + Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), + Ok(None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])), + Err(error) => Err(error), + } + } +} + +#[must_use] +pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { + token_set + .expires_at + .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) +} + +pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { + let Some(token_set) = load_saved_oauth_token()? else { + return Ok(None); + }; + resolve_saved_oauth_token_set(config, token_set).map(Some) +} + +pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result +where + F: FnOnce() -> Result, ApiError>, +{ + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(AuthSource::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(AuthSource::BearerToken(bearer_token)); + } + + let Some(token_set) = load_saved_oauth_token()? else { + return Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])); + }; + if !oauth_token_is_expired(&token_set) { + return Ok(AuthSource::BearerToken(token_set.access_token)); + } + if token_set.refresh_token.is_none() { + return Err(ApiError::ExpiredOAuthToken); + } + + let Some(config) = load_oauth_config()? else { + return Err(ApiError::Auth( + "saved OAuth token is expired; runtime OAuth config is missing".to_string(), + )); + }; + Ok(AuthSource::from(resolve_saved_oauth_token_set( + &config, token_set, + )?)) +} + +fn resolve_saved_oauth_token_set( + config: &OAuthConfig, + token_set: OAuthTokenSet, +) -> Result { + if !oauth_token_is_expired(&token_set) { + return Ok(token_set); + } + let Some(refresh_token) = token_set.refresh_token.clone() else { + return Err(ApiError::ExpiredOAuthToken); + }; + let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url()); + let refreshed = client_runtime_block_on(async { + client + .refresh_oauth_token( + config, + &OAuthRefreshRequest::from_config( + config, + refresh_token, + Some(token_set.scopes.clone()), + ), + ) + .await + })?; + let resolved = OAuthTokenSet { + access_token: refreshed.access_token, + refresh_token: refreshed.refresh_token.or(token_set.refresh_token), + expires_at: refreshed.expires_at, + scopes: refreshed.scopes, + }; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: resolved.access_token.clone(), + refresh_token: resolved.refresh_token.clone(), + expires_at: resolved.expires_at, + scopes: resolved.scopes.clone(), + }) + .map_err(ApiError::from)?; + Ok(resolved) +} + +fn client_runtime_block_on(future: F) -> Result +where + F: std::future::Future>, +{ + tokio::runtime::Runtime::new() + .map_err(ApiError::from)? + .block_on(future) +} + +fn load_saved_oauth_token() -> Result, ApiError> { + let token_set = load_oauth_credentials().map_err(ApiError::from)?; + Ok(token_set.map(|token_set| OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })) +} + +fn now_unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[cfg(test)] +fn read_api_key() -> Result { + let auth = AuthSource::from_env_or_saved()?; + auth.api_key() + .or_else(|| auth.bearer_token()) + .map(ToOwned::to_owned) + .ok_or(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])) +} + +#[cfg(test)] +fn read_auth_token() -> Option { + read_env_non_empty("ANTHROPIC_AUTH_TOKEN") + .ok() + .and_then(std::convert::identity) +} + +#[must_use] +pub fn read_base_url() -> String { + std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: SseParser, + pending: VecDeque, + done: bool, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + let remaining = self.parser.finish()?; + self.pending.extend(remaining); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + self.pending.extend(self.parser.push(&chunk)?); + } + None => { + self.done = true; + } + } + } + } +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_else(|_| String::new()); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .map(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .map(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +#[derive(Debug, Deserialize)] +struct AnthropicErrorEnvelope { + error: AnthropicErrorBody, +} + +#[derive(Debug, Deserialize)] +struct AnthropicErrorBody { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +#[cfg(test)] +mod tests { + use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::sync::{Mutex, OnceLock}; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; + + use super::{ + now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, + resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, + }; + use crate::types::{ContentBlockDelta, MessageRequest}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + fn temp_config_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!( + "api-oauth-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )) + } + + fn sample_oauth_config(token_url: String) -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url, + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + fn spawn_token_server(response_body: &'static str) -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let address = listener.local_addr().expect("local addr"); + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept connection"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer).expect("read request"); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + format!("http://{address}/oauth/token") + } + + #[test] + fn read_api_key_requires_presence() { + let _guard = env_lock(); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + let error = super::read_api_key().expect_err("missing key should error"); + assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. })); + } + + #[test] + fn read_api_key_requires_non_empty_value() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); + std::env::remove_var("ANTHROPIC_API_KEY"); + let error = super::read_api_key().expect_err("empty key should error"); + assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. })); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn read_api_key_prefers_api_key_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + assert_eq!( + super::read_api_key().expect("api key should load"), + "legacy-key" + ); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn read_auth_token_reads_auth_token_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn oauth_token_maps_to_bearer_auth_source() { + let auth = AuthSource::from(OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }); + assert_eq!(auth.bearer_token(), Some("access-token")); + assert_eq!(auth.api_key(), None); + } + + #[test] + fn auth_source_from_env_combines_api_key_and_bearer_token() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + let auth = AuthSource::from_env().expect("env auth"); + assert_eq!(auth.api_key(), Some("legacy-key")); + assert_eq!(auth.bearer_token(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn auth_source_from_saved_oauth_when_env_absent() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = AuthSource::from_env_or_saved().expect("saved auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn oauth_token_expiry_uses_expires_at_timestamp() { + assert!(oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(1), + scopes: Vec::new(), + })); + assert!(!oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(now_unix_timestamp() + 60), + scopes: Vec::new(), + })); + } + + #[test] + fn resolve_saved_oauth_token_refreshes_expired_credentials() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "refreshed-token"); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) + .expect("startup auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let error = + resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); + assert!( + matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) + ); + + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "expired-access-token"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn message_request_stream_helper_sets_stream_true() { + let request = MessageRequest { + model: "claude-opus-4-6".to_string(), + max_tokens: 64, + messages: vec![], + system: None, + tools: None, + tool_choice: None, + stream: false, + }; + + assert!(request.with_streaming().stream); + } + + #[test] + fn backoff_doubles_until_maximum() { + let client = AnthropicClient::new("test-key").with_retry_policy( + 3, + Duration::from_millis(10), + Duration::from_millis(25), + ); + assert_eq!( + client.backoff_for_attempt(1).expect("attempt 1"), + Duration::from_millis(10) + ); + assert_eq!( + client.backoff_for_attempt(2).expect("attempt 2"), + Duration::from_millis(20) + ); + assert_eq!( + client.backoff_for_attempt(3).expect("attempt 3"), + Duration::from_millis(25) + ); + } + + #[test] + fn retryable_statuses_are_detected() { + assert!(super::is_retryable_status( + reqwest::StatusCode::TOO_MANY_REQUESTS + )); + assert!(super::is_retryable_status( + reqwest::StatusCode::INTERNAL_SERVER_ERROR + )); + assert!(!super::is_retryable_status( + reqwest::StatusCode::UNAUTHORIZED + )); + } + + #[test] + fn tool_delta_variant_round_trips() { + let delta = ContentBlockDelta::InputJsonDelta { + partial_json: "{\"city\":\"Paris\"}".to_string(), + }; + let encoded = serde_json::to_string(&delta).expect("delta should serialize"); + let decoded: ContentBlockDelta = + serde_json::from_str(&encoded).expect("delta should deserialize"); + assert_eq!(decoded, delta); + } + + #[test] + fn request_id_uses_primary_or_fallback_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_primary") + ); + + headers.clear(); + headers.insert( + ALT_REQUEST_ID_HEADER, + "req_fallback".parse().expect("header"), + ); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_fallback") + ); + } + + #[test] + fn auth_source_applies_headers() { + let auth = AuthSource::ApiKeyAndBearer { + api_key: "test-key".to_string(), + bearer_token: "proxy-token".to_string(), + }; + let request = auth + .apply(reqwest::Client::new().post("https://example.test")) + .build() + .expect("request build"); + let headers = request.headers(); + assert_eq!( + headers.get("x-api-key").and_then(|v| v.to_str().ok()), + Some("test-key") + ); + assert_eq!( + headers.get("authorization").and_then(|v| v.to_str().ok()), + Some("Bearer proxy-token") + ); + } +} diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs new file mode 100644 index 0000000..cf891cc --- /dev/null +++ b/rust/crates/api/src/providers/mod.rs @@ -0,0 +1,202 @@ +use std::future::Future; +use std::pin::Pin; + +use crate::error::ApiError; +use crate::types::{MessageRequest, MessageResponse}; + +pub mod anthropic; +pub mod openai_compat; + +pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; + +pub trait Provider { + type Stream; + + fn send_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, MessageResponse>; + + fn stream_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, Self::Stream>; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProviderKind { + Anthropic, + Xai, + OpenAi, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ProviderMetadata { + pub provider: ProviderKind, + pub canonical_model: &'static str, + pub auth_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ + ( + "opus", + ProviderMetadata { + provider: ProviderKind::Anthropic, + canonical_model: "claude-opus-4-6", + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "sonnet", + ProviderMetadata { + provider: ProviderKind::Anthropic, + canonical_model: "claude-sonnet-4-6", + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "haiku", + ProviderMetadata { + provider: ProviderKind::Anthropic, + canonical_model: "claude-haiku-4-5-20251213", + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "grok", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-3", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-3", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-3-mini", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-3-mini", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-2", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-2", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), +]; + +#[must_use] +pub fn resolve_model_alias(model: &str) -> String { + let trimmed = model.trim(); + let lower = trimmed.to_ascii_lowercase(); + MODEL_REGISTRY + .iter() + .find_map(|(alias, metadata)| (*alias == lower).then_some(metadata.canonical_model)) + .map_or_else(|| trimmed.to_string(), ToOwned::to_owned) +} + +#[must_use] +pub fn metadata_for_model(model: &str) -> Option { + let canonical = resolve_model_alias(model); + if canonical.starts_with("claude") { + return Some(ProviderMetadata { + provider: ProviderKind::Anthropic, + canonical_model: Box::leak(canonical.into_boxed_str()), + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }); + } + if canonical.starts_with("grok") { + return Some(ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: Box::leak(canonical.into_boxed_str()), + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }); + } + None +} + +#[must_use] +pub fn detect_provider_kind(model: &str) -> ProviderKind { + if let Some(metadata) = metadata_for_model(model) { + return metadata.provider; + } + if anthropic::has_auth_from_env_or_saved().unwrap_or(false) { + return ProviderKind::Anthropic; + } + if openai_compat::has_api_key("OPENAI_API_KEY") { + return ProviderKind::OpenAi; + } + if openai_compat::has_api_key("XAI_API_KEY") { + return ProviderKind::Xai; + } + ProviderKind::Anthropic +} + +#[must_use] +pub fn max_tokens_for_model(model: &str) -> u32 { + let canonical = resolve_model_alias(model); + if canonical.contains("opus") { + 32_000 + } else { + 64_000 + } +} + +#[cfg(test)] +mod tests { + use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind}; + + #[test] + fn resolves_grok_aliases() { + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + assert_eq!(resolve_model_alias("grok-2"), "grok-2"); + } + + #[test] + fn detects_provider_from_model_name_first() { + assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); + assert_eq!(detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic); + } + + #[test] + fn keeps_existing_max_token_heuristic() { + assert_eq!(max_tokens_for_model("opus"), 32_000); + assert_eq!(max_tokens_for_model("grok-3"), 64_000); + } +}