feat: provider abstraction layer + Grok API support

This commit is contained in:
Yeachan-Heo
2026-04-01 04:10:46 +00:00
parent cbc0a83059
commit 2a0f4b677a
8 changed files with 1547 additions and 999 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,19 @@
mod client;
mod error;
mod providers;
mod sse;
mod types;
pub use client::{
oauth_token_is_expired, read_base_url, resolve_saved_oauth_token, resolve_startup_auth_source,
AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
oauth_token_is_expired, read_base_url, read_xai_base_url, resolve_saved_oauth_token,
resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
};
pub use error::ApiError;
pub use providers::anthropic::{AnthropicClient, AuthSource};
pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
pub use providers::{
detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind,
};
pub use sse::{parse_frame, SseParser};
pub use types::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,

View File

@@ -8,10 +8,12 @@ use runtime::{
use serde::Deserialize;
use crate::error::ApiError;
use super::{Provider, ProviderFuture};
use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const ANTHROPIC_VERSION: &str = "2023-06-01";
const REQUEST_ID_HEADER: &str = "request-id";
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
@@ -41,7 +43,10 @@ impl AuthSource {
}),
(Some(api_key), None) => Ok(Self::ApiKey(api_key)),
(None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
(None, None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])),
(None, None) => Err(ApiError::missing_credentials(
"Anthropic",
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
)),
}
}
@@ -362,7 +367,10 @@ impl AuthSource {
}
}
Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
Ok(None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])),
Ok(None) => Err(ApiError::missing_credentials(
"Anthropic",
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
)),
Err(error) => Err(error),
}
}
@@ -382,6 +390,12 @@ pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTok
resolve_saved_oauth_token_set(config, token_set).map(Some)
}
pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
|| read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
|| load_saved_oauth_token()?.is_some())
}
pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
where
F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
@@ -400,7 +414,10 @@ where
}
let Some(token_set) = load_saved_oauth_token()? else {
return Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]));
return Err(ApiError::missing_credentials(
"Anthropic",
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
));
};
if !oauth_token_is_expired(&token_set) {
return Ok(AuthSource::BearerToken(token_set.access_token));
@@ -497,7 +514,10 @@ fn read_api_key() -> Result<String, ApiError> {
auth.api_key()
.or_else(|| auth.bearer_token())
.map(ToOwned::to_owned)
.ok_or(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]))
.ok_or(ApiError::missing_credentials(
"Anthropic",
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
))
}
#[cfg(test)]
@@ -520,6 +540,24 @@ fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<Strin
.map(ToOwned::to_owned)
}
impl Provider for AnthropicClient {
type Stream = MessageStream;
fn send_message<'a>(
&'a self,
request: &'a MessageRequest,
) -> ProviderFuture<'a, MessageResponse> {
Box::pin(async move { self.send_message(request).await })
}
fn stream_message<'a>(
&'a self,
request: &'a MessageRequest,
) -> ProviderFuture<'a, Self::Stream> {
Box::pin(async move { self.stream_message(request).await })
}
}
#[derive(Debug)]
pub struct MessageStream {
request_id: Option<String>,
@@ -673,7 +711,10 @@ mod tests {
std::env::remove_var("ANTHROPIC_API_KEY");
std::env::remove_var("CLAUDE_CONFIG_HOME");
let error = super::read_api_key().expect_err("missing key should error");
assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. }));
assert!(matches!(
error,
crate::error::ApiError::MissingCredentials { .. }
));
}
#[test]
@@ -682,7 +723,10 @@ mod tests {
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
std::env::remove_var("ANTHROPIC_API_KEY");
let error = super::read_api_key().expect_err("empty key should error");
assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. }));
assert!(matches!(
error,
crate::error::ApiError::MissingCredentials { .. }
));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
}

View File

