mirror of
https://github.com/lWolvesl/claw-code.git
synced 2026-04-02 16:51:51 +08:00
feat: merge 2nd round from all rcc/* sessions
- api: tool_use parsing, message_delta, request_id tracking, retry logic - tools: extended tool suite (WebSearch, WebFetch, Agent, etc.) - cli: live streamed conversations, session restore, compact commands - runtime: config loading, system prompt builder, token usage, compaction
This commit is contained in:
@@ -1,9 +1,19 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::sse::SseParser;
|
||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
||||
const ANTHROPIC_VERSION: &str = "2023-06-01";
|
||||
const REQUEST_ID_HEADER: &str = "request-id";
|
||||
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
|
||||
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
||||
const DEFAULT_MAX_RETRIES: u32 = 2;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AnthropicClient {
|
||||
@@ -11,6 +21,9 @@ pub struct AnthropicClient {
|
||||
api_key: String,
|
||||
auth_token: Option<String>,
|
||||
base_url: String,
|
||||
max_retries: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
@@ -21,6 +34,9 @@ impl AnthropicClient {
|
||||
api_key: api_key.into(),
|
||||
auth_token: None,
|
||||
base_url: DEFAULT_BASE_URL.to_string(),
|
||||
max_retries: DEFAULT_MAX_RETRIES,
|
||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +63,19 @@ impl AnthropicClient {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_retry_policy(
|
||||
mut self,
|
||||
max_retries: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
) -> Self {
|
||||
self.max_retries = max_retries;
|
||||
self.initial_backoff = initial_backoff;
|
||||
self.max_backoff = max_backoff;
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn send_message(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
@@ -55,12 +84,16 @@ impl AnthropicClient {
|
||||
stream: false,
|
||||
..request.clone()
|
||||
};
|
||||
let response = self.send_raw_request(&request).await?;
|
||||
let response = expect_success(response).await?;
|
||||
response
|
||||
let response = self.send_with_retry(&request).await?;
|
||||
let request_id = request_id_from_headers(response.headers());
|
||||
let mut response = response
|
||||
.json::<MessageResponse>()
|
||||
.await
|
||||
.map_err(ApiError::from)
|
||||
.map_err(ApiError::from)?;
|
||||
if response.request_id.is_none() {
|
||||
response.request_id = request_id;
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn stream_message(
|
||||
@@ -68,17 +101,53 @@ impl AnthropicClient {
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageStream, ApiError> {
|
||||
let response = self
|
||||
.send_raw_request(&request.clone().with_streaming())
|
||||
.send_with_retry(&request.clone().with_streaming())
|
||||
.await?;
|
||||
let response = expect_success(response).await?;
|
||||
Ok(MessageStream {
|
||||
request_id: request_id_from_headers(response.headers()),
|
||||
response,
|
||||
parser: SseParser::new(),
|
||||
pending: std::collections::VecDeque::new(),
|
||||
pending: VecDeque::new(),
|
||||
done: false,
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_with_retry(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<reqwest::Response, ApiError> {
|
||||
let mut attempts = 0;
|
||||
let mut last_error: Option<ApiError>;
|
||||
|
||||
loop {
|
||||
attempts += 1;
|
||||
match self.send_raw_request(request).await {
|
||||
Ok(response) => match expect_success(response).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
||||
last_error = Some(error);
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
},
|
||||
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
||||
last_error = Some(error);
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
|
||||
if attempts > self.max_retries {
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
|
||||
}
|
||||
|
||||
Err(ApiError::RetriesExhausted {
|
||||
attempts,
|
||||
last_error: Box::new(last_error.expect("retry loop must capture an error")),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_raw_request(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
@@ -103,6 +172,19 @@ impl AnthropicClient {
|
||||
.await
|
||||
.map_err(ApiError::from)
|
||||
}
|
||||
|
||||
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
||||
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
|
||||
return Err(ApiError::BackoffOverflow {
|
||||
attempt,
|
||||
base_delay: self.initial_backoff,
|
||||
});
|
||||
};
|
||||
Ok(self
|
||||
.initial_backoff
|
||||
.checked_mul(multiplier)
|
||||
.map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
|
||||
}
|
||||
}
|
||||
|
||||
fn read_api_key(
|
||||
@@ -116,15 +198,29 @@ fn read_api_key(
|
||||
}
|
||||
}
|
||||
|
||||
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MessageStream {
|
||||
request_id: Option<String>,
|
||||
response: reqwest::Response,
|
||||
parser: SseParser,
|
||||
pending: std::collections::VecDeque<StreamEvent>,
|
||||
pending: VecDeque<StreamEvent>,
|
||||
done: bool,
|
||||
}
|
||||
|
||||
impl MessageStream {
|
||||
#[must_use]
|
||||
pub fn request_id(&self) -> Option<&str> {
|
||||
self.request_id.as_deref()
|
||||
}
|
||||
|
||||
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
||||
loop {
|
||||
if let Some(event) = self.pending.pop_front() {
|
||||
@@ -159,14 +255,46 @@ async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response
|
||||
}
|
||||
|
||||
let body = response.text().await.unwrap_or_else(|_| String::new());
|
||||
Err(ApiError::UnexpectedStatus { status, body })
|
||||
let parsed_error = serde_json::from_str::<AnthropicErrorEnvelope>(&body).ok();
|
||||
let retryable = is_retryable_status(status);
|
||||
|
||||
Err(ApiError::Api {
|
||||
status,
|
||||
error_type: parsed_error
|
||||
.as_ref()
|
||||
.map(|error| error.error.error_type.clone()),
|
||||
message: parsed_error
|
||||
.as_ref()
|
||||
.map(|error| error.error.message.clone()),
|
||||
body,
|
||||
retryable,
|
||||
})
|
||||
}
|
||||
|
||||
const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
|
||||
matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AnthropicErrorEnvelope {
|
||||
error: AnthropicErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AnthropicErrorBody {
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::env::VarError;
|
||||
|
||||
use crate::types::MessageRequest;
|
||||
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||
|
||||
#[test]
|
||||
fn read_api_key_requires_presence() {
|
||||
@@ -194,9 +322,76 @@ mod tests {
|
||||
max_tokens: 64,
|
||||
messages: vec![],
|
||||
system: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
assert!(request.with_streaming().stream);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn backoff_doubles_until_maximum() {
|
||||
let client = super::AnthropicClient::new("test-key").with_retry_policy(
|
||||
3,
|
||||
Duration::from_millis(10),
|
||||
Duration::from_millis(25),
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(1).expect("attempt 1"),
|
||||
Duration::from_millis(10)
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(2).expect("attempt 2"),
|
||||
Duration::from_millis(20)
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(3).expect("attempt 3"),
|
||||
Duration::from_millis(25)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retryable_statuses_are_detected() {
|
||||
assert!(super::is_retryable_status(
|
||||
reqwest::StatusCode::TOO_MANY_REQUESTS
|
||||
));
|
||||
assert!(super::is_retryable_status(
|
||||
reqwest::StatusCode::INTERNAL_SERVER_ERROR
|
||||
));
|
||||
assert!(!super::is_retryable_status(
|
||||
reqwest::StatusCode::UNAUTHORIZED
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_delta_variant_round_trips() {
|
||||
let delta = ContentBlockDelta::InputJsonDelta {
|
||||
partial_json: "{\"city\":\"Paris\"}".to_string(),
|
||||
};
|
||||
let encoded = serde_json::to_string(&delta).expect("delta should serialize");
|
||||
let decoded: ContentBlockDelta =
|
||||
serde_json::from_str(&encoded).expect("delta should deserialize");
|
||||
assert_eq!(decoded, delta);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_id_uses_primary_or_fallback_header() {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
|
||||
assert_eq!(
|
||||
super::request_id_from_headers(&headers).as_deref(),
|
||||
Some("req_primary")
|
||||
);
|
||||
|
||||
headers.clear();
|
||||
headers.insert(
|
||||
ALT_REQUEST_ID_HEADER,
|
||||
"req_fallback".parse().expect("header"),
|
||||
);
|
||||
assert_eq!(
|
||||
super::request_id_from_headers(&headers).as_deref(),
|
||||
Some("req_fallback")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::env::VarError;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ApiError {
|
||||
@@ -8,11 +9,39 @@ pub enum ApiError {
|
||||
Http(reqwest::Error),
|
||||
Io(std::io::Error),
|
||||
Json(serde_json::Error),
|
||||
UnexpectedStatus {
|
||||
Api {
|
||||
status: reqwest::StatusCode,
|
||||
error_type: Option<String>,
|
||||
message: Option<String>,
|
||||
body: String,
|
||||
retryable: bool,
|
||||
},
|
||||
RetriesExhausted {
|
||||
attempts: u32,
|
||||
last_error: Box<ApiError>,
|
||||
},
|
||||
InvalidSseFrame(&'static str),
|
||||
BackoffOverflow {
|
||||
attempt: u32,
|
||||
base_delay: Duration,
|
||||
},
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
#[must_use]
|
||||
pub fn is_retryable(&self) -> bool {
|
||||
match self {
|
||||
Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(),
|
||||
Self::Api { retryable, .. } => *retryable,
|
||||
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
||||
Self::MissingApiKey
|
||||
| Self::InvalidApiKeyEnv(_)
|
||||
| Self::Io(_)
|
||||
| Self::Json(_)
|
||||
| Self::InvalidSseFrame(_)
|
||||
| Self::BackoffOverflow { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ApiError {
|
||||
@@ -30,10 +59,36 @@ impl Display for ApiError {
|
||||
Self::Http(error) => write!(f, "http error: {error}"),
|
||||
Self::Io(error) => write!(f, "io error: {error}"),
|
||||
Self::Json(error) => write!(f, "json error: {error}"),
|
||||
Self::UnexpectedStatus { status, body } => {
|
||||
write!(f, "anthropic api returned {status}: {body}")
|
||||
}
|
||||
Self::Api {
|
||||
status,
|
||||
error_type,
|
||||
message,
|
||||
body,
|
||||
..
|
||||
} => match (error_type, message) {
|
||||
(Some(error_type), Some(message)) => {
|
||||
write!(
|
||||
f,
|
||||
"anthropic api returned {status} ({error_type}): {message}"
|
||||
)
|
||||
}
|
||||
_ => write!(f, "anthropic api returned {status}: {body}"),
|
||||
},
|
||||
Self::RetriesExhausted {
|
||||
attempts,
|
||||
last_error,
|
||||
} => write!(
|
||||
f,
|
||||
"anthropic api failed after {attempts} attempts: {last_error}"
|
||||
),
|
||||
Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"),
|
||||
Self::BackoffOverflow {
|
||||
attempt,
|
||||
base_delay,
|
||||
} => write!(
|
||||
f,
|
||||
"retry backoff overflowed on attempt {attempt} with base delay {base_delay:?}"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ pub use error::ApiError;
|
||||
pub use sse::{parse_frame, SseParser};
|
||||
pub use types::{
|
||||
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
||||
InputContentBlock, InputMessage, MessageRequest, MessageResponse, MessageStartEvent,
|
||||
MessageStopEvent, OutputContentBlock, StreamEvent, Usage,
|
||||
InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
|
||||
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
|
||||
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
|
||||
};
|
||||
|
||||
@@ -103,7 +103,7 @@ pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{parse_frame, SseParser};
|
||||
use crate::types::{ContentBlockDelta, OutputContentBlock, StreamEvent};
|
||||
use crate::types::{ContentBlockDelta, MessageDelta, OutputContentBlock, StreamEvent, Usage};
|
||||
|
||||
#[test]
|
||||
fn parses_single_frame() {
|
||||
@@ -158,6 +158,8 @@ mod tests {
|
||||
": keepalive\n",
|
||||
"event: ping\n",
|
||||
"data: {\"type\":\"ping\"}\n\n",
|
||||
"event: message_delta\n",
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}\n\n",
|
||||
"event: message_stop\n",
|
||||
"data: {\"type\":\"message_stop\"}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
@@ -168,7 +170,19 @@ mod tests {
|
||||
.expect("parser should succeed");
|
||||
assert_eq!(
|
||||
events,
|
||||
vec![StreamEvent::MessageStop(crate::types::MessageStopEvent {})]
|
||||
vec![
|
||||
StreamEvent::MessageDelta(crate::types::MessageDeltaEvent {
|
||||
delta: MessageDelta {
|
||||
stop_reason: Some("tool_use".to_string()),
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: Usage {
|
||||
input_tokens: 1,
|
||||
output_tokens: 2,
|
||||
},
|
||||
}),
|
||||
StreamEvent::MessageStop(crate::types::MessageStopEvent {}),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageRequest {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub messages: Vec<InputMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ToolDefinition>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub stream: bool,
|
||||
}
|
||||
@@ -19,7 +24,7 @@ impl MessageRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct InputMessage {
|
||||
pub role: String,
|
||||
pub content: Vec<InputContentBlock>,
|
||||
@@ -33,15 +38,64 @@ impl InputMessage {
|
||||
content: vec![InputContentBlock::Text { text: text.into() }],
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn user_tool_result(
|
||||
tool_use_id: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
is_error: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content: vec![InputContentBlock::ToolResult {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: vec![ToolResultContentBlock::Text {
|
||||
text: content.into(),
|
||||
}],
|
||||
is_error,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum InputContentBlock {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: Vec<ToolResultContentBlock>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
is_error: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ToolResultContentBlock {
|
||||
Text { text: String },
|
||||
Json { value: Value },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub input_schema: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum InputContentBlock {
|
||||
Text { text: String },
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Any,
|
||||
Tool { name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
@@ -54,12 +108,28 @@ pub struct MessageResponse {
|
||||
#[serde(default)]
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Usage,
|
||||
#[serde(default)]
|
||||
pub request_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
impl MessageResponse {
|
||||
#[must_use]
|
||||
pub fn total_tokens(&self) -> u32 {
|
||||
self.usage.total_tokens()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum OutputContentBlock {
|
||||
Text { text: String },
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -68,18 +138,39 @@ pub struct Usage {
|
||||
pub output_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
impl Usage {
|
||||
#[must_use]
|
||||
pub const fn total_tokens(&self) -> u32 {
|
||||
self.input_tokens + self.output_tokens
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageStartEvent {
|
||||
pub message: MessageResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageDeltaEvent {
|
||||
pub delta: MessageDelta,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct MessageDelta {
|
||||
#[serde(default)]
|
||||
pub stop_reason: Option<String>,
|
||||
#[serde(default)]
|
||||
pub stop_sequence: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ContentBlockStartEvent {
|
||||
pub index: u32,
|
||||
pub content_block: OutputContentBlock,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ContentBlockDeltaEvent {
|
||||
pub index: u32,
|
||||
pub delta: ContentBlockDelta,
|
||||
@@ -89,6 +180,7 @@ pub struct ContentBlockDeltaEvent {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlockDelta {
|
||||
TextDelta { text: String },
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -99,10 +191,11 @@ pub struct ContentBlockStopEvent {
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct MessageStopEvent {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum StreamEvent {
|
||||
MessageStart(MessageStartEvent),
|
||||
MessageDelta(MessageDeltaEvent),
|
||||
ContentBlockStart(ContentBlockStartEvent),
|
||||
ContentBlockDelta(ContentBlockDeltaEvent),
|
||||
ContentBlockStop(ContentBlockStopEvent),
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use api::{AnthropicClient, InputMessage, MessageRequest, OutputContentBlock, StreamEvent};
|
||||
use api::{
|
||||
AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
|
||||
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
|
||||
StreamEvent, ToolChoice, ToolDefinition,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -18,10 +24,15 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
"\"model\":\"claude-3-7-sonnet-latest\",",
|
||||
"\"stop_reason\":\"end_turn\",",
|
||||
"\"stop_sequence\":null,",
|
||||
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4}",
|
||||
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
|
||||
"\"request_id\":\"req_body_123\"",
|
||||
"}"
|
||||
);
|
||||
let server = spawn_server(state.clone(), http_response("application/json", body)).await;
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response("200 OK", "application/json", body)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_auth_token(Some("proxy-token".to_string()))
|
||||
@@ -32,6 +43,8 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(response.id, "msg_test");
|
||||
assert_eq!(response.total_tokens(), 16);
|
||||
assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
|
||||
assert_eq!(
|
||||
response.content,
|
||||
vec![OutputContentBlock::Text {
|
||||
@@ -51,39 +64,45 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
request.headers.get("authorization").map(String::as_str),
|
||||
Some("Bearer proxy-token")
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("anthropic-version").map(String::as_str),
|
||||
Some("2023-06-01")
|
||||
);
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_str(&request.body).expect("request body should be json");
|
||||
assert_eq!(
|
||||
body.get("model").and_then(serde_json::Value::as_str),
|
||||
Some("claude-3-7-sonnet-latest")
|
||||
);
|
||||
assert!(
|
||||
body.get("stream").is_none(),
|
||||
"non-stream request should omit stream=false"
|
||||
);
|
||||
assert!(body.get("stream").is_none());
|
||||
assert_eq!(body["tools"][0]["name"], json!("get_weather"));
|
||||
assert_eq!(body["tool_choice"]["type"], json!("auto"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_message_parses_sse_events() {
|
||||
async fn stream_message_parses_sse_events_with_tool_use() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let sse = concat!(
|
||||
"event: message_start\n",
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
|
||||
"event: content_block_start\n",
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
|
||||
"event: content_block_delta\n",
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n",
|
||||
"event: content_block_stop\n",
|
||||
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\n",
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
|
||||
"event: message_stop\n",
|
||||
"data: {\"type\":\"message_stop\"}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
);
|
||||
let server = spawn_server(state.clone(), http_response("text/event-stream", sse)).await;
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response_with_headers(
|
||||
"200 OK",
|
||||
"text/event-stream",
|
||||
sse,
|
||||
&[("request-id", "req_stream_456")],
|
||||
)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_auth_token(Some("proxy-token".to_string()))
|
||||
@@ -93,6 +112,8 @@ async fn stream_message_parses_sse_events() {
|
||||
.await
|
||||
.expect("stream should start");
|
||||
|
||||
assert_eq!(stream.request_id(), Some("req_stream_456"));
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(event) = stream
|
||||
.next_event()
|
||||
@@ -102,18 +123,126 @@ async fn stream_message_parses_sse_events() {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert_eq!(events.len(), 5);
|
||||
assert_eq!(events.len(), 6);
|
||||
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
|
||||
assert!(matches!(events[1], StreamEvent::ContentBlockStart(_)));
|
||||
assert!(matches!(events[2], StreamEvent::ContentBlockDelta(_)));
|
||||
assert!(matches!(
|
||||
events[1],
|
||||
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||
content_block: OutputContentBlock::ToolUse { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
assert!(matches!(
|
||||
events[2],
|
||||
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
|
||||
delta: ContentBlockDelta::InputJsonDelta { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
|
||||
assert!(matches!(events[4], StreamEvent::MessageStop(_)));
|
||||
assert!(matches!(
|
||||
events[4],
|
||||
StreamEvent::MessageDelta(MessageDeltaEvent { .. })
|
||||
));
|
||||
assert!(matches!(events[5], StreamEvent::MessageStop(_)));
|
||||
|
||||
match &events[1] {
|
||||
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||
content_block: OutputContentBlock::ToolUse { name, input, .. },
|
||||
..
|
||||
}) => {
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input, &json!({}));
|
||||
}
|
||||
other => panic!("expected tool_use block, got {other:?}"),
|
||||
}
|
||||
|
||||
let captured = state.lock().await;
|
||||
let request = captured.first().expect("server should capture request");
|
||||
assert!(request.body.contains("\"stream\":true"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retries_retryable_failures_before_succeeding() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![
|
||||
http_response(
|
||||
"429 Too Many Requests",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
|
||||
),
|
||||
http_response(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
"{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
|
||||
),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_base_url(server.base_url())
|
||||
.with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2));
|
||||
|
||||
let response = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect("retry should eventually succeed");
|
||||
|
||||
assert_eq!(response.total_tokens(), 5);
|
||||
assert_eq!(state.lock().await.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![
|
||||
http_response(
|
||||
"503 Service Unavailable",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
|
||||
),
|
||||
http_response(
|
||||
"503 Service Unavailable",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
|
||||
),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_base_url(server.base_url())
|
||||
.with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2));
|
||||
|
||||
let error = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect_err("persistent 503 should fail");
|
||||
|
||||
match error {
|
||||
ApiError::RetriesExhausted {
|
||||
attempts,
|
||||
last_error,
|
||||
} => {
|
||||
assert_eq!(attempts, 2);
|
||||
assert!(matches!(
|
||||
*last_error,
|
||||
ApiError::Api {
|
||||
status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
|
||||
retryable: true,
|
||||
..
|
||||
}
|
||||
));
|
||||
}
|
||||
other => panic!("expected retries exhausted, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires ANTHROPIC_API_KEY and network access"]
|
||||
async fn live_stream_smoke_test() {
|
||||
@@ -127,51 +256,18 @@ async fn live_stream_smoke_test() {
|
||||
"Reply with exactly: hello from rust",
|
||||
)],
|
||||
system: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
})
|
||||
.await
|
||||
.expect("live stream should start");
|
||||
|
||||
let mut saw_start = false;
|
||||
let mut saw_follow_up = false;
|
||||
let mut event_kinds = Vec::new();
|
||||
while let Some(event) = stream
|
||||
while let Some(_event) = stream
|
||||
.next_event()
|
||||
.await
|
||||
.expect("live stream should yield events")
|
||||
{
|
||||
match event {
|
||||
StreamEvent::MessageStart(_) => {
|
||||
saw_start = true;
|
||||
event_kinds.push("message_start");
|
||||
}
|
||||
StreamEvent::ContentBlockStart(_) => {
|
||||
saw_follow_up = true;
|
||||
event_kinds.push("content_block_start");
|
||||
}
|
||||
StreamEvent::ContentBlockDelta(_) => {
|
||||
saw_follow_up = true;
|
||||
event_kinds.push("content_block_delta");
|
||||
}
|
||||
StreamEvent::ContentBlockStop(_) => {
|
||||
saw_follow_up = true;
|
||||
event_kinds.push("content_block_stop");
|
||||
}
|
||||
StreamEvent::MessageStop(_) => {
|
||||
saw_follow_up = true;
|
||||
event_kinds.push("message_stop");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_start,
|
||||
"expected a message_start event; got {event_kinds:?}"
|
||||
);
|
||||
assert!(
|
||||
saw_follow_up,
|
||||
"expected at least one follow-up stream event; got {event_kinds:?}"
|
||||
);
|
||||
{}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -199,7 +295,10 @@ impl Drop for TestServer {
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_server(state: Arc<Mutex<Vec<CapturedRequest>>>, response: String) -> TestServer {
|
||||
async fn spawn_server(
|
||||
state: Arc<Mutex<Vec<CapturedRequest>>>,
|
||||
responses: Vec<String>,
|
||||
) -> TestServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
@@ -207,72 +306,75 @@ async fn spawn_server(state: Arc<Mutex<Vec<CapturedRequest>>>, response: String)
|
||||
.local_addr()
|
||||
.expect("listener should have local addr");
|
||||
let join_handle = tokio::spawn(async move {
|
||||
let (mut socket, _) = listener.accept().await.expect("server should accept");
|
||||
let mut buffer = Vec::new();
|
||||
let mut header_end = None;
|
||||
for response in responses {
|
||||
let (mut socket, _) = listener.accept().await.expect("server should accept");
|
||||
let mut buffer = Vec::new();
|
||||
let mut header_end = None;
|
||||
|
||||
loop {
|
||||
let mut chunk = [0_u8; 1024];
|
||||
let read = socket
|
||||
.read(&mut chunk)
|
||||
loop {
|
||||
let mut chunk = [0_u8; 1024];
|
||||
let read = socket
|
||||
.read(&mut chunk)
|
||||
.await
|
||||
.expect("request read should succeed");
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
if let Some(position) = find_header_end(&buffer) {
|
||||
header_end = Some(position);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let header_end = header_end.expect("request should include headers");
|
||||
let (header_bytes, remaining) = buffer.split_at(header_end);
|
||||
let header_text =
|
||||
String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
|
||||
let mut lines = header_text.split("\r\n");
|
||||
let request_line = lines.next().expect("request line should exist");
|
||||
let mut parts = request_line.split_whitespace();
|
||||
let method = parts.next().expect("method should exist").to_string();
|
||||
let path = parts.next().expect("path should exist").to_string();
|
||||
let mut headers = HashMap::new();
|
||||
let mut content_length = 0_usize;
|
||||
for line in lines {
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let (name, value) = line.split_once(':').expect("header should have colon");
|
||||
let value = value.trim().to_string();
|
||||
if name.eq_ignore_ascii_case("content-length") {
|
||||
content_length = value.parse().expect("content length should parse");
|
||||
}
|
||||
headers.insert(name.to_ascii_lowercase(), value);
|
||||
}
|
||||
|
||||
let mut body = remaining[4..].to_vec();
|
||||
while body.len() < content_length {
|
||||
let mut chunk = vec![0_u8; content_length - body.len()];
|
||||
let read = socket
|
||||
.read(&mut chunk)
|
||||
.await
|
||||
.expect("body read should succeed");
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
body.extend_from_slice(&chunk[..read]);
|
||||
}
|
||||
|
||||
state.lock().await.push(CapturedRequest {
|
||||
method,
|
||||
path,
|
||||
headers,
|
||||
body: String::from_utf8(body).expect("body should be utf8"),
|
||||
});
|
||||
|
||||
socket
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.expect("request read should succeed");
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
if let Some(position) = find_header_end(&buffer) {
|
||||
header_end = Some(position);
|
||||
break;
|
||||
}
|
||||
.expect("response write should succeed");
|
||||
}
|
||||
|
||||
let header_end = header_end.expect("request should include headers");
|
||||
let (header_bytes, remaining) = buffer.split_at(header_end);
|
||||
let header_text = String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
|
||||
let mut lines = header_text.split("\r\n");
|
||||
let request_line = lines.next().expect("request line should exist");
|
||||
let mut parts = request_line.split_whitespace();
|
||||
let method = parts.next().expect("method should exist").to_string();
|
||||
let path = parts.next().expect("path should exist").to_string();
|
||||
let mut headers = HashMap::new();
|
||||
let mut content_length = 0_usize;
|
||||
for line in lines {
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let (name, value) = line.split_once(':').expect("header should have colon");
|
||||
let value = value.trim().to_string();
|
||||
if name.eq_ignore_ascii_case("content-length") {
|
||||
content_length = value.parse().expect("content length should parse");
|
||||
}
|
||||
headers.insert(name.to_ascii_lowercase(), value);
|
||||
}
|
||||
|
||||
let mut body = remaining[4..].to_vec();
|
||||
while body.len() < content_length {
|
||||
let mut chunk = vec![0_u8; content_length - body.len()];
|
||||
let read = socket
|
||||
.read(&mut chunk)
|
||||
.await
|
||||
.expect("body read should succeed");
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
body.extend_from_slice(&chunk[..read]);
|
||||
}
|
||||
|
||||
state.lock().await.push(CapturedRequest {
|
||||
method,
|
||||
path,
|
||||
headers,
|
||||
body: String::from_utf8(body).expect("body should be utf8"),
|
||||
});
|
||||
|
||||
socket
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.expect("response write should succeed");
|
||||
});
|
||||
|
||||
TestServer {
|
||||
@@ -285,9 +387,23 @@ fn find_header_end(bytes: &[u8]) -> Option<usize> {
|
||||
bytes.windows(4).position(|window| window == b"\r\n\r\n")
|
||||
}
|
||||
|
||||
fn http_response(content_type: &str, body: &str) -> String {
|
||||
fn http_response(status: &str, content_type: &str, body: &str) -> String {
|
||||
http_response_with_headers(status, content_type, body, &[])
|
||||
}
|
||||
|
||||
fn http_response_with_headers(
|
||||
status: &str,
|
||||
content_type: &str,
|
||||
body: &str,
|
||||
headers: &[(&str, &str)],
|
||||
) -> String {
|
||||
let mut extra_headers = String::new();
|
||||
for (name, value) in headers {
|
||||
use std::fmt::Write as _;
|
||||
write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write should succeed");
|
||||
}
|
||||
format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
"HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
)
|
||||
}
|
||||
@@ -296,8 +412,32 @@ fn sample_request(stream: bool) -> MessageRequest {
|
||||
MessageRequest {
|
||||
model: "claude-3-7-sonnet-latest".to_string(),
|
||||
max_tokens: 64,
|
||||
messages: vec![InputMessage::user_text("Say hello")],
|
||||
system: None,
|
||||
messages: vec![InputMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
InputContentBlock::Text {
|
||||
text: "Say hello".to_string(),
|
||||
},
|
||||
InputContentBlock::ToolResult {
|
||||
tool_use_id: "toolu_prev".to_string(),
|
||||
content: vec![api::ToolResultContentBlock::Json {
|
||||
value: json!({"forecast": "sunny"}),
|
||||
}],
|
||||
is_error: false,
|
||||
},
|
||||
],
|
||||
}],
|
||||
system: Some("Use tools when needed".to_string()),
|
||||
tools: Some(vec![ToolDefinition {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Fetches the weather".to_string()),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"]
|
||||
}),
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user