use std::collections::BTreeMap; use std::fs::{self, File}; use std::io::{self, Read}; use std::path::PathBuf; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use sha2::{Digest, Sha256}; use crate::config::OAuthConfig; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct OAuthTokenSet { pub access_token: String, pub refresh_token: Option, pub expires_at: Option, pub scopes: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct PkceCodePair { pub verifier: String, pub challenge: String, pub challenge_method: PkceChallengeMethod, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PkceChallengeMethod { S256, } impl PkceChallengeMethod { #[must_use] pub const fn as_str(self) -> &'static str { match self { Self::S256 => "S256", } } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthAuthorizationRequest { pub authorize_url: String, pub client_id: String, pub redirect_uri: String, pub scopes: Vec, pub state: String, pub code_challenge: String, pub code_challenge_method: PkceChallengeMethod, pub extra_params: BTreeMap, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthTokenExchangeRequest { pub grant_type: &'static str, pub code: String, pub redirect_uri: String, pub client_id: String, pub code_verifier: String, pub state: String, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthRefreshRequest { pub grant_type: &'static str, pub refresh_token: String, pub client_id: String, pub scopes: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthCallbackParams { pub code: Option, pub state: Option, pub error: Option, pub error_description: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct StoredOAuthCredentials { access_token: String, #[serde(default)] refresh_token: Option, #[serde(default)] expires_at: Option, #[serde(default)] scopes: Vec, } impl From for StoredOAuthCredentials { fn from(value: OAuthTokenSet) -> Self { Self { access_token: value.access_token, refresh_token: value.refresh_token, expires_at: value.expires_at, scopes: value.scopes, } } } impl From for OAuthTokenSet { fn from(value: StoredOAuthCredentials) -> Self { Self { access_token: value.access_token, refresh_token: value.refresh_token, expires_at: value.expires_at, scopes: value.scopes, } } } impl OAuthAuthorizationRequest { #[must_use] pub fn from_config( config: &OAuthConfig, redirect_uri: impl Into, state: impl Into, pkce: &PkceCodePair, ) -> Self { Self { authorize_url: config.authorize_url.clone(), client_id: config.client_id.clone(), redirect_uri: redirect_uri.into(), scopes: config.scopes.clone(), state: state.into(), code_challenge: pkce.challenge.clone(), code_challenge_method: pkce.challenge_method, extra_params: BTreeMap::new(), } } #[must_use] pub fn with_extra_param(mut self, key: impl Into, value: impl Into) -> Self { self.extra_params.insert(key.into(), value.into()); self } #[must_use] pub fn build_url(&self) -> String { let mut params = vec![ ("response_type", "code".to_string()), ("client_id", self.client_id.clone()), ("redirect_uri", self.redirect_uri.clone()), ("scope", self.scopes.join(" ")), ("state", self.state.clone()), ("code_challenge", self.code_challenge.clone()), ( "code_challenge_method", self.code_challenge_method.as_str().to_string(), ), ]; params.extend( self.extra_params .iter() .map(|(key, value)| (key.as_str(), value.clone())), ); let query = params .into_iter() .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value))) .collect::>() .join("&"); format!( "{}{}{}", self.authorize_url, if self.authorize_url.contains('?') { '&' } else { '?' }, query ) } } impl OAuthTokenExchangeRequest { #[must_use] pub fn from_config( config: &OAuthConfig, code: impl Into, state: impl Into, verifier: impl Into, redirect_uri: impl Into, ) -> Self { Self { grant_type: "authorization_code", code: code.into(), redirect_uri: redirect_uri.into(), client_id: config.client_id.clone(), code_verifier: verifier.into(), state: state.into(), } } #[must_use] pub fn form_params(&self) -> BTreeMap<&str, String> { BTreeMap::from([ ("grant_type", self.grant_type.to_string()), ("code", self.code.clone()), ("redirect_uri", self.redirect_uri.clone()), ("client_id", self.client_id.clone()), ("code_verifier", self.code_verifier.clone()), ("state", self.state.clone()), ]) } } impl OAuthRefreshRequest { #[must_use] pub fn from_config( config: &OAuthConfig, refresh_token: impl Into, scopes: Option>, ) -> Self { Self { grant_type: "refresh_token", refresh_token: refresh_token.into(), client_id: config.client_id.clone(), scopes: scopes.unwrap_or_else(|| config.scopes.clone()), } } #[must_use] pub fn form_params(&self) -> BTreeMap<&str, String> { BTreeMap::from([ ("grant_type", self.grant_type.to_string()), ("refresh_token", self.refresh_token.clone()), ("client_id", self.client_id.clone()), ("scope", self.scopes.join(" ")), ]) } } pub fn generate_pkce_pair() -> io::Result { let verifier = generate_random_token(32)?; Ok(PkceCodePair { challenge: code_challenge_s256(&verifier), verifier, challenge_method: PkceChallengeMethod::S256, }) } pub fn generate_state() -> io::Result { generate_random_token(32) } #[must_use] pub fn code_challenge_s256(verifier: &str) -> String { let digest = Sha256::digest(verifier.as_bytes()); base64url_encode(&digest) } #[must_use] pub fn loopback_redirect_uri(port: u16) -> String { format!("http://localhost:{port}/callback") } pub fn credentials_path() -> io::Result { Ok(credentials_home_dir()?.join("credentials.json")) } pub fn load_oauth_credentials() -> io::Result> { let path = credentials_path()?; let root = read_credentials_root(&path)?; let Some(oauth) = root.get("oauth") else { return Ok(None); }; if oauth.is_null() { return Ok(None); } let stored = serde_json::from_value::(oauth.clone()) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; Ok(Some(stored.into())) } pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> { let path = credentials_path()?; let mut root = read_credentials_root(&path)?; root.insert( "oauth".to_string(), serde_json::to_value(StoredOAuthCredentials::from(token_set.clone())) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?, ); write_credentials_root(&path, &root) } pub fn clear_oauth_credentials() -> io::Result<()> { let path = credentials_path()?; let mut root = read_credentials_root(&path)?; root.remove("oauth"); write_credentials_root(&path, &root) } pub fn parse_oauth_callback_request_target(target: &str) -> Result { let (path, query) = target .split_once('?') .map_or((target, ""), |(path, query)| (path, query)); if path != "/callback" { return Err(format!("unexpected callback path: {path}")); } parse_oauth_callback_query(query) } pub fn parse_oauth_callback_query(query: &str) -> Result { let mut params = BTreeMap::new(); for pair in query.split('&').filter(|pair| !pair.is_empty()) { let (key, value) = pair .split_once('=') .map_or((pair, ""), |(key, value)| (key, value)); params.insert(percent_decode(key)?, percent_decode(value)?); } Ok(OAuthCallbackParams { code: params.get("code").cloned(), state: params.get("state").cloned(), error: params.get("error").cloned(), error_description: params.get("error_description").cloned(), }) } fn generate_random_token(bytes: usize) -> io::Result { let mut buffer = vec![0_u8; bytes]; File::open("/dev/urandom")?.read_exact(&mut buffer)?; Ok(base64url_encode(&buffer)) } fn credentials_home_dir() -> io::Result { if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") { return Ok(PathBuf::from(path)); } let home = std::env::var_os("HOME") .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?; Ok(PathBuf::from(home).join(".claude")) } fn read_credentials_root(path: &PathBuf) -> io::Result> { match fs::read_to_string(path) { Ok(contents) => { if contents.trim().is_empty() { return Ok(Map::new()); } serde_json::from_str::(&contents) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))? .as_object() .cloned() .ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidData, "credentials file must contain a JSON object", ) }) } Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()), Err(error) => Err(error), } } fn write_credentials_root(path: &PathBuf, root: &Map) -> io::Result<()> { if let Some(parent) = path.parent() { fs::create_dir_all(parent)?; } let rendered = serde_json::to_string_pretty(&Value::Object(root.clone())) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; let temp_path = path.with_extension("json.tmp"); fs::write(&temp_path, format!("{rendered}\n"))?; fs::rename(temp_path, path) } fn base64url_encode(bytes: &[u8]) -> String { const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; let mut output = String::new(); let mut index = 0; while index + 3 <= bytes.len() { let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8) | u32::from(bytes[index + 2]); output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); output.push(TABLE[((block >> 6) & 0x3F) as usize] as char); output.push(TABLE[(block & 0x3F) as usize] as char); index += 3; } match bytes.len().saturating_sub(index) { 1 => { let block = u32::from(bytes[index]) << 16; output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); } 2 => { let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8); output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); output.push(TABLE[((block >> 6) & 0x3F) as usize] as char); } _ => {} } output } fn percent_encode(value: &str) -> String { let mut encoded = String::new(); for byte in value.bytes() { match byte { b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { encoded.push(char::from(byte)); } _ => { use std::fmt::Write as _; let _ = write!(&mut encoded, "%{byte:02X}"); } } } encoded } fn percent_decode(value: &str) -> Result { let mut decoded = Vec::with_capacity(value.len()); let bytes = value.as_bytes(); let mut index = 0; while index < bytes.len() { match bytes[index] { b'%' if index + 2 < bytes.len() => { let hi = decode_hex(bytes[index + 1])?; let lo = decode_hex(bytes[index + 2])?; decoded.push((hi << 4) | lo); index += 3; } b'+' => { decoded.push(b' '); index += 1; } byte => { decoded.push(byte); index += 1; } } } String::from_utf8(decoded).map_err(|error| error.to_string()) } fn decode_hex(byte: u8) -> Result { match byte { b'0'..=b'9' => Ok(byte - b'0'), b'a'..=b'f' => Ok(byte - b'a' + 10), b'A'..=b'F' => Ok(byte - b'A' + 10), _ => Err(format!("invalid percent-encoding byte: {byte}")), } } #[cfg(test)] mod tests { use std::time::{SystemTime, UNIX_EPOCH}; use super::{ clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query, parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, }; fn sample_config() -> OAuthConfig { OAuthConfig { client_id: "runtime-client".to_string(), authorize_url: "https://console.test/oauth/authorize".to_string(), token_url: "https://console.test/oauth/token".to_string(), callback_port: Some(4545), manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), scopes: vec!["org:read".to_string(), "user:write".to_string()], } } fn env_lock() -> std::sync::MutexGuard<'static, ()> { crate::test_env_lock() } fn temp_config_home() -> std::path::PathBuf { std::env::temp_dir().join(format!( "runtime-oauth-test-{}-{}", std::process::id(), SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time") .as_nanos() )) } #[test] fn s256_challenge_matches_expected_vector() { assert_eq!( code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"), "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" ); } #[test] fn generates_pkce_pair_and_state() { let pair = generate_pkce_pair().expect("pkce pair"); let state = generate_state().expect("state"); assert!(!pair.verifier.is_empty()); assert!(!pair.challenge.is_empty()); assert!(!state.is_empty()); } #[test] fn builds_authorize_url_and_form_requests() { let config = sample_config(); let pair = generate_pkce_pair().expect("pkce"); let url = OAuthAuthorizationRequest::from_config( &config, loopback_redirect_uri(4545), "state-123", &pair, ) .with_extra_param("login_hint", "user@example.com") .build_url(); assert!(url.starts_with("https://console.test/oauth/authorize?")); assert!(url.contains("response_type=code")); assert!(url.contains("client_id=runtime-client")); assert!(url.contains("scope=org%3Aread%20user%3Awrite")); assert!(url.contains("login_hint=user%40example.com")); let exchange = OAuthTokenExchangeRequest::from_config( &config, "auth-code", "state-123", pair.verifier, loopback_redirect_uri(4545), ); assert_eq!( exchange.form_params().get("grant_type").map(String::as_str), Some("authorization_code") ); let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None); assert_eq!( refresh.form_params().get("scope").map(String::as_str), Some("org:read user:write") ); } #[test] fn oauth_credentials_round_trip_and_clear_preserves_other_fields() { let _guard = env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); let path = credentials_path().expect("credentials path"); std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent"); std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials"); let token_set = OAuthTokenSet { access_token: "access-token".to_string(), refresh_token: Some("refresh-token".to_string()), expires_at: Some(123), scopes: vec!["scope:a".to_string()], }; save_oauth_credentials(&token_set).expect("save credentials"); assert_eq!( load_oauth_credentials().expect("load credentials"), Some(token_set) ); let saved = std::fs::read_to_string(&path).expect("read saved file"); assert!(saved.contains("\"other\": \"value\"")); assert!(saved.contains("\"oauth\"")); clear_oauth_credentials().expect("clear credentials"); assert_eq!(load_oauth_credentials().expect("load cleared"), None); let cleared = std::fs::read_to_string(&path).expect("read cleared file"); assert!(cleared.contains("\"other\": \"value\"")); assert!(!cleared.contains("\"oauth\"")); std::env::remove_var("CLAUDE_CONFIG_HOME"); std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); } #[test] fn parses_callback_query_and_target() { let params = parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login") .expect("parse query"); assert_eq!(params.code.as_deref(), Some("abc123")); assert_eq!(params.state.as_deref(), Some("state-1")); assert_eq!(params.error_description.as_deref(), Some("needs login")); let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz") .expect("parse callback target"); assert_eq!(params.code.as_deref(), Some("abc")); assert_eq!(params.state.as_deref(), Some("xyz")); assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err()); } }