use std::collections::BTreeMap; use std::fs::File; use std::io::{self, Read}; use sha2::{Digest, Sha256}; use crate::config::OAuthConfig; #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthTokenSet { pub access_token: String, pub refresh_token: Option, pub expires_at: Option, pub scopes: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct PkceCodePair { pub verifier: String, pub challenge: String, pub challenge_method: PkceChallengeMethod, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PkceChallengeMethod { S256, } impl PkceChallengeMethod { #[must_use] pub const fn as_str(self) -> &'static str { match self { Self::S256 => "S256", } } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthAuthorizationRequest { pub authorize_url: String, pub client_id: String, pub redirect_uri: String, pub scopes: Vec, pub state: String, pub code_challenge: String, pub code_challenge_method: PkceChallengeMethod, pub extra_params: BTreeMap, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthTokenExchangeRequest { pub grant_type: &'static str, pub code: String, pub redirect_uri: String, pub client_id: String, pub code_verifier: String, pub state: String, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthRefreshRequest { pub grant_type: &'static str, pub refresh_token: String, pub client_id: String, pub scopes: Vec, } impl OAuthAuthorizationRequest { #[must_use] pub fn from_config( config: &OAuthConfig, redirect_uri: impl Into, state: impl Into, pkce: &PkceCodePair, ) -> Self { Self { authorize_url: config.authorize_url.clone(), client_id: config.client_id.clone(), redirect_uri: redirect_uri.into(), scopes: config.scopes.clone(), state: state.into(), code_challenge: pkce.challenge.clone(), code_challenge_method: pkce.challenge_method, extra_params: BTreeMap::new(), } } #[must_use] pub fn with_extra_param(mut self, key: impl Into, value: impl Into) -> Self { self.extra_params.insert(key.into(), value.into()); self } #[must_use] pub fn build_url(&self) -> String { let mut params = vec![ ("response_type", "code".to_string()), ("client_id", self.client_id.clone()), ("redirect_uri", self.redirect_uri.clone()), ("scope", self.scopes.join(" ")), ("state", self.state.clone()), ("code_challenge", self.code_challenge.clone()), ( "code_challenge_method", self.code_challenge_method.as_str().to_string(), ), ]; params.extend( self.extra_params .iter() .map(|(key, value)| (key.as_str(), value.clone())), ); let query = params .into_iter() .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value))) .collect::>() .join("&"); format!( "{}{}{}", self.authorize_url, if self.authorize_url.contains('?') { '&' } else { '?' }, query ) } } impl OAuthTokenExchangeRequest { #[must_use] pub fn from_config( config: &OAuthConfig, code: impl Into, state: impl Into, verifier: impl Into, redirect_uri: impl Into, ) -> Self { let _ = config; Self { grant_type: "authorization_code", code: code.into(), redirect_uri: redirect_uri.into(), client_id: config.client_id.clone(), code_verifier: verifier.into(), state: state.into(), } } #[must_use] pub fn form_params(&self) -> BTreeMap<&str, String> { BTreeMap::from([ ("grant_type", self.grant_type.to_string()), ("code", self.code.clone()), ("redirect_uri", self.redirect_uri.clone()), ("client_id", self.client_id.clone()), ("code_verifier", self.code_verifier.clone()), ("state", self.state.clone()), ]) } } impl OAuthRefreshRequest { #[must_use] pub fn from_config( config: &OAuthConfig, refresh_token: impl Into, scopes: Option>, ) -> Self { Self { grant_type: "refresh_token", refresh_token: refresh_token.into(), client_id: config.client_id.clone(), scopes: scopes.unwrap_or_else(|| config.scopes.clone()), } } #[must_use] pub fn form_params(&self) -> BTreeMap<&str, String> { BTreeMap::from([ ("grant_type", self.grant_type.to_string()), ("refresh_token", self.refresh_token.clone()), ("client_id", self.client_id.clone()), ("scope", self.scopes.join(" ")), ]) } } pub fn generate_pkce_pair() -> io::Result { let verifier = generate_random_token(32)?; Ok(PkceCodePair { challenge: code_challenge_s256(&verifier), verifier, challenge_method: PkceChallengeMethod::S256, }) } pub fn generate_state() -> io::Result { generate_random_token(32) } #[must_use] pub fn code_challenge_s256(verifier: &str) -> String { let digest = Sha256::digest(verifier.as_bytes()); base64url_encode(&digest) } #[must_use] pub fn loopback_redirect_uri(port: u16) -> String { format!("http://localhost:{port}/callback") } fn generate_random_token(bytes: usize) -> io::Result { let mut buffer = vec![0_u8; bytes]; File::open("/dev/urandom")?.read_exact(&mut buffer)?; Ok(base64url_encode(&buffer)) } fn base64url_encode(bytes: &[u8]) -> String { const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; let mut output = String::new(); let mut index = 0; while index + 3 <= bytes.len() { let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8) | u32::from(bytes[index + 2]); output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); output.push(TABLE[((block >> 6) & 0x3F) as usize] as char); output.push(TABLE[(block & 0x3F) as usize] as char); index += 3; } match bytes.len().saturating_sub(index) { 1 => { let block = u32::from(bytes[index]) << 16; output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); } 2 => { let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8); output.push(TABLE[((block >> 18) & 0x3F) as usize] as char); output.push(TABLE[((block >> 12) & 0x3F) as usize] as char); output.push(TABLE[((block >> 6) & 0x3F) as usize] as char); } _ => {} } output } fn percent_encode(value: &str) -> String { let mut encoded = String::new(); for byte in value.bytes() { match byte { b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { encoded.push(char::from(byte)); } _ => { use std::fmt::Write as _; let _ = write!(&mut encoded, "%{byte:02X}"); } } } encoded } #[cfg(test)] mod tests { use super::{ code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri, OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, }; fn sample_config() -> OAuthConfig { OAuthConfig { client_id: "runtime-client".to_string(), authorize_url: "https://console.test/oauth/authorize".to_string(), token_url: "https://console.test/oauth/token".to_string(), callback_port: Some(4545), manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), scopes: vec!["org:read".to_string(), "user:write".to_string()], } } #[test] fn s256_challenge_matches_expected_vector() { assert_eq!( code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"), "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" ); } #[test] fn generates_pkce_pair_and_state() { let pair = generate_pkce_pair().expect("pkce pair"); let state = generate_state().expect("state"); assert!(!pair.verifier.is_empty()); assert!(!pair.challenge.is_empty()); assert!(!state.is_empty()); } #[test] fn builds_authorize_url_and_form_requests() { let config = sample_config(); let pair = generate_pkce_pair().expect("pkce"); let url = OAuthAuthorizationRequest::from_config( &config, loopback_redirect_uri(4545), "state-123", &pair, ) .with_extra_param("login_hint", "user@example.com") .build_url(); assert!(url.starts_with("https://console.test/oauth/authorize?")); assert!(url.contains("response_type=code")); assert!(url.contains("client_id=runtime-client")); assert!(url.contains("scope=org%3Aread%20user%3Awrite")); assert!(url.contains("login_hint=user%40example.com")); let exchange = OAuthTokenExchangeRequest::from_config( &config, "auth-code", "state-123", pair.verifier, loopback_redirect_uri(4545), ); assert_eq!( exchange.form_params().get("grant_type").map(String::as_str), Some("authorization_code") ); let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None); assert_eq!( refresh.form_params().get("scope").map(String::as_str), Some("org:read user:write") ); } }