From f477dde4a6c4ab76253c3d7b96d78183f0a94c0a Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 05:45:27 +0000 Subject: [PATCH] feat: provider tests + grok integration --- rust/crates/api/src/client.rs | 9 +- rust/crates/api/tests/client_integration.rs | 47 +++++++++- .../api/tests/openai_compat_integration.rs | 70 ++++++++++++++- .../api/tests/provider_client_integration.rs | 86 +++++++++++++++++++ rust/crates/runtime/src/conversation.rs | 12 +-- rust/crates/runtime/src/hooks.rs | 62 +++++++------ 6 files changed, 244 insertions(+), 42 deletions(-) create mode 100644 rust/crates/api/tests/provider_client_integration.rs diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 467697e..a4ac1c0 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -36,11 +36,10 @@ impl ProviderClient { ) -> Result { let resolved_model = providers::resolve_model_alias(model); match providers::detect_provider_kind(&resolved_model) { - ProviderKind::Anthropic => Ok(Self::Anthropic( - anthropic_auth - .map(AnthropicClient::from_auth) - .unwrap_or(AnthropicClient::from_env()?), - )), + ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth { + Some(auth) => AnthropicClient::from_auth(auth), + None => AnthropicClient::from_env()?, + })), ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( OpenAiCompatConfig::xai(), )?)), diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index c37fa99..b52f890 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use std::time::Duration; use api::{ - AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, - InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, - StreamEvent, ToolChoice, ToolDefinition, + AnthropicClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent, + ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -195,6 +195,47 @@ async fn retries_retryable_failures_before_succeeding() { assert_eq!(state.lock().await.len(), 2); } +#[tokio::test] +async fn provider_client_dispatches_anthropic_requests() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + )], + ) + .await; + + let client = ProviderClient::from_model_with_anthropic_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("test-key".to_string())), + ) + .expect("anthropic provider client should be constructed"); + let client = match client { + ProviderClient::Anthropic(client) => { + ProviderClient::Anthropic(client.with_base_url(server.base_url())) + } + other => panic!("expected anthropic provider, got {other:?}"), + }; + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 5); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/v1/messages"); + assert_eq!( + request.headers.get("x-api-key").map(String::as_str), + Some("test-key") + ); +} + #[tokio::test] async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { let state = Arc::new(Mutex::new(Vec::::new())); diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs index b1b6a0a..81a65f4 100644 --- a/rust/crates/api/tests/openai_compat_integration.rs +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -1,10 +1,12 @@ use std::collections::HashMap; +use std::ffi::OsString; use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; use api::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, - OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -158,6 +160,43 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() { assert!(request.body.contains("\"stream\":true")); } +#[tokio::test] +async fn provider_client_dispatches_xai_requests_from_env() { + let _lock = env_lock(); + let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key"); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}", + )], + ) + .await; + let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url()); + + let client = + ProviderClient::from_model("grok").expect("xAI provider client should be constructed"); + assert!(matches!(client, ProviderClient::Xai(_))); + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 13); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); +} + #[derive(Debug, Clone, PartialEq, Eq)] struct CapturedRequest { path: String, @@ -310,3 +349,32 @@ fn sample_request(stream: bool) -> MessageRequest { stream, } } + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct ScopedEnvVar { + key: &'static str, + previous: Option, +} + +impl ScopedEnvVar { + fn set(key: &'static str, value: impl AsRef) -> Self { + let previous = std::env::var_os(key); + std::env::set_var(key, value); + Self { key, previous } + } +} + +impl Drop for ScopedEnvVar { + fn drop(&mut self) { + match &self.previous { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/api/tests/provider_client_integration.rs b/rust/crates/api/tests/provider_client_integration.rs new file mode 100644 index 0000000..204bf35 --- /dev/null +++ b/rust/crates/api/tests/provider_client_integration.rs @@ -0,0 +1,86 @@ +use std::ffi::OsString; +use std::sync::{Mutex, OnceLock}; + +use api::{read_xai_base_url, ApiError, AuthSource, ProviderClient, ProviderKind}; + +#[test] +fn provider_client_routes_grok_aliases_through_xai() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key")); + + let client = ProviderClient::from_model("grok-mini").expect("grok alias should resolve"); + + assert_eq!(client.provider_kind(), ProviderKind::Xai); +} + +#[test] +fn provider_client_reports_missing_xai_credentials_for_grok_models() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", None); + + let error = ProviderClient::from_model("grok-3") + .expect_err("grok requests without XAI_API_KEY should fail fast"); + + match error { + ApiError::MissingCredentials { provider, env_vars } => { + assert_eq!(provider, "xAI"); + assert_eq!(env_vars, &["XAI_API_KEY"]); + } + other => panic!("expected missing xAI credentials, got {other:?}"), + } +} + +#[test] +fn provider_client_uses_explicit_anthropic_auth_without_env_lookup() { + let _lock = env_lock(); + let _anthropic_api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None); + let _anthropic_auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None); + + let client = ProviderClient::from_model_with_anthropic_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("anthropic-test-key".to_string())), + ) + .expect("explicit anthropic auth should avoid env lookup"); + + assert_eq!(client.provider_kind(), ProviderKind::Anthropic); +} + +#[test] +fn read_xai_base_url_prefers_env_override() { + let _lock = env_lock(); + let _xai_base_url = EnvVarGuard::set("XAI_BASE_URL", Some("https://example.xai.test/v1")); + + assert_eq!(read_xai_base_url(), "https://example.xai.test/v1"); +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct EnvVarGuard { + key: &'static str, + original: Option, +} + +impl EnvVarGuard { + fn set(key: &'static str, value: Option<&str>) -> Self { + let original = std::env::var_os(key); + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + Self { key, original } + } +} + +impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.original { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index 4ffbabc..1abdce4 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -118,7 +118,7 @@ where tool_executor, permission_policy, system_prompt, - RuntimeFeatureConfig::default(), + &RuntimeFeatureConfig::default(), ) } @@ -129,7 +129,7 @@ where tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec, - feature_config: RuntimeFeatureConfig, + feature_config: &RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -140,7 +140,7 @@ where system_prompt, max_iterations: usize::MAX, usage_tracker, - hook_runner: HookRunner::from_feature_config(&feature_config), + hook_runner: HookRunner::from_feature_config(feature_config), } } @@ -609,7 +609,7 @@ mod tests { }), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), )), @@ -675,7 +675,7 @@ mod tests { StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'post hook ran'")], )), @@ -697,7 +697,7 @@ mod tests { "post hook should preserve non-error result: {output:?}" ); assert!( - output.contains("4"), + output.contains('4'), "tool output missing value: {output:?}" ); assert!( diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 36756a0..63ef9ff 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -51,6 +51,16 @@ pub struct HookRunner { config: RuntimeHookConfig, } +#[derive(Debug, Clone, Copy)] +struct HookCommandRequest<'a> { + event: HookEvent, + tool_name: &'a str, + tool_input: &'a str, + tool_output: Option<&'a str>, + is_error: bool, + payload: &'a str, +} + impl HookRunner { #[must_use] pub fn new(config: RuntimeHookConfig) -> Self { @@ -118,14 +128,16 @@ impl HookRunner { let mut messages = Vec::new(); for command in commands { - match self.run_command( + match Self::run_command( command, - event, - tool_name, - tool_input, - tool_output, - is_error, - &payload, + HookCommandRequest { + event, + tool_name, + tool_input, + tool_output, + is_error, + payload: &payload, + }, ) { HookCommandOutcome::Allow { message } => { if let Some(message) = message { @@ -149,29 +161,23 @@ impl HookRunner { HookRunResult::allow(messages) } - fn run_command( - &self, - command: &str, - event: HookEvent, - tool_name: &str, - tool_input: &str, - tool_output: Option<&str>, - is_error: bool, - payload: &str, - ) -> HookCommandOutcome { + fn run_command(command: &str, request: HookCommandRequest<'_>) -> HookCommandOutcome { let mut child = shell_command(command); child.stdin(std::process::Stdio::piped()); child.stdout(std::process::Stdio::piped()); child.stderr(std::process::Stdio::piped()); - child.env("HOOK_EVENT", event.as_str()); - child.env("HOOK_TOOL_NAME", tool_name); - child.env("HOOK_TOOL_INPUT", tool_input); - child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" }); - if let Some(tool_output) = tool_output { + child.env("HOOK_EVENT", request.event.as_str()); + child.env("HOOK_TOOL_NAME", request.tool_name); + child.env("HOOK_TOOL_INPUT", request.tool_input); + child.env( + "HOOK_TOOL_IS_ERROR", + if request.is_error { "1" } else { "0" }, + ); + if let Some(tool_output) = request.tool_output { child.env("HOOK_TOOL_OUTPUT", tool_output); } - match child.output_with_stdin(payload.as_bytes()) { + match child.output_with_stdin(request.payload.as_bytes()) { Ok(output) => { let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); @@ -189,16 +195,18 @@ impl HookRunner { }, None => HookCommandOutcome::Warn { message: format!( - "{} hook `{command}` terminated by signal while handling `{tool_name}`", - event.as_str() + "{} hook `{command}` terminated by signal while handling `{}`", + request.event.as_str(), + request.tool_name ), }, } } Err(error) => HookCommandOutcome::Warn { message: format!( - "{} hook `{command}` failed to start for `{tool_name}`: {error}", - event.as_str() + "{} hook `{command}` failed to start for `{}`: {error}", + request.event.as_str(), + request.tool_name ), }, }