mirror of
https://github.com/lWolvesl/claw-code.git
synced 2026-04-02 23:11:52 +08:00
feat: provider abstraction layer + Grok API support
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,19 @@
|
|||||||
mod client;
|
mod client;
|
||||||
mod error;
|
mod error;
|
||||||
|
mod providers;
|
||||||
mod sse;
|
mod sse;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
pub use client::{
|
pub use client::{
|
||||||
oauth_token_is_expired, read_base_url, resolve_saved_oauth_token, resolve_startup_auth_source,
|
oauth_token_is_expired, read_base_url, read_xai_base_url, resolve_saved_oauth_token,
|
||||||
AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
|
resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
|
||||||
};
|
};
|
||||||
pub use error::ApiError;
|
pub use error::ApiError;
|
||||||
|
pub use providers::anthropic::{AnthropicClient, AuthSource};
|
||||||
|
pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
|
||||||
|
pub use providers::{
|
||||||
|
detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind,
|
||||||
|
};
|
||||||
pub use sse::{parse_frame, SseParser};
|
pub use sse::{parse_frame, SseParser};
|
||||||
pub use types::{
|
pub use types::{
|
||||||
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ use runtime::{
|
|||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
|
|
||||||
|
use super::{Provider, ProviderFuture};
|
||||||
use crate::sse::SseParser;
|
use crate::sse::SseParser;
|
||||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
||||||
|
|
||||||
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
||||||
const ANTHROPIC_VERSION: &str = "2023-06-01";
|
const ANTHROPIC_VERSION: &str = "2023-06-01";
|
||||||
const REQUEST_ID_HEADER: &str = "request-id";
|
const REQUEST_ID_HEADER: &str = "request-id";
|
||||||
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
|
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
|
||||||
@@ -41,7 +43,10 @@ impl AuthSource {
|
|||||||
}),
|
}),
|
||||||
(Some(api_key), None) => Ok(Self::ApiKey(api_key)),
|
(Some(api_key), None) => Ok(Self::ApiKey(api_key)),
|
||||||
(None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
|
(None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
|
||||||
(None, None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])),
|
(None, None) => Err(ApiError::missing_credentials(
|
||||||
|
"Anthropic",
|
||||||
|
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,7 +367,10 @@ impl AuthSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
|
Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
|
||||||
Ok(None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])),
|
Ok(None) => Err(ApiError::missing_credentials(
|
||||||
|
"Anthropic",
|
||||||
|
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
|
||||||
|
)),
|
||||||
Err(error) => Err(error),
|
Err(error) => Err(error),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -382,6 +390,12 @@ pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTok
|
|||||||
resolve_saved_oauth_token_set(config, token_set).map(Some)
|
resolve_saved_oauth_token_set(config, token_set).map(Some)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
|
||||||
|
Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
|
||||||
|
|| read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
|
||||||
|
|| load_saved_oauth_token()?.is_some())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
|
pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
|
||||||
where
|
where
|
||||||
F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
|
F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
|
||||||
@@ -400,7 +414,10 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let Some(token_set) = load_saved_oauth_token()? else {
|
let Some(token_set) = load_saved_oauth_token()? else {
|
||||||
return Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]));
|
return Err(ApiError::missing_credentials(
|
||||||
|
"Anthropic",
|
||||||
|
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
|
||||||
|
));
|
||||||
};
|
};
|
||||||
if !oauth_token_is_expired(&token_set) {
|
if !oauth_token_is_expired(&token_set) {
|
||||||
return Ok(AuthSource::BearerToken(token_set.access_token));
|
return Ok(AuthSource::BearerToken(token_set.access_token));
|
||||||
@@ -497,7 +514,10 @@ fn read_api_key() -> Result<String, ApiError> {
|
|||||||
auth.api_key()
|
auth.api_key()
|
||||||
.or_else(|| auth.bearer_token())
|
.or_else(|| auth.bearer_token())
|
||||||
.map(ToOwned::to_owned)
|
.map(ToOwned::to_owned)
|
||||||
.ok_or(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]))
|
.ok_or(ApiError::missing_credentials(
|
||||||
|
"Anthropic",
|
||||||
|
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -520,6 +540,24 @@ fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<Strin
|
|||||||
.map(ToOwned::to_owned)
|
.map(ToOwned::to_owned)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Provider for AnthropicClient {
|
||||||
|
type Stream = MessageStream;
|
||||||
|
|
||||||
|
fn send_message<'a>(
|
||||||
|
&'a self,
|
||||||
|
request: &'a MessageRequest,
|
||||||
|
) -> ProviderFuture<'a, MessageResponse> {
|
||||||
|
Box::pin(async move { self.send_message(request).await })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_message<'a>(
|
||||||
|
&'a self,
|
||||||
|
request: &'a MessageRequest,
|
||||||
|
) -> ProviderFuture<'a, Self::Stream> {
|
||||||
|
Box::pin(async move { self.stream_message(request).await })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MessageStream {
|
pub struct MessageStream {
|
||||||
request_id: Option<String>,
|
request_id: Option<String>,
|
||||||
@@ -673,7 +711,10 @@ mod tests {
|
|||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
let error = super::read_api_key().expect_err("missing key should error");
|
let error = super::read_api_key().expect_err("missing key should error");
|
||||||
assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. }));
|
assert!(matches!(
|
||||||
|
error,
|
||||||
|
crate::error::ApiError::MissingCredentials { .. }
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -682,7 +723,10 @@ mod tests {
|
|||||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
|
||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
let error = super::read_api_key().expect_err("empty key should error");
|
let error = super::read_api_key().expect_err("empty key should error");
|
||||||
assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. }));
|
assert!(matches!(
|
||||||
|
error,
|
||||||
|
crate::error::ApiError::MissingCredentials { .. }
|
||||||
|
));
|
||||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,15 @@ pub type ProviderFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ApiError>
|
|||||||
pub trait Provider {
|
pub trait Provider {
|
||||||
type Stream;
|
type Stream;
|
||||||
|
|
||||||
fn send_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, MessageResponse>;
|
fn send_message<'a>(
|
||||||
|
&'a self,
|
||||||
|
request: &'a MessageRequest,
|
||||||
|
) -> ProviderFuture<'a, MessageResponse>;
|
||||||
|
|
||||||
fn stream_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, Self::Stream>;
|
fn stream_message<'a>(
|
||||||
|
&'a self,
|
||||||
|
request: &'a MessageRequest,
|
||||||
|
) -> ProviderFuture<'a, Self::Stream>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
@@ -27,7 +33,6 @@ pub enum ProviderKind {
|
|||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub struct ProviderMetadata {
|
pub struct ProviderMetadata {
|
||||||
pub provider: ProviderKind,
|
pub provider: ProviderKind,
|
||||||
pub canonical_model: &'static str,
|
|
||||||
pub auth_env: &'static str,
|
pub auth_env: &'static str,
|
||||||
pub base_url_env: &'static str,
|
pub base_url_env: &'static str,
|
||||||
pub default_base_url: &'static str,
|
pub default_base_url: &'static str,
|
||||||
@@ -38,7 +43,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
|||||||
"opus",
|
"opus",
|
||||||
ProviderMetadata {
|
ProviderMetadata {
|
||||||
provider: ProviderKind::Anthropic,
|
provider: ProviderKind::Anthropic,
|
||||||
canonical_model: "claude-opus-4-6",
|
|
||||||
auth_env: "ANTHROPIC_API_KEY",
|
auth_env: "ANTHROPIC_API_KEY",
|
||||||
base_url_env: "ANTHROPIC_BASE_URL",
|
base_url_env: "ANTHROPIC_BASE_URL",
|
||||||
default_base_url: anthropic::DEFAULT_BASE_URL,
|
default_base_url: anthropic::DEFAULT_BASE_URL,
|
||||||
@@ -48,7 +52,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
|||||||
"sonnet",
|
"sonnet",
|
||||||
ProviderMetadata {
|
ProviderMetadata {
|
||||||
provider: ProviderKind::Anthropic,
|
provider: ProviderKind::Anthropic,
|
||||||
canonical_model: "claude-sonnet-4-6",
|
|
||||||
auth_env: "ANTHROPIC_API_KEY",
|
auth_env: "ANTHROPIC_API_KEY",
|
||||||
base_url_env: "ANTHROPIC_BASE_URL",
|
base_url_env: "ANTHROPIC_BASE_URL",
|
||||||
default_base_url: anthropic::DEFAULT_BASE_URL,
|
default_base_url: anthropic::DEFAULT_BASE_URL,
|
||||||
@@ -58,7 +61,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
|||||||
"haiku",
|
"haiku",
|
||||||
ProviderMetadata {
|
ProviderMetadata {
|
||||||
provider: ProviderKind::Anthropic,
|
provider: ProviderKind::Anthropic,
|
||||||
canonical_model: "claude-haiku-4-5-20251213",
|
|
||||||
auth_env: "ANTHROPIC_API_KEY",
|
auth_env: "ANTHROPIC_API_KEY",
|
||||||
base_url_env: "ANTHROPIC_BASE_URL",
|
base_url_env: "ANTHROPIC_BASE_URL",
|
||||||
default_base_url: anthropic::DEFAULT_BASE_URL,
|
default_base_url: anthropic::DEFAULT_BASE_URL,
|
||||||
@@ -68,7 +70,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
|||||||
"grok",
|
"grok",
|
||||||
ProviderMetadata {
|
ProviderMetadata {
|
||||||
provider: ProviderKind::Xai,
|
provider: ProviderKind::Xai,
|
||||||
canonical_model: "grok-3",
|
|
||||||
auth_env: "XAI_API_KEY",
|
auth_env: "XAI_API_KEY",
|
||||||
base_url_env: "XAI_BASE_URL",
|
base_url_env: "XAI_BASE_URL",
|
||||||
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
||||||
@@ -78,7 +79,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
|||||||
"grok-3",
|
"grok-3",
|
||||||
ProviderMetadata {
|
ProviderMetadata {
|
||||||
provider: ProviderKind::Xai,
|
provider: ProviderKind::Xai,
|
||||||
canonical_model: "grok-3",
|
|
||||||
auth_env: "XAI_API_KEY",
|
auth_env: "XAI_API_KEY",
|
||||||
base_url_env: "XAI_BASE_URL",
|
base_url_env: "XAI_BASE_URL",
|
||||||
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
||||||
@@ -88,7 +88,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
|||||||
"grok-mini",
|
"grok-mini",
|
||||||
ProviderMetadata {
|
ProviderMetadata {
|
||||||
provider: ProviderKind::Xai,
|
provider: ProviderKind::Xai,
|
||||||
canonical_model: "grok-3-mini",
|
|
||||||
auth_env: "XAI_API_KEY",
|
auth_env: "XAI_API_KEY",
|
||||||
base_url_env: "XAI_BASE_URL",
|
base_url_env: "XAI_BASE_URL",
|
||||||
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
||||||
@@ -98,7 +97,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
|||||||
"grok-3-mini",
|
"grok-3-mini",
|
||||||
ProviderMetadata {
|
ProviderMetadata {
|
||||||
provider: ProviderKind::Xai,
|
provider: ProviderKind::Xai,
|
||||||
canonical_model: "grok-3-mini",
|
|
||||||
auth_env: "XAI_API_KEY",
|
auth_env: "XAI_API_KEY",
|
||||||
base_url_env: "XAI_BASE_URL",
|
base_url_env: "XAI_BASE_URL",
|
||||||
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
||||||
@@ -108,7 +106,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
|||||||
"grok-2",
|
"grok-2",
|
||||||
ProviderMetadata {
|
ProviderMetadata {
|
||||||
provider: ProviderKind::Xai,
|
provider: ProviderKind::Xai,
|
||||||
canonical_model: "grok-2",
|
|
||||||
auth_env: "XAI_API_KEY",
|
auth_env: "XAI_API_KEY",
|
||||||
base_url_env: "XAI_BASE_URL",
|
base_url_env: "XAI_BASE_URL",
|
||||||
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
||||||
@@ -122,7 +119,23 @@ pub fn resolve_model_alias(model: &str) -> String {
|
|||||||
let lower = trimmed.to_ascii_lowercase();
|
let lower = trimmed.to_ascii_lowercase();
|
||||||
MODEL_REGISTRY
|
MODEL_REGISTRY
|
||||||
.iter()
|
.iter()
|
||||||
.find_map(|(alias, metadata)| (*alias == lower).then_some(metadata.canonical_model))
|
.find_map(|(alias, metadata)| {
|
||||||
|
(*alias == lower).then_some(match metadata.provider {
|
||||||
|
ProviderKind::Anthropic => match *alias {
|
||||||
|
"opus" => "claude-opus-4-6",
|
||||||
|
"sonnet" => "claude-sonnet-4-6",
|
||||||
|
"haiku" => "claude-haiku-4-5-20251213",
|
||||||
|
_ => trimmed,
|
||||||
|
},
|
||||||
|
ProviderKind::Xai => match *alias {
|
||||||
|
"grok" | "grok-3" => "grok-3",
|
||||||
|
"grok-mini" | "grok-3-mini" => "grok-3-mini",
|
||||||
|
"grok-2" => "grok-2",
|
||||||
|
_ => trimmed,
|
||||||
|
},
|
||||||
|
ProviderKind::OpenAi => trimmed,
|
||||||
|
})
|
||||||
|
})
|
||||||
.map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
|
.map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,7 +145,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
|
|||||||
if canonical.starts_with("claude") {
|
if canonical.starts_with("claude") {
|
||||||
return Some(ProviderMetadata {
|
return Some(ProviderMetadata {
|
||||||
provider: ProviderKind::Anthropic,
|
provider: ProviderKind::Anthropic,
|
||||||
canonical_model: Box::leak(canonical.into_boxed_str()),
|
|
||||||
auth_env: "ANTHROPIC_API_KEY",
|
auth_env: "ANTHROPIC_API_KEY",
|
||||||
base_url_env: "ANTHROPIC_BASE_URL",
|
base_url_env: "ANTHROPIC_BASE_URL",
|
||||||
default_base_url: anthropic::DEFAULT_BASE_URL,
|
default_base_url: anthropic::DEFAULT_BASE_URL,
|
||||||
@@ -141,7 +153,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
|
|||||||
if canonical.starts_with("grok") {
|
if canonical.starts_with("grok") {
|
||||||
return Some(ProviderMetadata {
|
return Some(ProviderMetadata {
|
||||||
provider: ProviderKind::Xai,
|
provider: ProviderKind::Xai,
|
||||||
canonical_model: Box::leak(canonical.into_boxed_str()),
|
|
||||||
auth_env: "XAI_API_KEY",
|
auth_env: "XAI_API_KEY",
|
||||||
base_url_env: "XAI_BASE_URL",
|
base_url_env: "XAI_BASE_URL",
|
||||||
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
||||||
@@ -191,7 +202,10 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn detects_provider_from_model_name_first() {
|
fn detects_provider_from_model_name_first() {
|
||||||
assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
|
assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
|
||||||
assert_eq!(detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic);
|
assert_eq!(
|
||||||
|
detect_provider_kind("claude-sonnet-4-6"),
|
||||||
|
ProviderKind::Anthropic
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
1025
rust/crates/api/src/providers/openai_compat.rs
Normal file
1025
rust/crates/api/src/providers/openai_compat.rs
Normal file
File diff suppressed because it is too large
Load Diff
312
rust/crates/api/tests/openai_compat_integration.rs
Normal file
312
rust/crates/api/tests/openai_compat_integration.rs
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use api::{
|
||||||
|
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
||||||
|
InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
|
||||||
|
OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition,
|
||||||
|
};
|
||||||
|
use serde_json::json;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_message_uses_openai_compatible_endpoint_and_auth() {
|
||||||
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
let body = concat!(
|
||||||
|
"{",
|
||||||
|
"\"id\":\"chatcmpl_test\",",
|
||||||
|
"\"model\":\"grok-3\",",
|
||||||
|
"\"choices\":[{",
|
||||||
|
"\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},",
|
||||||
|
"\"finish_reason\":\"stop\"",
|
||||||
|
"}],",
|
||||||
|
"\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}",
|
||||||
|
"}"
|
||||||
|
);
|
||||||
|
let server = spawn_server(
|
||||||
|
state.clone(),
|
||||||
|
vec![http_response("200 OK", "application/json", body)],
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
|
||||||
|
.with_base_url(server.base_url());
|
||||||
|
let response = client
|
||||||
|
.send_message(&sample_request(false))
|
||||||
|
.await
|
||||||
|
.expect("request should succeed");
|
||||||
|
|
||||||
|
assert_eq!(response.model, "grok-3");
|
||||||
|
assert_eq!(response.total_tokens(), 16);
|
||||||
|
assert_eq!(
|
||||||
|
response.content,
|
||||||
|
vec![OutputContentBlock::Text {
|
||||||
|
text: "Hello from Grok".to_string(),
|
||||||
|
}]
|
||||||
|
);
|
||||||
|
|
||||||
|
let captured = state.lock().await;
|
||||||
|
let request = captured.first().expect("server should capture request");
|
||||||
|
assert_eq!(request.path, "/chat/completions");
|
||||||
|
assert_eq!(
|
||||||
|
request.headers.get("authorization").map(String::as_str),
|
||||||
|
Some("Bearer xai-test-key")
|
||||||
|
);
|
||||||
|
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
|
||||||
|
assert_eq!(body["model"], json!("grok-3"));
|
||||||
|
assert_eq!(body["messages"][0]["role"], json!("system"));
|
||||||
|
assert_eq!(body["tools"][0]["type"], json!("function"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn stream_message_normalizes_text_and_multiple_tool_calls() {
|
||||||
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
let sse = concat!(
|
||||||
|
"data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n",
|
||||||
|
"data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n",
|
||||||
|
"data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
|
||||||
|
"data: [DONE]\n\n"
|
||||||
|
);
|
||||||
|
let server = spawn_server(
|
||||||
|
state.clone(),
|
||||||
|
vec![http_response_with_headers(
|
||||||
|
"200 OK",
|
||||||
|
"text/event-stream",
|
||||||
|
sse,
|
||||||
|
&[("x-request-id", "req_grok_stream")],
|
||||||
|
)],
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
|
||||||
|
.with_base_url(server.base_url());
|
||||||
|
let mut stream = client
|
||||||
|
.stream_message(&sample_request(false))
|
||||||
|
.await
|
||||||
|
.expect("stream should start");
|
||||||
|
|
||||||
|
assert_eq!(stream.request_id(), Some("req_grok_stream"));
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
while let Some(event) = stream.next_event().await.expect("event should parse") {
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
|
||||||
|
assert!(matches!(
|
||||||
|
events[1],
|
||||||
|
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||||
|
content_block: OutputContentBlock::Text { .. },
|
||||||
|
..
|
||||||
|
})
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[2],
|
||||||
|
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
|
||||||
|
delta: ContentBlockDelta::TextDelta { .. },
|
||||||
|
..
|
||||||
|
})
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[3],
|
||||||
|
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||||
|
index: 1,
|
||||||
|
content_block: OutputContentBlock::ToolUse { .. },
|
||||||
|
})
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[4],
|
||||||
|
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
|
||||||
|
index: 1,
|
||||||
|
delta: ContentBlockDelta::InputJsonDelta { .. },
|
||||||
|
})
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[5],
|
||||||
|
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||||
|
index: 2,
|
||||||
|
content_block: OutputContentBlock::ToolUse { .. },
|
||||||
|
})
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[6],
|
||||||
|
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
|
||||||
|
index: 2,
|
||||||
|
delta: ContentBlockDelta::InputJsonDelta { .. },
|
||||||
|
})
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[7],
|
||||||
|
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 })
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[8],
|
||||||
|
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 })
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[9],
|
||||||
|
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
|
||||||
|
));
|
||||||
|
assert!(matches!(events[10], StreamEvent::MessageDelta(_)));
|
||||||
|
assert!(matches!(events[11], StreamEvent::MessageStop(_)));
|
||||||
|
|
||||||
|
let captured = state.lock().await;
|
||||||
|
let request = captured.first().expect("captured request");
|
||||||
|
assert_eq!(request.path, "/chat/completions");
|
||||||
|
assert!(request.body.contains("\"stream\":true"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
struct CapturedRequest {
|
||||||
|
path: String,
|
||||||
|
headers: HashMap<String, String>,
|
||||||
|
body: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestServer {
|
||||||
|
base_url: String,
|
||||||
|
join_handle: tokio::task::JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestServer {
|
||||||
|
fn base_url(&self) -> String {
|
||||||
|
self.base_url.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TestServer {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.join_handle.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn spawn_server(
|
||||||
|
state: Arc<Mutex<Vec<CapturedRequest>>>,
|
||||||
|
responses: Vec<String>,
|
||||||
|
) -> TestServer {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0")
|
||||||
|
.await
|
||||||
|
.expect("listener should bind");
|
||||||
|
let address = listener.local_addr().expect("listener addr");
|
||||||
|
let join_handle = tokio::spawn(async move {
|
||||||
|
for response in responses {
|
||||||
|
let (mut socket, _) = listener.accept().await.expect("accept");
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
let mut header_end = None;
|
||||||
|
loop {
|
||||||
|
let mut chunk = [0_u8; 1024];
|
||||||
|
let read = socket.read(&mut chunk).await.expect("read request");
|
||||||
|
if read == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
buffer.extend_from_slice(&chunk[..read]);
|
||||||
|
if let Some(position) = find_header_end(&buffer) {
|
||||||
|
header_end = Some(position);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let header_end = header_end.expect("headers should exist");
|
||||||
|
let (header_bytes, remaining) = buffer.split_at(header_end);
|
||||||
|
let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers");
|
||||||
|
let mut lines = header_text.split("\r\n");
|
||||||
|
let request_line = lines.next().expect("request line");
|
||||||
|
let path = request_line
|
||||||
|
.split_whitespace()
|
||||||
|
.nth(1)
|
||||||
|
.expect("path")
|
||||||
|
.to_string();
|
||||||
|
let mut headers = HashMap::new();
|
||||||
|
let mut content_length = 0_usize;
|
||||||
|
for line in lines {
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let (name, value) = line.split_once(':').expect("header");
|
||||||
|
let value = value.trim().to_string();
|
||||||
|
if name.eq_ignore_ascii_case("content-length") {
|
||||||
|
content_length = value.parse().expect("content length");
|
||||||
|
}
|
||||||
|
headers.insert(name.to_ascii_lowercase(), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut body = remaining[4..].to_vec();
|
||||||
|
while body.len() < content_length {
|
||||||
|
let mut chunk = vec![0_u8; content_length - body.len()];
|
||||||
|
let read = socket.read(&mut chunk).await.expect("read body");
|
||||||
|
if read == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
body.extend_from_slice(&chunk[..read]);
|
||||||
|
}
|
||||||
|
|
||||||
|
state.lock().await.push(CapturedRequest {
|
||||||
|
path,
|
||||||
|
headers,
|
||||||
|
body: String::from_utf8(body).expect("utf8 body"),
|
||||||
|
});
|
||||||
|
|
||||||
|
socket
|
||||||
|
.write_all(response.as_bytes())
|
||||||
|
.await
|
||||||
|
.expect("write response");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
TestServer {
|
||||||
|
base_url: format!("http://{address}"),
|
||||||
|
join_handle,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_header_end(bytes: &[u8]) -> Option<usize> {
|
||||||
|
bytes.windows(4).position(|window| window == b"\r\n\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn http_response(status: &str, content_type: &str, body: &str) -> String {
|
||||||
|
http_response_with_headers(status, content_type, body, &[])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn http_response_with_headers(
|
||||||
|
status: &str,
|
||||||
|
content_type: &str,
|
||||||
|
body: &str,
|
||||||
|
headers: &[(&str, &str)],
|
||||||
|
) -> String {
|
||||||
|
let mut extra_headers = String::new();
|
||||||
|
for (name, value) in headers {
|
||||||
|
use std::fmt::Write as _;
|
||||||
|
write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write");
|
||||||
|
}
|
||||||
|
format!(
|
||||||
|
"HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||||
|
body.len()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_request(stream: bool) -> MessageRequest {
|
||||||
|
MessageRequest {
|
||||||
|
model: "grok-3".to_string(),
|
||||||
|
max_tokens: 64,
|
||||||
|
messages: vec![InputMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![InputContentBlock::Text {
|
||||||
|
text: "Say hello".to_string(),
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
system: Some("Use tools when needed".to_string()),
|
||||||
|
tools: Some(vec![ToolDefinition {
|
||||||
|
name: "weather".to_string(),
|
||||||
|
description: Some("Fetches weather".to_string()),
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
|
}]),
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,8 +12,9 @@ use std::process::Command;
|
|||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
use api::{
|
use api::{
|
||||||
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
|
detect_provider_kind, max_tokens_for_model, resolve_model_alias, resolve_startup_auth_source,
|
||||||
InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
|
AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage,
|
||||||
|
MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, ProviderKind,
|
||||||
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
|
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -35,13 +36,6 @@ use serde_json::json;
|
|||||||
use tools::{execute_tool, mvp_tool_specs, ToolSpec};
|
use tools::{execute_tool, mvp_tool_specs, ToolSpec};
|
||||||
|
|
||||||
const DEFAULT_MODEL: &str = "claude-opus-4-6";
|
const DEFAULT_MODEL: &str = "claude-opus-4-6";
|
||||||
fn max_tokens_for_model(model: &str) -> u32 {
|
|
||||||
if model.contains("opus") {
|
|
||||||
32_000
|
|
||||||
} else {
|
|
||||||
64_000
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const DEFAULT_DATE: &str = "2026-03-31";
|
const DEFAULT_DATE: &str = "2026-03-31";
|
||||||
const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545;
|
const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545;
|
||||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||||
@@ -288,15 +282,6 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn resolve_model_alias(model: &str) -> &str {
|
|
||||||
match model {
|
|
||||||
"opus" => "claude-opus-4-6",
|
|
||||||
"sonnet" => "claude-sonnet-4-6",
|
|
||||||
"haiku" => "claude-haiku-4-5-20251213",
|
|
||||||
_ => model,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_allowed_tools(values: &[String]) -> Result<Option<AllowedToolSet>, String> {
|
fn normalize_allowed_tools(values: &[String]) -> Result<Option<AllowedToolSet>, String> {
|
||||||
if values.is_empty() {
|
if values.is_empty() {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
@@ -980,7 +965,7 @@ struct LiveCli {
|
|||||||
allowed_tools: Option<AllowedToolSet>,
|
allowed_tools: Option<AllowedToolSet>,
|
||||||
permission_mode: PermissionMode,
|
permission_mode: PermissionMode,
|
||||||
system_prompt: Vec<String>,
|
system_prompt: Vec<String>,
|
||||||
runtime: ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>,
|
runtime: ConversationRuntime<ProviderRuntimeClient, CliToolExecutor>,
|
||||||
session: SessionHandle,
|
session: SessionHandle,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1920,11 +1905,11 @@ fn build_runtime(
|
|||||||
emit_output: bool,
|
emit_output: bool,
|
||||||
allowed_tools: Option<AllowedToolSet>,
|
allowed_tools: Option<AllowedToolSet>,
|
||||||
permission_mode: PermissionMode,
|
permission_mode: PermissionMode,
|
||||||
) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
|
) -> Result<ConversationRuntime<ProviderRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
|
||||||
{
|
{
|
||||||
Ok(ConversationRuntime::new_with_features(
|
Ok(ConversationRuntime::new_with_features(
|
||||||
session,
|
session,
|
||||||
AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
|
ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
|
||||||
CliToolExecutor::new(allowed_tools, emit_output),
|
CliToolExecutor::new(allowed_tools, emit_output),
|
||||||
permission_policy(permission_mode),
|
permission_policy(permission_mode),
|
||||||
system_prompt,
|
system_prompt,
|
||||||
@@ -1978,26 +1963,33 @@ impl runtime::PermissionPrompter for CliPermissionPrompter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AnthropicRuntimeClient {
|
struct ProviderRuntimeClient {
|
||||||
runtime: tokio::runtime::Runtime,
|
runtime: tokio::runtime::Runtime,
|
||||||
client: AnthropicClient,
|
client: ProviderClient,
|
||||||
model: String,
|
model: String,
|
||||||
enable_tools: bool,
|
enable_tools: bool,
|
||||||
emit_output: bool,
|
emit_output: bool,
|
||||||
allowed_tools: Option<AllowedToolSet>,
|
allowed_tools: Option<AllowedToolSet>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicRuntimeClient {
|
impl ProviderRuntimeClient {
|
||||||
fn new(
|
fn new(
|
||||||
model: String,
|
model: String,
|
||||||
enable_tools: bool,
|
enable_tools: bool,
|
||||||
emit_output: bool,
|
emit_output: bool,
|
||||||
allowed_tools: Option<AllowedToolSet>,
|
allowed_tools: Option<AllowedToolSet>,
|
||||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
|
let model = resolve_model_alias(&model).to_string();
|
||||||
|
let client = match detect_provider_kind(&model) {
|
||||||
|
ProviderKind::Anthropic => ProviderClient::from_model_with_anthropic_auth(
|
||||||
|
&model,
|
||||||
|
Some(resolve_cli_auth_source()?),
|
||||||
|
)?,
|
||||||
|
ProviderKind::Xai | ProviderKind::OpenAi => ProviderClient::from_model(&model)?,
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
runtime: tokio::runtime::Runtime::new()?,
|
runtime: tokio::runtime::Runtime::new()?,
|
||||||
client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
|
client,
|
||||||
.with_base_url(api::read_base_url()),
|
|
||||||
model,
|
model,
|
||||||
enable_tools,
|
enable_tools,
|
||||||
emit_output,
|
emit_output,
|
||||||
@@ -2016,7 +2008,7 @@ fn resolve_cli_auth_source() -> Result<AuthSource, Box<dyn std::error::Error>> {
|
|||||||
})?)
|
})?)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ApiClient for AnthropicRuntimeClient {
|
impl ApiClient for ProviderRuntimeClient {
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
||||||
let message_request = MessageRequest {
|
let message_request = MessageRequest {
|
||||||
@@ -2911,6 +2903,9 @@ mod tests {
|
|||||||
assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
|
assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
|
||||||
assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6");
|
assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6");
|
||||||
assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213");
|
assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213");
|
||||||
|
assert_eq!(resolve_model_alias("grok"), "grok-3");
|
||||||
|
assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
|
||||||
|
assert_eq!(resolve_model_alias("grok-2"), "grok-2");
|
||||||
assert_eq!(resolve_model_alias("claude-opus"), "claude-opus");
|
assert_eq!(resolve_model_alias("claude-opus"), "claude-opus");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ use std::process::Command;
|
|||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use api::{
|
use api::{
|
||||||
read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage,
|
detect_provider_kind, max_tokens_for_model, resolve_model_alias, ContentBlockDelta,
|
||||||
MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice,
|
InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
|
||||||
ToolDefinition, ToolResultContentBlock,
|
ProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition,
|
||||||
|
ToolResultContentBlock,
|
||||||
};
|
};
|
||||||
use reqwest::blocking::Client;
|
use reqwest::blocking::Client;
|
||||||
use runtime::{
|
use runtime::{
|
||||||
@@ -1459,14 +1460,14 @@ fn run_agent_job(job: &AgentJob) -> Result<(), String> {
|
|||||||
|
|
||||||
fn build_agent_runtime(
|
fn build_agent_runtime(
|
||||||
job: &AgentJob,
|
job: &AgentJob,
|
||||||
) -> Result<ConversationRuntime<AnthropicRuntimeClient, SubagentToolExecutor>, String> {
|
) -> Result<ConversationRuntime<ProviderRuntimeClient, SubagentToolExecutor>, String> {
|
||||||
let model = job
|
let model = job
|
||||||
.manifest
|
.manifest
|
||||||
.model
|
.model
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
|
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
|
||||||
let allowed_tools = job.allowed_tools.clone();
|
let allowed_tools = job.allowed_tools.clone();
|
||||||
let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?;
|
let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?;
|
||||||
let tool_executor = SubagentToolExecutor::new(allowed_tools);
|
let tool_executor = SubagentToolExecutor::new(allowed_tools);
|
||||||
Ok(ConversationRuntime::new(
|
Ok(ConversationRuntime::new(
|
||||||
Session::new(),
|
Session::new(),
|
||||||
@@ -1635,18 +1636,21 @@ fn format_agent_terminal_output(status: &str, result: Option<&str>, error: Optio
|
|||||||
sections.join("")
|
sections.join("")
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AnthropicRuntimeClient {
|
struct ProviderRuntimeClient {
|
||||||
runtime: tokio::runtime::Runtime,
|
runtime: tokio::runtime::Runtime,
|
||||||
client: AnthropicClient,
|
client: ProviderClient,
|
||||||
model: String,
|
model: String,
|
||||||
allowed_tools: BTreeSet<String>,
|
allowed_tools: BTreeSet<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicRuntimeClient {
|
impl ProviderRuntimeClient {
|
||||||
fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
|
fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
|
||||||
let client = AnthropicClient::from_env()
|
let model = resolve_model_alias(&model).to_string();
|
||||||
.map_err(|error| error.to_string())?
|
let client = match detect_provider_kind(&model) {
|
||||||
.with_base_url(read_base_url());
|
ProviderKind::Anthropic | ProviderKind::Xai | ProviderKind::OpenAi => {
|
||||||
|
ProviderClient::from_model(&model).map_err(|error| error.to_string())?
|
||||||
|
}
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
|
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
|
||||||
client,
|
client,
|
||||||
@@ -1656,7 +1660,7 @@ impl AnthropicRuntimeClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ApiClient for AnthropicRuntimeClient {
|
impl ApiClient for ProviderRuntimeClient {
|
||||||
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
||||||
let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
|
let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -1668,7 +1672,7 @@ impl ApiClient for AnthropicRuntimeClient {
|
|||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let message_request = MessageRequest {
|
let message_request = MessageRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
max_tokens: 32_000,
|
max_tokens: max_tokens_for_model(&self.model),
|
||||||
messages: convert_messages(&request.messages),
|
messages: convert_messages(&request.messages),
|
||||||
system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")),
|
system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")),
|
||||||
tools: (!tools.is_empty()).then_some(tools),
|
tools: (!tools.is_empty()).then_some(tools),
|
||||||
|
|||||||
Reference in New Issue
Block a user