From 40008b65138d4760923cec68d1002805608b89ce Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 06:00:48 +0000 Subject: [PATCH] wip: grok provider abstraction --- .../crates/api/src/providers/openai_compat.rs | 31 ++++++++++++++-- .../api/tests/openai_compat_integration.rs | 35 +++++++++++++++++++ rust/crates/rusty-claude-cli/src/main.rs | 3 +- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 8a0fe9c..e8210ae 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -185,7 +185,7 @@ impl OpenAiCompatClient { &self, request: &MessageRequest, ) -> Result { - let request_url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); + let request_url = chat_completions_endpoint(&self.base_url); self.http .post(&request_url) .header("content-type", "application/json") @@ -866,6 +866,15 @@ pub fn read_base_url(config: OpenAiCompatConfig) -> String { std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) } +fn chat_completions_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/chat/completions") { + trimmed.to_string() + } else { + format!("{trimmed}/chat/completions") + } +} + fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { headers .get(REQUEST_ID_HEADER) @@ -927,8 +936,8 @@ impl StringExt for String { #[cfg(test)] mod tests { use super::{ - build_chat_completion_request, normalize_finish_reason, openai_tool_choice, - parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, + build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason, + openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, }; use crate::error::ApiError; use crate::types::{ @@ -1010,6 +1019,22 @@ mod tests { )); } + #[test] + fn endpoint_builder_accepts_base_urls_and_full_endpoints() { + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/chat/completions"), + "https://api.x.ai/v1/chat/completions" + ); + } + fn env_lock() -> std::sync::MutexGuard<'static, ()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs index 81a65f4..b345b1f 100644 --- a/rust/crates/api/tests/openai_compat_integration.rs +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -62,6 +62,41 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() { assert_eq!(body["tools"][0]["type"], json!("function")); } +#[tokio::test] +async fn send_message_accepts_full_chat_completions_endpoint_override() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_full_endpoint\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let endpoint_url = format!("{}/chat/completions", server.base_url()); + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(endpoint_url); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.total_tokens(), 10); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); +} + #[tokio::test] async fn stream_message_normalizes_text_and_multiple_tool_calls() { let state = Arc::new(Mutex::new(Vec::::new())); diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index dee54d9..847f94f 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1907,13 +1907,14 @@ fn build_runtime( permission_mode: PermissionMode, ) -> Result, Box> { + let feature_config = build_runtime_feature_config()?; Ok(ConversationRuntime::new_with_features( session, ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, CliToolExecutor::new(allowed_tools, emit_output), permission_policy(permission_mode), system_prompt, - build_runtime_feature_config()?, + &feature_config, )) }