feat: provider tests + grok integration

This commit is contained in:
Yeachan-Heo
2026-04-01 05:45:27 +00:00
parent 178934a9a0
commit f477dde4a6
6 changed files with 244 additions and 42 deletions

View File

@@ -3,9 +3,9 @@ use std::sync::Arc;
use std::time::Duration;
use api::{
AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
StreamEvent, ToolChoice, ToolDefinition,
AnthropicClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent,
ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest,
OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
};
use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -195,6 +195,47 @@ async fn retries_retryable_failures_before_succeeding() {
assert_eq!(state.lock().await.len(), 2);
}
#[tokio::test]
async fn provider_client_dispatches_anthropic_requests() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
)],
)
.await;
let client = ProviderClient::from_model_with_anthropic_auth(
"claude-sonnet-4-6",
Some(AuthSource::ApiKey("test-key".to_string())),
)
.expect("anthropic provider client should be constructed");
let client = match client {
ProviderClient::Anthropic(client) => {
ProviderClient::Anthropic(client.with_base_url(server.base_url()))
}
other => panic!("expected anthropic provider, got {other:?}"),
};
let response = client
.send_message(&sample_request(false))
.await
.expect("provider-dispatched request should succeed");
assert_eq!(response.total_tokens(), 5);
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
assert_eq!(request.path, "/v1/messages");
assert_eq!(
request.headers.get("x-api-key").map(String::as_str),
Some("test-key")
);
}
#[tokio::test]
async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));

View File

@@ -1,10 +1,12 @@
use std::collections::HashMap;
use std::ffi::OsString;
use std::sync::Arc;
use std::sync::{Mutex as StdMutex, OnceLock};
use api::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition,
OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
};
use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -158,6 +160,43 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() {
assert!(request.body.contains("\"stream\":true"));
}
#[tokio::test]
async fn provider_client_dispatches_xai_requests_from_env() {
let _lock = env_lock();
let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key");
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response(
"200 OK",
"application/json",
"{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}",
)],
)
.await;
let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url());
let client =
ProviderClient::from_model("grok").expect("xAI provider client should be constructed");
assert!(matches!(client, ProviderClient::Xai(_)));
let response = client
.send_message(&sample_request(false))
.await
.expect("provider-dispatched request should succeed");
assert_eq!(response.total_tokens(), 13);
let captured = state.lock().await;
let request = captured.first().expect("captured request");
assert_eq!(request.path, "/chat/completions");
assert_eq!(
request.headers.get("authorization").map(String::as_str),
Some("Bearer xai-test-key")
);
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct CapturedRequest {
path: String,
@@ -310,3 +349,32 @@ fn sample_request(stream: bool) -> MessageRequest {
stream,
}
}
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
LOCK.get_or_init(|| StdMutex::new(()))
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
struct ScopedEnvVar {
key: &'static str,
previous: Option<OsString>,
}
impl ScopedEnvVar {
fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
let previous = std::env::var_os(key);
std::env::set_var(key, value);
Self { key, previous }
}
}
impl Drop for ScopedEnvVar {
fn drop(&mut self) {
match &self.previous {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}

View File

@@ -0,0 +1,86 @@
use std::ffi::OsString;
use std::sync::{Mutex, OnceLock};
use api::{read_xai_base_url, ApiError, AuthSource, ProviderClient, ProviderKind};
#[test]
fn provider_client_routes_grok_aliases_through_xai() {
let _lock = env_lock();
let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key"));
let client = ProviderClient::from_model("grok-mini").expect("grok alias should resolve");
assert_eq!(client.provider_kind(), ProviderKind::Xai);
}
#[test]
fn provider_client_reports_missing_xai_credentials_for_grok_models() {
let _lock = env_lock();
let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", None);
let error = ProviderClient::from_model("grok-3")
.expect_err("grok requests without XAI_API_KEY should fail fast");
match error {
ApiError::MissingCredentials { provider, env_vars } => {
assert_eq!(provider, "xAI");
assert_eq!(env_vars, &["XAI_API_KEY"]);
}
other => panic!("expected missing xAI credentials, got {other:?}"),
}
}
#[test]
fn provider_client_uses_explicit_anthropic_auth_without_env_lookup() {
let _lock = env_lock();
let _anthropic_api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None);
let _anthropic_auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None);
let client = ProviderClient::from_model_with_anthropic_auth(
"claude-sonnet-4-6",
Some(AuthSource::ApiKey("anthropic-test-key".to_string())),
)
.expect("explicit anthropic auth should avoid env lookup");
assert_eq!(client.provider_kind(), ProviderKind::Anthropic);
}
#[test]
fn read_xai_base_url_prefers_env_override() {
let _lock = env_lock();
let _xai_base_url = EnvVarGuard::set("XAI_BASE_URL", Some("https://example.xai.test/v1"));
assert_eq!(read_xai_base_url(), "https://example.xai.test/v1");
}
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
struct EnvVarGuard {
key: &'static str,
original: Option<OsString>,
}
impl EnvVarGuard {
fn set(key: &'static str, value: Option<&str>) -> Self {
let original = std::env::var_os(key);
match value {
Some(value) => std::env::set_var(key, value),
None => std::env::remove_var(key),
}
Self { key, original }
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.original {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}