1use std::ops::Deref;
8
9use anyhow::bail;
10use camino::Utf8PathBuf;
11use mas_iana::oauth::OAuthClientAuthenticationMethod;
12use mas_jose::jwk::PublicJsonWebKeySet;
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize, de::Error};
15use serde_with::serde_as;
16use ulid::Ulid;
17use url::Url;
18
19use super::ConfigurationSection;
20
21#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
22#[serde(rename_all = "snake_case")]
23pub enum JwksOrJwksUri {
24    Jwks(PublicJsonWebKeySet),
25    JwksUri(Url),
26}
27
28impl From<PublicJsonWebKeySet> for JwksOrJwksUri {
29    fn from(jwks: PublicJsonWebKeySet) -> Self {
30        Self::Jwks(jwks)
31    }
32}
33
34#[derive(Clone, Debug)]
39pub enum ClientSecret {
40    File(Utf8PathBuf),
41    Value(String),
42}
43
44#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
46struct ClientSecretRaw {
47    #[schemars(with = "Option<String>")]
51    #[serde(skip_serializing_if = "Option::is_none")]
52    client_secret_file: Option<Utf8PathBuf>,
53
54    #[serde(skip_serializing_if = "Option::is_none")]
57    client_secret: Option<String>,
58}
59
60impl TryFrom<ClientSecretRaw> for Option<ClientSecret> {
61    type Error = anyhow::Error;
62
63    fn try_from(value: ClientSecretRaw) -> Result<Self, Self::Error> {
64        match (value.client_secret, value.client_secret_file) {
65            (None, None) => Ok(None),
66            (None, Some(path)) => Ok(Some(ClientSecret::File(path))),
67            (Some(client_secret), None) => Ok(Some(ClientSecret::Value(client_secret))),
68            (Some(_), Some(_)) => {
69                bail!("Cannot specify both `client_secret` and `client_secret_file`")
70            }
71        }
72    }
73}
74
75impl From<Option<ClientSecret>> for ClientSecretRaw {
76    fn from(value: Option<ClientSecret>) -> Self {
77        match value {
78            Some(ClientSecret::File(path)) => ClientSecretRaw {
79                client_secret_file: Some(path),
80                client_secret: None,
81            },
82            Some(ClientSecret::Value(client_secret)) => ClientSecretRaw {
83                client_secret_file: None,
84                client_secret: Some(client_secret),
85            },
86            None => ClientSecretRaw {
87                client_secret_file: None,
88                client_secret: None,
89            },
90        }
91    }
92}
93
94#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
96#[serde(rename_all = "snake_case")]
97pub enum ClientAuthMethodConfig {
98    None,
100
101    ClientSecretBasic,
104
105    ClientSecretPost,
108
109    ClientSecretJwt,
112
113    PrivateKeyJwt,
116}
117
118impl std::fmt::Display for ClientAuthMethodConfig {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        match self {
121            ClientAuthMethodConfig::None => write!(f, "none"),
122            ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
123            ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
124            ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
125            ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
126        }
127    }
128}
129
130#[serde_as]
132#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
133pub struct ClientConfig {
134    #[schemars(
136        with = "String",
137        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
138        description = "A ULID as per https://github.com/ulid/spec"
139    )]
140    pub client_id: Ulid,
141
142    client_auth_method: ClientAuthMethodConfig,
144
145    #[serde(skip_serializing_if = "Option::is_none")]
147    pub client_name: Option<String>,
148
149    #[schemars(with = "ClientSecretRaw")]
152    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
153    #[serde(flatten)]
154    pub client_secret: Option<ClientSecret>,
155
156    #[serde(skip_serializing_if = "Option::is_none")]
159    pub jwks: Option<PublicJsonWebKeySet>,
160
161    #[serde(skip_serializing_if = "Option::is_none")]
164    pub jwks_uri: Option<Url>,
165
166    #[serde(default, skip_serializing_if = "Vec::is_empty")]
168    pub redirect_uris: Vec<Url>,
169}
170
171impl ClientConfig {
172    fn validate(&self) -> Result<(), Box<figment::error::Error>> {
173        let auth_method = self.client_auth_method;
174        match self.client_auth_method {
175            ClientAuthMethodConfig::PrivateKeyJwt => {
176                if self.jwks.is_none() && self.jwks_uri.is_none() {
177                    let error = figment::error::Error::custom(
178                        "jwks or jwks_uri is required for private_key_jwt",
179                    );
180                    return Err(Box::new(error.with_path("client_auth_method")));
181                }
182
183                if self.jwks.is_some() && self.jwks_uri.is_some() {
184                    let error =
185                        figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
186                    return Err(Box::new(error.with_path("jwks")));
187                }
188
189                if self.client_secret.is_some() {
190                    let error = figment::error::Error::custom(
191                        "client_secret is not allowed with private_key_jwt",
192                    );
193                    return Err(Box::new(error.with_path("client_secret")));
194                }
195            }
196
197            ClientAuthMethodConfig::ClientSecretPost
198            | ClientAuthMethodConfig::ClientSecretBasic
199            | ClientAuthMethodConfig::ClientSecretJwt => {
200                if self.client_secret.is_none() {
201                    let error = figment::error::Error::custom(format!(
202                        "client_secret is required for {auth_method}"
203                    ));
204                    return Err(Box::new(error.with_path("client_auth_method")));
205                }
206
207                if self.jwks.is_some() {
208                    let error = figment::error::Error::custom(format!(
209                        "jwks is not allowed with {auth_method}"
210                    ));
211                    return Err(Box::new(error.with_path("jwks")));
212                }
213
214                if self.jwks_uri.is_some() {
215                    let error = figment::error::Error::custom(format!(
216                        "jwks_uri is not allowed with {auth_method}"
217                    ));
218                    return Err(Box::new(error.with_path("jwks_uri")));
219                }
220            }
221
222            ClientAuthMethodConfig::None => {
223                if self.client_secret.is_some() {
224                    let error = figment::error::Error::custom(
225                        "client_secret is not allowed with none authentication method",
226                    );
227                    return Err(Box::new(error.with_path("client_secret")));
228                }
229
230                if self.jwks.is_some() {
231                    let error = figment::error::Error::custom(
232                        "jwks is not allowed with none authentication method",
233                    );
234                    return Err(Box::new(error));
235                }
236
237                if self.jwks_uri.is_some() {
238                    let error = figment::error::Error::custom(
239                        "jwks_uri is not allowed with none authentication method",
240                    );
241                    return Err(Box::new(error));
242                }
243            }
244        }
245
246        Ok(())
247    }
248
249    #[must_use]
251    pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
252        match self.client_auth_method {
253            ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
254            ClientAuthMethodConfig::ClientSecretBasic => {
255                OAuthClientAuthenticationMethod::ClientSecretBasic
256            }
257            ClientAuthMethodConfig::ClientSecretPost => {
258                OAuthClientAuthenticationMethod::ClientSecretPost
259            }
260            ClientAuthMethodConfig::ClientSecretJwt => {
261                OAuthClientAuthenticationMethod::ClientSecretJwt
262            }
263            ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
264        }
265    }
266
267    pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
275        Ok(match &self.client_secret {
276            Some(ClientSecret::File(path)) => Some(tokio::fs::read_to_string(path).await?),
277            Some(ClientSecret::Value(client_secret)) => Some(client_secret.clone()),
278            None => None,
279        })
280    }
281}
282
283#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
285#[serde(transparent)]
286pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);
287
288impl ClientsConfig {
289    pub(crate) fn is_default(&self) -> bool {
291        self.0.is_empty()
292    }
293}
294
295impl Deref for ClientsConfig {
296    type Target = Vec<ClientConfig>;
297
298    fn deref(&self) -> &Self::Target {
299        &self.0
300    }
301}
302
303impl IntoIterator for ClientsConfig {
304    type Item = ClientConfig;
305    type IntoIter = std::vec::IntoIter<ClientConfig>;
306
307    fn into_iter(self) -> Self::IntoIter {
308        self.0.into_iter()
309    }
310}
311
312impl ConfigurationSection for ClientsConfig {
313    const PATH: Option<&'static str> = Some("clients");
314
315    fn validate(
316        &self,
317        figment: &figment::Figment,
318    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
319        for (index, client) in self.0.iter().enumerate() {
320            client.validate().map_err(|mut err| {
321                err.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
323                err.profile = Some(figment::Profile::Default);
324                err.path.insert(0, Self::PATH.unwrap().to_owned());
325                err.path.insert(1, format!("{index}"));
326                err
327            })?;
328        }
329
330        Ok(())
331    }
332}
333
// Tests for loading the `clients` configuration section.
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    // Loads one client per authentication method from a YAML file and checks
    // the parsed IDs, redirect URIs and client secrets. Both inline secrets
    // and `client_secret_file` secrets are read back through
    // `ClientConfig::client_secret`.
    #[tokio::test]
    async fn load_config() {
        // `Jail::expect_with` takes a synchronous closure, so it is driven from
        // a blocking task rather than directly inside the async test body.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    r#"
                      clients:
                        - client_id: 01GFWR28C4KNE04WG3HKXB7C9R
                          client_auth_method: none
                          redirect_uris:
                            - https://exemple.fr/callback

                        - client_id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                          client_auth_method: client_secret_basic
                          client_secret_file: secret

                        - client_id: 01GFWR3WHR93Y5HK389H28VHZ9
                          client_auth_method: client_secret_post
                          client_secret: c1!3n753c237

                        - client_id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                          client_auth_method: client_secret_jwt
                          client_secret_file: secret

                        - client_id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                          client_auth_method: private_key_jwt
                          jwks:
                            keys:
                            - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                              kty: "RSA"
                              alg: "RS256"
                              use: "sig"
                              e: "AQAB"
                              n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                            - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                              kty: "RSA"
                              alg: "RS256"
                              use: "sig"
                              e: "AQAB"
                              n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                    "#,
                )?;
                // Target of the `client_secret_file: secret` entries above.
                // It holds the same value as the third client's inline secret.
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<ClientsConfig>("clients")?;

                assert_eq!(config.0.len(), 5);

                assert_eq!(
                    config.0[0].client_id,
                    Ulid::from_str("01GFWR28C4KNE04WG3HKXB7C9R").unwrap()
                );
                assert_eq!(
                    config.0[0].redirect_uris,
                    vec!["https://exemple.fr/callback".parse().unwrap()]
                );

                assert_eq!(
                    config.0[1].client_id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );
                // `redirect_uris` is optional and defaults to an empty list.
                assert_eq!(config.0[1].redirect_uris, Vec::new());

                // The flattened `client_secret` / `client_secret_file` keys map
                // to the matching `ClientSecret` variant, or to `None`.
                assert!(config.0[0].client_secret.is_none());
                assert!(matches!(config.0[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.0[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.0[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.0[4].client_secret.is_none());

                // `client_secret()` is async. It is awaited here through the
                // runtime handle that is reachable from this blocking task.
                Handle::current().block_on(async move {
                    assert_eq!(config.0[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}