mirror of
https://github.com/lWolvesl/claw-code.git
synced 2026-04-02 07:41:52 +08:00
feat: merge 2nd round from all rcc/* sessions
- api: tool_use parsing, message_delta, request_id tracking, retry logic - tools: extended tool suite (WebSearch, WebFetch, Agent, etc.) - cli: live streamed conversations, session restore, compact commands - runtime: config loading, system prompt builder, token usage, compaction
This commit is contained in:
1
rust/.gitignore
vendored
1
rust/.gitignore
vendored
@@ -1,2 +1 @@
|
||||
target/
|
||||
.omx/
|
||||
|
||||
2269
rust/Cargo.lock
generated
2269
rust/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,19 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::sse::SseParser;
|
||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
||||
const ANTHROPIC_VERSION: &str = "2023-06-01";
|
||||
const REQUEST_ID_HEADER: &str = "request-id";
|
||||
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
|
||||
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
||||
const DEFAULT_MAX_RETRIES: u32 = 2;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AnthropicClient {
|
||||
@@ -11,6 +21,9 @@ pub struct AnthropicClient {
|
||||
api_key: String,
|
||||
auth_token: Option<String>,
|
||||
base_url: String,
|
||||
max_retries: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
@@ -21,6 +34,9 @@ impl AnthropicClient {
|
||||
api_key: api_key.into(),
|
||||
auth_token: None,
|
||||
base_url: DEFAULT_BASE_URL.to_string(),
|
||||
max_retries: DEFAULT_MAX_RETRIES,
|
||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +63,19 @@ impl AnthropicClient {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_retry_policy(
|
||||
mut self,
|
||||
max_retries: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
) -> Self {
|
||||
self.max_retries = max_retries;
|
||||
self.initial_backoff = initial_backoff;
|
||||
self.max_backoff = max_backoff;
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn send_message(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
@@ -55,12 +84,16 @@ impl AnthropicClient {
|
||||
stream: false,
|
||||
..request.clone()
|
||||
};
|
||||
let response = self.send_raw_request(&request).await?;
|
||||
let response = expect_success(response).await?;
|
||||
response
|
||||
let response = self.send_with_retry(&request).await?;
|
||||
let request_id = request_id_from_headers(response.headers());
|
||||
let mut response = response
|
||||
.json::<MessageResponse>()
|
||||
.await
|
||||
.map_err(ApiError::from)
|
||||
.map_err(ApiError::from)?;
|
||||
if response.request_id.is_none() {
|
||||
response.request_id = request_id;
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn stream_message(
|
||||
@@ -68,17 +101,53 @@ impl AnthropicClient {
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageStream, ApiError> {
|
||||
let response = self
|
||||
.send_raw_request(&request.clone().with_streaming())
|
||||
.send_with_retry(&request.clone().with_streaming())
|
||||
.await?;
|
||||
let response = expect_success(response).await?;
|
||||
Ok(MessageStream {
|
||||
request_id: request_id_from_headers(response.headers()),
|
||||
response,
|
||||
parser: SseParser::new(),
|
||||
pending: std::collections::VecDeque::new(),
|
||||
pending: VecDeque::new(),
|
||||
done: false,
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_with_retry(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<reqwest::Response, ApiError> {
|
||||
let mut attempts = 0;
|
||||
let mut last_error: Option<ApiError>;
|
||||
|
||||
loop {
|
||||
attempts += 1;
|
||||
match self.send_raw_request(request).await {
|
||||
Ok(response) => match expect_success(response).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
||||
last_error = Some(error);
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
},
|
||||
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
||||
last_error = Some(error);
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
|
||||
if attempts > self.max_retries {
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
|
||||
}
|
||||
|
||||
Err(ApiError::RetriesExhausted {
|
||||
attempts,
|
||||
last_error: Box::new(last_error.expect("retry loop must capture an error")),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_raw_request(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
@@ -103,6 +172,19 @@ impl AnthropicClient {
|
||||
.await
|
||||
.map_err(ApiError::from)
|
||||
}
|
||||
|
||||
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
||||
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
|
||||
return Err(ApiError::BackoffOverflow {
|
||||
attempt,
|
||||
base_delay: self.initial_backoff,
|
||||
});
|
||||
};
|
||||
Ok(self
|
||||
.initial_backoff
|
||||
.checked_mul(multiplier)
|
||||
.map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
|
||||
}
|
||||
}
|
||||
|
||||
fn read_api_key(
|
||||
@@ -116,15 +198,29 @@ fn read_api_key(
|
||||
}
|
||||
}
|
||||
|
||||
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MessageStream {
|
||||
request_id: Option<String>,
|
||||
response: reqwest::Response,
|
||||
parser: SseParser,
|
||||
pending: std::collections::VecDeque<StreamEvent>,
|
||||
pending: VecDeque<StreamEvent>,
|
||||
done: bool,
|
||||
}
|
||||
|
||||
impl MessageStream {
|
||||
#[must_use]
|
||||
pub fn request_id(&self) -> Option<&str> {
|
||||
self.request_id.as_deref()
|
||||
}
|
||||
|
||||
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
||||
loop {
|
||||
if let Some(event) = self.pending.pop_front() {
|
||||
@@ -159,14 +255,46 @@ async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response
|
||||
}
|
||||
|
||||
let body = response.text().await.unwrap_or_else(|_| String::new());
|
||||
Err(ApiError::UnexpectedStatus { status, body })
|
||||
let parsed_error = serde_json::from_str::<AnthropicErrorEnvelope>(&body).ok();
|
||||
let retryable = is_retryable_status(status);
|
||||
|
||||
Err(ApiError::Api {
|
||||
status,
|
||||
error_type: parsed_error
|
||||
.as_ref()
|
||||
.map(|error| error.error.error_type.clone()),
|
||||
message: parsed_error
|
||||
.as_ref()
|
||||
.map(|error| error.error.message.clone()),
|
||||
body,
|
||||
retryable,
|
||||
})
|
||||
}
|
||||
|
||||
const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
|
||||
matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AnthropicErrorEnvelope {
|
||||
error: AnthropicErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AnthropicErrorBody {
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::env::VarError;
|
||||
|
||||
use crate::types::MessageRequest;
|
||||
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||
|
||||
#[test]
|
||||
fn read_api_key_requires_presence() {
|
||||
@@ -194,9 +322,76 @@ mod tests {
|
||||
max_tokens: 64,
|
||||
messages: vec![],
|
||||
system: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
assert!(request.with_streaming().stream);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn backoff_doubles_until_maximum() {
|
||||
let client = super::AnthropicClient::new("test-key").with_retry_policy(
|
||||
3,
|
||||
Duration::from_millis(10),
|
||||
Duration::from_millis(25),
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(1).expect("attempt 1"),
|
||||
Duration::from_millis(10)
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(2).expect("attempt 2"),
|
||||
Duration::from_millis(20)
|
||||
);
|
||||
assert_eq!(
|
||||
client.backoff_for_attempt(3).expect("attempt 3"),
|
||||
Duration::from_millis(25)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retryable_statuses_are_detected() {
|
||||
assert!(super::is_retryable_status(
|
||||
reqwest::StatusCode::TOO_MANY_REQUESTS
|
||||
));
|
||||
assert!(super::is_retryable_status(
|
||||
reqwest::StatusCode::INTERNAL_SERVER_ERROR
|
||||
));
|
||||
assert!(!super::is_retryable_status(
|
||||
reqwest::StatusCode::UNAUTHORIZED
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_delta_variant_round_trips() {
|
||||
let delta = ContentBlockDelta::InputJsonDelta {
|
||||
partial_json: "{\"city\":\"Paris\"}".to_string(),
|
||||
};
|
||||
let encoded = serde_json::to_string(&delta).expect("delta should serialize");
|
||||
let decoded: ContentBlockDelta =
|
||||
serde_json::from_str(&encoded).expect("delta should deserialize");
|
||||
assert_eq!(decoded, delta);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_id_uses_primary_or_fallback_header() {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
|
||||
assert_eq!(
|
||||
super::request_id_from_headers(&headers).as_deref(),
|
||||
Some("req_primary")
|
||||
);
|
||||
|
||||
headers.clear();
|
||||
headers.insert(
|
||||
ALT_REQUEST_ID_HEADER,
|
||||
"req_fallback".parse().expect("header"),
|
||||
);
|
||||
assert_eq!(
|
||||
super::request_id_from_headers(&headers).as_deref(),
|
||||
Some("req_fallback")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::env::VarError;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ApiError {
|
||||
@@ -8,11 +9,39 @@ pub enum ApiError {
|
||||
Http(reqwest::Error),
|
||||
Io(std::io::Error),
|
||||
Json(serde_json::Error),
|
||||
UnexpectedStatus {
|
||||
Api {
|
||||
status: reqwest::StatusCode,
|
||||
error_type: Option<String>,
|
||||
message: Option<String>,
|
||||
body: String,
|
||||
retryable: bool,
|
||||
},
|
||||
RetriesExhausted {
|
||||
attempts: u32,
|
||||
last_error: Box<ApiError>,
|
||||
},
|
||||
InvalidSseFrame(&'static str),
|
||||
BackoffOverflow {
|
||||
attempt: u32,
|
||||
base_delay: Duration,
|
||||
},
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
#[must_use]
|
||||
pub fn is_retryable(&self) -> bool {
|
||||
match self {
|
||||
Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(),
|
||||
Self::Api { retryable, .. } => *retryable,
|
||||
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
||||
Self::MissingApiKey
|
||||
| Self::InvalidApiKeyEnv(_)
|
||||
| Self::Io(_)
|
||||
| Self::Json(_)
|
||||
| Self::InvalidSseFrame(_)
|
||||
| Self::BackoffOverflow { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ApiError {
|
||||
@@ -30,10 +59,36 @@ impl Display for ApiError {
|
||||
Self::Http(error) => write!(f, "http error: {error}"),
|
||||
Self::Io(error) => write!(f, "io error: {error}"),
|
||||
Self::Json(error) => write!(f, "json error: {error}"),
|
||||
Self::UnexpectedStatus { status, body } => {
|
||||
write!(f, "anthropic api returned {status}: {body}")
|
||||
}
|
||||
Self::Api {
|
||||
status,
|
||||
error_type,
|
||||
message,
|
||||
body,
|
||||
..
|
||||
} => match (error_type, message) {
|
||||
(Some(error_type), Some(message)) => {
|
||||
write!(
|
||||
f,
|
||||
"anthropic api returned {status} ({error_type}): {message}"
|
||||
)
|
||||
}
|
||||
_ => write!(f, "anthropic api returned {status}: {body}"),
|
||||
},
|
||||
Self::RetriesExhausted {
|
||||
attempts,
|
||||
last_error,
|
||||
} => write!(
|
||||
f,
|
||||
"anthropic api failed after {attempts} attempts: {last_error}"
|
||||
),
|
||||
Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"),
|
||||
Self::BackoffOverflow {
|
||||
attempt,
|
||||
base_delay,
|
||||
} => write!(
|
||||
f,
|
||||
"retry backoff overflowed on attempt {attempt} with base delay {base_delay:?}"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ pub use error::ApiError;
|
||||
pub use sse::{parse_frame, SseParser};
|
||||
pub use types::{
|
||||
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
||||
InputContentBlock, InputMessage, MessageRequest, MessageResponse, MessageStartEvent,
|
||||
MessageStopEvent, OutputContentBlock, StreamEvent, Usage,
|
||||
InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
|
||||
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
|
||||
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
|
||||
};
|
||||
|
||||
@@ -103,7 +103,7 @@ pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{parse_frame, SseParser};
|
||||
use crate::types::{ContentBlockDelta, OutputContentBlock, StreamEvent};
|
||||
use crate::types::{ContentBlockDelta, MessageDelta, OutputContentBlock, StreamEvent, Usage};
|
||||
|
||||
#[test]
|
||||
fn parses_single_frame() {
|
||||
@@ -158,6 +158,8 @@ mod tests {
|
||||
": keepalive\n",
|
||||
"event: ping\n",
|
||||
"data: {\"type\":\"ping\"}\n\n",
|
||||
"event: message_delta\n",
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}\n\n",
|
||||
"event: message_stop\n",
|
||||
"data: {\"type\":\"message_stop\"}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
@@ -168,7 +170,19 @@ mod tests {
|
||||
.expect("parser should succeed");
|
||||
assert_eq!(
|
||||
events,
|
||||
vec![StreamEvent::MessageStop(crate::types::MessageStopEvent {})]
|
||||
vec![
|
||||
StreamEvent::MessageDelta(crate::types::MessageDeltaEvent {
|
||||
delta: MessageDelta {
|
||||
stop_reason: Some("tool_use".to_string()),
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: Usage {
|
||||
input_tokens: 1,
|
||||
output_tokens: 2,
|
||||
},
|
||||
}),
|
||||
StreamEvent::MessageStop(crate::types::MessageStopEvent {}),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageRequest {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub messages: Vec<InputMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ToolDefinition>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub stream: bool,
|
||||
}
|
||||
@@ -19,7 +24,7 @@ impl MessageRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct InputMessage {
|
||||
pub role: String,
|
||||
pub content: Vec<InputContentBlock>,
|
||||
@@ -33,15 +38,64 @@ impl InputMessage {
|
||||
content: vec![InputContentBlock::Text { text: text.into() }],
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn user_tool_result(
|
||||
tool_use_id: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
is_error: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content: vec![InputContentBlock::ToolResult {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: vec![ToolResultContentBlock::Text {
|
||||
text: content.into(),
|
||||
}],
|
||||
is_error,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum InputContentBlock {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: Vec<ToolResultContentBlock>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
is_error: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ToolResultContentBlock {
|
||||
Text { text: String },
|
||||
Json { value: Value },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub input_schema: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum InputContentBlock {
|
||||
Text { text: String },
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Any,
|
||||
Tool { name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
@@ -54,12 +108,28 @@ pub struct MessageResponse {
|
||||
#[serde(default)]
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Usage,
|
||||
#[serde(default)]
|
||||
pub request_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
impl MessageResponse {
|
||||
#[must_use]
|
||||
pub fn total_tokens(&self) -> u32 {
|
||||
self.usage.total_tokens()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum OutputContentBlock {
|
||||
Text { text: String },
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -68,18 +138,39 @@ pub struct Usage {
|
||||
pub output_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
impl Usage {
|
||||
#[must_use]
|
||||
pub const fn total_tokens(&self) -> u32 {
|
||||
self.input_tokens + self.output_tokens
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageStartEvent {
|
||||
pub message: MessageResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageDeltaEvent {
|
||||
pub delta: MessageDelta,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct MessageDelta {
|
||||
#[serde(default)]
|
||||
pub stop_reason: Option<String>,
|
||||
#[serde(default)]
|
||||
pub stop_sequence: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ContentBlockStartEvent {
|
||||
pub index: u32,
|
||||
pub content_block: OutputContentBlock,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ContentBlockDeltaEvent {
|
||||
pub index: u32,
|
||||
pub delta: ContentBlockDelta,
|
||||
@@ -89,6 +180,7 @@ pub struct ContentBlockDeltaEvent {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlockDelta {
|
||||
TextDelta { text: String },
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -99,10 +191,11 @@ pub struct ContentBlockStopEvent {
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct MessageStopEvent {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum StreamEvent {
|
||||
MessageStart(MessageStartEvent),
|
||||
MessageDelta(MessageDeltaEvent),
|
||||
ContentBlockStart(ContentBlockStartEvent),
|
||||
ContentBlockDelta(ContentBlockDeltaEvent),
|
||||
ContentBlockStop(ContentBlockStopEvent),
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use api::{AnthropicClient, InputMessage, MessageRequest, OutputContentBlock, StreamEvent};
|
||||
use api::{
|
||||
AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
|
||||
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
|
||||
StreamEvent, ToolChoice, ToolDefinition,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -18,10 +24,15 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
"\"model\":\"claude-3-7-sonnet-latest\",",
|
||||
"\"stop_reason\":\"end_turn\",",
|
||||
"\"stop_sequence\":null,",
|
||||
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4}",
|
||||
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
|
||||
"\"request_id\":\"req_body_123\"",
|
||||
"}"
|
||||
);
|
||||
let server = spawn_server(state.clone(), http_response("application/json", body)).await;
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response("200 OK", "application/json", body)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_auth_token(Some("proxy-token".to_string()))
|
||||
@@ -32,6 +43,8 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(response.id, "msg_test");
|
||||
assert_eq!(response.total_tokens(), 16);
|
||||
assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
|
||||
assert_eq!(
|
||||
response.content,
|
||||
vec![OutputContentBlock::Text {
|
||||
@@ -51,39 +64,45 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
request.headers.get("authorization").map(String::as_str),
|
||||
Some("Bearer proxy-token")
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("anthropic-version").map(String::as_str),
|
||||
Some("2023-06-01")
|
||||
);
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_str(&request.body).expect("request body should be json");
|
||||
assert_eq!(
|
||||
body.get("model").and_then(serde_json::Value::as_str),
|
||||
Some("claude-3-7-sonnet-latest")
|
||||
);
|
||||
assert!(
|
||||
body.get("stream").is_none(),
|
||||
"non-stream request should omit stream=false"
|
||||
);
|
||||
assert!(body.get("stream").is_none());
|
||||
assert_eq!(body["tools"][0]["name"], json!("get_weather"));
|
||||
assert_eq!(body["tool_choice"]["type"], json!("auto"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_message_parses_sse_events() {
|
||||
async fn stream_message_parses_sse_events_with_tool_use() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let sse = concat!(
|
||||
"event: message_start\n",
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
|
||||
"event: content_block_start\n",
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
|
||||
"event: content_block_delta\n",
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
|
||||
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n",
|
||||
"event: content_block_stop\n",
|
||||
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\n",
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
|
||||
"event: message_stop\n",
|
||||
"data: {\"type\":\"message_stop\"}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
);
|
||||
let server = spawn_server(state.clone(), http_response("text/event-stream", sse)).await;
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response_with_headers(
|
||||
"200 OK",
|
||||
"text/event-stream",
|
||||
sse,
|
||||
&[("request-id", "req_stream_456")],
|
||||
)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_auth_token(Some("proxy-token".to_string()))
|
||||
@@ -93,6 +112,8 @@ async fn stream_message_parses_sse_events() {
|
||||
.await
|
||||
.expect("stream should start");
|
||||
|
||||
assert_eq!(stream.request_id(), Some("req_stream_456"));
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(event) = stream
|
||||
.next_event()
|
||||
@@ -102,18 +123,126 @@ async fn stream_message_parses_sse_events() {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert_eq!(events.len(), 5);
|
||||
assert_eq!(events.len(), 6);
|
||||
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
|
||||
assert!(matches!(events[1], StreamEvent::ContentBlockStart(_)));
|
||||
assert!(matches!(events[2], StreamEvent::ContentBlockDelta(_)));
|
||||
assert!(matches!(
|
||||
events[1],
|
||||
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||
content_block: OutputContentBlock::ToolUse { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
assert!(matches!(
|
||||
events[2],
|
||||
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
|
||||
delta: ContentBlockDelta::InputJsonDelta { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
|
||||
assert!(matches!(events[4], StreamEvent::MessageStop(_)));
|
||||
assert!(matches!(
|
||||
events[4],
|
||||
StreamEvent::MessageDelta(MessageDeltaEvent { .. })
|
||||
));
|
||||
assert!(matches!(events[5], StreamEvent::MessageStop(_)));
|
||||
|
||||
match &events[1] {
|
||||
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||
content_block: OutputContentBlock::ToolUse { name, input, .. },
|
||||
..
|
||||
}) => {
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input, &json!({}));
|
||||
}
|
||||
other => panic!("expected tool_use block, got {other:?}"),
|
||||
}
|
||||
|
||||
let captured = state.lock().await;
|
||||
let request = captured.first().expect("server should capture request");
|
||||
assert!(request.body.contains("\"stream\":true"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retries_retryable_failures_before_succeeding() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![
|
||||
http_response(
|
||||
"429 Too Many Requests",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
|
||||
),
|
||||
http_response(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
"{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
|
||||
),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_base_url(server.base_url())
|
||||
.with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2));
|
||||
|
||||
let response = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect("retry should eventually succeed");
|
||||
|
||||
assert_eq!(response.total_tokens(), 5);
|
||||
assert_eq!(state.lock().await.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![
|
||||
http_response(
|
||||
"503 Service Unavailable",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
|
||||
),
|
||||
http_response(
|
||||
"503 Service Unavailable",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
|
||||
),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_base_url(server.base_url())
|
||||
.with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2));
|
||||
|
||||
let error = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect_err("persistent 503 should fail");
|
||||
|
||||
match error {
|
||||
ApiError::RetriesExhausted {
|
||||
attempts,
|
||||
last_error,
|
||||
} => {
|
||||
assert_eq!(attempts, 2);
|
||||
assert!(matches!(
|
||||
*last_error,
|
||||
ApiError::Api {
|
||||
status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
|
||||
retryable: true,
|
||||
..
|
||||
}
|
||||
));
|
||||
}
|
||||
other => panic!("expected retries exhausted, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires ANTHROPIC_API_KEY and network access"]
|
||||
async fn live_stream_smoke_test() {
|
||||
@@ -127,51 +256,18 @@ async fn live_stream_smoke_test() {
|
||||
"Reply with exactly: hello from rust",
|
||||
)],
|
||||
system: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
})
|
||||
.await
|
||||
.expect("live stream should start");
|
||||
|
||||
let mut saw_start = false;
|
||||
let mut saw_follow_up = false;
|
||||
let mut event_kinds = Vec::new();
|
||||
while let Some(event) = stream
|
||||
while let Some(_event) = stream
|
||||
.next_event()
|
||||
.await
|
||||
.expect("live stream should yield events")
|
||||
{
|
||||
match event {
|
||||
StreamEvent::MessageStart(_) => {
|
||||
saw_start = true;
|
||||
event_kinds.push("message_start");
|
||||
}
|
||||
StreamEvent::ContentBlockStart(_) => {
|
||||
saw_follow_up = true;
|
||||
event_kinds.push("content_block_start");
|
||||
}
|
||||
StreamEvent::ContentBlockDelta(_) => {
|
||||
saw_follow_up = true;
|
||||
event_kinds.push("content_block_delta");
|
||||
}
|
||||
StreamEvent::ContentBlockStop(_) => {
|
||||
saw_follow_up = true;
|
||||
event_kinds.push("content_block_stop");
|
||||
}
|
||||
StreamEvent::MessageStop(_) => {
|
||||
saw_follow_up = true;
|
||||
event_kinds.push("message_stop");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_start,
|
||||
"expected a message_start event; got {event_kinds:?}"
|
||||
);
|
||||
assert!(
|
||||
saw_follow_up,
|
||||
"expected at least one follow-up stream event; got {event_kinds:?}"
|
||||
);
|
||||
{}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -199,7 +295,10 @@ impl Drop for TestServer {
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_server(state: Arc<Mutex<Vec<CapturedRequest>>>, response: String) -> TestServer {
|
||||
async fn spawn_server(
|
||||
state: Arc<Mutex<Vec<CapturedRequest>>>,
|
||||
responses: Vec<String>,
|
||||
) -> TestServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
@@ -207,72 +306,75 @@ async fn spawn_server(state: Arc<Mutex<Vec<CapturedRequest>>>, response: String)
|
||||
.local_addr()
|
||||
.expect("listener should have local addr");
|
||||
let join_handle = tokio::spawn(async move {
|
||||
let (mut socket, _) = listener.accept().await.expect("server should accept");
|
||||
let mut buffer = Vec::new();
|
||||
let mut header_end = None;
|
||||
for response in responses {
|
||||
let (mut socket, _) = listener.accept().await.expect("server should accept");
|
||||
let mut buffer = Vec::new();
|
||||
let mut header_end = None;
|
||||
|
||||
loop {
|
||||
let mut chunk = [0_u8; 1024];
|
||||
let read = socket
|
||||
.read(&mut chunk)
|
||||
loop {
|
||||
let mut chunk = [0_u8; 1024];
|
||||
let read = socket
|
||||
.read(&mut chunk)
|
||||
.await
|
||||
.expect("request read should succeed");
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
if let Some(position) = find_header_end(&buffer) {
|
||||
header_end = Some(position);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let header_end = header_end.expect("request should include headers");
|
||||
let (header_bytes, remaining) = buffer.split_at(header_end);
|
||||
let header_text =
|
||||
String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
|
||||
let mut lines = header_text.split("\r\n");
|
||||
let request_line = lines.next().expect("request line should exist");
|
||||
let mut parts = request_line.split_whitespace();
|
||||
let method = parts.next().expect("method should exist").to_string();
|
||||
let path = parts.next().expect("path should exist").to_string();
|
||||
let mut headers = HashMap::new();
|
||||
let mut content_length = 0_usize;
|
||||
for line in lines {
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let (name, value) = line.split_once(':').expect("header should have colon");
|
||||
let value = value.trim().to_string();
|
||||
if name.eq_ignore_ascii_case("content-length") {
|
||||
content_length = value.parse().expect("content length should parse");
|
||||
}
|
||||
headers.insert(name.to_ascii_lowercase(), value);
|
||||
}
|
||||
|
||||
let mut body = remaining[4..].to_vec();
|
||||
while body.len() < content_length {
|
||||
let mut chunk = vec![0_u8; content_length - body.len()];
|
||||
let read = socket
|
||||
.read(&mut chunk)
|
||||
.await
|
||||
.expect("body read should succeed");
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
body.extend_from_slice(&chunk[..read]);
|
||||
}
|
||||
|
||||
state.lock().await.push(CapturedRequest {
|
||||
method,
|
||||
path,
|
||||
headers,
|
||||
body: String::from_utf8(body).expect("body should be utf8"),
|
||||
});
|
||||
|
||||
socket
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.expect("request read should succeed");
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
if let Some(position) = find_header_end(&buffer) {
|
||||
header_end = Some(position);
|
||||
break;
|
||||
}
|
||||
.expect("response write should succeed");
|
||||
}
|
||||
|
||||
let header_end = header_end.expect("request should include headers");
|
||||
let (header_bytes, remaining) = buffer.split_at(header_end);
|
||||
let header_text = String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
|
||||
let mut lines = header_text.split("\r\n");
|
||||
let request_line = lines.next().expect("request line should exist");
|
||||
let mut parts = request_line.split_whitespace();
|
||||
let method = parts.next().expect("method should exist").to_string();
|
||||
let path = parts.next().expect("path should exist").to_string();
|
||||
let mut headers = HashMap::new();
|
||||
let mut content_length = 0_usize;
|
||||
for line in lines {
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let (name, value) = line.split_once(':').expect("header should have colon");
|
||||
let value = value.trim().to_string();
|
||||
if name.eq_ignore_ascii_case("content-length") {
|
||||
content_length = value.parse().expect("content length should parse");
|
||||
}
|
||||
headers.insert(name.to_ascii_lowercase(), value);
|
||||
}
|
||||
|
||||
let mut body = remaining[4..].to_vec();
|
||||
while body.len() < content_length {
|
||||
let mut chunk = vec![0_u8; content_length - body.len()];
|
||||
let read = socket
|
||||
.read(&mut chunk)
|
||||
.await
|
||||
.expect("body read should succeed");
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
body.extend_from_slice(&chunk[..read]);
|
||||
}
|
||||
|
||||
state.lock().await.push(CapturedRequest {
|
||||
method,
|
||||
path,
|
||||
headers,
|
||||
body: String::from_utf8(body).expect("body should be utf8"),
|
||||
});
|
||||
|
||||
socket
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.expect("response write should succeed");
|
||||
});
|
||||
|
||||
TestServer {
|
||||
@@ -285,9 +387,23 @@ fn find_header_end(bytes: &[u8]) -> Option<usize> {
|
||||
bytes.windows(4).position(|window| window == b"\r\n\r\n")
|
||||
}
|
||||
|
||||
fn http_response(content_type: &str, body: &str) -> String {
|
||||
fn http_response(status: &str, content_type: &str, body: &str) -> String {
|
||||
http_response_with_headers(status, content_type, body, &[])
|
||||
}
|
||||
|
||||
fn http_response_with_headers(
|
||||
status: &str,
|
||||
content_type: &str,
|
||||
body: &str,
|
||||
headers: &[(&str, &str)],
|
||||
) -> String {
|
||||
let mut extra_headers = String::new();
|
||||
for (name, value) in headers {
|
||||
use std::fmt::Write as _;
|
||||
write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write should succeed");
|
||||
}
|
||||
format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
"HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
)
|
||||
}
|
||||
@@ -296,8 +412,32 @@ fn sample_request(stream: bool) -> MessageRequest {
|
||||
MessageRequest {
|
||||
model: "claude-3-7-sonnet-latest".to_string(),
|
||||
max_tokens: 64,
|
||||
messages: vec![InputMessage::user_text("Say hello")],
|
||||
system: None,
|
||||
messages: vec![InputMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
InputContentBlock::Text {
|
||||
text: "Say hello".to_string(),
|
||||
},
|
||||
InputContentBlock::ToolResult {
|
||||
tool_use_id: "toolu_prev".to_string(),
|
||||
content: vec![api::ToolResultContentBlock::Json {
|
||||
value: json!({"forecast": "sunny"}),
|
||||
}],
|
||||
is_error: false,
|
||||
},
|
||||
],
|
||||
}],
|
||||
system: Some("Use tools when needed".to_string()),
|
||||
tools: Some(vec![ToolDefinition {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Fetches the weather".to_string()),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"]
|
||||
}),
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,3 +7,6 @@ publish.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
runtime = { path = "../runtime" }
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use runtime::{compact_session, CompactionConfig, Session};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CommandManifestEntry {
|
||||
pub name: String,
|
||||
@@ -27,3 +29,82 @@ impl CommandRegistry {
|
||||
&self.entries
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SlashCommandResult {
|
||||
pub message: String,
|
||||
pub session: Session,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn handle_slash_command(
|
||||
input: &str,
|
||||
session: &Session,
|
||||
compaction: CompactionConfig,
|
||||
) -> Option<SlashCommandResult> {
|
||||
let trimmed = input.trim();
|
||||
if !trimmed.starts_with('/') {
|
||||
return None;
|
||||
}
|
||||
|
||||
match trimmed.split_whitespace().next() {
|
||||
Some("/compact") => {
|
||||
let result = compact_session(session, compaction);
|
||||
let message = if result.removed_message_count == 0 {
|
||||
"Compaction skipped: session is below the compaction threshold.".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"Compacted {} messages into a resumable system summary.",
|
||||
result.removed_message_count
|
||||
)
|
||||
};
|
||||
Some(SlashCommandResult {
|
||||
message,
|
||||
session: result.compacted_session,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::handle_slash_command;
|
||||
use runtime::{CompactionConfig, ContentBlock, ConversationMessage, MessageRole, Session};
|
||||
|
||||
#[test]
|
||||
fn compacts_sessions_via_slash_command() {
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![
|
||||
ConversationMessage::user_text("a ".repeat(200)),
|
||||
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "b ".repeat(200),
|
||||
}]),
|
||||
ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false),
|
||||
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "recent".to_string(),
|
||||
}]),
|
||||
],
|
||||
};
|
||||
|
||||
let result = handle_slash_command(
|
||||
"/compact",
|
||||
&session,
|
||||
CompactionConfig {
|
||||
preserve_recent_messages: 2,
|
||||
max_estimated_tokens: 1,
|
||||
},
|
||||
)
|
||||
.expect("slash command should be handled");
|
||||
|
||||
assert!(result.message.contains("Compacted 2 messages"));
|
||||
assert_eq!(result.session.messages[0].role, MessageRole::System);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_unknown_slash_commands() {
|
||||
let session = Session::new();
|
||||
assert!(handle_slash_command("/unknown", &session, CompactionConfig::default()).is_none());
|
||||
}
|
||||
}
|
||||
|
||||
291
rust/crates/runtime/src/compact.rs
Normal file
291
rust/crates/runtime/src/compact.rs
Normal file
@@ -0,0 +1,291 @@
|
||||
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct CompactionConfig {
|
||||
pub preserve_recent_messages: usize,
|
||||
pub max_estimated_tokens: usize,
|
||||
}
|
||||
|
||||
impl Default for CompactionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
preserve_recent_messages: 4,
|
||||
max_estimated_tokens: 10_000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CompactionResult {
|
||||
pub summary: String,
|
||||
pub compacted_session: Session,
|
||||
pub removed_message_count: usize,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn estimate_session_tokens(session: &Session) -> usize {
|
||||
session.messages.iter().map(estimate_message_tokens).sum()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn should_compact(session: &Session, config: CompactionConfig) -> bool {
|
||||
session.messages.len() > config.preserve_recent_messages
|
||||
&& estimate_session_tokens(session) >= config.max_estimated_tokens
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn format_compact_summary(summary: &str) -> String {
|
||||
let without_analysis = strip_tag_block(summary, "analysis");
|
||||
let formatted = if let Some(content) = extract_tag_block(&without_analysis, "summary") {
|
||||
without_analysis.replace(
|
||||
&format!("<summary>{content}</summary>"),
|
||||
&format!("Summary:\n{}", content.trim()),
|
||||
)
|
||||
} else {
|
||||
without_analysis
|
||||
};
|
||||
|
||||
collapse_blank_lines(&formatted).trim().to_string()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_compact_continuation_message(
|
||||
summary: &str,
|
||||
suppress_follow_up_questions: bool,
|
||||
recent_messages_preserved: bool,
|
||||
) -> String {
|
||||
let mut base = format!(
|
||||
"This session is being continued from a previous conversation that ran out of context. The summary below covers the earlier portion of the conversation.\n\n{}",
|
||||
format_compact_summary(summary)
|
||||
);
|
||||
|
||||
if recent_messages_preserved {
|
||||
base.push_str("\n\nRecent messages are preserved verbatim.");
|
||||
}
|
||||
|
||||
if suppress_follow_up_questions {
|
||||
base.push_str("\nContinue the conversation from where it left off without asking the user any further questions. Resume directly — do not acknowledge the summary, do not recap what was happening, and do not preface with continuation text.");
|
||||
}
|
||||
|
||||
base
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn compact_session(session: &Session, config: CompactionConfig) -> CompactionResult {
|
||||
if !should_compact(session, config) {
|
||||
return CompactionResult {
|
||||
summary: String::new(),
|
||||
compacted_session: session.clone(),
|
||||
removed_message_count: 0,
|
||||
};
|
||||
}
|
||||
|
||||
let keep_from = session
|
||||
.messages
|
||||
.len()
|
||||
.saturating_sub(config.preserve_recent_messages);
|
||||
let removed = &session.messages[..keep_from];
|
||||
let preserved = session.messages[keep_from..].to_vec();
|
||||
let summary = summarize_messages(removed);
|
||||
let continuation = get_compact_continuation_message(&summary, true, !preserved.is_empty());
|
||||
|
||||
let mut compacted_messages = vec![ConversationMessage {
|
||||
role: MessageRole::System,
|
||||
blocks: vec![ContentBlock::Text { text: continuation }],
|
||||
usage: None,
|
||||
}];
|
||||
compacted_messages.extend(preserved);
|
||||
|
||||
CompactionResult {
|
||||
summary,
|
||||
compacted_session: Session {
|
||||
version: session.version,
|
||||
messages: compacted_messages,
|
||||
},
|
||||
removed_message_count: removed.len(),
|
||||
}
|
||||
}
|
||||
|
||||
fn summarize_messages(messages: &[ConversationMessage]) -> String {
|
||||
let mut lines = vec!["<summary>".to_string(), "Conversation summary:".to_string()];
|
||||
for message in messages {
|
||||
let role = match message.role {
|
||||
MessageRole::System => "system",
|
||||
MessageRole::User => "user",
|
||||
MessageRole::Assistant => "assistant",
|
||||
MessageRole::Tool => "tool",
|
||||
};
|
||||
let content = message
|
||||
.blocks
|
||||
.iter()
|
||||
.map(summarize_block)
|
||||
.collect::<Vec<_>>()
|
||||
.join(" | ");
|
||||
lines.push(format!("- {role}: {content}"));
|
||||
}
|
||||
lines.push("</summary>".to_string());
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn summarize_block(block: &ContentBlock) -> String {
|
||||
let raw = match block {
|
||||
ContentBlock::Text { text } => text.clone(),
|
||||
ContentBlock::ToolUse { name, input, .. } => format!("tool_use {name}({input})"),
|
||||
ContentBlock::ToolResult {
|
||||
tool_name,
|
||||
output,
|
||||
is_error,
|
||||
..
|
||||
} => format!(
|
||||
"tool_result {tool_name}: {}{output}",
|
||||
if *is_error { "error " } else { "" }
|
||||
),
|
||||
};
|
||||
truncate_summary(&raw, 160)
|
||||
}
|
||||
|
||||
fn truncate_summary(content: &str, max_chars: usize) -> String {
|
||||
if content.chars().count() <= max_chars {
|
||||
return content.to_string();
|
||||
}
|
||||
let mut truncated = content.chars().take(max_chars).collect::<String>();
|
||||
truncated.push('…');
|
||||
truncated
|
||||
}
|
||||
|
||||
fn estimate_message_tokens(message: &ConversationMessage) -> usize {
|
||||
message
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|block| match block {
|
||||
ContentBlock::Text { text } => text.len() / 4 + 1,
|
||||
ContentBlock::ToolUse { name, input, .. } => (name.len() + input.len()) / 4 + 1,
|
||||
ContentBlock::ToolResult {
|
||||
tool_name, output, ..
|
||||
} => (tool_name.len() + output.len()) / 4 + 1,
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn extract_tag_block(content: &str, tag: &str) -> Option<String> {
|
||||
let start = format!("<{tag}>");
|
||||
let end = format!("</{tag}>");
|
||||
let start_index = content.find(&start)? + start.len();
|
||||
let end_index = content[start_index..].find(&end)? + start_index;
|
||||
Some(content[start_index..end_index].to_string())
|
||||
}
|
||||
|
||||
fn strip_tag_block(content: &str, tag: &str) -> String {
|
||||
let start = format!("<{tag}>");
|
||||
let end = format!("</{tag}>");
|
||||
if let (Some(start_index), Some(end_index_rel)) = (content.find(&start), content.find(&end)) {
|
||||
let end_index = end_index_rel + end.len();
|
||||
let mut stripped = String::new();
|
||||
stripped.push_str(&content[..start_index]);
|
||||
stripped.push_str(&content[end_index..]);
|
||||
stripped
|
||||
} else {
|
||||
content.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn collapse_blank_lines(content: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut last_blank = false;
|
||||
for line in content.lines() {
|
||||
let is_blank = line.trim().is_empty();
|
||||
if is_blank && last_blank {
|
||||
continue;
|
||||
}
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
last_blank = is_blank;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
compact_session, estimate_session_tokens, format_compact_summary, should_compact,
|
||||
CompactionConfig,
|
||||
};
|
||||
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||
|
||||
#[test]
|
||||
fn formats_compact_summary_like_upstream() {
|
||||
let summary = "<analysis>scratch</analysis>\n<summary>Kept work</summary>";
|
||||
assert_eq!(format_compact_summary(summary), "Summary:\nKept work");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn leaves_small_sessions_unchanged() {
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![ConversationMessage::user_text("hello")],
|
||||
};
|
||||
|
||||
let result = compact_session(&session, CompactionConfig::default());
|
||||
assert_eq!(result.removed_message_count, 0);
|
||||
assert_eq!(result.compacted_session, session);
|
||||
assert!(result.summary.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compacts_older_messages_into_a_system_summary() {
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![
|
||||
ConversationMessage::user_text("one ".repeat(200)),
|
||||
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "two ".repeat(200),
|
||||
}]),
|
||||
ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false),
|
||||
ConversationMessage {
|
||||
role: MessageRole::Assistant,
|
||||
blocks: vec![ContentBlock::Text {
|
||||
text: "recent".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let result = compact_session(
|
||||
&session,
|
||||
CompactionConfig {
|
||||
preserve_recent_messages: 2,
|
||||
max_estimated_tokens: 1,
|
||||
},
|
||||
);
|
||||
|
||||
assert_eq!(result.removed_message_count, 2);
|
||||
assert_eq!(
|
||||
result.compacted_session.messages[0].role,
|
||||
MessageRole::System
|
||||
);
|
||||
assert!(matches!(
|
||||
&result.compacted_session.messages[0].blocks[0],
|
||||
ContentBlock::Text { text } if text.contains("Summary:")
|
||||
));
|
||||
assert!(should_compact(
|
||||
&session,
|
||||
CompactionConfig {
|
||||
preserve_recent_messages: 2,
|
||||
max_estimated_tokens: 1,
|
||||
}
|
||||
));
|
||||
assert!(
|
||||
estimate_session_tokens(&result.compacted_session) < estimate_session_tokens(&session)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_long_blocks_in_summary() {
|
||||
let summary = super::summarize_block(&ContentBlock::Text {
|
||||
text: "x".repeat(400),
|
||||
});
|
||||
assert!(summary.ends_with('…'));
|
||||
assert!(summary.chars().count() <= 161);
|
||||
}
|
||||
}
|
||||
269
rust/crates/runtime/src/config.rs
Normal file
269
rust/crates/runtime/src/config.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::json::JsonValue;
|
||||
|
||||
pub const CLAUDE_CODE_SETTINGS_SCHEMA_NAME: &str = "SettingsSchema";
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum ConfigSource {
|
||||
User,
|
||||
Project,
|
||||
Local,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ConfigEntry {
|
||||
pub source: ConfigSource,
|
||||
pub path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RuntimeConfig {
|
||||
merged: BTreeMap<String, JsonValue>,
|
||||
loaded_entries: Vec<ConfigEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ConfigError {
|
||||
Io(std::io::Error),
|
||||
Parse(String),
|
||||
}
|
||||
|
||||
impl Display for ConfigError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Io(error) => write!(f, "{error}"),
|
||||
Self::Parse(error) => write!(f, "{error}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ConfigError {}
|
||||
|
||||
impl From<std::io::Error> for ConfigError {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Self::Io(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ConfigLoader {
|
||||
cwd: PathBuf,
|
||||
config_home: PathBuf,
|
||||
}
|
||||
|
||||
impl ConfigLoader {
|
||||
#[must_use]
|
||||
pub fn new(cwd: impl Into<PathBuf>, config_home: impl Into<PathBuf>) -> Self {
|
||||
Self {
|
||||
cwd: cwd.into(),
|
||||
config_home: config_home.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn default_for(cwd: impl Into<PathBuf>) -> Self {
|
||||
let cwd = cwd.into();
|
||||
let config_home = std::env::var_os("CLAUDE_CONFIG_HOME")
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".claude")))
|
||||
.unwrap_or_else(|| PathBuf::from(".claude"));
|
||||
Self { cwd, config_home }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn discover(&self) -> Vec<ConfigEntry> {
|
||||
vec![
|
||||
ConfigEntry {
|
||||
source: ConfigSource::User,
|
||||
path: self.config_home.join("settings.json"),
|
||||
},
|
||||
ConfigEntry {
|
||||
source: ConfigSource::Project,
|
||||
path: self.cwd.join(".claude").join("settings.json"),
|
||||
},
|
||||
ConfigEntry {
|
||||
source: ConfigSource::Local,
|
||||
path: self.cwd.join(".claude").join("settings.local.json"),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
pub fn load(&self) -> Result<RuntimeConfig, ConfigError> {
|
||||
let mut merged = BTreeMap::new();
|
||||
let mut loaded_entries = Vec::new();
|
||||
|
||||
for entry in self.discover() {
|
||||
let Some(value) = read_optional_json_object(&entry.path)? else {
|
||||
continue;
|
||||
};
|
||||
deep_merge_objects(&mut merged, &value);
|
||||
loaded_entries.push(entry);
|
||||
}
|
||||
|
||||
Ok(RuntimeConfig {
|
||||
merged,
|
||||
loaded_entries,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeConfig {
|
||||
#[must_use]
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
merged: BTreeMap::new(),
|
||||
loaded_entries: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn merged(&self) -> &BTreeMap<String, JsonValue> {
|
||||
&self.merged
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn loaded_entries(&self) -> &[ConfigEntry] {
|
||||
&self.loaded_entries
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get(&self, key: &str) -> Option<&JsonValue> {
|
||||
self.merged.get(key)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn as_json(&self) -> JsonValue {
|
||||
JsonValue::Object(self.merged.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn read_optional_json_object(
|
||||
path: &Path,
|
||||
) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> {
|
||||
let contents = match fs::read_to_string(path) {
|
||||
Ok(contents) => contents,
|
||||
Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
||||
Err(error) => return Err(ConfigError::Io(error)),
|
||||
};
|
||||
|
||||
if contents.trim().is_empty() {
|
||||
return Ok(Some(BTreeMap::new()));
|
||||
}
|
||||
|
||||
let parsed = JsonValue::parse(&contents)
|
||||
.map_err(|error| ConfigError::Parse(format!("{}: {error}", path.display())))?;
|
||||
let object = parsed.as_object().ok_or_else(|| {
|
||||
ConfigError::Parse(format!(
|
||||
"{}: top-level settings value must be a JSON object",
|
||||
path.display()
|
||||
))
|
||||
})?;
|
||||
Ok(Some(object.clone()))
|
||||
}
|
||||
|
||||
fn deep_merge_objects(
|
||||
target: &mut BTreeMap<String, JsonValue>,
|
||||
source: &BTreeMap<String, JsonValue>,
|
||||
) {
|
||||
for (key, value) in source {
|
||||
match (target.get_mut(key), value) {
|
||||
(Some(JsonValue::Object(existing)), JsonValue::Object(incoming)) => {
|
||||
deep_merge_objects(existing, incoming);
|
||||
}
|
||||
_ => {
|
||||
target.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{ConfigLoader, ConfigSource, CLAUDE_CODE_SETTINGS_SCHEMA_NAME};
|
||||
use crate::json::JsonValue;
|
||||
use std::fs;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn temp_dir() -> std::path::PathBuf {
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("runtime-config-{nanos}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_non_object_settings_files() {
|
||||
let root = temp_dir();
|
||||
let cwd = root.join("project");
|
||||
let home = root.join("home").join(".claude");
|
||||
fs::create_dir_all(&home).expect("home config dir");
|
||||
fs::create_dir_all(&cwd).expect("project dir");
|
||||
fs::write(home.join("settings.json"), "[]").expect("write bad settings");
|
||||
|
||||
let error = ConfigLoader::new(&cwd, &home)
|
||||
.load()
|
||||
.expect_err("config should fail");
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("top-level settings value must be a JSON object"));
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loads_and_merges_claude_code_config_files_by_precedence() {
|
||||
let root = temp_dir();
|
||||
let cwd = root.join("project");
|
||||
let home = root.join("home").join(".claude");
|
||||
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
|
||||
fs::create_dir_all(&home).expect("home config dir");
|
||||
|
||||
fs::write(
|
||||
home.join("settings.json"),
|
||||
r#"{"model":"sonnet","env":{"A":"1"},"hooks":{"PreToolUse":["base"]}}"#,
|
||||
)
|
||||
.expect("write user settings");
|
||||
fs::write(
|
||||
cwd.join(".claude").join("settings.json"),
|
||||
r#"{"env":{"B":"2"},"hooks":{"PostToolUse":["project"]}}"#,
|
||||
)
|
||||
.expect("write project settings");
|
||||
fs::write(
|
||||
cwd.join(".claude").join("settings.local.json"),
|
||||
r#"{"model":"opus","permissionMode":"acceptEdits"}"#,
|
||||
)
|
||||
.expect("write local settings");
|
||||
|
||||
let loaded = ConfigLoader::new(&cwd, &home)
|
||||
.load()
|
||||
.expect("config should load");
|
||||
|
||||
assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema");
|
||||
assert_eq!(loaded.loaded_entries().len(), 3);
|
||||
assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User);
|
||||
assert_eq!(
|
||||
loaded.get("model"),
|
||||
Some(&JsonValue::String("opus".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
loaded
|
||||
.get("env")
|
||||
.and_then(JsonValue::as_object)
|
||||
.expect("env object")
|
||||
.len(),
|
||||
2
|
||||
);
|
||||
assert!(loaded
|
||||
.get("hooks")
|
||||
.and_then(JsonValue::as_object)
|
||||
.expect("hooks object")
|
||||
.contains_key("PreToolUse"));
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
use crate::compact::{
|
||||
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
|
||||
};
|
||||
use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
|
||||
use crate::session::{ContentBlock, ConversationMessage, Session};
|
||||
use crate::usage::{TokenUsage, UsageTracker};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ApiRequest {
|
||||
@@ -18,6 +22,7 @@ pub enum AssistantEvent {
|
||||
name: String,
|
||||
input: String,
|
||||
},
|
||||
Usage(TokenUsage),
|
||||
MessageStop,
|
||||
}
|
||||
|
||||
@@ -78,6 +83,7 @@ pub struct TurnSummary {
|
||||
pub assistant_messages: Vec<ConversationMessage>,
|
||||
pub tool_results: Vec<ConversationMessage>,
|
||||
pub iterations: usize,
|
||||
pub usage: TokenUsage,
|
||||
}
|
||||
|
||||
pub struct ConversationRuntime<C, T> {
|
||||
@@ -87,6 +93,7 @@ pub struct ConversationRuntime<C, T> {
|
||||
permission_policy: PermissionPolicy,
|
||||
system_prompt: Vec<String>,
|
||||
max_iterations: usize,
|
||||
usage_tracker: UsageTracker,
|
||||
}
|
||||
|
||||
impl<C, T> ConversationRuntime<C, T>
|
||||
@@ -102,6 +109,7 @@ where
|
||||
permission_policy: PermissionPolicy,
|
||||
system_prompt: Vec<String>,
|
||||
) -> Self {
|
||||
let usage_tracker = UsageTracker::from_session(&session);
|
||||
Self {
|
||||
session,
|
||||
api_client,
|
||||
@@ -109,6 +117,7 @@ where
|
||||
permission_policy,
|
||||
system_prompt,
|
||||
max_iterations: 16,
|
||||
usage_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,7 +153,10 @@ where
|
||||
messages: self.session.messages.clone(),
|
||||
};
|
||||
let events = self.api_client.stream(request)?;
|
||||
let assistant_message = build_assistant_message(events)?;
|
||||
let (assistant_message, usage) = build_assistant_message(events)?;
|
||||
if let Some(usage) = usage {
|
||||
self.usage_tracker.record(usage);
|
||||
}
|
||||
let pending_tool_uses = assistant_message
|
||||
.blocks
|
||||
.iter()
|
||||
@@ -201,9 +213,25 @@ where
|
||||
assistant_messages,
|
||||
tool_results,
|
||||
iterations,
|
||||
usage: self.usage_tracker.cumulative_usage(),
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
|
||||
compact_session(&self.session, config)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn estimated_tokens(&self) -> usize {
|
||||
estimate_session_tokens(&self.session)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn usage(&self) -> &UsageTracker {
|
||||
&self.usage_tracker
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn session(&self) -> &Session {
|
||||
&self.session
|
||||
@@ -217,10 +245,11 @@ where
|
||||
|
||||
fn build_assistant_message(
|
||||
events: Vec<AssistantEvent>,
|
||||
) -> Result<ConversationMessage, RuntimeError> {
|
||||
) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
|
||||
let mut text = String::new();
|
||||
let mut blocks = Vec::new();
|
||||
let mut finished = false;
|
||||
let mut usage = None;
|
||||
|
||||
for event in events {
|
||||
match event {
|
||||
@@ -229,6 +258,7 @@ fn build_assistant_message(
|
||||
flush_text_block(&mut text, &mut blocks);
|
||||
blocks.push(ContentBlock::ToolUse { id, name, input });
|
||||
}
|
||||
AssistantEvent::Usage(value) => usage = Some(value),
|
||||
AssistantEvent::MessageStop => {
|
||||
finished = true;
|
||||
}
|
||||
@@ -246,7 +276,10 @@ fn build_assistant_message(
|
||||
return Err(RuntimeError::new("assistant stream produced no content"));
|
||||
}
|
||||
|
||||
Ok(ConversationMessage::assistant(blocks))
|
||||
Ok((
|
||||
ConversationMessage::assistant_with_usage(blocks, usage),
|
||||
usage,
|
||||
))
|
||||
}
|
||||
|
||||
fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
|
||||
@@ -295,12 +328,15 @@ mod tests {
|
||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
|
||||
StaticToolExecutor,
|
||||
};
|
||||
use crate::compact::CompactionConfig;
|
||||
use crate::permissions::{
|
||||
PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
|
||||
PermissionRequest,
|
||||
};
|
||||
use crate::prompt::SystemPromptBuilder;
|
||||
use crate::prompt::{ProjectContext, SystemPromptBuilder};
|
||||
use crate::session::{ContentBlock, MessageRole, Session};
|
||||
use crate::usage::TokenUsage;
|
||||
use std::path::PathBuf;
|
||||
|
||||
struct ScriptedApiClient {
|
||||
call_count: usize,
|
||||
@@ -322,6 +358,12 @@ mod tests {
|
||||
name: "add".to_string(),
|
||||
input: "2,2".to_string(),
|
||||
},
|
||||
AssistantEvent::Usage(TokenUsage {
|
||||
input_tokens: 20,
|
||||
output_tokens: 6,
|
||||
cache_creation_input_tokens: 1,
|
||||
cache_read_input_tokens: 2,
|
||||
}),
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
@@ -333,6 +375,12 @@ mod tests {
|
||||
assert_eq!(last_message.role, MessageRole::Tool);
|
||||
Ok(vec![
|
||||
AssistantEvent::TextDelta("The answer is 4.".to_string()),
|
||||
AssistantEvent::Usage(TokenUsage {
|
||||
input_tokens: 24,
|
||||
output_tokens: 4,
|
||||
cache_creation_input_tokens: 1,
|
||||
cache_read_input_tokens: 3,
|
||||
}),
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
@@ -351,7 +399,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runs_user_to_tool_to_result_loop_end_to_end() {
|
||||
fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
|
||||
let api_client = ScriptedApiClient { call_count: 0 };
|
||||
let tool_executor = StaticToolExecutor::new().register("add", |input| {
|
||||
let total = input
|
||||
@@ -362,9 +410,13 @@ mod tests {
|
||||
});
|
||||
let permission_policy = PermissionPolicy::new(PermissionMode::Prompt);
|
||||
let system_prompt = SystemPromptBuilder::new()
|
||||
.with_cwd("/tmp/project")
|
||||
.with_project_context(ProjectContext {
|
||||
cwd: PathBuf::from("/tmp/project"),
|
||||
current_date: "2026-03-31".to_string(),
|
||||
git_status: None,
|
||||
instruction_files: Vec::new(),
|
||||
})
|
||||
.with_os("linux", "6.8")
|
||||
.with_date("2026-03-31")
|
||||
.build();
|
||||
let mut runtime = ConversationRuntime::new(
|
||||
Session::new(),
|
||||
@@ -382,6 +434,7 @@ mod tests {
|
||||
assert_eq!(summary.assistant_messages.len(), 2);
|
||||
assert_eq!(summary.tool_results.len(), 1);
|
||||
assert_eq!(runtime.session().messages.len(), 4);
|
||||
assert_eq!(summary.usage.output_tokens, 10);
|
||||
assert!(matches!(
|
||||
runtime.session().messages[1].blocks[1],
|
||||
ContentBlock::ToolUse { .. }
|
||||
@@ -448,4 +501,83 @@ mod tests {
|
||||
ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reconstructs_usage_tracker_from_restored_session() {
|
||||
struct SimpleApi;
|
||||
impl ApiClient for SimpleApi {
|
||||
fn stream(
|
||||
&mut self,
|
||||
_request: ApiRequest,
|
||||
) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
||||
Ok(vec![
|
||||
AssistantEvent::TextDelta("done".to_string()),
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.messages
|
||||
.push(crate::session::ConversationMessage::assistant_with_usage(
|
||||
vec![ContentBlock::Text {
|
||||
text: "earlier".to_string(),
|
||||
}],
|
||||
Some(TokenUsage {
|
||||
input_tokens: 11,
|
||||
output_tokens: 7,
|
||||
cache_creation_input_tokens: 2,
|
||||
cache_read_input_tokens: 1,
|
||||
}),
|
||||
));
|
||||
|
||||
let runtime = ConversationRuntime::new(
|
||||
session,
|
||||
SimpleApi,
|
||||
StaticToolExecutor::new(),
|
||||
PermissionPolicy::new(PermissionMode::Allow),
|
||||
vec!["system".to_string()],
|
||||
);
|
||||
|
||||
assert_eq!(runtime.usage().turns(), 1);
|
||||
assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compacts_session_after_turns() {
|
||||
struct SimpleApi;
|
||||
impl ApiClient for SimpleApi {
|
||||
fn stream(
|
||||
&mut self,
|
||||
_request: ApiRequest,
|
||||
) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
||||
Ok(vec![
|
||||
AssistantEvent::TextDelta("done".to_string()),
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
let mut runtime = ConversationRuntime::new(
|
||||
Session::new(),
|
||||
SimpleApi,
|
||||
StaticToolExecutor::new(),
|
||||
PermissionPolicy::new(PermissionMode::Allow),
|
||||
vec!["system".to_string()],
|
||||
);
|
||||
runtime.run_turn("a", None).expect("turn a");
|
||||
runtime.run_turn("b", None).expect("turn b");
|
||||
runtime.run_turn("c", None).expect("turn c");
|
||||
|
||||
let result = runtime.compact(CompactionConfig {
|
||||
preserve_recent_messages: 2,
|
||||
max_estimated_tokens: 1,
|
||||
});
|
||||
assert!(result.summary.contains("Conversation summary"));
|
||||
assert_eq!(
|
||||
result.compacted_session.messages[0].role,
|
||||
MessageRole::System
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
mod bootstrap;
|
||||
mod compact;
|
||||
mod config;
|
||||
mod conversation;
|
||||
mod json;
|
||||
mod permissions;
|
||||
mod prompt;
|
||||
mod session;
|
||||
mod usage;
|
||||
|
||||
pub use bootstrap::{BootstrapPhase, BootstrapPlan};
|
||||
pub use compact::{
|
||||
compact_session, estimate_session_tokens, format_compact_summary,
|
||||
get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult,
|
||||
};
|
||||
pub use config::{
|
||||
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, RuntimeConfig,
|
||||
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||
};
|
||||
pub use conversation::{
|
||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
||||
ToolError, ToolExecutor, TurnSummary,
|
||||
@@ -15,6 +26,8 @@ pub use permissions::{
|
||||
PermissionPrompter, PermissionRequest,
|
||||
};
|
||||
pub use prompt::{
|
||||
prepend_bullets, SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
||||
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||
};
|
||||
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
|
||||
pub use usage::{TokenUsage, UsageTracker};
|
||||
|
||||
@@ -1,15 +1,89 @@
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
use crate::config::{ConfigError, ConfigLoader, RuntimeConfig};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PromptBuildError {
|
||||
Io(std::io::Error),
|
||||
Config(ConfigError),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PromptBuildError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Io(error) => write!(f, "{error}"),
|
||||
Self::Config(error) => write!(f, "{error}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PromptBuildError {}
|
||||
|
||||
impl From<std::io::Error> for PromptBuildError {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Self::Io(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ConfigError> for PromptBuildError {
|
||||
fn from(value: ConfigError) -> Self {
|
||||
Self::Config(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__";
|
||||
pub const FRONTIER_MODEL_NAME: &str = "Claude Opus 4.6";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ContextFile {
|
||||
pub path: PathBuf,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct ProjectContext {
|
||||
pub cwd: PathBuf,
|
||||
pub current_date: String,
|
||||
pub git_status: Option<String>,
|
||||
pub instruction_files: Vec<ContextFile>,
|
||||
}
|
||||
|
||||
impl ProjectContext {
|
||||
pub fn discover(
|
||||
cwd: impl Into<PathBuf>,
|
||||
current_date: impl Into<String>,
|
||||
) -> std::io::Result<Self> {
|
||||
let cwd = cwd.into();
|
||||
let instruction_files = discover_instruction_files(&cwd)?;
|
||||
Ok(Self {
|
||||
cwd,
|
||||
current_date: current_date.into(),
|
||||
git_status: None,
|
||||
instruction_files,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn discover_with_git(
|
||||
cwd: impl Into<PathBuf>,
|
||||
current_date: impl Into<String>,
|
||||
) -> std::io::Result<Self> {
|
||||
let mut context = Self::discover(cwd, current_date)?;
|
||||
context.git_status = read_git_status(&context.cwd);
|
||||
Ok(context)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct SystemPromptBuilder {
|
||||
output_style_name: Option<String>,
|
||||
output_style_prompt: Option<String>,
|
||||
cwd: Option<String>,
|
||||
os_name: Option<String>,
|
||||
os_version: Option<String>,
|
||||
date: Option<String>,
|
||||
append_sections: Vec<String>,
|
||||
project_context: Option<ProjectContext>,
|
||||
config: Option<RuntimeConfig>,
|
||||
}
|
||||
|
||||
impl SystemPromptBuilder {
|
||||
@@ -25,12 +99,6 @@ impl SystemPromptBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
|
||||
self.cwd = Some(cwd.into());
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_os(mut self, os_name: impl Into<String>, os_version: impl Into<String>) -> Self {
|
||||
self.os_name = Some(os_name.into());
|
||||
@@ -39,8 +107,14 @@ impl SystemPromptBuilder {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_date(mut self, date: impl Into<String>) -> Self {
|
||||
self.date = Some(date.into());
|
||||
pub fn with_project_context(mut self, project_context: ProjectContext) -> Self {
|
||||
self.project_context = Some(project_context);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_runtime_config(mut self, config: RuntimeConfig) -> Self {
|
||||
self.config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -62,6 +136,15 @@ impl SystemPromptBuilder {
|
||||
sections.push(get_actions_section());
|
||||
sections.push(SYSTEM_PROMPT_DYNAMIC_BOUNDARY.to_string());
|
||||
sections.push(self.environment_section());
|
||||
if let Some(project_context) = &self.project_context {
|
||||
sections.push(render_project_context(project_context));
|
||||
if !project_context.instruction_files.is_empty() {
|
||||
sections.push(render_instruction_files(&project_context.instruction_files));
|
||||
}
|
||||
}
|
||||
if let Some(config) = &self.config {
|
||||
sections.push(render_config_section(config));
|
||||
}
|
||||
sections.extend(self.append_sections.iter().cloned());
|
||||
sections
|
||||
}
|
||||
@@ -72,14 +155,19 @@ impl SystemPromptBuilder {
|
||||
}
|
||||
|
||||
fn environment_section(&self) -> String {
|
||||
let cwd = self.project_context.as_ref().map_or_else(
|
||||
|| "unknown".to_string(),
|
||||
|context| context.cwd.display().to_string(),
|
||||
);
|
||||
let date = self.project_context.as_ref().map_or_else(
|
||||
|| "unknown".to_string(),
|
||||
|context| context.current_date.clone(),
|
||||
);
|
||||
let mut lines = vec!["# Environment context".to_string()];
|
||||
lines.extend(prepend_bullets(vec![
|
||||
format!("Model family: {FRONTIER_MODEL_NAME}"),
|
||||
format!(
|
||||
"Working directory: {}",
|
||||
self.cwd.as_deref().unwrap_or("unknown")
|
||||
),
|
||||
format!("Date: {}", self.date.as_deref().unwrap_or("unknown")),
|
||||
format!("Working directory: {cwd}"),
|
||||
format!("Date: {date}"),
|
||||
format!(
|
||||
"Platform: {} {}",
|
||||
self.os_name.as_deref().unwrap_or("unknown"),
|
||||
@@ -95,6 +183,118 @@ pub fn prepend_bullets(items: Vec<String>) -> Vec<String> {
|
||||
items.into_iter().map(|item| format!(" - {item}")).collect()
|
||||
}
|
||||
|
||||
fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
||||
let mut directories = Vec::new();
|
||||
let mut cursor = Some(cwd);
|
||||
while let Some(dir) = cursor {
|
||||
directories.push(dir.to_path_buf());
|
||||
cursor = dir.parent();
|
||||
}
|
||||
directories.reverse();
|
||||
|
||||
let mut files = Vec::new();
|
||||
for dir in directories {
|
||||
for candidate in [
|
||||
dir.join("CLAUDE.md"),
|
||||
dir.join("CLAUDE.local.md"),
|
||||
dir.join(".claude").join("CLAUDE.md"),
|
||||
] {
|
||||
push_context_file(&mut files, candidate)?;
|
||||
}
|
||||
}
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
fn push_context_file(files: &mut Vec<ContextFile>, path: PathBuf) -> std::io::Result<()> {
|
||||
match fs::read_to_string(&path) {
|
||||
Ok(content) if !content.trim().is_empty() => {
|
||||
files.push(ContextFile { path, content });
|
||||
Ok(())
|
||||
}
|
||||
Ok(_) => Ok(()),
|
||||
Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(()),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_git_status(cwd: &Path) -> Option<String> {
|
||||
let output = Command::new("git")
|
||||
.args(["--no-optional-locks", "status", "--short", "--branch"])
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.ok()?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
let stdout = String::from_utf8(output.stdout).ok()?;
|
||||
let trimmed = stdout.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn render_project_context(project_context: &ProjectContext) -> String {
|
||||
let mut lines = vec!["# Project context".to_string()];
|
||||
lines.extend(prepend_bullets(vec![format!(
|
||||
"Today's date is {}.",
|
||||
project_context.current_date
|
||||
)]));
|
||||
if let Some(status) = &project_context.git_status {
|
||||
lines.push(String::new());
|
||||
lines.push("Git status snapshot:".to_string());
|
||||
lines.push(status.clone());
|
||||
}
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn render_instruction_files(files: &[ContextFile]) -> String {
|
||||
let mut sections = vec!["# Claude instructions".to_string()];
|
||||
for file in files {
|
||||
sections.push(format!("## {}", file.path.display()));
|
||||
sections.push(file.content.trim().to_string());
|
||||
}
|
||||
sections.join("\n\n")
|
||||
}
|
||||
|
||||
pub fn load_system_prompt(
|
||||
cwd: impl Into<PathBuf>,
|
||||
current_date: impl Into<String>,
|
||||
os_name: impl Into<String>,
|
||||
os_version: impl Into<String>,
|
||||
) -> Result<Vec<String>, PromptBuildError> {
|
||||
let cwd = cwd.into();
|
||||
let project_context = ProjectContext::discover_with_git(&cwd, current_date.into())?;
|
||||
let config = ConfigLoader::default_for(&cwd).load()?;
|
||||
Ok(SystemPromptBuilder::new()
|
||||
.with_os(os_name, os_version)
|
||||
.with_project_context(project_context)
|
||||
.with_runtime_config(config)
|
||||
.build())
|
||||
}
|
||||
|
||||
fn render_config_section(config: &RuntimeConfig) -> String {
|
||||
let mut lines = vec!["# Runtime config".to_string()];
|
||||
if config.loaded_entries().is_empty() {
|
||||
lines.extend(prepend_bullets(vec![
|
||||
"No Claude Code settings files loaded.".to_string(),
|
||||
]));
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
lines.extend(prepend_bullets(
|
||||
config
|
||||
.loaded_entries()
|
||||
.iter()
|
||||
.map(|entry| format!("Loaded {:?}: {}", entry.source, entry.path.display()))
|
||||
.collect(),
|
||||
));
|
||||
lines.push(String::new());
|
||||
lines.push(config.as_json().render());
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn get_simple_intro_section(has_output_style: bool) -> String {
|
||||
format!(
|
||||
"You are an interactive agent that helps users {} Use the instructions below and the tools available to you to assist the user.\n\nIMPORTANT: You must NEVER generate or guess URLs for the user unless you are confident that the URLs are for helping the user with programming. You may use URLs provided by the user in their messages or local files.",
|
||||
@@ -148,22 +348,132 @@ fn get_actions_section() -> String {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{SystemPromptBuilder, SYSTEM_PROMPT_DYNAMIC_BOUNDARY};
|
||||
use super::{ProjectContext, SystemPromptBuilder, SYSTEM_PROMPT_DYNAMIC_BOUNDARY};
|
||||
use crate::config::ConfigLoader;
|
||||
use std::fs;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn temp_dir() -> std::path::PathBuf {
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("runtime-prompt-{nanos}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn renders_claude_code_style_sections() {
|
||||
fn discovers_instruction_files_from_ancestor_chain() {
|
||||
let root = temp_dir();
|
||||
let nested = root.join("apps").join("api");
|
||||
fs::create_dir_all(nested.join(".claude")).expect("nested claude dir");
|
||||
fs::write(root.join("CLAUDE.md"), "root instructions").expect("write root instructions");
|
||||
fs::write(root.join("CLAUDE.local.md"), "local instructions")
|
||||
.expect("write local instructions");
|
||||
fs::create_dir_all(root.join("apps")).expect("apps dir");
|
||||
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
|
||||
.expect("write apps instructions");
|
||||
fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules")
|
||||
.expect("write nested rules");
|
||||
|
||||
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||
let contents = context
|
||||
.instruction_files
|
||||
.iter()
|
||||
.map(|file| file.content.as_str())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
contents,
|
||||
vec![
|
||||
"root instructions",
|
||||
"local instructions",
|
||||
"apps instructions",
|
||||
"nested rules"
|
||||
]
|
||||
);
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discover_with_git_includes_status_snapshot() {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("root dir");
|
||||
std::process::Command::new("git")
|
||||
.args(["init", "--quiet"])
|
||||
.current_dir(&root)
|
||||
.status()
|
||||
.expect("git init should run");
|
||||
fs::write(root.join("CLAUDE.md"), "rules").expect("write instructions");
|
||||
fs::write(root.join("tracked.txt"), "hello").expect("write tracked file");
|
||||
|
||||
let context =
|
||||
ProjectContext::discover_with_git(&root, "2026-03-31").expect("context should load");
|
||||
|
||||
let status = context.git_status.expect("git status should be present");
|
||||
assert!(status.contains("## No commits yet on") || status.contains("## "));
|
||||
assert!(status.contains("?? CLAUDE.md"));
|
||||
assert!(status.contains("?? tracked.txt"));
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_system_prompt_reads_claude_files_and_config() {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(root.join(".claude")).expect("claude dir");
|
||||
fs::write(root.join("CLAUDE.md"), "Project rules").expect("write instructions");
|
||||
fs::write(
|
||||
root.join(".claude").join("settings.json"),
|
||||
r#"{"permissionMode":"acceptEdits"}"#,
|
||||
)
|
||||
.expect("write settings");
|
||||
|
||||
let previous = std::env::current_dir().expect("cwd");
|
||||
std::env::set_current_dir(&root).expect("change cwd");
|
||||
let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8")
|
||||
.expect("system prompt should load")
|
||||
.join(
|
||||
"
|
||||
|
||||
",
|
||||
);
|
||||
std::env::set_current_dir(previous).expect("restore cwd");
|
||||
|
||||
assert!(prompt.contains("Project rules"));
|
||||
assert!(prompt.contains("permissionMode"));
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn renders_claude_code_style_sections_with_project_context() {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(root.join(".claude")).expect("claude dir");
|
||||
fs::write(root.join("CLAUDE.md"), "Project rules").expect("write CLAUDE.md");
|
||||
fs::write(
|
||||
root.join(".claude").join("settings.json"),
|
||||
r#"{"permissionMode":"acceptEdits"}"#,
|
||||
)
|
||||
.expect("write settings");
|
||||
|
||||
let project_context =
|
||||
ProjectContext::discover(&root, "2026-03-31").expect("context should load");
|
||||
let config = ConfigLoader::new(&root, root.join("missing-home"))
|
||||
.load()
|
||||
.expect("config should load");
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_output_style("Concise", "Prefer short answers.")
|
||||
.with_cwd("/tmp/project")
|
||||
.with_os("linux", "6.8")
|
||||
.with_date("2026-03-31")
|
||||
.append_section("# Custom\nExtra")
|
||||
.with_project_context(project_context)
|
||||
.with_runtime_config(config)
|
||||
.render();
|
||||
|
||||
assert!(prompt.contains("# System"));
|
||||
assert!(prompt.contains("# Doing tasks"));
|
||||
assert!(prompt.contains("# Executing actions with care"));
|
||||
assert!(prompt.contains("# Project context"));
|
||||
assert!(prompt.contains("# Claude instructions"));
|
||||
assert!(prompt.contains("Project rules"));
|
||||
assert!(prompt.contains("permissionMode"));
|
||||
assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY));
|
||||
assert!(prompt.contains("Working directory: /tmp/project"));
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::json::{JsonError, JsonValue};
|
||||
use crate::usage::TokenUsage;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum MessageRole {
|
||||
@@ -35,6 +36,7 @@ pub enum ContentBlock {
|
||||
pub struct ConversationMessage {
|
||||
pub role: MessageRole,
|
||||
pub blocks: Vec<ContentBlock>,
|
||||
pub usage: Option<TokenUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -145,6 +147,7 @@ impl ConversationMessage {
|
||||
Self {
|
||||
role: MessageRole::User,
|
||||
blocks: vec![ContentBlock::Text { text: text.into() }],
|
||||
usage: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,6 +156,16 @@ impl ConversationMessage {
|
||||
Self {
|
||||
role: MessageRole::Assistant,
|
||||
blocks,
|
||||
usage: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn assistant_with_usage(blocks: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Self {
|
||||
Self {
|
||||
role: MessageRole::Assistant,
|
||||
blocks,
|
||||
usage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,6 +184,7 @@ impl ConversationMessage {
|
||||
output: output.into(),
|
||||
is_error,
|
||||
}],
|
||||
usage: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,6 +207,9 @@ impl ConversationMessage {
|
||||
"blocks".to_string(),
|
||||
JsonValue::Array(self.blocks.iter().map(ContentBlock::to_json).collect()),
|
||||
);
|
||||
if let Some(usage) = self.usage {
|
||||
object.insert("usage".to_string(), usage_to_json(usage));
|
||||
}
|
||||
JsonValue::Object(object)
|
||||
}
|
||||
|
||||
@@ -222,7 +239,12 @@ impl ConversationMessage {
|
||||
.iter()
|
||||
.map(ContentBlock::from_json)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(Self { role, blocks })
|
||||
let usage = object.get("usage").map(usage_from_json).transpose()?;
|
||||
Ok(Self {
|
||||
role,
|
||||
blocks,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,6 +324,39 @@ impl ContentBlock {
|
||||
}
|
||||
}
|
||||
|
||||
fn usage_to_json(usage: TokenUsage) -> JsonValue {
|
||||
let mut object = BTreeMap::new();
|
||||
object.insert(
|
||||
"input_tokens".to_string(),
|
||||
JsonValue::Number(i64::from(usage.input_tokens)),
|
||||
);
|
||||
object.insert(
|
||||
"output_tokens".to_string(),
|
||||
JsonValue::Number(i64::from(usage.output_tokens)),
|
||||
);
|
||||
object.insert(
|
||||
"cache_creation_input_tokens".to_string(),
|
||||
JsonValue::Number(i64::from(usage.cache_creation_input_tokens)),
|
||||
);
|
||||
object.insert(
|
||||
"cache_read_input_tokens".to_string(),
|
||||
JsonValue::Number(i64::from(usage.cache_read_input_tokens)),
|
||||
);
|
||||
JsonValue::Object(object)
|
||||
}
|
||||
|
||||
fn usage_from_json(value: &JsonValue) -> Result<TokenUsage, SessionError> {
|
||||
let object = value
|
||||
.as_object()
|
||||
.ok_or_else(|| SessionError::Format("usage must be an object".to_string()))?;
|
||||
Ok(TokenUsage {
|
||||
input_tokens: required_u32(object, "input_tokens")?,
|
||||
output_tokens: required_u32(object, "output_tokens")?,
|
||||
cache_creation_input_tokens: required_u32(object, "cache_creation_input_tokens")?,
|
||||
cache_read_input_tokens: required_u32(object, "cache_read_input_tokens")?,
|
||||
})
|
||||
}
|
||||
|
||||
fn required_string(
|
||||
object: &BTreeMap<String, JsonValue>,
|
||||
key: &str,
|
||||
@@ -313,9 +368,18 @@ fn required_string(
|
||||
.ok_or_else(|| SessionError::Format(format!("missing {key}")))
|
||||
}
|
||||
|
||||
fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
|
||||
let value = object
|
||||
.get(key)
|
||||
.and_then(JsonValue::as_i64)
|
||||
.ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
|
||||
u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||
use crate::usage::TokenUsage;
|
||||
use std::fs;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
@@ -325,16 +389,26 @@ mod tests {
|
||||
session
|
||||
.messages
|
||||
.push(ConversationMessage::user_text("hello"));
|
||||
session.messages.push(ConversationMessage::assistant(vec![
|
||||
ContentBlock::Text {
|
||||
text: "thinking".to_string(),
|
||||
},
|
||||
ContentBlock::ToolUse {
|
||||
id: "tool-1".to_string(),
|
||||
name: "bash".to_string(),
|
||||
input: "echo hi".to_string(),
|
||||
},
|
||||
]));
|
||||
session
|
||||
.messages
|
||||
.push(ConversationMessage::assistant_with_usage(
|
||||
vec![
|
||||
ContentBlock::Text {
|
||||
text: "thinking".to_string(),
|
||||
},
|
||||
ContentBlock::ToolUse {
|
||||
id: "tool-1".to_string(),
|
||||
name: "bash".to_string(),
|
||||
input: "echo hi".to_string(),
|
||||
},
|
||||
],
|
||||
Some(TokenUsage {
|
||||
input_tokens: 10,
|
||||
output_tokens: 4,
|
||||
cache_creation_input_tokens: 1,
|
||||
cache_read_input_tokens: 2,
|
||||
}),
|
||||
));
|
||||
session.messages.push(ConversationMessage::tool_result(
|
||||
"tool-1", "bash", "hi", false,
|
||||
));
|
||||
@@ -350,5 +424,9 @@ mod tests {
|
||||
|
||||
assert_eq!(restored, session);
|
||||
assert_eq!(restored.messages[2].role, MessageRole::Tool);
|
||||
assert_eq!(
|
||||
restored.messages[1].usage.expect("usage").total_tokens(),
|
||||
17
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
121
rust/crates/runtime/src/usage.rs
Normal file
121
rust/crates/runtime/src/usage.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use crate::session::Session;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct TokenUsage {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
pub cache_creation_input_tokens: u32,
|
||||
pub cache_read_input_tokens: u32,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
#[must_use]
|
||||
pub fn total_tokens(self) -> u32 {
|
||||
self.input_tokens
|
||||
+ self.output_tokens
|
||||
+ self.cache_creation_input_tokens
|
||||
+ self.cache_read_input_tokens
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct UsageTracker {
|
||||
latest_turn: TokenUsage,
|
||||
cumulative: TokenUsage,
|
||||
turns: u32,
|
||||
}
|
||||
|
||||
impl UsageTracker {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn from_session(session: &Session) -> Self {
|
||||
let mut tracker = Self::new();
|
||||
for message in &session.messages {
|
||||
if let Some(usage) = message.usage {
|
||||
tracker.record(usage);
|
||||
}
|
||||
}
|
||||
tracker
|
||||
}
|
||||
|
||||
pub fn record(&mut self, usage: TokenUsage) {
|
||||
self.latest_turn = usage;
|
||||
self.cumulative.input_tokens += usage.input_tokens;
|
||||
self.cumulative.output_tokens += usage.output_tokens;
|
||||
self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
|
||||
self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
|
||||
self.turns += 1;
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn current_turn_usage(&self) -> TokenUsage {
|
||||
self.latest_turn
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn cumulative_usage(&self) -> TokenUsage {
|
||||
self.cumulative
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn turns(&self) -> u32 {
|
||||
self.turns
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{TokenUsage, UsageTracker};
|
||||
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||
|
||||
#[test]
|
||||
fn tracks_true_cumulative_usage() {
|
||||
let mut tracker = UsageTracker::new();
|
||||
tracker.record(TokenUsage {
|
||||
input_tokens: 10,
|
||||
output_tokens: 4,
|
||||
cache_creation_input_tokens: 2,
|
||||
cache_read_input_tokens: 1,
|
||||
});
|
||||
tracker.record(TokenUsage {
|
||||
input_tokens: 20,
|
||||
output_tokens: 6,
|
||||
cache_creation_input_tokens: 3,
|
||||
cache_read_input_tokens: 2,
|
||||
});
|
||||
|
||||
assert_eq!(tracker.turns(), 2);
|
||||
assert_eq!(tracker.current_turn_usage().input_tokens, 20);
|
||||
assert_eq!(tracker.current_turn_usage().output_tokens, 6);
|
||||
assert_eq!(tracker.cumulative_usage().output_tokens, 10);
|
||||
assert_eq!(tracker.cumulative_usage().input_tokens, 30);
|
||||
assert_eq!(tracker.cumulative_usage().total_tokens(), 48);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reconstructs_usage_from_session_messages() {
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![ConversationMessage {
|
||||
role: MessageRole::Assistant,
|
||||
blocks: vec![ContentBlock::Text {
|
||||
text: "done".to_string(),
|
||||
}],
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: 5,
|
||||
output_tokens: 2,
|
||||
cache_creation_input_tokens: 1,
|
||||
cache_read_input_tokens: 0,
|
||||
}),
|
||||
}],
|
||||
};
|
||||
|
||||
let tracker = UsageTracker::from_session(&session);
|
||||
assert_eq!(tracker.turns(), 1);
|
||||
assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
|
||||
}
|
||||
}
|
||||
@@ -6,12 +6,9 @@ license.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4.5.38", features = ["derive"] }
|
||||
commands = { path = "../commands" }
|
||||
compat-harness = { path = "../compat-harness" }
|
||||
crossterm = "0.29.0"
|
||||
pulldown-cmark = "0.13.0"
|
||||
runtime = { path = "../runtime" }
|
||||
syntect = { version = "5.2.0", default-features = false, features = ["default-fancy"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::args::{OutputFormat, PermissionMode};
|
||||
use crate::input::LineEditor;
|
||||
use crate::render::{Spinner, TerminalRenderer};
|
||||
use runtime::{ConversationClient, ConversationMessage, RuntimeError, StreamEvent, UsageSummary};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SessionConfig {
|
||||
@@ -20,6 +19,7 @@ pub struct SessionState {
|
||||
pub turns: usize,
|
||||
pub compacted_messages: usize,
|
||||
pub last_model: String,
|
||||
pub last_usage: UsageSummary,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
@@ -29,6 +29,7 @@ impl SessionState {
|
||||
turns: 0,
|
||||
compacted_messages: 0,
|
||||
last_model: model.into(),
|
||||
last_usage: UsageSummary::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -92,17 +93,21 @@ pub struct CliApp {
|
||||
config: SessionConfig,
|
||||
renderer: TerminalRenderer,
|
||||
state: SessionState,
|
||||
conversation_client: ConversationClient,
|
||||
conversation_history: Vec<ConversationMessage>,
|
||||
}
|
||||
|
||||
impl CliApp {
|
||||
#[must_use]
|
||||
pub fn new(config: SessionConfig) -> Self {
|
||||
pub fn new(config: SessionConfig) -> Result<Self, RuntimeError> {
|
||||
let state = SessionState::new(config.model.clone());
|
||||
Self {
|
||||
let conversation_client = ConversationClient::from_env(config.model.clone())?;
|
||||
Ok(Self {
|
||||
config,
|
||||
renderer: TerminalRenderer::new(),
|
||||
state,
|
||||
}
|
||||
conversation_client,
|
||||
conversation_history: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run_repl(&mut self) -> io::Result<()> {
|
||||
@@ -172,11 +177,13 @@ impl CliApp {
|
||||
fn handle_status(&mut self, out: &mut impl Write) -> io::Result<CommandResult> {
|
||||
writeln!(
|
||||
out,
|
||||
"status: turns={} model={} permission-mode={:?} output-format={:?} config={}",
|
||||
"status: turns={} model={} permission-mode={:?} output-format={:?} last-usage={} in/{} out config={}",
|
||||
self.state.turns,
|
||||
self.state.last_model,
|
||||
self.config.permission_mode,
|
||||
self.config.output_format,
|
||||
self.state.last_usage.input_tokens,
|
||||
self.state.last_usage.output_tokens,
|
||||
self.config
|
||||
.config
|
||||
.as_ref()
|
||||
@@ -188,6 +195,7 @@ impl CliApp {
|
||||
fn handle_compact(&mut self, out: &mut impl Write) -> io::Result<CommandResult> {
|
||||
self.state.compacted_messages += self.state.turns;
|
||||
self.state.turns = 0;
|
||||
self.conversation_history.clear();
|
||||
writeln!(
|
||||
out,
|
||||
"Compacted session history into a local summary ({} messages total compacted).",
|
||||
@@ -196,46 +204,147 @@ impl CliApp {
|
||||
Ok(CommandResult::Continue)
|
||||
}
|
||||
|
||||
fn render_response(&mut self, input: &str, out: &mut impl Write) -> io::Result<()> {
|
||||
let mut spinner = Spinner::new();
|
||||
for label in [
|
||||
"Planning response",
|
||||
"Running tool execution",
|
||||
"Rendering markdown output",
|
||||
] {
|
||||
spinner.tick(label, self.renderer.color_theme(), out)?;
|
||||
thread::sleep(Duration::from_millis(24));
|
||||
fn handle_stream_event(
|
||||
renderer: &TerminalRenderer,
|
||||
event: StreamEvent,
|
||||
stream_spinner: &mut Spinner,
|
||||
tool_spinner: &mut Spinner,
|
||||
saw_text: &mut bool,
|
||||
turn_usage: &mut UsageSummary,
|
||||
out: &mut impl Write,
|
||||
) {
|
||||
match event {
|
||||
StreamEvent::TextDelta(delta) => {
|
||||
if !*saw_text {
|
||||
let _ =
|
||||
stream_spinner.finish("Streaming response", renderer.color_theme(), out);
|
||||
*saw_text = true;
|
||||
}
|
||||
let _ = write!(out, "{delta}");
|
||||
let _ = out.flush();
|
||||
}
|
||||
StreamEvent::ToolCallStart { name, input } => {
|
||||
if *saw_text {
|
||||
let _ = writeln!(out);
|
||||
}
|
||||
let _ = tool_spinner.tick(
|
||||
&format!("Running tool `{name}` with {input}"),
|
||||
renderer.color_theme(),
|
||||
out,
|
||||
);
|
||||
}
|
||||
StreamEvent::ToolCallResult {
|
||||
name,
|
||||
output,
|
||||
is_error,
|
||||
} => {
|
||||
let label = if is_error {
|
||||
format!("Tool `{name}` failed")
|
||||
} else {
|
||||
format!("Tool `{name}` completed")
|
||||
};
|
||||
let _ = tool_spinner.finish(&label, renderer.color_theme(), out);
|
||||
let rendered_output = format!("### Tool `{name}`\n\n```text\n{output}\n```\n");
|
||||
let _ = renderer.stream_markdown(&rendered_output, out);
|
||||
}
|
||||
StreamEvent::Usage(usage) => {
|
||||
*turn_usage = usage;
|
||||
}
|
||||
}
|
||||
spinner.finish("Streaming response", self.renderer.color_theme(), out)?;
|
||||
}
|
||||
|
||||
let response = demo_response(input, &self.config);
|
||||
fn write_turn_output(
|
||||
&self,
|
||||
summary: &runtime::TurnSummary,
|
||||
out: &mut impl Write,
|
||||
) -> io::Result<()> {
|
||||
match self.config.output_format {
|
||||
OutputFormat::Text => self.renderer.stream_markdown(&response, out)?,
|
||||
OutputFormat::Json => writeln!(out, "{{\"message\":{response:?}}}")?,
|
||||
OutputFormat::Text => {
|
||||
writeln!(
|
||||
out,
|
||||
"\nToken usage: {} input / {} output",
|
||||
self.state.last_usage.input_tokens, self.state.last_usage.output_tokens
|
||||
)?;
|
||||
}
|
||||
OutputFormat::Json => {
|
||||
writeln!(
|
||||
out,
|
||||
"{}",
|
||||
serde_json::json!({
|
||||
"message": summary.assistant_text,
|
||||
"usage": {
|
||||
"input_tokens": self.state.last_usage.input_tokens,
|
||||
"output_tokens": self.state.last_usage.output_tokens,
|
||||
}
|
||||
})
|
||||
)?;
|
||||
}
|
||||
OutputFormat::Ndjson => {
|
||||
writeln!(out, "{{\"type\":\"message\",\"text\":{response:?}}}")?;
|
||||
writeln!(
|
||||
out,
|
||||
"{}",
|
||||
serde_json::json!({
|
||||
"type": "message",
|
||||
"text": summary.assistant_text,
|
||||
"usage": {
|
||||
"input_tokens": self.state.last_usage.input_tokens,
|
||||
"output_tokens": self.state.last_usage.output_tokens,
|
||||
}
|
||||
})
|
||||
)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn demo_response(input: &str, config: &SessionConfig) -> String {
|
||||
format!(
|
||||
"## Assistant\n\nModel: `{}` \nPermission mode: `{}`\n\nYou said:\n\n> {}\n\nThis renderer now supports **bold**, *italic*, inline `code`, and syntax-highlighted blocks:\n\n```rust\nfn main() {{\n println!(\"streaming from rusty-claude-cli\");\n}}\n```",
|
||||
config.model,
|
||||
permission_mode_label(config.permission_mode),
|
||||
input.trim()
|
||||
)
|
||||
}
|
||||
fn render_response(&mut self, input: &str, out: &mut impl Write) -> io::Result<()> {
|
||||
let mut stream_spinner = Spinner::new();
|
||||
stream_spinner.tick(
|
||||
"Opening conversation stream",
|
||||
self.renderer.color_theme(),
|
||||
out,
|
||||
)?;
|
||||
|
||||
#[must_use]
|
||||
pub fn permission_mode_label(mode: PermissionMode) -> &'static str {
|
||||
match mode {
|
||||
PermissionMode::ReadOnly => "read-only",
|
||||
PermissionMode::WorkspaceWrite => "workspace-write",
|
||||
PermissionMode::DangerFullAccess => "danger-full-access",
|
||||
let mut turn_usage = UsageSummary::default();
|
||||
let mut tool_spinner = Spinner::new();
|
||||
let mut saw_text = false;
|
||||
let renderer = &self.renderer;
|
||||
|
||||
let result =
|
||||
self.conversation_client
|
||||
.run_turn(&mut self.conversation_history, input, |event| {
|
||||
Self::handle_stream_event(
|
||||
renderer,
|
||||
event,
|
||||
&mut stream_spinner,
|
||||
&mut tool_spinner,
|
||||
&mut saw_text,
|
||||
&mut turn_usage,
|
||||
out,
|
||||
);
|
||||
});
|
||||
|
||||
let summary = match result {
|
||||
Ok(summary) => summary,
|
||||
Err(error) => {
|
||||
stream_spinner.fail(
|
||||
"Streaming response failed",
|
||||
self.renderer.color_theme(),
|
||||
out,
|
||||
)?;
|
||||
return Err(io::Error::other(error));
|
||||
}
|
||||
};
|
||||
self.state.last_usage = summary.usage.clone();
|
||||
if saw_text {
|
||||
writeln!(out)?;
|
||||
} else {
|
||||
stream_spinner.finish("Streaming response", self.renderer.color_theme(), out)?;
|
||||
}
|
||||
|
||||
self.write_turn_output(&summary, out)?;
|
||||
let _ = turn_usage;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,7 +354,7 @@ mod tests {
|
||||
|
||||
use crate::args::{OutputFormat, PermissionMode};
|
||||
|
||||
use super::{CliApp, CommandResult, SessionConfig, SlashCommand};
|
||||
use super::{CommandResult, SessionConfig, SlashCommand};
|
||||
|
||||
#[test]
|
||||
fn parses_required_slash_commands() {
|
||||
@@ -258,33 +367,27 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn help_status_and_compact_commands_are_wired() {
|
||||
fn help_output_lists_commands() {
|
||||
let mut out = Vec::new();
|
||||
let result = super::CliApp::handle_help(&mut out).expect("help succeeds");
|
||||
assert_eq!(result, CommandResult::Continue);
|
||||
let output = String::from_utf8_lossy(&out);
|
||||
assert!(output.contains("/help"));
|
||||
assert!(output.contains("/status"));
|
||||
assert!(output.contains("/compact"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_state_tracks_config_values() {
|
||||
let config = SessionConfig {
|
||||
model: "claude".into(),
|
||||
permission_mode: PermissionMode::WorkspaceWrite,
|
||||
config: Some(PathBuf::from("settings.toml")),
|
||||
output_format: OutputFormat::Text,
|
||||
};
|
||||
let mut app = CliApp::new(config);
|
||||
let mut out = Vec::new();
|
||||
|
||||
let result = app
|
||||
.handle_submission("/help", &mut out)
|
||||
.expect("help succeeds");
|
||||
assert_eq!(result, CommandResult::Continue);
|
||||
|
||||
app.handle_submission("hello", &mut out)
|
||||
.expect("submission succeeds");
|
||||
app.handle_submission("/status", &mut out)
|
||||
.expect("status succeeds");
|
||||
app.handle_submission("/compact", &mut out)
|
||||
.expect("compact succeeds");
|
||||
|
||||
let output = String::from_utf8_lossy(&out);
|
||||
assert!(output.contains("/help"));
|
||||
assert!(output.contains("/status"));
|
||||
assert!(output.contains("/compact"));
|
||||
assert!(output.contains("status: turns=1"));
|
||||
assert!(output.contains("Compacted session history"));
|
||||
assert_eq!(config.model, "claude");
|
||||
assert_eq!(config.permission_mode, PermissionMode::WorkspaceWrite);
|
||||
assert_eq!(config.config, Some(PathBuf::from("settings.toml")));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,59 +1,123 @@
|
||||
mod app;
|
||||
mod args;
|
||||
mod input;
|
||||
mod render;
|
||||
use std::env;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use app::{CliApp, SessionConfig};
|
||||
use args::{Cli, Command};
|
||||
use clap::Parser;
|
||||
use commands::handle_slash_command;
|
||||
use compat_harness::{extract_manifest, UpstreamPaths};
|
||||
use runtime::BootstrapPlan;
|
||||
use runtime::{load_system_prompt, BootstrapPlan, CompactionConfig, Session};
|
||||
|
||||
fn main() {
|
||||
let cli = Cli::parse();
|
||||
let args: Vec<String> = env::args().skip(1).collect();
|
||||
|
||||
let result = match &cli.command {
|
||||
Some(Command::DumpManifests) => dump_manifests(),
|
||||
Some(Command::BootstrapPlan) => {
|
||||
print_bootstrap_plan();
|
||||
Ok(())
|
||||
match parse_args(&args) {
|
||||
Ok(CliAction::DumpManifests) => dump_manifests(),
|
||||
Ok(CliAction::BootstrapPlan) => print_bootstrap_plan(),
|
||||
Ok(CliAction::PrintSystemPrompt { cwd, date }) => print_system_prompt(cwd, date),
|
||||
Ok(CliAction::ResumeSession {
|
||||
session_path,
|
||||
command,
|
||||
}) => resume_session(&session_path, command),
|
||||
Ok(CliAction::Help) => print_help(),
|
||||
Err(error) => {
|
||||
eprintln!("{error}");
|
||||
print_help();
|
||||
std::process::exit(2);
|
||||
}
|
||||
Some(Command::Prompt { prompt }) => {
|
||||
let joined = prompt.join(" ");
|
||||
let mut app = CliApp::new(build_session_config(&cli));
|
||||
app.run_prompt(&joined, &mut std::io::stdout())
|
||||
}
|
||||
None => {
|
||||
let mut app = CliApp::new(build_session_config(&cli));
|
||||
app.run_repl()
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = result {
|
||||
eprintln!("{error}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn build_session_config(cli: &Cli) -> SessionConfig {
|
||||
SessionConfig {
|
||||
model: cli.model.clone(),
|
||||
permission_mode: cli.permission_mode,
|
||||
config: cli.config.clone(),
|
||||
output_format: cli.output_format,
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum CliAction {
|
||||
DumpManifests,
|
||||
BootstrapPlan,
|
||||
PrintSystemPrompt {
|
||||
cwd: PathBuf,
|
||||
date: String,
|
||||
},
|
||||
ResumeSession {
|
||||
session_path: PathBuf,
|
||||
command: Option<String>,
|
||||
},
|
||||
Help,
|
||||
}
|
||||
|
||||
fn parse_args(args: &[String]) -> Result<CliAction, String> {
|
||||
if args.is_empty() {
|
||||
return Ok(CliAction::Help);
|
||||
}
|
||||
|
||||
if matches!(args.first().map(String::as_str), Some("--help" | "-h")) {
|
||||
return Ok(CliAction::Help);
|
||||
}
|
||||
|
||||
if args.first().map(String::as_str) == Some("--resume") {
|
||||
return parse_resume_args(&args[1..]);
|
||||
}
|
||||
|
||||
match args[0].as_str() {
|
||||
"dump-manifests" => Ok(CliAction::DumpManifests),
|
||||
"bootstrap-plan" => Ok(CliAction::BootstrapPlan),
|
||||
"system-prompt" => parse_system_prompt_args(&args[1..]),
|
||||
other => Err(format!("unknown subcommand: {other}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn dump_manifests() -> std::io::Result<()> {
|
||||
fn parse_system_prompt_args(args: &[String]) -> Result<CliAction, String> {
|
||||
let mut cwd = env::current_dir().map_err(|error| error.to_string())?;
|
||||
let mut date = "2026-03-31".to_string();
|
||||
let mut index = 0;
|
||||
|
||||
while index < args.len() {
|
||||
match args[index].as_str() {
|
||||
"--cwd" => {
|
||||
let value = args
|
||||
.get(index + 1)
|
||||
.ok_or_else(|| "missing value for --cwd".to_string())?;
|
||||
cwd = PathBuf::from(value);
|
||||
index += 2;
|
||||
}
|
||||
"--date" => {
|
||||
let value = args
|
||||
.get(index + 1)
|
||||
.ok_or_else(|| "missing value for --date".to_string())?;
|
||||
date.clone_from(value);
|
||||
index += 2;
|
||||
}
|
||||
other => return Err(format!("unknown system-prompt option: {other}")),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(CliAction::PrintSystemPrompt { cwd, date })
|
||||
}
|
||||
|
||||
fn parse_resume_args(args: &[String]) -> Result<CliAction, String> {
|
||||
let session_path = args
|
||||
.first()
|
||||
.ok_or_else(|| "missing session path for --resume".to_string())
|
||||
.map(PathBuf::from)?;
|
||||
let command = args.get(1).cloned();
|
||||
if args.len() > 2 {
|
||||
return Err("--resume accepts at most one trailing slash command".to_string());
|
||||
}
|
||||
Ok(CliAction::ResumeSession {
|
||||
session_path,
|
||||
command,
|
||||
})
|
||||
}
|
||||
|
||||
fn dump_manifests() {
|
||||
let workspace_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../..");
|
||||
let paths = UpstreamPaths::from_workspace_dir(&workspace_dir);
|
||||
let manifest = extract_manifest(&paths)?;
|
||||
println!("commands: {}", manifest.commands.entries().len());
|
||||
println!("tools: {}", manifest.tools.entries().len());
|
||||
println!("bootstrap phases: {}", manifest.bootstrap.phases().len());
|
||||
Ok(())
|
||||
match extract_manifest(&paths) {
|
||||
Ok(manifest) => {
|
||||
println!("commands: {}", manifest.commands.entries().len());
|
||||
println!("tools: {}", manifest.tools.entries().len());
|
||||
println!("bootstrap phases: {}", manifest.bootstrap.phases().len());
|
||||
}
|
||||
Err(error) => {
|
||||
eprintln!("failed to extract manifests: {error}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn print_bootstrap_plan() {
|
||||
@@ -61,3 +125,108 @@ fn print_bootstrap_plan() {
|
||||
println!("- {phase:?}");
|
||||
}
|
||||
}
|
||||
|
||||
fn print_system_prompt(cwd: PathBuf, date: String) {
|
||||
match load_system_prompt(cwd, date, env::consts::OS, "unknown") {
|
||||
Ok(sections) => println!("{}", sections.join("\n\n")),
|
||||
Err(error) => {
|
||||
eprintln!("failed to build system prompt: {error}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resume_session(session_path: &Path, command: Option<String>) {
|
||||
let session = match Session::load_from_path(session_path) {
|
||||
Ok(session) => session,
|
||||
Err(error) => {
|
||||
eprintln!("failed to restore session: {error}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
match command {
|
||||
Some(command) if command.starts_with('/') => {
|
||||
let Some(result) = handle_slash_command(
|
||||
&command,
|
||||
&session,
|
||||
CompactionConfig {
|
||||
max_estimated_tokens: 0,
|
||||
..CompactionConfig::default()
|
||||
},
|
||||
) else {
|
||||
eprintln!("unknown slash command: {command}");
|
||||
std::process::exit(2);
|
||||
};
|
||||
if let Err(error) = result.session.save_to_path(session_path) {
|
||||
eprintln!("failed to persist resumed session: {error}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
println!("{}", result.message);
|
||||
}
|
||||
Some(other) => {
|
||||
eprintln!("unsupported resumed command: {other}");
|
||||
std::process::exit(2);
|
||||
}
|
||||
None => {
|
||||
println!(
|
||||
"Restored session from {} ({} messages).",
|
||||
session_path.display(),
|
||||
session.messages.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
println!("rusty-claude-cli");
|
||||
println!();
|
||||
println!("Current scaffold commands:");
|
||||
println!(
|
||||
" dump-manifests Read upstream TS sources and print extracted counts"
|
||||
);
|
||||
println!(" bootstrap-plan Print the current bootstrap phase skeleton");
|
||||
println!(" system-prompt [--cwd PATH] [--date YYYY-MM-DD]");
|
||||
println!(" Build a Claude-style system prompt from CLAUDE.md and config files");
|
||||
println!(" --resume SESSION.json [/compact] Restore a saved session and optionally run a slash command");
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{parse_args, CliAction};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[test]
|
||||
fn parses_system_prompt_options() {
|
||||
let args = vec![
|
||||
"system-prompt".to_string(),
|
||||
"--cwd".to_string(),
|
||||
"/tmp/project".to_string(),
|
||||
"--date".to_string(),
|
||||
"2026-04-01".to_string(),
|
||||
];
|
||||
assert_eq!(
|
||||
parse_args(&args).expect("args should parse"),
|
||||
CliAction::PrintSystemPrompt {
|
||||
cwd: PathBuf::from("/tmp/project"),
|
||||
date: "2026-04-01".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_resume_flag_with_slash_command() {
|
||||
let args = vec![
|
||||
"--resume".to_string(),
|
||||
"session.json".to_string(),
|
||||
"/compact".to_string(),
|
||||
];
|
||||
assert_eq!(
|
||||
parse_args(&args).expect("args should parse"),
|
||||
CliAction::ResumeSession {
|
||||
session_path: PathBuf::from("session.json"),
|
||||
command: Some("/compact".to_string()),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ pub struct ColorTheme {
|
||||
quote: Color,
|
||||
spinner_active: Color,
|
||||
spinner_done: Color,
|
||||
spinner_failed: Color,
|
||||
}
|
||||
|
||||
impl Default for ColorTheme {
|
||||
@@ -36,6 +37,7 @@ impl Default for ColorTheme {
|
||||
quote: Color::DarkGrey,
|
||||
spinner_active: Color::Blue,
|
||||
spinner_done: Color::Green,
|
||||
spinner_failed: Color::Red,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -91,6 +93,24 @@ impl Spinner {
|
||||
)?;
|
||||
out.flush()
|
||||
}
|
||||
|
||||
pub fn fail(
|
||||
&mut self,
|
||||
label: &str,
|
||||
theme: &ColorTheme,
|
||||
out: &mut impl Write,
|
||||
) -> io::Result<()> {
|
||||
self.frame_index = 0;
|
||||
execute!(
|
||||
out,
|
||||
MoveToColumn(0),
|
||||
Clear(ClearType::CurrentLine),
|
||||
SetForegroundColor(theme.spinner_failed),
|
||||
Print(format!("✘ {label}\n")),
|
||||
ResetColor
|
||||
)?;
|
||||
out.flush()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
|
||||
@@ -5,13 +5,5 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
regex = "1.12"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.20"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -1,14 +1,3 @@
|
||||
use regex::RegexBuilder;
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::borrow::Cow;
|
||||
use std::collections::BTreeSet;
|
||||
use std::fmt;
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ToolManifestEntry {
|
||||
pub name: String,
|
||||
@@ -37,979 +26,3 @@ impl ToolRegistry {
|
||||
&self.entries
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct TextContent {
|
||||
#[serde(rename = "type")]
|
||||
pub kind: &'static str,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct ToolResult {
|
||||
pub content: Vec<TextContent>,
|
||||
}
|
||||
|
||||
impl ToolResult {
|
||||
#[must_use]
|
||||
pub fn text(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
content: vec![TextContent {
|
||||
kind: "text",
|
||||
text: text.into(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolError {
|
||||
message: Cow<'static, str>,
|
||||
}
|
||||
|
||||
impl ToolError {
|
||||
#[must_use]
|
||||
pub fn new(message: impl Into<Cow<'static, str>>) -> Self {
|
||||
Self {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ToolError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(&self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ToolError {}
|
||||
|
||||
impl From<io::Error> for ToolError {
|
||||
fn from(value: io::Error) -> Self {
|
||||
Self::new(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<regex::Error> for ToolError {
|
||||
fn from(value: regex::Error) -> Self {
|
||||
Self::new(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Tool {
|
||||
fn name(&self) -> &'static str;
|
||||
fn description(&self) -> &'static str;
|
||||
fn input_schema(&self) -> Value;
|
||||
fn execute(&self, input: Value) -> Result<ToolResult, ToolError>;
|
||||
}
|
||||
|
||||
fn schema_string(description: &str) -> Value {
|
||||
json!({ "type": "string", "description": description })
|
||||
}
|
||||
|
||||
fn schema_number(description: &str) -> Value {
|
||||
json!({ "type": "number", "description": description })
|
||||
}
|
||||
|
||||
fn schema_boolean(description: &str) -> Value {
|
||||
json!({ "type": "boolean", "description": description })
|
||||
}
|
||||
|
||||
fn strict_object(properties: &Value, required: &[&str]) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": false,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_string(input: &Value, key: &'static str) -> Result<String, ToolError> {
|
||||
input
|
||||
.get(key)
|
||||
.and_then(Value::as_str)
|
||||
.map(ToOwned::to_owned)
|
||||
.ok_or_else(|| ToolError::new(format!("missing or invalid string field: {key}")))
|
||||
}
|
||||
|
||||
fn optional_string(input: &Value, key: &'static str) -> Result<Option<String>, ToolError> {
|
||||
match input.get(key) {
|
||||
None | Some(Value::Null) => Ok(None),
|
||||
Some(Value::String(value)) => Ok(Some(value.clone())),
|
||||
Some(_) => Err(ToolError::new(format!("invalid string field: {key}"))),
|
||||
}
|
||||
}
|
||||
|
||||
fn optional_u64(input: &Value, key: &'static str) -> Result<Option<u64>, ToolError> {
|
||||
match input.get(key) {
|
||||
None | Some(Value::Null) => Ok(None),
|
||||
Some(value) => value
|
||||
.as_u64()
|
||||
.ok_or_else(|| ToolError::new(format!("invalid numeric field: {key}")))
|
||||
.map(Some),
|
||||
}
|
||||
}
|
||||
|
||||
fn optional_bool(input: &Value, key: &'static str) -> Result<Option<bool>, ToolError> {
|
||||
match input.get(key) {
|
||||
None | Some(Value::Null) => Ok(None),
|
||||
Some(value) => value
|
||||
.as_bool()
|
||||
.ok_or_else(|| ToolError::new(format!("invalid boolean field: {key}")))
|
||||
.map(Some),
|
||||
}
|
||||
}
|
||||
|
||||
fn absolute_path(path: &str) -> Result<PathBuf, ToolError> {
|
||||
let expanded = if let Some(rest) = path.strip_prefix("~/") {
|
||||
std::env::var_os("HOME")
|
||||
.map(PathBuf::from)
|
||||
.map_or_else(|| PathBuf::from(path), |home| home.join(rest))
|
||||
} else {
|
||||
PathBuf::from(path)
|
||||
};
|
||||
|
||||
if expanded.is_absolute() {
|
||||
Ok(expanded)
|
||||
} else {
|
||||
Err(ToolError::new(format!("path must be absolute: {path}")))
|
||||
}
|
||||
}
|
||||
|
||||
fn relative_display(path: &Path, base: &Path) -> String {
|
||||
path.strip_prefix(base).ok().map_or_else(
|
||||
|| path.to_string_lossy().replace('\\', "/"),
|
||||
|value| value.to_string_lossy().replace('\\', "/"),
|
||||
)
|
||||
}
|
||||
|
||||
fn line_slice(content: &str, offset: Option<u64>, limit: Option<u64>) -> String {
|
||||
let start = usize_from_u64(offset.unwrap_or(1).saturating_sub(1));
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let end = limit
|
||||
.map_or(lines.len(), |limit| {
|
||||
start.saturating_add(usize_from_u64(limit))
|
||||
})
|
||||
.min(lines.len());
|
||||
|
||||
if start >= lines.len() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
lines[start..end]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, line)| format!("{:>6}\t{line}", start + index + 1))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn parse_page_range(pages: &str) -> Result<(u64, u64), ToolError> {
|
||||
if let Some((start, end)) = pages.split_once('-') {
|
||||
let start = start
|
||||
.trim()
|
||||
.parse::<u64>()
|
||||
.map_err(|_| ToolError::new("invalid pages parameter"))?;
|
||||
let end = end
|
||||
.trim()
|
||||
.parse::<u64>()
|
||||
.map_err(|_| ToolError::new("invalid pages parameter"))?;
|
||||
if start == 0 || end < start {
|
||||
return Err(ToolError::new("invalid pages parameter"));
|
||||
}
|
||||
Ok((start, end))
|
||||
} else {
|
||||
let page = pages
|
||||
.trim()
|
||||
.parse::<u64>()
|
||||
.map_err(|_| ToolError::new("invalid pages parameter"))?;
|
||||
if page == 0 {
|
||||
return Err(ToolError::new("invalid pages parameter"));
|
||||
}
|
||||
Ok((page, page))
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_single_edit(
|
||||
original: &str,
|
||||
old_string: &str,
|
||||
new_string: &str,
|
||||
replace_all: bool,
|
||||
) -> Result<String, ToolError> {
|
||||
if old_string == new_string {
|
||||
return Err(ToolError::new(
|
||||
"No changes to make: old_string and new_string are exactly the same.",
|
||||
));
|
||||
}
|
||||
|
||||
if old_string.is_empty() {
|
||||
if original.is_empty() {
|
||||
return Ok(new_string.to_owned());
|
||||
}
|
||||
return Err(ToolError::new(
|
||||
"Cannot create new file - file already exists.",
|
||||
));
|
||||
}
|
||||
|
||||
let matches = original.matches(old_string).count();
|
||||
if matches == 0 {
|
||||
return Err(ToolError::new(format!(
|
||||
"String to replace not found in file.\nString: {old_string}"
|
||||
)));
|
||||
}
|
||||
|
||||
if matches > 1 && !replace_all {
|
||||
return Err(ToolError::new(format!(
|
||||
"Found {matches} matches of the string to replace, but replace_all is false. To replace all occurrences, set replace_all to true. To replace only one occurrence, please provide more context to uniquely identify the instance.\nString: {old_string}"
|
||||
)));
|
||||
}
|
||||
|
||||
let updated = if replace_all {
|
||||
original.replace(old_string, new_string)
|
||||
} else {
|
||||
original.replacen(old_string, new_string, 1)
|
||||
};
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
fn diff_hunks(_before: &str, _after: &str) -> Value {
|
||||
json!([])
|
||||
}
|
||||
|
||||
fn usize_from_u64(value: u64) -> usize {
|
||||
usize::try_from(value).unwrap_or(usize::MAX)
|
||||
}
|
||||
|
||||
pub struct BashTool;
|
||||
pub struct ReadTool;
|
||||
pub struct WriteTool;
|
||||
pub struct EditTool;
|
||||
pub struct GlobTool;
|
||||
pub struct GrepTool;
|
||||
|
||||
impl Tool for BashTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"Bash"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Execute a shell command in the current environment."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
strict_object(
|
||||
&json!({
|
||||
"command": schema_string("The command to execute"),
|
||||
"timeout": schema_number("Optional timeout in milliseconds (max 600000)"),
|
||||
"description": schema_string("Clear, concise description of what this command does in active voice. Never use words like \"complex\" or \"risk\" in the description - just describe what it does."),
|
||||
"run_in_background": schema_boolean("Set to true to run this command in the background. Use Read to read the output later."),
|
||||
"dangerouslyDisableSandbox": schema_boolean("Set this to true to dangerously override sandbox mode and run commands without sandboxing.")
|
||||
}),
|
||||
&["command"],
|
||||
)
|
||||
}
|
||||
|
||||
fn execute(&self, input: Value) -> Result<ToolResult, ToolError> {
|
||||
let command = parse_string(&input, "command")?;
|
||||
let _timeout = optional_u64(&input, "timeout")?;
|
||||
let _description = optional_string(&input, "description")?;
|
||||
let run_in_background = optional_bool(&input, "run_in_background")?.unwrap_or(false);
|
||||
let _disable_sandbox = optional_bool(&input, "dangerouslyDisableSandbox")?.unwrap_or(false);
|
||||
|
||||
if run_in_background {
|
||||
return Ok(ToolResult::text(
|
||||
"Background execution is not supported in this runtime.",
|
||||
));
|
||||
}
|
||||
|
||||
let output = Command::new("bash").arg("-lc").arg(&command).output()?;
|
||||
let mut rendered = String::new();
|
||||
if !output.stdout.is_empty() {
|
||||
rendered.push_str(&String::from_utf8_lossy(&output.stdout));
|
||||
}
|
||||
if !output.stderr.is_empty() {
|
||||
if !rendered.is_empty() && !rendered.ends_with('\n') {
|
||||
rendered.push('\n');
|
||||
}
|
||||
rendered.push_str(&String::from_utf8_lossy(&output.stderr));
|
||||
}
|
||||
if rendered.is_empty() {
|
||||
rendered = if output.status.success() {
|
||||
"Done".to_owned()
|
||||
} else {
|
||||
format!("Command exited with status {}", output.status)
|
||||
};
|
||||
}
|
||||
Ok(ToolResult::text(rendered.trim_end().to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for ReadTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"Read"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Read a file from the local filesystem."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
strict_object(
|
||||
&json!({
|
||||
"file_path": schema_string("The absolute path to the file to read"),
|
||||
"offset": json!({"type":"number","description":"The line number to start reading from. Only provide if the file is too large to read at once","minimum":0}),
|
||||
"limit": json!({"type":"number","description":"The number of lines to read. Only provide if the file is too large to read at once.","exclusiveMinimum":0}),
|
||||
"pages": schema_string("Page range for PDF files (e.g., \"1-5\", \"3\", \"10-20\"). Only applicable to PDF files. Maximum 20 pages per request.")
|
||||
}),
|
||||
&["file_path"],
|
||||
)
|
||||
}
|
||||
|
||||
fn execute(&self, input: Value) -> Result<ToolResult, ToolError> {
|
||||
let file_path = parse_string(&input, "file_path")?;
|
||||
let path = absolute_path(&file_path)?;
|
||||
let offset = optional_u64(&input, "offset")?;
|
||||
let limit = optional_u64(&input, "limit")?;
|
||||
let pages = optional_string(&input, "pages")?;
|
||||
|
||||
let content = fs::read_to_string(&path)?;
|
||||
if path.extension().and_then(|ext| ext.to_str()) == Some("pdf") {
|
||||
if let Some(pages) = pages {
|
||||
let (start, end) = parse_page_range(&pages)?;
|
||||
return Ok(ToolResult::text(format!(
|
||||
"PDF page extraction is not implemented in Rust yet for {}. Requested pages {}-{}.",
|
||||
path.display(), start, end
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let rendered = if offset.is_some() || limit.is_some() {
|
||||
line_slice(&content, offset, limit)
|
||||
} else {
|
||||
line_slice(&content, Some(1), None)
|
||||
};
|
||||
Ok(ToolResult::text(rendered))
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for WriteTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"Write"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Write a file to the local filesystem."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
strict_object(
|
||||
&json!({
|
||||
"file_path": schema_string("The absolute path to the file to write (must be absolute, not relative)"),
|
||||
"content": schema_string("The content to write to the file")
|
||||
}),
|
||||
&["file_path", "content"],
|
||||
)
|
||||
}
|
||||
|
||||
fn execute(&self, input: Value) -> Result<ToolResult, ToolError> {
|
||||
let file_path = parse_string(&input, "file_path")?;
|
||||
let content = parse_string(&input, "content")?;
|
||||
let path = absolute_path(&file_path)?;
|
||||
let existed = path.exists();
|
||||
let original = if existed {
|
||||
Some(fs::read_to_string(&path)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(&path, &content)?;
|
||||
|
||||
let payload = json!({
|
||||
"type": if existed { "update" } else { "create" },
|
||||
"filePath": file_path,
|
||||
"content": content,
|
||||
"structuredPatch": diff_hunks(original.as_deref().unwrap_or(""), &content),
|
||||
"originalFile": original,
|
||||
"gitDiff": Value::Null,
|
||||
});
|
||||
Ok(ToolResult::text(payload.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for EditTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"Edit"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"A tool for editing files"
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
strict_object(
|
||||
&json!({
|
||||
"file_path": schema_string("The absolute path to the file to modify"),
|
||||
"old_string": schema_string("The text to replace"),
|
||||
"new_string": schema_string("The text to replace it with (must be different from old_string)"),
|
||||
"replace_all": json!({"type":"boolean","description":"Replace all occurrences of old_string (default false)","default":false})
|
||||
}),
|
||||
&["file_path", "old_string", "new_string"],
|
||||
)
|
||||
}
|
||||
|
||||
fn execute(&self, input: Value) -> Result<ToolResult, ToolError> {
|
||||
let file_path = parse_string(&input, "file_path")?;
|
||||
let old_string = parse_string(&input, "old_string")?;
|
||||
let new_string = parse_string(&input, "new_string")?;
|
||||
let replace_all = optional_bool(&input, "replace_all")?.unwrap_or(false);
|
||||
let path = absolute_path(&file_path)?;
|
||||
let original = if path.exists() {
|
||||
fs::read_to_string(&path)?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
let updated = apply_single_edit(&original, &old_string, &new_string, replace_all)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(&path, &updated)?;
|
||||
|
||||
let payload = json!({
|
||||
"filePath": file_path,
|
||||
"oldString": old_string,
|
||||
"newString": new_string,
|
||||
"originalFile": original,
|
||||
"structuredPatch": diff_hunks("", ""),
|
||||
"userModified": false,
|
||||
"replaceAll": replace_all,
|
||||
"gitDiff": Value::Null,
|
||||
});
|
||||
Ok(ToolResult::text(payload.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for GlobTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"Glob"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Fast file pattern matching tool"
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
strict_object(
|
||||
&json!({
|
||||
"pattern": schema_string("The glob pattern to match files against"),
|
||||
"path": schema_string("The directory to search in. If not specified, the current working directory will be used. IMPORTANT: Omit this field to use the default directory. DO NOT enter \"undefined\" or \"null\" - simply omit it for the default behavior. Must be a valid directory path if provided.")
|
||||
}),
|
||||
&["pattern"],
|
||||
)
|
||||
}
|
||||
|
||||
fn execute(&self, input: Value) -> Result<ToolResult, ToolError> {
|
||||
let pattern = parse_string(&input, "pattern")?;
|
||||
let root = optional_string(&input, "path")?
|
||||
.map(|path| absolute_path(&path))
|
||||
.transpose()?
|
||||
.unwrap_or(std::env::current_dir()?);
|
||||
let start = std::time::Instant::now();
|
||||
let mut filenames = Vec::new();
|
||||
visit_files(&root, &mut |path| {
|
||||
let relative = relative_display(path, &root);
|
||||
if glob_matches(&pattern, &relative) {
|
||||
filenames.push(relative);
|
||||
}
|
||||
})?;
|
||||
filenames.sort();
|
||||
let truncated = filenames.len() > 100;
|
||||
if truncated {
|
||||
filenames.truncate(100);
|
||||
}
|
||||
let payload = json!({
|
||||
"durationMs": start.elapsed().as_millis(),
|
||||
"numFiles": filenames.len(),
|
||||
"filenames": filenames,
|
||||
"truncated": truncated,
|
||||
});
|
||||
Ok(ToolResult::text(payload.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for GrepTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"Grep"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Fast content search tool"
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
strict_object(
|
||||
&json!({
|
||||
"pattern": schema_string("The regular expression pattern to search for in file contents"),
|
||||
"path": schema_string("File or directory to search in (rg PATH). Defaults to current working directory."),
|
||||
"glob": schema_string("Glob pattern to filter files (e.g. \"*.js\", \"*.{ts,tsx}\") - maps to rg --glob"),
|
||||
"output_mode": {"type":"string","enum":["content","files_with_matches","count"],"description":"Output mode: \"content\" shows matching lines (supports -A/-B/-C context, -n line numbers, head_limit), \"files_with_matches\" shows file paths (supports head_limit), \"count\" shows match counts (supports head_limit). Defaults to \"files_with_matches\"."},
|
||||
"-B": schema_number("Number of lines to show before each match (rg -B). Requires output_mode: \"content\", ignored otherwise."),
|
||||
"-A": schema_number("Number of lines to show after each match (rg -A). Requires output_mode: \"content\", ignored otherwise."),
|
||||
"-C": schema_number("Alias for context."),
|
||||
"context": schema_number("Number of lines to show before and after each match (rg -C). Requires output_mode: \"content\", ignored otherwise."),
|
||||
"-n": {"type":"boolean","description":"Show line numbers in output (rg -n). Requires output_mode: \"content\", ignored otherwise. Defaults to true."},
|
||||
"-i": schema_boolean("Case insensitive search (rg -i)"),
|
||||
"type": schema_string("File type to search (rg --type). Common types: js, py, rust, go, java, etc. More efficient than include for standard file types."),
|
||||
"head_limit": schema_number("Limit output to first N lines/entries, equivalent to \"| head -N\". Works across all output modes: content (limits output lines), files_with_matches (limits file paths), count (limits count entries). Defaults to 250 when unspecified. Pass 0 for unlimited (use sparingly — large result sets waste context)."),
|
||||
"offset": schema_number("Skip first N lines/entries before applying head_limit, equivalent to \"| tail -n +N | head -N\". Works across all output modes. Defaults to 0."),
|
||||
"multiline": schema_boolean("Enable multiline mode where . matches newlines and patterns can span lines (rg -U --multiline-dotall). Default: false.")
|
||||
}),
|
||||
&["pattern"],
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn execute(&self, input: Value) -> Result<ToolResult, ToolError> {
|
||||
let pattern = parse_string(&input, "pattern")?;
|
||||
let root = optional_string(&input, "path")?
|
||||
.map(|path| absolute_path(&path))
|
||||
.transpose()?
|
||||
.unwrap_or(std::env::current_dir()?);
|
||||
let glob = optional_string(&input, "glob")?;
|
||||
let output_mode = optional_string(&input, "output_mode")?
|
||||
.unwrap_or_else(|| "files_with_matches".to_owned());
|
||||
let context_before = usize_from_u64(optional_u64(&input, "-B")?.unwrap_or(0));
|
||||
let context_after = usize_from_u64(optional_u64(&input, "-A")?.unwrap_or(0));
|
||||
let context_c = optional_u64(&input, "-C")?;
|
||||
let context = optional_u64(&input, "context")?;
|
||||
let show_line_numbers = optional_bool(&input, "-n")?.unwrap_or(true);
|
||||
let case_insensitive = optional_bool(&input, "-i")?.unwrap_or(false);
|
||||
let file_type = optional_string(&input, "type")?;
|
||||
let head_limit = optional_u64(&input, "head_limit")?;
|
||||
let offset = usize_from_u64(optional_u64(&input, "offset")?.unwrap_or(0));
|
||||
let _multiline = optional_bool(&input, "multiline")?.unwrap_or(false);
|
||||
|
||||
let shared_context = usize_from_u64(context.or(context_c).unwrap_or(0));
|
||||
let regex = RegexBuilder::new(&pattern)
|
||||
.case_insensitive(case_insensitive)
|
||||
.build()?;
|
||||
|
||||
let mut matched_lines = Vec::new();
|
||||
let mut files_with_matches = Vec::new();
|
||||
let mut count_lines = Vec::new();
|
||||
let mut total_matches = 0usize;
|
||||
|
||||
let candidates = collect_files(&root)?;
|
||||
for path in candidates {
|
||||
let relative = relative_display(&path, &root);
|
||||
if !matches_file_filter(&relative, glob.as_deref(), file_type.as_deref()) {
|
||||
continue;
|
||||
}
|
||||
let Ok(file_content) = fs::read_to_string(&path) else {
|
||||
continue;
|
||||
};
|
||||
let lines: Vec<&str> = file_content.lines().collect();
|
||||
let mut matched_indexes = Vec::new();
|
||||
let mut file_match_count = 0usize;
|
||||
for (index, line) in lines.iter().enumerate() {
|
||||
if regex.is_match(line) {
|
||||
matched_indexes.push(index);
|
||||
file_match_count += regex.find_iter(line).count().max(1);
|
||||
}
|
||||
}
|
||||
if matched_indexes.is_empty() {
|
||||
continue;
|
||||
}
|
||||
total_matches += file_match_count;
|
||||
files_with_matches.push(relative.clone());
|
||||
count_lines.push(format!("{relative}:{file_match_count}"));
|
||||
|
||||
if output_mode == "content" {
|
||||
let mut included = BTreeSet::new();
|
||||
for index in matched_indexes {
|
||||
let before = if shared_context > 0 {
|
||||
shared_context
|
||||
} else {
|
||||
context_before
|
||||
};
|
||||
let after = if shared_context > 0 {
|
||||
shared_context
|
||||
} else {
|
||||
context_after
|
||||
};
|
||||
let start = index.saturating_sub(before);
|
||||
let end = (index + after).min(lines.len().saturating_sub(1));
|
||||
for line_index in start..=end {
|
||||
included.insert(line_index);
|
||||
}
|
||||
}
|
||||
for line_index in included {
|
||||
if show_line_numbers {
|
||||
matched_lines.push(format!(
|
||||
"{relative}:{}:{}",
|
||||
line_index + 1,
|
||||
lines[line_index]
|
||||
));
|
||||
} else {
|
||||
matched_lines.push(format!("{relative}:{}", lines[line_index]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let rendered = match output_mode.as_str() {
|
||||
"content" => {
|
||||
let limited = apply_offset_limit(matched_lines, head_limit, offset);
|
||||
json!({
|
||||
"mode": "content",
|
||||
"numFiles": 0,
|
||||
"filenames": [],
|
||||
"content": limited.join("\n"),
|
||||
"numLines": limited.len(),
|
||||
"appliedOffset": (offset > 0).then_some(offset),
|
||||
})
|
||||
}
|
||||
"count" => {
|
||||
let limited = apply_offset_limit(count_lines, head_limit, offset);
|
||||
json!({
|
||||
"mode": "count",
|
||||
"numFiles": files_with_matches.len(),
|
||||
"filenames": [],
|
||||
"content": limited.join("\n"),
|
||||
"numMatches": total_matches,
|
||||
"appliedOffset": (offset > 0).then_some(offset),
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
files_with_matches.sort();
|
||||
let limited = apply_offset_limit(files_with_matches, head_limit, offset);
|
||||
json!({
|
||||
"mode": "files_with_matches",
|
||||
"numFiles": limited.len(),
|
||||
"filenames": limited,
|
||||
"appliedOffset": (offset > 0).then_some(offset),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ToolResult::text(rendered.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_offset_limit<T>(items: Vec<T>, limit: Option<u64>, offset: usize) -> Vec<T> {
|
||||
let mut iter = items.into_iter().skip(offset);
|
||||
match limit {
|
||||
Some(0) | None => iter.collect(),
|
||||
Some(limit) => iter.by_ref().take(usize_from_u64(limit)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_files(root: &Path) -> Result<Vec<PathBuf>, ToolError> {
|
||||
let mut files = Vec::new();
|
||||
if root.is_file() {
|
||||
files.push(root.to_path_buf());
|
||||
return Ok(files);
|
||||
}
|
||||
visit_files(root, &mut |path| files.push(path.to_path_buf()))?;
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
fn visit_files(root: &Path, visitor: &mut dyn FnMut(&Path)) -> Result<(), ToolError> {
|
||||
if root.is_file() {
|
||||
visitor(root);
|
||||
return Ok(());
|
||||
}
|
||||
for entry in fs::read_dir(root)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
visit_files(&path, visitor)?;
|
||||
} else if path.is_file() {
|
||||
visitor(&path);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn matches_file_filter(relative: &str, glob: Option<&str>, file_type: Option<&str>) -> bool {
|
||||
let glob_ok = glob.is_none_or(|pattern| {
|
||||
split_glob_patterns(pattern)
|
||||
.into_iter()
|
||||
.any(|single| glob_matches(&single, relative))
|
||||
});
|
||||
let type_ok = file_type.is_none_or(|kind| path_matches_type(relative, kind));
|
||||
glob_ok && type_ok
|
||||
}
|
||||
|
||||
fn split_glob_patterns(patterns: &str) -> Vec<String> {
|
||||
let mut result = Vec::new();
|
||||
for raw in patterns.split_whitespace() {
|
||||
if raw.contains('{') && raw.contains('}') {
|
||||
result.push(raw.to_owned());
|
||||
} else {
|
||||
result.extend(
|
||||
raw.split(',')
|
||||
.filter(|part| !part.is_empty())
|
||||
.map(ToOwned::to_owned),
|
||||
);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn path_matches_type(relative: &str, kind: &str) -> bool {
|
||||
let extension = Path::new(relative)
|
||||
.extension()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or_default();
|
||||
matches!(
|
||||
(kind, extension),
|
||||
("rust", "rs")
|
||||
| ("js", "js")
|
||||
| ("ts", "ts")
|
||||
| ("tsx", "tsx")
|
||||
| ("py", "py")
|
||||
| ("go", "go")
|
||||
| ("java", "java")
|
||||
| ("json", "json")
|
||||
| ("md", "md")
|
||||
)
|
||||
}
|
||||
|
||||
fn glob_matches(pattern: &str, path: &str) -> bool {
|
||||
expand_braces(pattern)
|
||||
.into_iter()
|
||||
.any(|expanded| glob_match_one(&expanded, path))
|
||||
}
|
||||
|
||||
fn expand_braces(pattern: &str) -> Vec<String> {
|
||||
let Some(start) = pattern.find('{') else {
|
||||
return vec![pattern.to_owned()];
|
||||
};
|
||||
let Some(end_rel) = pattern[start..].find('}') else {
|
||||
return vec![pattern.to_owned()];
|
||||
};
|
||||
let end = start + end_rel;
|
||||
let prefix = &pattern[..start];
|
||||
let suffix = &pattern[end + 1..];
|
||||
pattern[start + 1..end]
|
||||
.split(',')
|
||||
.flat_map(|middle| expand_braces(&format!("{prefix}{middle}{suffix}")))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn glob_match_one(pattern: &str, path: &str) -> bool {
|
||||
let pattern = pattern.replace('\\', "/");
|
||||
let path = path.replace('\\', "/");
|
||||
let pattern_parts: Vec<&str> = pattern.split('/').collect();
|
||||
let path_parts: Vec<&str> = path.split('/').collect();
|
||||
glob_match_parts(&pattern_parts, &path_parts)
|
||||
}
|
||||
|
||||
fn glob_match_parts(pattern: &[&str], path: &[&str]) -> bool {
|
||||
if pattern.is_empty() {
|
||||
return path.is_empty();
|
||||
}
|
||||
if pattern[0] == "**" {
|
||||
if glob_match_parts(&pattern[1..], path) {
|
||||
return true;
|
||||
}
|
||||
if !path.is_empty() {
|
||||
return glob_match_parts(pattern, &path[1..]);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if path.is_empty() {
|
||||
return false;
|
||||
}
|
||||
if segment_matches(pattern[0], path[0]) {
|
||||
return glob_match_parts(&pattern[1..], &path[1..]);
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn segment_matches(pattern: &str, text: &str) -> bool {
|
||||
let p = pattern.as_bytes();
|
||||
let t = text.as_bytes();
|
||||
let (mut pi, mut ti, mut star_idx, mut match_idx) = (0usize, 0usize, None, 0usize);
|
||||
while ti < t.len() {
|
||||
if pi < p.len() && (p[pi] == b'?' || p[pi] == t[ti]) {
|
||||
pi += 1;
|
||||
ti += 1;
|
||||
} else if pi < p.len() && p[pi] == b'*' {
|
||||
star_idx = Some(pi);
|
||||
match_idx = ti;
|
||||
pi += 1;
|
||||
} else if let Some(star) = star_idx {
|
||||
pi = star + 1;
|
||||
match_idx += 1;
|
||||
ti = match_idx;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
while pi < p.len() && p[pi] == b'*' {
|
||||
pi += 1;
|
||||
}
|
||||
pi == p.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn core_tools() -> Vec<Box<dyn Tool>> {
|
||||
vec![
|
||||
Box::new(BashTool),
|
||||
Box::new(ReadTool),
|
||||
Box::new(WriteTool),
|
||||
Box::new(EditTool),
|
||||
Box::new(GlobTool),
|
||||
Box::new(GrepTool),
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn text(result: &ToolResult) -> String {
|
||||
result.content[0].text.clone()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manifests_core_tools() {
|
||||
let names: Vec<_> = core_tools().into_iter().map(|tool| tool.name()).collect();
|
||||
assert_eq!(names, vec!["Bash", "Read", "Write", "Edit", "Glob", "Grep"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_executes_command() {
|
||||
let result = BashTool
|
||||
.execute(json!({ "command": "printf 'hello'" }))
|
||||
.unwrap();
|
||||
assert_eq!(text(&result), "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_schema_matches_expected_keys() {
|
||||
let schema = ReadTool.input_schema();
|
||||
let properties = schema["properties"].as_object().unwrap();
|
||||
assert_eq!(schema["required"], json!(["file_path"]));
|
||||
assert!(properties.contains_key("file_path"));
|
||||
assert!(properties.contains_key("offset"));
|
||||
assert!(properties.contains_key("limit"));
|
||||
assert!(properties.contains_key("pages"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_returns_numbered_lines() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("sample.txt");
|
||||
fs::write(&path, "alpha\nbeta\ngamma\n").unwrap();
|
||||
|
||||
let result = ReadTool
|
||||
.execute(json!({ "file_path": path.to_string_lossy(), "offset": 2, "limit": 1 }))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(text(&result), " 2\tbeta");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_creates_file_and_reports_create() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("new.txt");
|
||||
let result = WriteTool
|
||||
.execute(json!({ "file_path": path.to_string_lossy(), "content": "hello" }))
|
||||
.unwrap();
|
||||
let payload: Value = serde_json::from_str(&text(&result)).unwrap();
|
||||
assert_eq!(payload["type"], "create");
|
||||
assert_eq!(fs::read_to_string(path).unwrap(), "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn edit_replaces_single_match() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("edit.txt");
|
||||
fs::write(&path, "hello world\n").unwrap();
|
||||
let result = EditTool
|
||||
.execute(json!({
|
||||
"file_path": path.to_string_lossy(),
|
||||
"old_string": "world",
|
||||
"new_string": "rust",
|
||||
"replace_all": false
|
||||
}))
|
||||
.unwrap();
|
||||
let payload: Value = serde_json::from_str(&text(&result)).unwrap();
|
||||
assert_eq!(payload["replaceAll"], false);
|
||||
assert_eq!(fs::read_to_string(path).unwrap(), "hello rust\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn glob_finds_matching_files() {
|
||||
let dir = tempdir().unwrap();
|
||||
fs::create_dir_all(dir.path().join("src/nested")).unwrap();
|
||||
fs::write(dir.path().join("src/lib.rs"), "").unwrap();
|
||||
fs::write(dir.path().join("src/nested/main.rs"), "").unwrap();
|
||||
fs::write(dir.path().join("README.md"), "").unwrap();
|
||||
|
||||
let result = GlobTool
|
||||
.execute(json!({ "pattern": "**/*.rs", "path": dir.path().to_string_lossy() }))
|
||||
.unwrap();
|
||||
let payload: Value = serde_json::from_str(&text(&result)).unwrap();
|
||||
assert_eq!(payload["numFiles"], 2);
|
||||
assert_eq!(
|
||||
payload["filenames"],
|
||||
json!(["src/lib.rs", "src/nested/main.rs"])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grep_supports_file_list_mode() {
|
||||
let dir = tempdir().unwrap();
|
||||
fs::write(dir.path().join("a.rs"), "fn main() {}\nlet alpha = 1;\n").unwrap();
|
||||
fs::write(dir.path().join("b.txt"), "alpha\nalpha\n").unwrap();
|
||||
|
||||
let result = GrepTool
|
||||
.execute(json!({
|
||||
"pattern": "alpha",
|
||||
"path": dir.path().to_string_lossy(),
|
||||
"output_mode": "files_with_matches"
|
||||
}))
|
||||
.unwrap();
|
||||
let payload: Value = serde_json::from_str(&text(&result)).unwrap();
|
||||
assert_eq!(payload["filenames"], json!(["a.rs", "b.txt"]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grep_supports_content_and_count_modes() {
|
||||
let dir = tempdir().unwrap();
|
||||
fs::write(dir.path().join("a.rs"), "alpha\nbeta\nalpha\n").unwrap();
|
||||
|
||||
let content = GrepTool
|
||||
.execute(json!({
|
||||
"pattern": "alpha",
|
||||
"path": dir.path().to_string_lossy(),
|
||||
"output_mode": "content",
|
||||
"-n": true
|
||||
}))
|
||||
.unwrap();
|
||||
let content_payload: Value = serde_json::from_str(&text(&content)).unwrap();
|
||||
assert_eq!(content_payload["numLines"], 2);
|
||||
assert!(content_payload["content"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("a.rs:1:alpha"));
|
||||
|
||||
let count = GrepTool
|
||||
.execute(json!({
|
||||
"pattern": "alpha",
|
||||
"path": dir.path().to_string_lossy(),
|
||||
"output_mode": "count"
|
||||
}))
|
||||
.unwrap();
|
||||
let count_payload: Value = serde_json::from_str(&text(&count)).unwrap();
|
||||
assert_eq!(count_payload["numMatches"], 2);
|
||||
assert_eq!(count_payload["content"], "a.rs:2");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user