diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 308a108..806c309 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -54,6 +54,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.20.2" @@ -104,6 +113,15 @@ dependencies = [ "tools", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -138,6 +156,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "deranged" version = "0.5.8" @@ -147,6 +175,16 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -238,6 +276,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.24" @@ -950,6 +998,7 @@ dependencies = [ "regex", "serde", "serde_json", + "sha2", "tokio", "walkdir", ] @@ -1106,6 +1155,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -1427,6 +1487,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicase" version = "2.9.0" @@ -1469,6 +1535,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "walkdir" version = "2.5.0" diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 5756b3e..5e7d319 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -15,11 +15,90 @@ const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); const DEFAULT_MAX_RETRIES: u32 = 2; +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthSource { + None, + ApiKey(String), + BearerToken(String), + ApiKeyAndBearer { + api_key: String, + bearer_token: String, + }, +} + +impl AuthSource { + pub fn from_env() -> Result { + let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; + let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; + match (api_key, auth_token) { + (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + (Some(api_key), None) => Ok(Self::ApiKey(api_key)), + (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), + (None, None) => Err(ApiError::MissingApiKey), + } + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + match self { + Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), + Self::None | Self::BearerToken(_) => None, + } + } + + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Self::BearerToken(token) + | Self::ApiKeyAndBearer { + bearer_token: token, + .. + } => Some(token), + Self::None | Self::ApiKey(_) => None, + } + } + + #[must_use] + pub fn masked_authorization_header(&self) -> &'static str { + if self.bearer_token().is_some() { + "Bearer [REDACTED]" + } else { + "" + } + } + + pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(api_key) = self.api_key() { + request_builder = request_builder.header("x-api-key", api_key); + } + if let Some(token) = self.bearer_token() { + request_builder = request_builder.bearer_auth(token); + } + request_builder + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + pub scopes: Vec, +} + +impl From for AuthSource { + fn from(value: OAuthTokenSet) -> Self { + Self::BearerToken(value.access_token) + } +} + #[derive(Debug, Clone)] pub struct AnthropicClient { http: reqwest::Client, - api_key: String, - auth_token: Option, + auth: AuthSource, base_url: String, max_retries: u32, initial_backoff: Duration, @@ -31,8 +110,19 @@ impl AnthropicClient { pub fn new(api_key: impl Into) -> Self { Self { http: reqwest::Client::new(), - api_key: api_key.into(), - auth_token: None, + auth: AuthSource::ApiKey(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + #[must_use] + pub fn from_auth(auth: AuthSource) -> Self { + Self { + http: reqwest::Client::new(), + auth, base_url: DEFAULT_BASE_URL.to_string(), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, @@ -41,14 +131,37 @@ impl AnthropicClient { } pub fn from_env() -> Result { - Ok(Self::new(read_api_key()?) - .with_auth_token(read_auth_token()) - .with_base_url(read_base_url())) + Ok(Self::from_auth(AuthSource::from_env()?).with_base_url(read_base_url())) + } + + #[must_use] + pub fn with_auth_source(mut self, auth: AuthSource) -> Self { + self.auth = auth; + self } #[must_use] pub fn with_auth_token(mut self, auth_token: Option) -> Self { - self.auth_token = auth_token.filter(|token| !token.is_empty()); + match ( + self.auth.api_key().map(ToOwned::to_owned), + auth_token.filter(|token| !token.is_empty()), + ) { + (Some(api_key), Some(bearer_token)) => { + self.auth = AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }; + } + (Some(api_key), None) => { + self.auth = AuthSource::ApiKey(api_key); + } + (None, Some(bearer_token)) => { + self.auth = AuthSource::BearerToken(bearer_token); + } + (None, None) => { + self.auth = AuthSource::None; + } + } self } @@ -71,6 +184,11 @@ impl AnthropicClient { self } + #[must_use] + pub fn auth_source(&self) -> &AuthSource { + &self.auth + } + pub async fn send_message( &self, request: &MessageRequest, @@ -151,28 +269,25 @@ impl AnthropicClient { let resolved_base_url = self.base_url.trim_end_matches('/'); eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}"); eprintln!("[anthropic-client] request_url={request_url}"); - let mut request_builder = self + let request_builder = self .http .post(&request_url) - .header("x-api-key", &self.api_key) .header("anthropic-version", ANTHROPIC_VERSION) .header("content-type", "application/json"); + let mut request_builder = self.auth.apply(request_builder); - let auth_header = self - .auth_token - .as_ref() - .map_or("", |_| "Bearer [REDACTED]"); - eprintln!("[anthropic-client] headers x-api-key=[REDACTED] authorization={auth_header} anthropic-version={ANTHROPIC_VERSION} content-type=application/json"); + eprintln!( + "[anthropic-client] headers x-api-key={} authorization={} anthropic-version={ANTHROPIC_VERSION} content-type=application/json", + if self.auth.api_key().is_some() { + "[REDACTED]" + } else { + "" + }, + self.auth.masked_authorization_header() + ); - if let Some(auth_token) = &self.auth_token { - request_builder = request_builder.bearer_auth(auth_token); - } - - request_builder - .json(request) - .send() - .await - .map_err(ApiError::from) + request_builder = request_builder.json(request); + request_builder.send().await.map_err(ApiError::from) } fn backoff_for_attempt(&self, attempt: u32) -> Result { @@ -189,24 +304,28 @@ impl AnthropicClient { } } -fn read_api_key() -> Result { - match std::env::var("ANTHROPIC_API_KEY") { - Ok(api_key) if !api_key.is_empty() => Ok(api_key), - Ok(_) => Err(ApiError::MissingApiKey), - Err(std::env::VarError::NotPresent) => match std::env::var("ANTHROPIC_AUTH_TOKEN") { - Ok(api_key) if !api_key.is_empty() => Ok(api_key), - Ok(_) | Err(std::env::VarError::NotPresent) => Err(ApiError::MissingApiKey), - Err(error) => Err(ApiError::from(error)), - }, +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), Err(error) => Err(ApiError::from(error)), } } +#[cfg(test)] +fn read_api_key() -> Result { + let auth = AuthSource::from_env()?; + auth.api_key() + .or_else(|| auth.bearer_token()) + .map(ToOwned::to_owned) + .ok_or(ApiError::MissingApiKey) +} + +#[cfg(test)] fn read_auth_token() -> Option { - match std::env::var("ANTHROPIC_AUTH_TOKEN") { - Ok(token) if !token.is_empty() => Some(token), - _ => None, - } + read_env_non_empty("ANTHROPIC_AUTH_TOKEN") + .ok() + .and_then(std::convert::identity) } fn read_base_url() -> String { @@ -308,14 +427,14 @@ mod tests { use std::sync::{Mutex, OnceLock}; use std::time::Duration; + use crate::client::{AuthSource, OAuthTokenSet}; use crate::types::{ContentBlockDelta, MessageRequest}; fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static ENV_LOCK: OnceLock> = OnceLock::new(); - ENV_LOCK - .get_or_init(|| Mutex::new(())) + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) .lock() - .expect("env lock should not be poisoned") + .expect("env lock") } #[test] @@ -357,6 +476,30 @@ mod tests { std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); } + #[test] + fn oauth_token_maps_to_bearer_auth_source() { + let auth = AuthSource::from(OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }); + assert_eq!(auth.bearer_token(), Some("access-token")); + assert_eq!(auth.api_key(), None); + } + + #[test] + fn auth_source_from_env_combines_api_key_and_bearer_token() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + let auth = AuthSource::from_env().expect("env auth"); + assert_eq!(auth.api_key(), Some("legacy-key")); + assert_eq!(auth.bearer_token(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + #[test] fn message_request_stream_helper_sets_stream_true() { let request = MessageRequest { @@ -436,4 +579,25 @@ mod tests { Some("req_fallback") ); } + + #[test] + fn auth_source_applies_headers() { + let auth = AuthSource::ApiKeyAndBearer { + api_key: "test-key".to_string(), + bearer_token: "proxy-token".to_string(), + }; + let request = auth + .apply(reqwest::Client::new().post("https://example.test")) + .build() + .expect("request build"); + let headers = request.headers(); + assert_eq!( + headers.get("x-api-key").and_then(|v| v.to_str().ok()), + Some("test-key") + ); + assert_eq!( + headers.get("authorization").and_then(|v| v.to_str().ok()), + Some("Bearer proxy-token") + ); + } } diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index e08e3d7..9d587ee 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -3,7 +3,7 @@ mod error; mod sse; mod types; -pub use client::{AnthropicClient, MessageStream}; +pub use client::{AnthropicClient, AuthSource, MessageStream, OAuthTokenSet}; pub use error::ApiError; pub use sse::{parse_frame, SseParser}; pub use types::{ diff --git a/rust/crates/runtime/Cargo.toml b/rust/crates/runtime/Cargo.toml index 8bd9a42..3803c10 100644 --- a/rust/crates/runtime/Cargo.toml +++ b/rust/crates/runtime/Cargo.toml @@ -6,6 +6,7 @@ license.workspace = true publish.workspace = true [dependencies] +sha2 = "0.10" glob = "0.3" regex = "1" serde = { version = "1", features = ["derive"] } diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index 4939557..559ae6a 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -24,6 +24,95 @@ pub struct ConfigEntry { pub struct RuntimeConfig { merged: BTreeMap, loaded_entries: Vec, + feature_config: RuntimeFeatureConfig, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RuntimeFeatureConfig { + mcp: McpConfigCollection, + oauth: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct McpConfigCollection { + servers: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ScopedMcpServerConfig { + pub scope: ConfigSource, + pub config: McpServerConfig, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum McpTransport { + Stdio, + Sse, + Http, + Ws, + Sdk, + ClaudeAiProxy, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum McpServerConfig { + Stdio(McpStdioServerConfig), + Sse(McpRemoteServerConfig), + Http(McpRemoteServerConfig), + Ws(McpWebSocketServerConfig), + Sdk(McpSdkServerConfig), + ClaudeAiProxy(McpClaudeAiProxyServerConfig), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpStdioServerConfig { + pub command: String, + pub args: Vec, + pub env: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpRemoteServerConfig { + pub url: String, + pub headers: BTreeMap, + pub headers_helper: Option, + pub oauth: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpWebSocketServerConfig { + pub url: String, + pub headers: BTreeMap, + pub headers_helper: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpSdkServerConfig { + pub name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpClaudeAiProxyServerConfig { + pub url: String, + pub id: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpOAuthConfig { + pub client_id: Option, + pub callback_port: Option, + pub auth_server_metadata_url: Option, + pub xaa: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthConfig { + pub client_id: String, + pub authorize_url: String, + pub token_url: String, + pub callback_port: Option, + pub manual_redirect_url: Option, + pub scopes: Vec, } #[derive(Debug)] @@ -95,18 +184,31 @@ impl ConfigLoader { pub fn load(&self) -> Result { let mut merged = BTreeMap::new(); let mut loaded_entries = Vec::new(); + let mut mcp_servers = BTreeMap::new(); for entry in self.discover() { let Some(value) = read_optional_json_object(&entry.path)? else { continue; }; + merge_mcp_servers(&mut mcp_servers, entry.source, &value, &entry.path)?; deep_merge_objects(&mut merged, &value); loaded_entries.push(entry); } + let feature_config = RuntimeFeatureConfig { + mcp: McpConfigCollection { + servers: mcp_servers, + }, + oauth: parse_optional_oauth_config( + &JsonValue::Object(merged.clone()), + "merged settings.oauth", + )?, + }; + Ok(RuntimeConfig { merged, loaded_entries, + feature_config, }) } } @@ -117,6 +219,7 @@ impl RuntimeConfig { Self { merged: BTreeMap::new(), loaded_entries: Vec::new(), + feature_config: RuntimeFeatureConfig::default(), } } @@ -139,6 +242,66 @@ impl RuntimeConfig { pub fn as_json(&self) -> JsonValue { JsonValue::Object(self.merged.clone()) } + + #[must_use] + pub fn feature_config(&self) -> &RuntimeFeatureConfig { + &self.feature_config + } + + #[must_use] + pub fn mcp(&self) -> &McpConfigCollection { + &self.feature_config.mcp + } + + #[must_use] + pub fn oauth(&self) -> Option<&OAuthConfig> { + self.feature_config.oauth.as_ref() + } +} + +impl RuntimeFeatureConfig { + #[must_use] + pub fn mcp(&self) -> &McpConfigCollection { + &self.mcp + } + + #[must_use] + pub fn oauth(&self) -> Option<&OAuthConfig> { + self.oauth.as_ref() + } +} + +impl McpConfigCollection { + #[must_use] + pub fn servers(&self) -> &BTreeMap { + &self.servers + } + + #[must_use] + pub fn get(&self, name: &str) -> Option<&ScopedMcpServerConfig> { + self.servers.get(name) + } +} + +impl ScopedMcpServerConfig { + #[must_use] + pub fn transport(&self) -> McpTransport { + self.config.transport() + } +} + +impl McpServerConfig { + #[must_use] + pub fn transport(&self) -> McpTransport { + match self { + Self::Stdio(_) => McpTransport::Stdio, + Self::Sse(_) => McpTransport::Sse, + Self::Http(_) => McpTransport::Http, + Self::Ws(_) => McpTransport::Ws, + Self::Sdk(_) => McpTransport::Sdk, + Self::ClaudeAiProxy(_) => McpTransport::ClaudeAiProxy, + } + } } fn read_optional_json_object( @@ -165,6 +328,253 @@ fn read_optional_json_object( Ok(Some(object.clone())) } +fn merge_mcp_servers( + target: &mut BTreeMap, + source: ConfigSource, + root: &BTreeMap, + path: &Path, +) -> Result<(), ConfigError> { + let Some(mcp_servers) = root.get("mcpServers") else { + return Ok(()); + }; + let servers = expect_object(mcp_servers, &format!("{}: mcpServers", path.display()))?; + for (name, value) in servers { + let parsed = parse_mcp_server_config( + name, + value, + &format!("{}: mcpServers.{name}", path.display()), + )?; + target.insert( + name.clone(), + ScopedMcpServerConfig { + scope: source, + config: parsed, + }, + ); + } + Ok(()) +} + +fn parse_optional_oauth_config( + root: &JsonValue, + context: &str, +) -> Result, ConfigError> { + let Some(oauth_value) = root.as_object().and_then(|object| object.get("oauth")) else { + return Ok(None); + }; + let object = expect_object(oauth_value, context)?; + let client_id = expect_string(object, "clientId", context)?.to_string(); + let authorize_url = expect_string(object, "authorizeUrl", context)?.to_string(); + let token_url = expect_string(object, "tokenUrl", context)?.to_string(); + let callback_port = optional_u16(object, "callbackPort", context)?; + let manual_redirect_url = + optional_string(object, "manualRedirectUrl", context)?.map(str::to_string); + let scopes = optional_string_array(object, "scopes", context)?.unwrap_or_default(); + Ok(Some(OAuthConfig { + client_id, + authorize_url, + token_url, + callback_port, + manual_redirect_url, + scopes, + })) +} + +fn parse_mcp_server_config( + server_name: &str, + value: &JsonValue, + context: &str, +) -> Result { + let object = expect_object(value, context)?; + let server_type = optional_string(object, "type", context)?.unwrap_or("stdio"); + match server_type { + "stdio" => Ok(McpServerConfig::Stdio(McpStdioServerConfig { + command: expect_string(object, "command", context)?.to_string(), + args: optional_string_array(object, "args", context)?.unwrap_or_default(), + env: optional_string_map(object, "env", context)?.unwrap_or_default(), + })), + "sse" => Ok(McpServerConfig::Sse(parse_mcp_remote_server_config( + object, context, + )?)), + "http" => Ok(McpServerConfig::Http(parse_mcp_remote_server_config( + object, context, + )?)), + "ws" => Ok(McpServerConfig::Ws(McpWebSocketServerConfig { + url: expect_string(object, "url", context)?.to_string(), + headers: optional_string_map(object, "headers", context)?.unwrap_or_default(), + headers_helper: optional_string(object, "headersHelper", context)?.map(str::to_string), + })), + "sdk" => Ok(McpServerConfig::Sdk(McpSdkServerConfig { + name: expect_string(object, "name", context)?.to_string(), + })), + "claudeai-proxy" => Ok(McpServerConfig::ClaudeAiProxy( + McpClaudeAiProxyServerConfig { + url: expect_string(object, "url", context)?.to_string(), + id: expect_string(object, "id", context)?.to_string(), + }, + )), + other => Err(ConfigError::Parse(format!( + "{context}: unsupported MCP server type for {server_name}: {other}" + ))), + } +} + +fn parse_mcp_remote_server_config( + object: &BTreeMap, + context: &str, +) -> Result { + Ok(McpRemoteServerConfig { + url: expect_string(object, "url", context)?.to_string(), + headers: optional_string_map(object, "headers", context)?.unwrap_or_default(), + headers_helper: optional_string(object, "headersHelper", context)?.map(str::to_string), + oauth: parse_optional_mcp_oauth_config(object, context)?, + }) +} + +fn parse_optional_mcp_oauth_config( + object: &BTreeMap, + context: &str, +) -> Result, ConfigError> { + let Some(value) = object.get("oauth") else { + return Ok(None); + }; + let oauth = expect_object(value, &format!("{context}.oauth"))?; + Ok(Some(McpOAuthConfig { + client_id: optional_string(oauth, "clientId", context)?.map(str::to_string), + callback_port: optional_u16(oauth, "callbackPort", context)?, + auth_server_metadata_url: optional_string(oauth, "authServerMetadataUrl", context)? + .map(str::to_string), + xaa: optional_bool(oauth, "xaa", context)?, + })) +} + +fn expect_object<'a>( + value: &'a JsonValue, + context: &str, +) -> Result<&'a BTreeMap, ConfigError> { + value + .as_object() + .ok_or_else(|| ConfigError::Parse(format!("{context}: expected JSON object"))) +} + +fn expect_string<'a>( + object: &'a BTreeMap, + key: &str, + context: &str, +) -> Result<&'a str, ConfigError> { + object + .get(key) + .and_then(JsonValue::as_str) + .ok_or_else(|| ConfigError::Parse(format!("{context}: missing string field {key}"))) +} + +fn optional_string<'a>( + object: &'a BTreeMap, + key: &str, + context: &str, +) -> Result, ConfigError> { + match object.get(key) { + Some(value) => value + .as_str() + .map(Some) + .ok_or_else(|| ConfigError::Parse(format!("{context}: field {key} must be a string"))), + None => Ok(None), + } +} + +fn optional_bool( + object: &BTreeMap, + key: &str, + context: &str, +) -> Result, ConfigError> { + match object.get(key) { + Some(value) => value + .as_bool() + .map(Some) + .ok_or_else(|| ConfigError::Parse(format!("{context}: field {key} must be a boolean"))), + None => Ok(None), + } +} + +fn optional_u16( + object: &BTreeMap, + key: &str, + context: &str, +) -> Result, ConfigError> { + match object.get(key) { + Some(value) => { + let Some(number) = value.as_i64() else { + return Err(ConfigError::Parse(format!( + "{context}: field {key} must be an integer" + ))); + }; + let number = u16::try_from(number).map_err(|_| { + ConfigError::Parse(format!("{context}: field {key} is out of range")) + })?; + Ok(Some(number)) + } + None => Ok(None), + } +} + +fn optional_string_array( + object: &BTreeMap, + key: &str, + context: &str, +) -> Result>, ConfigError> { + match object.get(key) { + Some(value) => { + let Some(array) = value.as_array() else { + return Err(ConfigError::Parse(format!( + "{context}: field {key} must be an array" + ))); + }; + array + .iter() + .map(|item| { + item.as_str().map(ToOwned::to_owned).ok_or_else(|| { + ConfigError::Parse(format!( + "{context}: field {key} must contain only strings" + )) + }) + }) + .collect::, _>>() + .map(Some) + } + None => Ok(None), + } +} + +fn optional_string_map( + object: &BTreeMap, + key: &str, + context: &str, +) -> Result>, ConfigError> { + match object.get(key) { + Some(value) => { + let Some(map) = value.as_object() else { + return Err(ConfigError::Parse(format!( + "{context}: field {key} must be an object" + ))); + }; + map.iter() + .map(|(entry_key, entry_value)| { + entry_value + .as_str() + .map(|text| (entry_key.clone(), text.to_string())) + .ok_or_else(|| { + ConfigError::Parse(format!( + "{context}: field {key} must contain only string values" + )) + }) + }) + .collect::, _>>() + .map(Some) + } + None => Ok(None), + } +} + fn deep_merge_objects( target: &mut BTreeMap, source: &BTreeMap, @@ -183,7 +593,9 @@ fn deep_merge_objects( #[cfg(test)] mod tests { - use super::{ConfigLoader, ConfigSource, CLAUDE_CODE_SETTINGS_SCHEMA_NAME}; + use super::{ + ConfigLoader, ConfigSource, McpServerConfig, McpTransport, CLAUDE_CODE_SETTINGS_SCHEMA_NAME, + }; use crate::json::JsonValue; use std::fs; use std::time::{SystemTime, UNIX_EPOCH}; @@ -266,4 +678,118 @@ mod tests { fs::remove_dir_all(root).expect("cleanup temp dir"); } + + #[test] + fn parses_typed_mcp_and_oauth_config() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claude"); + fs::create_dir_all(cwd.join(".claude")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + + fs::write( + home.join("settings.json"), + r#"{ + "mcpServers": { + "stdio-server": { + "command": "uvx", + "args": ["mcp-server"], + "env": {"TOKEN": "secret"} + }, + "remote-server": { + "type": "http", + "url": "https://example.test/mcp", + "headers": {"Authorization": "Bearer token"}, + "headersHelper": "helper.sh", + "oauth": { + "clientId": "mcp-client", + "callbackPort": 7777, + "authServerMetadataUrl": "https://issuer.test/.well-known/oauth-authorization-server", + "xaa": true + } + } + }, + "oauth": { + "clientId": "runtime-client", + "authorizeUrl": "https://console.test/oauth/authorize", + "tokenUrl": "https://console.test/oauth/token", + "callbackPort": 54545, + "manualRedirectUrl": "https://console.test/oauth/callback", + "scopes": ["org:read", "user:write"] + } + }"#, + ) + .expect("write user settings"); + fs::write( + cwd.join(".claude").join("settings.local.json"), + r#"{ + "mcpServers": { + "remote-server": { + "type": "ws", + "url": "wss://override.test/mcp", + "headers": {"X-Env": "local"} + } + } + }"#, + ) + .expect("write local settings"); + + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + let stdio_server = loaded + .mcp() + .get("stdio-server") + .expect("stdio server should exist"); + assert_eq!(stdio_server.scope, ConfigSource::User); + assert_eq!(stdio_server.transport(), McpTransport::Stdio); + + let remote_server = loaded + .mcp() + .get("remote-server") + .expect("remote server should exist"); + assert_eq!(remote_server.scope, ConfigSource::Local); + assert_eq!(remote_server.transport(), McpTransport::Ws); + match &remote_server.config { + McpServerConfig::Ws(config) => { + assert_eq!(config.url, "wss://override.test/mcp"); + assert_eq!( + config.headers.get("X-Env").map(String::as_str), + Some("local") + ); + } + other => panic!("expected ws config, got {other:?}"), + } + + let oauth = loaded.oauth().expect("oauth config should exist"); + assert_eq!(oauth.client_id, "runtime-client"); + assert_eq!(oauth.callback_port, Some(54_545)); + assert_eq!(oauth.scopes, vec!["org:read", "user:write"]); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn rejects_invalid_mcp_server_shapes() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claude"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write( + home.join("settings.json"), + r#"{"mcpServers":{"broken":{"type":"http","url":123}}}"#, + ) + .expect("write broken settings"); + + let error = ConfigLoader::new(&cwd, &home) + .load() + .expect_err("config should fail"); + assert!(error + .to_string() + .contains("mcpServers.broken: missing string field url")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } } diff --git a/rust/crates/runtime/src/file_ops.rs b/rust/crates/runtime/src/file_ops.rs index 47a5f7e..e18bed7 100644 --- a/rust/crates/runtime/src/file_ops.rs +++ b/rust/crates/runtime/src/file_ops.rs @@ -285,7 +285,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result { .output_mode .clone() .unwrap_or_else(|| String::from("files_with_matches")); - let context_window = input.context.or(input.context_short).unwrap_or(0); + let context = input.context.or(input.context_short).unwrap_or(0); let mut filenames = Vec::new(); let mut content_lines = Vec::new(); @@ -325,8 +325,8 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result { filenames.push(file_path.to_string_lossy().into_owned()); if output_mode == "content" { for index in matched_lines { - let start = index.saturating_sub(input.before.unwrap_or(context_window)); - let end = (index + input.after.unwrap_or(context_window) + 1).min(lines.len()); + let start = index.saturating_sub(input.before.unwrap_or(context)); + let end = (index + input.after.unwrap_or(context) + 1).min(lines.len()); for (current, line_content) in lines.iter().enumerate().take(end).skip(start) { let prefix = if input.line_numbers.unwrap_or(true) { format!("{}:{}:", file_path.to_string_lossy(), current + 1) @@ -341,7 +341,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result { let (filenames, applied_limit, applied_offset) = apply_limit(filenames, input.head_limit, input.offset); - let content = if output_mode == "content" { + let rendered_content = if output_mode == "content" { let (lines, limit, offset) = apply_limit(content_lines, input.head_limit, input.offset); return Ok(GrepSearchOutput { mode: Some(output_mode), @@ -361,7 +361,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result { mode: Some(output_mode.clone()), num_files: filenames.len(), filenames, - content, + content: rendered_content, num_lines: None, num_matches: (output_mode == "count").then_some(total_matches), applied_limit, diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 573f858..b257eeb 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -5,8 +5,12 @@ mod config; mod conversation; mod file_ops; mod json; +mod mcp; +mod mcp_client; +mod oauth; mod permissions; mod prompt; +mod remote; mod session; mod usage; @@ -17,8 +21,10 @@ pub use compact::{ get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult, }; pub use config::{ - ConfigEntry, ConfigError, ConfigLoader, ConfigSource, RuntimeConfig, - CLAUDE_CODE_SETTINGS_SCHEMA_NAME, + ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig, + McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, + McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, + RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME, }; pub use conversation::{ ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, @@ -29,6 +35,19 @@ pub use file_ops::{ GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload, WriteFileOutput, }; +pub use mcp::{ + mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp, + scoped_mcp_config_hash, unwrap_ccr_proxy_url, +}; +pub use mcp_client::{ + McpClaudeAiProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport, + McpRemoteTransport, McpSdkTransport, McpStdioTransport, +}; +pub use oauth::{ + code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri, + OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, + PkceChallengeMethod, PkceCodePair, +}; pub use permissions::{ PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision, PermissionPrompter, PermissionRequest, @@ -37,6 +56,11 @@ pub use prompt::{ load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError, SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY, }; +pub use remote::{ + inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url, + RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL, + DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS, +}; pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError}; pub use usage::{ format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker, diff --git a/rust/crates/runtime/src/mcp.rs b/rust/crates/runtime/src/mcp.rs new file mode 100644 index 0000000..103fbe4 --- /dev/null +++ b/rust/crates/runtime/src/mcp.rs @@ -0,0 +1,300 @@ +use crate::config::{McpServerConfig, ScopedMcpServerConfig}; + +const CLAUDEAI_SERVER_PREFIX: &str = "claude.ai "; +const CCR_PROXY_PATH_MARKERS: [&str; 2] = ["/v2/session_ingress/shttp/mcp/", "/v2/ccr-sessions/"]; + +#[must_use] +pub fn normalize_name_for_mcp(name: &str) -> String { + let mut normalized = name + .chars() + .map(|ch| match ch { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' => ch, + _ => '_', + }) + .collect::(); + + if name.starts_with(CLAUDEAI_SERVER_PREFIX) { + normalized = collapse_underscores(&normalized) + .trim_matches('_') + .to_string(); + } + + normalized +} + +#[must_use] +pub fn mcp_tool_prefix(server_name: &str) -> String { + format!("mcp__{}__", normalize_name_for_mcp(server_name)) +} + +#[must_use] +pub fn mcp_tool_name(server_name: &str, tool_name: &str) -> String { + format!( + "{}{}", + mcp_tool_prefix(server_name), + normalize_name_for_mcp(tool_name) + ) +} + +#[must_use] +pub fn unwrap_ccr_proxy_url(url: &str) -> String { + if !CCR_PROXY_PATH_MARKERS + .iter() + .any(|marker| url.contains(marker)) + { + return url.to_string(); + } + + let Some(query_start) = url.find('?') else { + return url.to_string(); + }; + let query = &url[query_start + 1..]; + for pair in query.split('&') { + let mut parts = pair.splitn(2, '='); + if matches!(parts.next(), Some("mcp_url")) { + if let Some(value) = parts.next() { + return percent_decode(value); + } + } + } + + url.to_string() +} + +#[must_use] +pub fn mcp_server_signature(config: &McpServerConfig) -> Option { + match config { + McpServerConfig::Stdio(config) => { + let mut command = vec![config.command.clone()]; + command.extend(config.args.clone()); + Some(format!("stdio:{}", render_command_signature(&command))) + } + McpServerConfig::Sse(config) | McpServerConfig::Http(config) => { + Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))) + } + McpServerConfig::Ws(config) => Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))), + McpServerConfig::ClaudeAiProxy(config) => { + Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))) + } + McpServerConfig::Sdk(_) => None, + } +} + +#[must_use] +pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String { + let rendered = match &config.config { + McpServerConfig::Stdio(stdio) => format!( + "stdio|{}|{}|{}", + stdio.command, + render_command_signature(&stdio.args), + render_env_signature(&stdio.env) + ), + McpServerConfig::Sse(remote) => format!( + "sse|{}|{}|{}|{}", + remote.url, + render_env_signature(&remote.headers), + remote.headers_helper.as_deref().unwrap_or(""), + render_oauth_signature(remote.oauth.as_ref()) + ), + McpServerConfig::Http(remote) => format!( + "http|{}|{}|{}|{}", + remote.url, + render_env_signature(&remote.headers), + remote.headers_helper.as_deref().unwrap_or(""), + render_oauth_signature(remote.oauth.as_ref()) + ), + McpServerConfig::Ws(ws) => format!( + "ws|{}|{}|{}", + ws.url, + render_env_signature(&ws.headers), + ws.headers_helper.as_deref().unwrap_or("") + ), + McpServerConfig::Sdk(sdk) => format!("sdk|{}", sdk.name), + McpServerConfig::ClaudeAiProxy(proxy) => { + format!("claudeai-proxy|{}|{}", proxy.url, proxy.id) + } + }; + stable_hex_hash(&rendered) +} + +fn render_command_signature(command: &[String]) -> String { + let escaped = command + .iter() + .map(|part| part.replace('\\', "\\\\").replace('|', "\\|")) + .collect::>(); + format!("[{}]", escaped.join("|")) +} + +fn render_env_signature(map: &std::collections::BTreeMap) -> String { + map.iter() + .map(|(key, value)| format!("{key}={value}")) + .collect::>() + .join(";") +} + +fn render_oauth_signature(oauth: Option<&crate::config::McpOAuthConfig>) -> String { + oauth.map_or_else(String::new, |oauth| { + format!( + "{}|{}|{}|{}", + oauth.client_id.as_deref().unwrap_or(""), + oauth + .callback_port + .map_or_else(String::new, |port| port.to_string()), + oauth.auth_server_metadata_url.as_deref().unwrap_or(""), + oauth.xaa.map_or_else(String::new, |flag| flag.to_string()) + ) + }) +} + +fn stable_hex_hash(value: &str) -> String { + let mut hash = 0xcbf2_9ce4_8422_2325_u64; + for byte in value.as_bytes() { + hash ^= u64::from(*byte); + hash = hash.wrapping_mul(0x0100_0000_01b3); + } + format!("{hash:016x}") +} + +fn collapse_underscores(value: &str) -> String { + let mut collapsed = String::with_capacity(value.len()); + let mut last_was_underscore = false; + for ch in value.chars() { + if ch == '_' { + if !last_was_underscore { + collapsed.push(ch); + } + last_was_underscore = true; + } else { + collapsed.push(ch); + last_was_underscore = false; + } + } + collapsed +} + +fn percent_decode(value: &str) -> String { + let bytes = value.as_bytes(); + let mut decoded = Vec::with_capacity(bytes.len()); + let mut index = 0; + while index < bytes.len() { + match bytes[index] { + b'%' if index + 2 < bytes.len() => { + let hex = &value[index + 1..index + 3]; + if let Ok(byte) = u8::from_str_radix(hex, 16) { + decoded.push(byte); + index += 3; + continue; + } + decoded.push(bytes[index]); + index += 1; + } + b'+' => { + decoded.push(b' '); + index += 1; + } + byte => { + decoded.push(byte); + index += 1; + } + } + } + String::from_utf8_lossy(&decoded).into_owned() +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use crate::config::{ + ConfigSource, McpRemoteServerConfig, McpServerConfig, McpStdioServerConfig, + McpWebSocketServerConfig, ScopedMcpServerConfig, + }; + + use super::{ + mcp_server_signature, mcp_tool_name, normalize_name_for_mcp, scoped_mcp_config_hash, + unwrap_ccr_proxy_url, + }; + + #[test] + fn normalizes_server_names_for_mcp_tooling() { + assert_eq!(normalize_name_for_mcp("github.com"), "github_com"); + assert_eq!(normalize_name_for_mcp("tool name!"), "tool_name_"); + assert_eq!( + normalize_name_for_mcp("claude.ai Example Server!!"), + "claude_ai_Example_Server" + ); + assert_eq!( + mcp_tool_name("claude.ai Example Server", "weather tool"), + "mcp__claude_ai_Example_Server__weather_tool" + ); + } + + #[test] + fn unwraps_ccr_proxy_urls_for_signature_matching() { + let wrapped = "https://api.anthropic.com/v2/session_ingress/shttp/mcp/123?mcp_url=https%3A%2F%2Fvendor.example%2Fmcp&other=1"; + assert_eq!(unwrap_ccr_proxy_url(wrapped), "https://vendor.example/mcp"); + assert_eq!( + unwrap_ccr_proxy_url("https://vendor.example/mcp"), + "https://vendor.example/mcp" + ); + } + + #[test] + fn computes_signatures_for_stdio_and_remote_servers() { + let stdio = McpServerConfig::Stdio(McpStdioServerConfig { + command: "uvx".to_string(), + args: vec!["mcp-server".to_string()], + env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]), + }); + assert_eq!( + mcp_server_signature(&stdio), + Some("stdio:[uvx|mcp-server]".to_string()) + ); + + let remote = McpServerConfig::Ws(McpWebSocketServerConfig { + url: "https://api.anthropic.com/v2/ccr-sessions/1?mcp_url=wss%3A%2F%2Fvendor.example%2Fmcp".to_string(), + headers: BTreeMap::new(), + headers_helper: None, + }); + assert_eq!( + mcp_server_signature(&remote), + Some("url:wss://vendor.example/mcp".to_string()) + ); + } + + #[test] + fn scoped_hash_ignores_scope_but_tracks_config_content() { + let base_config = McpServerConfig::Http(McpRemoteServerConfig { + url: "https://vendor.example/mcp".to_string(), + headers: BTreeMap::from([("Authorization".to_string(), "Bearer token".to_string())]), + headers_helper: Some("helper.sh".to_string()), + oauth: None, + }); + let user = ScopedMcpServerConfig { + scope: ConfigSource::User, + config: base_config.clone(), + }; + let local = ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: base_config, + }; + assert_eq!( + scoped_mcp_config_hash(&user), + scoped_mcp_config_hash(&local) + ); + + let changed = ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Http(McpRemoteServerConfig { + url: "https://vendor.example/v2/mcp".to_string(), + headers: BTreeMap::new(), + headers_helper: None, + oauth: None, + }), + }; + assert_ne!( + scoped_mcp_config_hash(&user), + scoped_mcp_config_hash(&changed) + ); + } +} diff --git a/rust/crates/runtime/src/mcp_client.rs b/rust/crates/runtime/src/mcp_client.rs new file mode 100644 index 0000000..23ccb95 --- /dev/null +++ b/rust/crates/runtime/src/mcp_client.rs @@ -0,0 +1,236 @@ +use std::collections::BTreeMap; + +use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig}; +use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum McpClientTransport { + Stdio(McpStdioTransport), + Sse(McpRemoteTransport), + Http(McpRemoteTransport), + WebSocket(McpRemoteTransport), + Sdk(McpSdkTransport), + ClaudeAiProxy(McpClaudeAiProxyTransport), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpStdioTransport { + pub command: String, + pub args: Vec, + pub env: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpRemoteTransport { + pub url: String, + pub headers: BTreeMap, + pub headers_helper: Option, + pub auth: McpClientAuth, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpSdkTransport { + pub name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpClaudeAiProxyTransport { + pub url: String, + pub id: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum McpClientAuth { + None, + OAuth(McpOAuthConfig), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpClientBootstrap { + pub server_name: String, + pub normalized_name: String, + pub tool_prefix: String, + pub signature: Option, + pub transport: McpClientTransport, +} + +impl McpClientBootstrap { + #[must_use] + pub fn from_scoped_config(server_name: &str, config: &ScopedMcpServerConfig) -> Self { + Self { + server_name: server_name.to_string(), + normalized_name: normalize_name_for_mcp(server_name), + tool_prefix: mcp_tool_prefix(server_name), + signature: mcp_server_signature(&config.config), + transport: McpClientTransport::from_config(&config.config), + } + } +} + +impl McpClientTransport { + #[must_use] + pub fn from_config(config: &McpServerConfig) -> Self { + match config { + McpServerConfig::Stdio(config) => Self::Stdio(McpStdioTransport { + command: config.command.clone(), + args: config.args.clone(), + env: config.env.clone(), + }), + McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport { + url: config.url.clone(), + headers: config.headers.clone(), + headers_helper: config.headers_helper.clone(), + auth: McpClientAuth::from_oauth(config.oauth.clone()), + }), + McpServerConfig::Http(config) => Self::Http(McpRemoteTransport { + url: config.url.clone(), + headers: config.headers.clone(), + headers_helper: config.headers_helper.clone(), + auth: McpClientAuth::from_oauth(config.oauth.clone()), + }), + McpServerConfig::Ws(config) => Self::WebSocket(McpRemoteTransport { + url: config.url.clone(), + headers: config.headers.clone(), + headers_helper: config.headers_helper.clone(), + auth: McpClientAuth::None, + }), + McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport { + name: config.name.clone(), + }), + McpServerConfig::ClaudeAiProxy(config) => { + Self::ClaudeAiProxy(McpClaudeAiProxyTransport { + url: config.url.clone(), + id: config.id.clone(), + }) + } + } + } +} + +impl McpClientAuth { + #[must_use] + pub fn from_oauth(oauth: Option) -> Self { + oauth.map_or(Self::None, Self::OAuth) + } + + #[must_use] + pub const fn requires_user_auth(&self) -> bool { + matches!(self, Self::OAuth(_)) + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use crate::config::{ + ConfigSource, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig, + McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig, + }; + + use super::{McpClientAuth, McpClientBootstrap, McpClientTransport}; + + #[test] + fn bootstraps_stdio_servers_into_transport_targets() { + let config = ScopedMcpServerConfig { + scope: ConfigSource::User, + config: McpServerConfig::Stdio(McpStdioServerConfig { + command: "uvx".to_string(), + args: vec!["mcp-server".to_string()], + env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]), + }), + }; + + let bootstrap = McpClientBootstrap::from_scoped_config("stdio-server", &config); + assert_eq!(bootstrap.normalized_name, "stdio-server"); + assert_eq!(bootstrap.tool_prefix, "mcp__stdio-server__"); + assert_eq!( + bootstrap.signature.as_deref(), + Some("stdio:[uvx|mcp-server]") + ); + match bootstrap.transport { + McpClientTransport::Stdio(transport) => { + assert_eq!(transport.command, "uvx"); + assert_eq!(transport.args, vec!["mcp-server"]); + assert_eq!( + transport.env.get("TOKEN").map(String::as_str), + Some("secret") + ); + } + other => panic!("expected stdio transport, got {other:?}"), + } + } + + #[test] + fn bootstraps_remote_servers_with_oauth_auth() { + let config = ScopedMcpServerConfig { + scope: ConfigSource::Project, + config: McpServerConfig::Http(McpRemoteServerConfig { + url: "https://vendor.example/mcp".to_string(), + headers: BTreeMap::from([("X-Test".to_string(), "1".to_string())]), + headers_helper: Some("helper.sh".to_string()), + oauth: Some(McpOAuthConfig { + client_id: Some("client-id".to_string()), + callback_port: Some(7777), + auth_server_metadata_url: Some( + "https://issuer.example/.well-known/oauth-authorization-server".to_string(), + ), + xaa: Some(true), + }), + }), + }; + + let bootstrap = McpClientBootstrap::from_scoped_config("remote server", &config); + assert_eq!(bootstrap.normalized_name, "remote_server"); + match bootstrap.transport { + McpClientTransport::Http(transport) => { + assert_eq!(transport.url, "https://vendor.example/mcp"); + assert_eq!(transport.headers_helper.as_deref(), Some("helper.sh")); + assert!(transport.auth.requires_user_auth()); + match transport.auth { + McpClientAuth::OAuth(oauth) => { + assert_eq!(oauth.client_id.as_deref(), Some("client-id")); + } + other @ McpClientAuth::None => panic!("expected oauth auth, got {other:?}"), + } + } + other => panic!("expected http transport, got {other:?}"), + } + } + + #[test] + fn bootstraps_websocket_and_sdk_transports_without_oauth() { + let ws = ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Ws(McpWebSocketServerConfig { + url: "wss://vendor.example/mcp".to_string(), + headers: BTreeMap::new(), + headers_helper: None, + }), + }; + let sdk = ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Sdk(McpSdkServerConfig { + name: "sdk-server".to_string(), + }), + }; + + let ws_bootstrap = McpClientBootstrap::from_scoped_config("ws server", &ws); + match ws_bootstrap.transport { + McpClientTransport::WebSocket(transport) => { + assert_eq!(transport.url, "wss://vendor.example/mcp"); + assert!(!transport.auth.requires_user_auth()); + } + other => panic!("expected websocket transport, got {other:?}"), + } + + let sdk_bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &sdk); + assert_eq!(sdk_bootstrap.signature, None); + match sdk_bootstrap.transport { + McpClientTransport::Sdk(transport) => { + assert_eq!(transport.name, "sdk-server"); + } + other => panic!("expected sdk transport, got {other:?}"), + } + } +} diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs new file mode 100644 index 0000000..320a8ee --- /dev/null +++ b/rust/crates/runtime/src/oauth.rs @@ -0,0 +1,338 @@ +use std::collections::BTreeMap; +use std::fs::File; +use std::io::{self, Read}; + +use sha2::{Digest, Sha256}; + +use crate::config::OAuthConfig; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + pub scopes: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PkceCodePair { + pub verifier: String, + pub challenge: String, + pub challenge_method: PkceChallengeMethod, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PkceChallengeMethod { + S256, +} + +impl PkceChallengeMethod { + #[must_use] + pub const fn as_str(self) -> &'static str { + match self { + Self::S256 => "S256", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthAuthorizationRequest { + pub authorize_url: String, + pub client_id: String, + pub redirect_uri: String, + pub scopes: Vec, + pub state: String, + pub code_challenge: String, + pub code_challenge_method: PkceChallengeMethod, + pub extra_params: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthTokenExchangeRequest { + pub grant_type: &'static str, + pub code: String, + pub redirect_uri: String, + pub client_id: String, + pub code_verifier: String, + pub state: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthRefreshRequest { + pub grant_type: &'static str, + pub refresh_token: String, + pub client_id: String, + pub scopes: Vec, +} + +impl OAuthAuthorizationRequest { + #[must_use] + pub fn from_config( + config: &OAuthConfig, + redirect_uri: impl Into, + state: impl Into, + pkce: &PkceCodePair, + ) -> Self { + Self { + authorize_url: config.authorize_url.clone(), + client_id: config.client_id.clone(), + redirect_uri: redirect_uri.into(), + scopes: config.scopes.clone(), + state: state.into(), + code_challenge: pkce.challenge.clone(), + code_challenge_method: pkce.challenge_method, + extra_params: BTreeMap::new(), + } + } + + #[must_use] + pub fn with_extra_param(mut self, key: impl Into, value: impl Into) -> Self { + self.extra_params.insert(key.into(), value.into()); + self + } + + #[must_use] + pub fn build_url(&self) -> String { + let mut params = vec![ + ("response_type", "code".to_string()), + ("client_id", self.client_id.clone()), + ("redirect_uri", self.redirect_uri.clone()), + ("scope", self.scopes.join(" ")), + ("state", self.state.clone()), + ("code_challenge", self.code_challenge.clone()), + ( + "code_challenge_method", + self.code_challenge_method.as_str().to_string(), + ), + ]; + params.extend( + self.extra_params + .iter() + .map(|(key, value)| (key.as_str(), value.clone())), + ); + let query = params + .into_iter() + .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value))) + .collect::>() + .join("&"); + format!( + "{}{}{}", + self.authorize_url, + if self.authorize_url.contains('?') { + '&' + } else { + '?' + }, + query + ) + } +} + +impl OAuthTokenExchangeRequest { + #[must_use] + pub fn from_config( + config: &OAuthConfig, + code: impl Into, + state: impl Into, + verifier: impl Into, + redirect_uri: impl Into, + ) -> Self { + let _ = config; + Self { + grant_type: "authorization_code", + code: code.into(), + redirect_uri: redirect_uri.into(), + client_id: config.client_id.clone(), + code_verifier: verifier.into(), + state: state.into(), + } + } + + #[must_use] + pub fn form_params(&self) -> BTreeMap<&str, String> { + BTreeMap::from([ + ("grant_type", self.grant_type.to_string()), + ("code", self.code.clone()), + ("redirect_uri", self.redirect_uri.clone()), + ("client_id", self.client_id.clone()), + ("code_verifier", self.code_verifier.clone()), + ("state", self.state.clone()), + ]) + } +} + +impl OAuthRefreshRequest { + #[must_use] + pub fn from_config( + config: &OAuthConfig, + refresh_token: impl Into, + scopes: Option>, + ) -> Self { + Self { + grant_type: "refresh_token", + refresh_token: refresh_token.into(), + client_id: config.client_id.clone(), + scopes: scopes.unwrap_or_else(|| config.scopes.clone()), + } + } + + #[must_use] + pub fn form_params(&self) -> BTreeMap<&str, String> { + BTreeMap::from([ + ("grant_type", self.grant_type.to_string()), + ("refresh_token", self.refresh_token.clone()), + ("client_id", self.client_id.clone()), + ("scope", self.scopes.join(" ")), + ]) + } +} + +pub fn generate_pkce_pair() -> io::Result { + let verifier = generate_random_token(32)?; + Ok(PkceCodePair { + challenge: code_challenge_s256(&verifier), + verifier, + challenge_method: PkceChallengeMethod::S256, + }) +} + +pub fn generate_state() -> io::Result { + generate_random_token(32) +} + +#[must_use] +pub fn code_challenge_s256(verifier: &str) -> String { + let digest = Sha256::digest(verifier.as_bytes()); + base64url_encode(&digest) +} + +#[must_use] +pub fn loopback_redirect_uri(port: u16) -> String { + format!("http://localhost:{port}/callback") +} + +fn generate_random_token(bytes: usize) -> io::Result { + let mut buffer = vec![0_u8; bytes]; + File::open("/dev/urandom")?.read_exact(&mut buffer)?; + Ok(base64url_encode(&buffer)) +} + +fn base64url_encode(bytes: &[u8]) -> String { + const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + let mut output = String::new(); + let mut index = 0; + while index + 3 <= bytes.len() { + let block = (u32::from(bytes[index]) << 16) + | (u32::from(bytes[index + 1]) << 8) + | u32::from(bytes[index + 2]); + output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 6) & 0x3F) as usize] as char); + output.push(TABLE[(block & 0x3F) as usize] as char); + index += 3; + } + match bytes.len().saturating_sub(index) { + 1 => { + let block = u32::from(bytes[index]) << 16; + output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); + } + 2 => { + let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8); + output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); + output.push(TABLE[((block >> 6) & 0x3F) as usize] as char); + } + _ => {} + } + output +} + +fn percent_encode(value: &str) -> String { + let mut encoded = String::new(); + for byte in value.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + encoded.push(char::from(byte)); + } + _ => { + use std::fmt::Write as _; + let _ = write!(&mut encoded, "%{byte:02X}"); + } + } + } + encoded +} + +#[cfg(test)] +mod tests { + use super::{ + code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri, + OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, + }; + + fn sample_config() -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url: "https://console.test/oauth/token".to_string(), + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + #[test] + fn s256_challenge_matches_expected_vector() { + assert_eq!( + code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"), + "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + ); + } + + #[test] + fn generates_pkce_pair_and_state() { + let pair = generate_pkce_pair().expect("pkce pair"); + let state = generate_state().expect("state"); + assert!(!pair.verifier.is_empty()); + assert!(!pair.challenge.is_empty()); + assert!(!state.is_empty()); + } + + #[test] + fn builds_authorize_url_and_form_requests() { + let config = sample_config(); + let pair = generate_pkce_pair().expect("pkce"); + let url = OAuthAuthorizationRequest::from_config( + &config, + loopback_redirect_uri(4545), + "state-123", + &pair, + ) + .with_extra_param("login_hint", "user@example.com") + .build_url(); + assert!(url.starts_with("https://console.test/oauth/authorize?")); + assert!(url.contains("response_type=code")); + assert!(url.contains("client_id=runtime-client")); + assert!(url.contains("scope=org%3Aread%20user%3Awrite")); + assert!(url.contains("login_hint=user%40example.com")); + + let exchange = OAuthTokenExchangeRequest::from_config( + &config, + "auth-code", + "state-123", + pair.verifier, + loopback_redirect_uri(4545), + ); + assert_eq!( + exchange.form_params().get("grant_type").map(String::as_str), + Some("authorization_code") + ); + + let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None); + assert_eq!( + refresh.form_params().get("scope").map(String::as_str), + Some("org:read user:write") + ); + } +} diff --git a/rust/crates/runtime/src/remote.rs b/rust/crates/runtime/src/remote.rs new file mode 100644 index 0000000..24ee780 --- /dev/null +++ b/rust/crates/runtime/src/remote.rs @@ -0,0 +1,401 @@ +use std::collections::BTreeMap; +use std::env; +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; + +pub const DEFAULT_REMOTE_BASE_URL: &str = "https://api.anthropic.com"; +pub const DEFAULT_SESSION_TOKEN_PATH: &str = "/run/ccr/session_token"; +pub const DEFAULT_SYSTEM_CA_BUNDLE: &str = "/etc/ssl/certs/ca-certificates.crt"; + +pub const UPSTREAM_PROXY_ENV_KEYS: [&str; 8] = [ + "HTTPS_PROXY", + "https_proxy", + "NO_PROXY", + "no_proxy", + "SSL_CERT_FILE", + "NODE_EXTRA_CA_CERTS", + "REQUESTS_CA_BUNDLE", + "CURL_CA_BUNDLE", +]; + +pub const NO_PROXY_HOSTS: [&str; 16] = [ + "localhost", + "127.0.0.1", + "::1", + "169.254.0.0/16", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "anthropic.com", + ".anthropic.com", + "*.anthropic.com", + "github.com", + "api.github.com", + "*.github.com", + "*.githubusercontent.com", + "registry.npmjs.org", + "index.crates.io", +]; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RemoteSessionContext { + pub enabled: bool, + pub session_id: Option, + pub base_url: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UpstreamProxyBootstrap { + pub remote: RemoteSessionContext, + pub upstream_proxy_enabled: bool, + pub token_path: PathBuf, + pub ca_bundle_path: PathBuf, + pub system_ca_path: PathBuf, + pub token: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UpstreamProxyState { + pub enabled: bool, + pub proxy_url: Option, + pub ca_bundle_path: Option, + pub no_proxy: String, +} + +impl RemoteSessionContext { + #[must_use] + pub fn from_env() -> Self { + Self::from_env_map(&env::vars().collect()) + } + + #[must_use] + pub fn from_env_map(env_map: &BTreeMap) -> Self { + Self { + enabled: env_truthy(env_map.get("CLAUDE_CODE_REMOTE")), + session_id: env_map + .get("CLAUDE_CODE_REMOTE_SESSION_ID") + .filter(|value| !value.is_empty()) + .cloned(), + base_url: env_map + .get("ANTHROPIC_BASE_URL") + .filter(|value| !value.is_empty()) + .cloned() + .unwrap_or_else(|| DEFAULT_REMOTE_BASE_URL.to_string()), + } + } +} + +impl UpstreamProxyBootstrap { + #[must_use] + pub fn from_env() -> Self { + Self::from_env_map(&env::vars().collect()) + } + + #[must_use] + pub fn from_env_map(env_map: &BTreeMap) -> Self { + let remote = RemoteSessionContext::from_env_map(env_map); + let token_path = env_map + .get("CCR_SESSION_TOKEN_PATH") + .filter(|value| !value.is_empty()) + .map_or_else(|| PathBuf::from(DEFAULT_SESSION_TOKEN_PATH), PathBuf::from); + let system_ca_path = env_map + .get("CCR_SYSTEM_CA_BUNDLE") + .filter(|value| !value.is_empty()) + .map_or_else(|| PathBuf::from(DEFAULT_SYSTEM_CA_BUNDLE), PathBuf::from); + let ca_bundle_path = env_map + .get("CCR_CA_BUNDLE_PATH") + .filter(|value| !value.is_empty()) + .map_or_else(default_ca_bundle_path, PathBuf::from); + let token = read_token(&token_path).ok().flatten(); + + Self { + remote, + upstream_proxy_enabled: env_truthy(env_map.get("CCR_UPSTREAM_PROXY_ENABLED")), + token_path, + ca_bundle_path, + system_ca_path, + token, + } + } + + #[must_use] + pub fn should_enable(&self) -> bool { + self.remote.enabled + && self.upstream_proxy_enabled + && self.remote.session_id.is_some() + && self.token.is_some() + } + + #[must_use] + pub fn ws_url(&self) -> String { + upstream_proxy_ws_url(&self.remote.base_url) + } + + #[must_use] + pub fn state_for_port(&self, port: u16) -> UpstreamProxyState { + if !self.should_enable() { + return UpstreamProxyState::disabled(); + } + UpstreamProxyState { + enabled: true, + proxy_url: Some(format!("http://127.0.0.1:{port}")), + ca_bundle_path: Some(self.ca_bundle_path.clone()), + no_proxy: no_proxy_list(), + } + } +} + +impl UpstreamProxyState { + #[must_use] + pub fn disabled() -> Self { + Self { + enabled: false, + proxy_url: None, + ca_bundle_path: None, + no_proxy: no_proxy_list(), + } + } + + #[must_use] + pub fn subprocess_env(&self) -> BTreeMap { + if !self.enabled { + return BTreeMap::new(); + } + let Some(proxy_url) = &self.proxy_url else { + return BTreeMap::new(); + }; + let Some(ca_bundle_path) = &self.ca_bundle_path else { + return BTreeMap::new(); + }; + let ca_bundle_path = ca_bundle_path.to_string_lossy().into_owned(); + BTreeMap::from([ + ("HTTPS_PROXY".to_string(), proxy_url.clone()), + ("https_proxy".to_string(), proxy_url.clone()), + ("NO_PROXY".to_string(), self.no_proxy.clone()), + ("no_proxy".to_string(), self.no_proxy.clone()), + ("SSL_CERT_FILE".to_string(), ca_bundle_path.clone()), + ("NODE_EXTRA_CA_CERTS".to_string(), ca_bundle_path.clone()), + ("REQUESTS_CA_BUNDLE".to_string(), ca_bundle_path.clone()), + ("CURL_CA_BUNDLE".to_string(), ca_bundle_path), + ]) + } +} + +pub fn read_token(path: &Path) -> io::Result> { + match fs::read_to_string(path) { + Ok(contents) => { + let token = contents.trim(); + if token.is_empty() { + Ok(None) + } else { + Ok(Some(token.to_string())) + } + } + Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(None), + Err(error) => Err(error), + } +} + +#[must_use] +pub fn upstream_proxy_ws_url(base_url: &str) -> String { + let base = base_url.trim_end_matches('/'); + let ws_base = if let Some(stripped) = base.strip_prefix("https://") { + format!("wss://{stripped}") + } else if let Some(stripped) = base.strip_prefix("http://") { + format!("ws://{stripped}") + } else { + format!("wss://{base}") + }; + format!("{ws_base}/v1/code/upstreamproxy/ws") +} + +#[must_use] +pub fn no_proxy_list() -> String { + let mut hosts = NO_PROXY_HOSTS.to_vec(); + hosts.extend(["pypi.org", "files.pythonhosted.org", "proxy.golang.org"]); + hosts.join(",") +} + +#[must_use] +pub fn inherited_upstream_proxy_env( + env_map: &BTreeMap, +) -> BTreeMap { + if !(env_map.contains_key("HTTPS_PROXY") && env_map.contains_key("SSL_CERT_FILE")) { + return BTreeMap::new(); + } + UPSTREAM_PROXY_ENV_KEYS + .iter() + .filter_map(|key| { + env_map + .get(*key) + .map(|value| ((*key).to_string(), value.clone())) + }) + .collect() +} + +fn default_ca_bundle_path() -> PathBuf { + env::var_os("HOME") + .map_or_else(|| PathBuf::from("."), PathBuf::from) + .join(".ccr") + .join("ca-bundle.crt") +} + +fn env_truthy(value: Option<&String>) -> bool { + value.is_some_and(|raw| { + matches!( + raw.trim().to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) +} + +#[cfg(test)] +mod tests { + use super::{ + inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url, + RemoteSessionContext, UpstreamProxyBootstrap, + }; + use std::collections::BTreeMap; + use std::fs; + use std::path::PathBuf; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir() -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("runtime-remote-{nanos}")) + } + + #[test] + fn remote_context_reads_env_state() { + let env = BTreeMap::from([ + ("CLAUDE_CODE_REMOTE".to_string(), "true".to_string()), + ( + "CLAUDE_CODE_REMOTE_SESSION_ID".to_string(), + "session-123".to_string(), + ), + ( + "ANTHROPIC_BASE_URL".to_string(), + "https://remote.test".to_string(), + ), + ]); + let context = RemoteSessionContext::from_env_map(&env); + assert!(context.enabled); + assert_eq!(context.session_id.as_deref(), Some("session-123")); + assert_eq!(context.base_url, "https://remote.test"); + } + + #[test] + fn bootstrap_fails_open_when_token_or_session_is_missing() { + let env = BTreeMap::from([ + ("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()), + ("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()), + ]); + let bootstrap = UpstreamProxyBootstrap::from_env_map(&env); + assert!(!bootstrap.should_enable()); + assert!(!bootstrap.state_for_port(8080).enabled); + } + + #[test] + fn bootstrap_derives_proxy_state_and_env() { + let root = temp_dir(); + let token_path = root.join("session_token"); + fs::create_dir_all(&root).expect("temp dir"); + fs::write(&token_path, "secret-token\n").expect("write token"); + + let env = BTreeMap::from([ + ("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()), + ("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()), + ( + "CLAUDE_CODE_REMOTE_SESSION_ID".to_string(), + "session-123".to_string(), + ), + ( + "ANTHROPIC_BASE_URL".to_string(), + "https://remote.test".to_string(), + ), + ( + "CCR_SESSION_TOKEN_PATH".to_string(), + token_path.to_string_lossy().into_owned(), + ), + ( + "CCR_CA_BUNDLE_PATH".to_string(), + root.join("ca-bundle.crt").to_string_lossy().into_owned(), + ), + ]); + + let bootstrap = UpstreamProxyBootstrap::from_env_map(&env); + assert!(bootstrap.should_enable()); + assert_eq!(bootstrap.token.as_deref(), Some("secret-token")); + assert_eq!( + bootstrap.ws_url(), + "wss://remote.test/v1/code/upstreamproxy/ws" + ); + + let state = bootstrap.state_for_port(9443); + assert!(state.enabled); + let env = state.subprocess_env(); + assert_eq!( + env.get("HTTPS_PROXY").map(String::as_str), + Some("http://127.0.0.1:9443") + ); + assert_eq!( + env.get("SSL_CERT_FILE").map(String::as_str), + Some(root.join("ca-bundle.crt").to_string_lossy().as_ref()) + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn token_reader_trims_and_handles_missing_files() { + let root = temp_dir(); + fs::create_dir_all(&root).expect("temp dir"); + let token_path = root.join("session_token"); + fs::write(&token_path, " abc123 \n").expect("write token"); + assert_eq!( + read_token(&token_path).expect("read token").as_deref(), + Some("abc123") + ); + assert_eq!( + read_token(&root.join("missing")).expect("missing token"), + None + ); + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn inherited_proxy_env_requires_proxy_and_ca() { + let env = BTreeMap::from([ + ( + "HTTPS_PROXY".to_string(), + "http://127.0.0.1:8888".to_string(), + ), + ( + "SSL_CERT_FILE".to_string(), + "/tmp/ca-bundle.crt".to_string(), + ), + ("NO_PROXY".to_string(), "localhost".to_string()), + ]); + let inherited = inherited_upstream_proxy_env(&env); + assert_eq!(inherited.len(), 3); + assert_eq!( + inherited.get("NO_PROXY").map(String::as_str), + Some("localhost") + ); + assert!(inherited_upstream_proxy_env(&BTreeMap::new()).is_empty()); + } + + #[test] + fn helper_outputs_match_expected_shapes() { + assert_eq!( + upstream_proxy_ws_url("http://localhost:3000/"), + "ws://localhost:3000/v1/code/upstreamproxy/ws" + ); + assert!(no_proxy_list().contains("anthropic.com")); + assert!(no_proxy_list().contains("github.com")); + } +}