wip: grok provider abstraction

This commit is contained in:
Yeachan-Heo
2026-04-01 06:00:48 +00:00
parent f477dde4a6
commit 40008b6513
3 changed files with 65 additions and 4 deletions

View File

@@ -185,7 +185,7 @@ impl OpenAiCompatClient {
&self,
request: &MessageRequest,
) -> Result<reqwest::Response, ApiError> {
let request_url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
let request_url = chat_completions_endpoint(&self.base_url);
self.http
.post(&request_url)
.header("content-type", "application/json")
@@ -866,6 +866,15 @@ pub fn read_base_url(config: OpenAiCompatConfig) -> String {
std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
}
/// Build the `/chat/completions` request URL for `base_url`.
///
/// Accepts either a provider base URL (with or without a trailing slash)
/// or a URL that already points at the chat-completions endpoint; the
/// returned URL never carries a duplicated path segment or trailing slash.
fn chat_completions_endpoint(base_url: &str) -> String {
    const SUFFIX: &str = "/chat/completions";
    // Strip trailing slashes first so both "…/v1/" and "…/chat/completions/"
    // normalize before the suffix check.
    let root = base_url.trim_end_matches('/');
    if root.ends_with(SUFFIX) {
        // Caller already supplied the full endpoint — use it verbatim.
        return root.to_string();
    }
    format!("{root}{SUFFIX}")
}
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
headers
.get(REQUEST_ID_HEADER)
@@ -927,8 +936,8 @@ impl StringExt for String {
#[cfg(test)]
mod tests {
use super::{
build_chat_completion_request, normalize_finish_reason, openai_tool_choice,
parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
};
use crate::error::ApiError;
use crate::types::{
@@ -1010,6 +1019,22 @@ mod tests {
));
}
#[test]
fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
    // A bare base URL, a base URL with a trailing slash, and an
    // already-complete endpoint URL must all resolve to the same
    // canonical chat-completions URL.
    let inputs = [
        "https://api.x.ai/v1",
        "https://api.x.ai/v1/",
        "https://api.x.ai/v1/chat/completions",
    ];
    for input in inputs {
        assert_eq!(
            chat_completions_endpoint(input),
            "https://api.x.ai/v1/chat/completions"
        );
    }
}
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))

View File

@@ -62,6 +62,41 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() {
assert_eq!(body["tools"][0]["type"], json!("function"));
}
// Verifies that a base-URL override which already ends in "/chat/completions"
// is used as-is: the client must not append the path a second time and
// produce ".../chat/completions/chat/completions".
#[tokio::test]
async fn send_message_accepts_full_chat_completions_endpoint_override() {
// Shared buffer the stub server appends each captured request to.
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
// Canned OpenAI-compatible chat-completion response: one assistant choice
// plus usage counts (7 prompt + 3 completion = 10 total tokens).
let body = concat!(
"{",
"\"id\":\"chatcmpl_full_endpoint\",",
"\"model\":\"grok-3\",",
"\"choices\":[{",
"\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},",
"\"finish_reason\":\"stop\"",
"}],",
"\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}",
"}"
);
// Stub HTTP server that replays the single canned 200 response above.
let server = spawn_server(
state.clone(),
vec![http_response("200 OK", "application/json", body)],
)
.await;
// Deliberately include the "/chat/completions" suffix in the override so the
// client's endpoint builder has to detect it rather than blindly appending.
let endpoint_url = format!("{}/chat/completions", server.base_url());
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
.with_base_url(endpoint_url);
let response = client
.send_message(&sample_request(false))
.await
.expect("request should succeed");
// 7 prompt + 3 completion tokens from the canned usage block.
assert_eq!(response.total_tokens(), 10);
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
// The server must have seen the path exactly once — no duplicated segment.
assert_eq!(request.path, "/chat/completions");
}
#[tokio::test]
async fn stream_message_normalizes_text_and_multiple_tool_calls() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));