@@ -12,9 +12,15 @@ pub type ProviderFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ApiError>
pub trait Provider {
type Stream;
fn send_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, MessageResponse>;
fn send_message<'a>(
&'a self,
request: &'a MessageRequest,
) -> ProviderFuture<'a, MessageResponse>;
fn stream_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, Self::Stream>;
fn stream_message<'a>(
&'a self,
request: &'a MessageRequest,
) -> ProviderFuture<'a, Self::Stream>;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -27,7 +33,6 @@ pub enum ProviderKind {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderMetadata {
pub provider: ProviderKind,
pub canonical_model: &'static str,
pub auth_env: &'static str,
pub base_url_env: &'static str,
pub default_base_url: &'static str,
@@ -38,7 +43,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
"opus",
ProviderMetadata {
provider: ProviderKind::Anthropic,
canonical_model: "claude-opus-4-6",
auth_env: "ANTHROPIC_API_KEY",
base_url_env: "ANTHROPIC_BASE_URL",
default_base_url: anthropic::DEFAULT_BASE_URL,
@@ -48,7 +52,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
"sonnet",
ProviderMetadata {
provider: ProviderKind::Anthropic,
canonical_model: "claude-sonnet-4-6",
auth_env: "ANTHROPIC_API_KEY",
base_url_env: "ANTHROPIC_BASE_URL",
default_base_url: anthropic::DEFAULT_BASE_URL,
@@ -58,7 +61,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
"haiku",
ProviderMetadata {
provider: ProviderKind::Anthropic,
canonical_model: "claude-haiku-4-5-20251213",
auth_env: "ANTHROPIC_API_KEY",
base_url_env: "ANTHROPIC_BASE_URL",
default_base_url: anthropic::DEFAULT_BASE_URL,
@@ -68,7 +70,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
"grok",
ProviderMetadata {
provider: ProviderKind::Xai,
canonical_model: "grok-3",
auth_env: "XAI_API_KEY",
base_url_env: "XAI_BASE_URL",
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -78,7 +79,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
"grok-3",
ProviderMetadata {
provider: ProviderKind::Xai,
canonical_model: "grok-3",
auth_env: "XAI_API_KEY",
base_url_env: "XAI_BASE_URL",
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -88,7 +88,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
"grok-mini",
ProviderMetadata {
provider: ProviderKind::Xai,
canonical_model: "grok-3-mini",
auth_env: "XAI_API_KEY",
base_url_env: "XAI_BASE_URL",
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -98,7 +97,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
"grok-3-mini",
ProviderMetadata {
provider: ProviderKind::Xai,
canonical_model: "grok-3-mini",
auth_env: "XAI_API_KEY",
base_url_env: "XAI_BASE_URL",
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -108,7 +106,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
"grok-2",
ProviderMetadata {
provider: ProviderKind::Xai,
canonical_model: "grok-2",
auth_env: "XAI_API_KEY",
base_url_env: "XAI_BASE_URL",
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -122,7 +119,23 @@ pub fn resolve_model_alias(model: &str) -> String {
let lower = trimmed.to_ascii_lowercase();
MODEL_REGISTRY
.iter()
.find_map(|(alias, metadata)| (*alias == lower).then_some(metadata.canonical_model))
.find_map(|(alias, metadata)| {
(*alias == lower).then_some(match metadata.provider {
ProviderKind::Anthropic => match *alias {
"opus" => "claude-opus-4-6",
"sonnet" => "claude-sonnet-4-6",
"haiku" => "claude-haiku-4-5-20251213",
_ => trimmed,
},
ProviderKind::Xai => match *alias {
"grok" | "grok-3" => "grok-3",
"grok-mini" | "grok-3-mini" => "grok-3-mini",
"grok-2" => "grok-2",
_ => trimmed,
},
ProviderKind::OpenAi => trimmed,
})
})
.map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
}
@@ -132,7 +145,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
if canonical.starts_with("claude") {
return Some(ProviderMetadata {
provider: ProviderKind::Anthropic,
canonical_model: Box::leak(canonical.into_boxed_str()),
auth_env: "ANTHROPIC_API_KEY",
base_url_env: "ANTHROPIC_BASE_URL",
default_base_url: anthropic::DEFAULT_BASE_URL,
@@ -141,7 +153,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
if canonical.starts_with("grok") {
return Some(ProviderMetadata {
provider: ProviderKind::Xai,
canonical_model: Box::leak(canonical.into_boxed_str()),
auth_env: "XAI_API_KEY",
base_url_env: "XAI_BASE_URL",
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -191,7 +202,10 @@ mod tests {
#[test]
fn detects_provider_from_model_name_first() {
assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
assert_eq!(detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic);
assert_eq!(
detect_provider_kind("claude-sonnet-4-6"),
ProviderKind::Anthropic
);
}
#[test]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,312 @@
use std::collections::HashMap;
use std::sync::Arc;
use api::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition,
};
use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
#[tokio::test]
async fn send_message_uses_openai_compatible_endpoint_and_auth() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let body = concat!(
"{",
"\"id\":\"chatcmpl_test\",",
"\"model\":\"grok-3\",",
"\"choices\":[{",
"\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},",
"\"finish_reason\":\"stop\"",
"}],",
"\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}",
"}"
);
let server = spawn_server(
state.clone(),
vec![http_response("200 OK", "application/json", body)],
)
.await;
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
.with_base_url(server.base_url());
let response = client
.send_message(&sample_request(false))
.await
.expect("request should succeed");
assert_eq!(response.model, "grok-3");
assert_eq!(response.total_tokens(), 16);
assert_eq!(
response.content,
vec![OutputContentBlock::Text {
text: "Hello from Grok".to_string(),
}]
);
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
assert_eq!(request.path, "/chat/completions");
assert_eq!(
request.headers.get("authorization").map(String::as_str),
Some("Bearer xai-test-key")
);
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
assert_eq!(body["model"], json!("grok-3"));
assert_eq!(body["messages"][0]["role"], json!("system"));
assert_eq!(body["tools"][0]["type"], json!("function"));
}
#[tokio::test]
async fn stream_message_normalizes_text_and_multiple_tool_calls() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!(
"data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n",
"data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n",
"data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
"data: [DONE]\n\n"
);
let server = spawn_server(
state.clone(),
vec![http_response_with_headers(
"200 OK",
"text/event-stream",
sse,
&[("x-request-id", "req_grok_stream")],
)],
)
.await;
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
.with_base_url(server.base_url());
let mut stream = client
.stream_message(&sample_request(false))
.await
.expect("stream should start");
assert_eq!(stream.request_id(), Some("req_grok_stream"));
let mut events = Vec::new();
while let Some(event) = stream.next_event().await.expect("event should parse") {
events.push(event);
}
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
assert!(matches!(
events[1],
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
content_block: OutputContentBlock::Text { .. },
..
})
));
assert!(matches!(
events[2],
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
delta: ContentBlockDelta::TextDelta { .. },
..
})
));
assert!(matches!(
events[3],
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
index: 1,
content_block: OutputContentBlock::ToolUse { .. },
})
));
assert!(matches!(
events[4],
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
index: 1,
delta: ContentBlockDelta::InputJsonDelta { .. },
})
));
assert!(matches!(
events[5],
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
index: 2,
content_block: OutputContentBlock::ToolUse { .. },
})
));
assert!(matches!(
events[6],
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
index: 2,
delta: ContentBlockDelta::InputJsonDelta { .. },
})
));
assert!(matches!(
events[7],
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 })
));
assert!(matches!(
events[8],
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 })
));
assert!(matches!(
events[9],
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
));
assert!(matches!(events[10], StreamEvent::MessageDelta(_)));
assert!(matches!(events[11], StreamEvent::MessageStop(_)));
let captured = state.lock().await;
let request = captured.first().expect("captured request");
assert_eq!(request.path, "/chat/completions");
assert!(request.body.contains("\"stream\":true"));
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct CapturedRequest {
path: String,
headers: HashMap<String, String>,
body: String,
}
struct TestServer {
base_url: String,
join_handle: tokio::task::JoinHandle<()>,
}
impl TestServer {
fn base_url(&self) -> String {
self.base_url.clone()
}
}
impl Drop for TestServer {
fn drop(&mut self) {
self.join_handle.abort();
}
}
async fn spawn_server(
state: Arc<Mutex<Vec<CapturedRequest>>>,
responses: Vec<String>,
) -> TestServer {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let address = listener.local_addr().expect("listener addr");
let join_handle = tokio::spawn(async move {
for response in responses {
let (mut socket, _) = listener.accept().await.expect("accept");
let mut buffer = Vec::new();
let mut header_end = None;
loop {
let mut chunk = [0_u8; 1024];
let read = socket.read(&mut chunk).await.expect("read request");
if read == 0 {
break;
}
buffer.extend_from_slice(&chunk[..read]);
if let Some(position) = find_header_end(&buffer) {
header_end = Some(position);
break;
}
}
let header_end = header_end.expect("headers should exist");
let (header_bytes, remaining) = buffer.split_at(header_end);
let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers");
let mut lines = header_text.split("\r\n");
let request_line = lines.next().expect("request line");
let path = request_line
.split_whitespace()
.nth(1)
.expect("path")
.to_string();
let mut headers = HashMap::new();
let mut content_length = 0_usize;
for line in lines {
if line.is_empty() {
continue;
}
let (name, value) = line.split_once(':').expect("header");
let value = value.trim().to_string();
if name.eq_ignore_ascii_case("content-length") {
content_length = value.parse().expect("content length");
}
headers.insert(name.to_ascii_lowercase(), value);
}
let mut body = remaining[4..].to_vec();
while body.len() < content_length {
let mut chunk = vec![0_u8; content_length - body.len()];
let read = socket.read(&mut chunk).await.expect("read body");
if read == 0 {
break;
}
body.extend_from_slice(&chunk[..read]);
}
state.lock().await.push(CapturedRequest {
path,
headers,
body: String::from_utf8(body).expect("utf8 body"),
});
socket
.write_all(response.as_bytes())
.await
.expect("write response");
}
});
TestServer {
base_url: format!("http://{address}"),
join_handle,
}
}
fn find_header_end(bytes: &[u8]) -> Option<usize> {
bytes.windows(4).position(|window| window == b"\r\n\r\n")
}
fn http_response(status: &str, content_type: &str, body: &str) -> String {
http_response_with_headers(status, content_type, body, &[])
}
fn http_response_with_headers(
status: &str,
content_type: &str,
body: &str,
headers: &[(&str, &str)],
) -> String {
let mut extra_headers = String::new();
for (name, value) in headers {
use std::fmt::Write as _;
write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write");
}
format!(
"HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
body.len()
)
}
fn sample_request(stream: bool) -> MessageRequest {
MessageRequest {
model: "grok-3".to_string(),
max_tokens: 64,
messages: vec![InputMessage {
role: "user".to_string(),
content: vec![InputContentBlock::Text {
text: "Say hello".to_string(),
}],
}],
system: Some("Use tools when needed".to_string()),
tools: Some(vec![ToolDefinition {
name: "weather".to_string(),
description: Some("Fetches weather".to_string()),
input_schema: json!({
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"]
}),
}]),
tool_choice: Some(ToolChoice::Auto),
stream,
}
}