use std::{collections::HashMap, convert::Infallible, marker::PhantomData, ops::Deref};
use base64ct::{Base64UrlUnpadded, Encoding};
use mas_iana::jose::JsonWebSignatureAlg;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sha2::{Digest, Sha256, Sha384, Sha512};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ClaimError {
    #[error("missing claim {0:?}")]
    MissingClaim(&'static str),
    #[error("invalid claim {0:?}")]
    InvalidClaim(&'static str),
    #[error("could not validate claim {claim:?}")]
    ValidationError {
        claim: &'static str,
        #[source]
        source: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
}
pub trait Validator<T> {
    type Error;
    fn validate(&self, value: &T) -> Result<(), Self::Error>;
}
impl<T> Validator<T> for () {
    type Error = Infallible;
    fn validate(&self, _value: &T) -> Result<(), Self::Error> {
        Ok(())
    }
}
pub struct Claim<T, V = ()> {
    claim: &'static str,
    t: PhantomData<T>,
    v: PhantomData<V>,
}
impl<T, V> Claim<T, V>
where
    V: Validator<T>,
{
    #[must_use]
    pub const fn new(claim: &'static str) -> Self {
        Self {
            claim,
            t: PhantomData,
            v: PhantomData,
        }
    }
    pub fn insert<I>(
        &self,
        claims: &mut HashMap<String, serde_json::Value>,
        value: I,
    ) -> Result<(), ClaimError>
    where
        I: Into<T>,
        T: Serialize,
    {
        let value = value.into();
        let value: serde_json::Value =
            serde_json::to_value(&value).map_err(|_| ClaimError::InvalidClaim(self.claim))?;
        claims.insert(self.claim.to_owned(), value);
        Ok(())
    }
    pub fn extract_required(
        &self,
        claims: &mut HashMap<String, serde_json::Value>,
    ) -> Result<T, ClaimError>
    where
        T: DeserializeOwned,
        V: Default,
        V::Error: std::error::Error + Send + Sync + 'static,
    {
        let validator = V::default();
        self.extract_required_with_options(claims, validator)
    }
    pub fn extract_required_with_options<I>(
        &self,
        claims: &mut HashMap<String, serde_json::Value>,
        validator: I,
    ) -> Result<T, ClaimError>
    where
        T: DeserializeOwned,
        I: Into<V>,
        V::Error: std::error::Error + Send + Sync + 'static,
    {
        let validator: V = validator.into();
        let claim = claims
            .remove(self.claim)
            .ok_or(ClaimError::MissingClaim(self.claim))?;
        let res =
            serde_json::from_value(claim).map_err(|_| ClaimError::InvalidClaim(self.claim))?;
        validator
            .validate(&res)
            .map_err(|source| ClaimError::ValidationError {
                claim: self.claim,
                source: Box::new(source),
            })?;
        Ok(res)
    }
    pub fn extract_optional(
        &self,
        claims: &mut HashMap<String, serde_json::Value>,
    ) -> Result<Option<T>, ClaimError>
    where
        T: DeserializeOwned,
        V: Default,
        V::Error: std::error::Error + Send + Sync + 'static,
    {
        let validator = V::default();
        self.extract_optional_with_options(claims, validator)
    }
    pub fn extract_optional_with_options<I>(
        &self,
        claims: &mut HashMap<String, serde_json::Value>,
        validator: I,
    ) -> Result<Option<T>, ClaimError>
    where
        T: DeserializeOwned,
        I: Into<V>,
        V::Error: std::error::Error + Send + Sync + 'static,
    {
        match self.extract_required_with_options(claims, validator) {
            Ok(v) => Ok(Some(v)),
            Err(ClaimError::MissingClaim(_)) => Ok(None),
            Err(e) => Err(e),
        }
    }
}
#[derive(Debug, Clone)]
pub struct TimeOptions {
    when: chrono::DateTime<chrono::Utc>,
    leeway: chrono::Duration,
}
impl TimeOptions {
    #[must_use]
    pub fn new(when: chrono::DateTime<chrono::Utc>) -> Self {
        Self {
            when,
            leeway: chrono::Duration::microseconds(5 * 60 * 1000 * 1000),
        }
    }
    #[must_use]
    pub fn leeway(mut self, leeway: chrono::Duration) -> Self {
        self.leeway = leeway;
        self
    }
}
#[derive(Debug, Clone, Copy, Error)]
#[error("Current time is too far away")]
pub struct TimeTooFarError;
#[derive(Debug, Clone)]
pub struct TimeNotAfter(TimeOptions);
impl Validator<Timestamp> for TimeNotAfter {
    type Error = TimeTooFarError;
    fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
        if self.0.when <= value.0 + self.0.leeway {
            Ok(())
        } else {
            Err(TimeTooFarError)
        }
    }
}
impl From<TimeOptions> for TimeNotAfter {
    fn from(opt: TimeOptions) -> Self {
        Self(opt)
    }
}
impl From<&TimeOptions> for TimeNotAfter {
    fn from(opt: &TimeOptions) -> Self {
        opt.clone().into()
    }
}
#[derive(Debug, Clone)]
pub struct TimeNotBefore(TimeOptions);
impl Validator<Timestamp> for TimeNotBefore {
    type Error = TimeTooFarError;
    fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
        if self.0.when >= value.0 - self.0.leeway {
            Ok(())
        } else {
            Err(TimeTooFarError)
        }
    }
}
impl From<TimeOptions> for TimeNotBefore {
    fn from(opt: TimeOptions) -> Self {
        Self(opt)
    }
}
impl From<&TimeOptions> for TimeNotBefore {
    fn from(opt: &TimeOptions) -> Self {
        opt.clone().into()
    }
}
pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> Result<String, TokenHashError> {
    let bits = match alg {
        JsonWebSignatureAlg::Hs256
        | JsonWebSignatureAlg::Rs256
        | JsonWebSignatureAlg::Es256
        | JsonWebSignatureAlg::Ps256
        | JsonWebSignatureAlg::Es256K => {
            let mut hasher = Sha256::new();
            hasher.update(token);
            let hash: [u8; 32] = hasher.finalize().into();
            hash[..16].to_owned()
        }
        JsonWebSignatureAlg::Hs384
        | JsonWebSignatureAlg::Rs384
        | JsonWebSignatureAlg::Es384
        | JsonWebSignatureAlg::Ps384 => {
            let mut hasher = Sha384::new();
            hasher.update(token);
            let hash: [u8; 48] = hasher.finalize().into();
            hash[..24].to_owned()
        }
        JsonWebSignatureAlg::Hs512
        | JsonWebSignatureAlg::Rs512
        | JsonWebSignatureAlg::Es512
        | JsonWebSignatureAlg::Ps512 => {
            let mut hasher = Sha512::new();
            hasher.update(token);
            let hash: [u8; 64] = hasher.finalize().into();
            hash[..32].to_owned()
        }
        _ => return Err(TokenHashError::UnsupportedAlgorithm),
    };
    Ok(Base64UrlUnpadded::encode_string(&bits))
}
#[derive(Debug, Clone, Copy, Error)]
pub enum TokenHashError {
    #[error("Hashes don't match")]
    HashMismatch,
    #[error("Unsupported algorithm for hashing")]
    UnsupportedAlgorithm,
}
#[derive(Debug, Clone)]
pub struct TokenHash<'a> {
    alg: &'a JsonWebSignatureAlg,
    token: &'a str,
}
impl<'a> TokenHash<'a> {
    #[must_use]
    pub fn new(alg: &'a JsonWebSignatureAlg, token: &'a str) -> Self {
        Self { alg, token }
    }
}
impl<'a> Validator<String> for TokenHash<'a> {
    type Error = TokenHashError;
    fn validate(&self, value: &String) -> Result<(), Self::Error> {
        if hash_token(self.alg, self.token)? == *value {
            Ok(())
        } else {
            Err(TokenHashError::HashMismatch)
        }
    }
}
#[derive(Debug, Clone, Copy, Error)]
#[error("Values don't match")]
pub struct EqualityError;
#[derive(Debug, Clone)]
pub struct Equality<'a, T: ?Sized> {
    value: &'a T,
}
impl<'a, T: ?Sized> Equality<'a, T> {
    #[must_use]
    pub fn new(value: &'a T) -> Self {
        Self { value }
    }
}
impl<'a, T1, T2> Validator<T1> for Equality<'a, T2>
where
    T2: PartialEq<T1> + ?Sized,
{
    type Error = EqualityError;
    fn validate(&self, value: &T1) -> Result<(), Self::Error> {
        if *self.value == *value {
            Ok(())
        } else {
            Err(EqualityError)
        }
    }
}
impl<'a, T: ?Sized> From<&'a T> for Equality<'a, T> {
    fn from(value: &'a T) -> Self {
        Self::new(value)
    }
}
#[derive(Debug, Clone)]
pub struct Contains<'a, T> {
    value: &'a T,
}
impl<'a, T> Contains<'a, T> {
    #[must_use]
    pub fn new(value: &'a T) -> Self {
        Self { value }
    }
}
#[derive(Debug, Clone, Copy, Error)]
#[error("OneOrMany doesn't contain value")]
pub struct ContainsError;
impl<'a, T> Validator<OneOrMany<T>> for Contains<'a, T>
where
    T: PartialEq,
{
    type Error = ContainsError;
    fn validate(&self, value: &OneOrMany<T>) -> Result<(), Self::Error> {
        if value.contains(self.value) {
            Ok(())
        } else {
            Err(ContainsError)
        }
    }
}
impl<'a, T> From<&'a T> for Contains<'a, T> {
    fn from(value: &'a T) -> Self {
        Self::new(value)
    }
}
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
#[serde(transparent)]
pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime<chrono::Utc>);
impl Deref for Timestamp {
    type Target = chrono::DateTime<chrono::Utc>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl From<chrono::DateTime<chrono::Utc>> for Timestamp {
    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
        Timestamp(value)
    }
}
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
#[serde(
    transparent,
    bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>")
)]
pub struct OneOrMany<T>(
    #[serde(
        with = "serde_with::As::<serde_with::OneOrMany<serde_with::Same, serde_with::formats::PreferOne>>"
    )]
    Vec<T>,
);
impl<T> Deref for OneOrMany<T> {
    type Target = Vec<T>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl<T> From<Vec<T>> for OneOrMany<T> {
    fn from(value: Vec<T>) -> Self {
        Self(value)
    }
}
impl<T> From<T> for OneOrMany<T> {
    fn from(value: T) -> Self {
        Self(vec![value])
    }
}
mod rfc7519 {
    use super::{Claim, Contains, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp};
    pub const ISS: Claim<String, Equality<str>> = Claim::new("iss");
    pub const SUB: Claim<String> = Claim::new("sub");
    pub const AUD: Claim<OneOrMany<String>, Contains<String>> = Claim::new("aud");
    pub const NBF: Claim<Timestamp, TimeNotBefore> = Claim::new("nbf");
    pub const EXP: Claim<Timestamp, TimeNotAfter> = Claim::new("exp");
    pub const IAT: Claim<Timestamp, TimeNotBefore> = Claim::new("iat");
    pub const JTI: Claim<String> = Claim::new("jti");
}
mod oidc_core {
    use url::Url;
    use super::{Claim, Equality, Timestamp, TokenHash};
    pub const AUTH_TIME: Claim<Timestamp> = Claim::new("auth_time");
    pub const NONCE: Claim<String, Equality<str>> = Claim::new("nonce");
    pub const AT_HASH: Claim<String, TokenHash> = Claim::new("at_hash");
    pub const C_HASH: Claim<String, TokenHash> = Claim::new("c_hash");
    pub const NAME: Claim<String> = Claim::new("name");
    pub const GIVEN_NAME: Claim<String> = Claim::new("given_name");
    pub const FAMILY_NAME: Claim<String> = Claim::new("family_name");
    pub const MIDDLE_NAME: Claim<String> = Claim::new("middle_name");
    pub const NICKNAME: Claim<String> = Claim::new("nickname");
    pub const PREFERRED_USERNAME: Claim<String> = Claim::new("preferred_username");
    pub const PROFILE: Claim<Url> = Claim::new("profile");
    pub const PICTURE: Claim<Url> = Claim::new("picture");
    pub const WEBSITE: Claim<Url> = Claim::new("website");
    pub const EMAIL: Claim<String> = Claim::new("email");
    pub const EMAIL_VERIFIED: Claim<bool> = Claim::new("email_verified");
    pub const GENDER: Claim<String> = Claim::new("gender");
    pub const BIRTHDATE: Claim<String> = Claim::new("birthdate");
    pub const ZONEINFO: Claim<String> = Claim::new("zoneinfo");
    pub const LOCALE: Claim<String> = Claim::new("locale");
    pub const PHONE_NUMBER: Claim<String> = Claim::new("phone_number");
    pub const PHONE_NUMBER_VERIFIED: Claim<bool> = Claim::new("phone_number_verified");
    pub const UPDATED_AT: Claim<Timestamp> = Claim::new("updated_at");
}
pub use self::{oidc_core::*, rfc7519::*};
#[cfg(test)]
mod tests {
    use chrono::TimeZone;
    use super::*;
    #[test]
    fn timestamp_serde() {
        let datetime = Timestamp(
            chrono::Utc
                .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
                .unwrap(),
        );
        let timestamp = serde_json::Value::Number(1_516_239_022.into());
        assert_eq!(datetime, serde_json::from_value(timestamp.clone()).unwrap());
        assert_eq!(timestamp, serde_json::to_value(&datetime).unwrap());
    }
    #[test]
    fn one_or_many_serde() {
        let one = OneOrMany(vec!["one".to_owned()]);
        let many = OneOrMany(vec!["one".to_owned(), "two".to_owned()]);
        assert_eq!(
            one,
            serde_json::from_value(serde_json::json!("one")).unwrap()
        );
        assert_eq!(
            one,
            serde_json::from_value(serde_json::json!(["one"])).unwrap()
        );
        assert_eq!(
            many,
            serde_json::from_value(serde_json::json!(["one", "two"])).unwrap()
        );
        assert_eq!(
            serde_json::to_value(&one).unwrap(),
            serde_json::json!("one")
        );
        assert_eq!(
            serde_json::to_value(&many).unwrap(),
            serde_json::json!(["one", "two"])
        );
    }
    #[test]
    fn extract_claims() {
        let now = chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap();
        let expiration = now + chrono::Duration::microseconds(5 * 60 * 1000 * 1000);
        let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());
        let claims = serde_json::json!({
            "iss": "https://foo.com",
            "sub": "johndoe",
            "aud": ["abcd-efgh"],
            "iat": 1_516_239_022,
            "nbf": 1_516_239_022,
            "exp": 1_516_239_322,
            "jti": "1122-3344-5566-7788",
        });
        let mut claims = serde_json::from_value(claims).unwrap();
        let iss = ISS
            .extract_required_with_options(&mut claims, "https://foo.com")
            .unwrap();
        let sub = SUB.extract_optional(&mut claims).unwrap();
        let aud = AUD
            .extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned())
            .unwrap();
        let nbf = NBF
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let exp = EXP
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let iat = IAT
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let jti = JTI.extract_optional(&mut claims).unwrap();
        assert_eq!(iss, "https://foo.com".to_owned());
        assert_eq!(sub, Some("johndoe".to_owned()));
        assert_eq!(aud.as_deref(), Some(&vec!["abcd-efgh".to_owned()]));
        assert_eq!(iat.as_deref(), Some(&now));
        assert_eq!(nbf.as_deref(), Some(&now));
        assert_eq!(exp.as_deref(), Some(&expiration));
        assert_eq!(jti, Some("1122-3344-5566-7788".to_owned()));
        assert!(claims.is_empty());
    }
    #[test]
    fn time_validation() {
        let now = chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap();
        let claims = serde_json::json!({
            "iat": 1_516_239_022,
            "nbf": 1_516_239_022,
            "exp": 1_516_239_322,
        });
        let claims: HashMap<String, serde_json::Value> = serde_json::from_value(claims).unwrap();
        {
            let mut claims = claims.clone();
            let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());
            assert!(IAT
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(NBF
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(EXP
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
        }
        let now = now - chrono::Duration::microseconds(60 * 1000 * 1000);
        {
            let mut claims = claims.clone();
            let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());
            assert!(matches!(
                IAT.extract_required_with_options(&mut claims, &time_options),
                Err(ClaimError::ValidationError { claim: "iat", .. }),
            ));
            assert!(matches!(
                NBF.extract_required_with_options(&mut claims, &time_options),
                Err(ClaimError::ValidationError { claim: "nbf", .. }),
            ));
            assert!(EXP
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
        }
        {
            let mut claims = claims.clone();
            let time_options =
                TimeOptions::new(now).leeway(chrono::Duration::microseconds(2 * 60 * 1000 * 1000));
            assert!(IAT
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(NBF
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(EXP
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
        }
        let now = now + chrono::Duration::microseconds((1 + 6) * 60 * 1000 * 1000);
        {
            let mut claims = claims.clone();
            let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());
            assert!(IAT
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(NBF
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(matches!(
                EXP.extract_required_with_options(&mut claims, &time_options),
                Err(ClaimError::ValidationError { claim: "exp", .. }),
            ));
        }
        {
            let mut claims = claims;
            let time_options =
                TimeOptions::new(now).leeway(chrono::Duration::try_minutes(2).unwrap());
            assert!(IAT
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(NBF
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(EXP
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
        }
    }
    #[test]
    fn invalid_claims() {
        let now = chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap();
        let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());
        let claims = serde_json::json!({
            "iss": 123,
            "sub": 456,
            "aud": 789,
            "iat": "123",
            "nbf": "456",
            "exp": "789",
            "jti": 123,
        });
        let mut claims = serde_json::from_value(claims).unwrap();
        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://foo.com"),
            Err(ClaimError::InvalidClaim("iss"))
        ));
        assert!(matches!(
            SUB.extract_required(&mut claims),
            Err(ClaimError::InvalidClaim("sub"))
        ));
        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Err(ClaimError::InvalidClaim("aud"))
        ));
        assert!(matches!(
            NBF.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("nbf"))
        ));
        assert!(matches!(
            EXP.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("exp"))
        ));
        assert!(matches!(
            IAT.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("iat"))
        ));
        assert!(matches!(
            JTI.extract_required(&mut claims),
            Err(ClaimError::InvalidClaim("jti"))
        ));
    }
    #[test]
    fn missing_claims() {
        let mut claims = HashMap::new();
        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://foo.com"),
            Err(ClaimError::MissingClaim("iss"))
        ));
        assert!(matches!(
            SUB.extract_required(&mut claims),
            Err(ClaimError::MissingClaim("sub"))
        ));
        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Err(ClaimError::MissingClaim("aud"))
        ));
        assert!(matches!(
            ISS.extract_optional_with_options(&mut claims, "https://foo.com"),
            Ok(None)
        ));
        assert!(matches!(SUB.extract_optional(&mut claims), Ok(None)));
        assert!(matches!(
            AUD.extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Ok(None)
        ));
    }
    #[test]
    fn string_eq_validation() {
        let claims = serde_json::json!({
            "iss": "https://foo.com",
        });
        let mut claims: HashMap<String, serde_json::Value> =
            serde_json::from_value(claims).unwrap();
        ISS.extract_required_with_options(&mut claims.clone(), "https://foo.com")
            .unwrap();
        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://bar.com"),
            Err(ClaimError::ValidationError { claim: "iss", .. }),
        ));
    }
    #[test]
    fn contains_validation() {
        let claims = serde_json::json!({
            "aud": "abcd-efgh",
        });
        let mut claims: HashMap<String, serde_json::Value> =
            serde_json::from_value(claims).unwrap();
        AUD.extract_required_with_options(&mut claims.clone(), &"abcd-efgh".to_owned())
            .unwrap();
        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"wxyz".to_owned()),
            Err(ClaimError::ValidationError { claim: "aud", .. }),
        ));
    }
}