diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c79bbcc..fc01bf8 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -67,11 +67,14 @@ version = "0.1.0" dependencies = [ "api", "async-trait", + "base64 0.22.1", "bitcoin_hashes 0.19.0", "hex-conservative 1.0.1", "jsonwebtoken", + "openssl", "secp256k1", "serde", + "serde_json", "tokio", ] diff --git a/rust/auth-impls/Cargo.toml b/rust/auth-impls/Cargo.toml index a6a0698..819a748 100644 --- a/rust/auth-impls/Cargo.toml +++ b/rust/auth-impls/Cargo.toml @@ -5,18 +5,20 @@ edition = "2021" rust-version.workspace = true [features] -jwt = [ "jsonwebtoken", "serde" ] +jwt = [ "base64", "serde", "serde_json", "openssl" ] sigs = [ "bitcoin_hashes", "hex-conservative", "secp256k1" ] [dependencies] -async-trait = "0.1.77" api = { path = "../api" } -jsonwebtoken = { version = "9.3.0", optional = true, default-features = false, features = ["use_pem"] } -serde = { version = "1.0.210", optional = true, default-features = false, features = ["derive"] } - +async-trait = "0.1.77" +base64 = { version = "0.22.1", optional = true, default-features = false, features = ["std"] } bitcoin_hashes = { version = "0.19", optional = true, default-features = false } hex-conservative = { version = "1.0", optional = true, default-features = false } +openssl = { version = "0.10.75", optional = true, default-features = false } secp256k1 = { version = "0.31", optional = true, default-features = false, features = [ "global-context" ] } +serde = { version = "1.0.210", optional = true, default-features = false, features = ["derive"] } +serde_json = { version = "1.0.149", optional = true, default-features = false, features = ["std"] } [dev-dependencies] +jsonwebtoken = { version = "9.3.0", default-features = false, features = ["use_pem"] } tokio = { version = "1.38.0", default-features = false, features = ["rt-multi-thread", "macros"] } diff --git a/rust/auth-impls/src/jwt.rs b/rust/auth-impls/src/jwt.rs index 01889de..e881228 100644 --- a/rust/auth-impls/src/jwt.rs +++ b/rust/auth-impls/src/jwt.rs @@ -5,8 +5,13 @@ use api::auth::{AuthResponse, Authorizer}; use api::error::VssError; use async_trait::async_trait; -use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; -use serde::{Deserialize, Serialize}; +use base64::engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD}; +use base64::Engine; +use openssl::hash::MessageDigest; +use openssl::pkey::PKey; +use openssl::pkey::Public; +use openssl::sign::Verifier; +use serde::Deserialize; use std::collections::HashMap; /// A JWT based authorizer, only allows requests with verified 'JsonWebToken' signed by the given @@ -14,13 +19,13 @@ use std::collections::HashMap; /// /// Refer: https://datatracker.ietf.org/doc/html/rfc7519 pub struct JWTAuthorizer { - jwt_issuer_key: DecodingKey, + jwt_issuer_key: PKey, } /// A set of Claims claimed by 'JsonWebToken' /// /// Refer: https://datatracker.ietf.org/doc/html/rfc7519#section-4 -#[derive(Serialize, Deserialize, Debug)] +#[derive(Deserialize, Debug)] pub(crate) struct Claims { /// The "sub" (subject) claim identifies the principal that is the subject of the JWT. /// The claims in a JWT are statements about the subject. This can be used as user identifier. @@ -31,10 +36,22 @@ pub(crate) struct Claims { const BEARER_PREFIX: &str = "Bearer "; +fn parse_public_key_pem(pem: &str) -> Result, String> { + let body = pem + .trim() + .strip_prefix("-----BEGIN PUBLIC KEY-----") + .ok_or(String::from("Prefix not found"))? + .strip_suffix("-----END PUBLIC KEY-----") + .ok_or(String::from("Suffix not found"))?; + let body: String = body.lines().map(|line| line.trim()).collect(); + let body = STANDARD.decode(body).map_err(|_| String::from("Base64 decode failed"))?; + PKey::public_key_from_der(&body).map_err(|_| String::from("DER decode failed")) +} + impl JWTAuthorizer { /// Creates a new instance of [`JWTAuthorizer`], fails on failure to parse the PEM formatted RSA public key pub async fn new(rsa_pem: &str) -> Result { - let jwt_issuer_key = DecodingKey::from_rsa_pem(rsa_pem.as_bytes()) + let jwt_issuer_key = parse_public_key_pem(rsa_pem) .map_err(|e| format!("Failed to parse the PEM formatted RSA public key: {}", e))?; Ok(Self { jwt_issuer_key }) } @@ -53,10 +70,45 @@ impl Authorizer for JWTAuthorizer { .strip_prefix(BEARER_PREFIX) .ok_or(VssError::AuthError("Invalid token format.".to_string()))?; - let claims = - decode::(token, &self.jwt_issuer_key, &Validation::new(Algorithm::RS256)) - .map_err(|e| VssError::AuthError(format!("Authentication failure. {}", e)))? - .claims; + let mut iter = token.split('.'); + let [header_base64, claims_base64, signature_base64] = + match [iter.next(), iter.next(), iter.next(), iter.next()] { + [Some(h), Some(c), Some(s), None] => [h, c, s], + _ => { + return Err(VssError::AuthError(String::from( + "Token does not have three parts", + ))) + }, + }; + + let header_bytes = URL_SAFE_NO_PAD + .decode(header_base64) + .map_err(|_| VssError::AuthError(String::from("Header base64 decode failed")))?; + let header: serde_json::Value = serde_json::from_slice(&header_bytes) + .map_err(|_| VssError::AuthError(String::from("Header json decode failed")))?; + match header["alg"] { + serde_json::Value::String(ref alg) if alg == "RS256" => (), + _ => return Err(VssError::AuthError(String::from("alg: RS256 not found in header"))), + } + + let (message, _) = token.rsplit_once('.').expect("There are two periods in the token"); + let signature = URL_SAFE_NO_PAD + .decode(signature_base64) + .map_err(|_| VssError::AuthError(String::from("Signature base64 decode failed")))?; + let mut verifier = Verifier::new(MessageDigest::sha256(), &self.jwt_issuer_key) + .map_err(|_| VssError::AuthError(String::from("RSA initialization failed")))?; + if !verifier + .verify_oneshot(&signature, message.as_bytes()) + .map_err(|_| VssError::AuthError(String::from("RSA verification failed")))? + { + return Err(VssError::AuthError(String::from("RSA verification failed"))); + } + + let claims_json = URL_SAFE_NO_PAD + .decode(claims_base64) + .map_err(|_| VssError::AuthError(String::from("Claims base64 decode failed")))?; + let claims: Claims = serde_json::from_slice(&claims_json) + .map_err(|_| VssError::AuthError(String::from("Claims json decode failed")))?; Ok(AuthResponse { user_token: claims.sub }) } diff --git a/rust/auth-impls/src/signature.rs b/rust/auth-impls/src/signature.rs index 428dd57..dfc8cf6 100644 --- a/rust/auth-impls/src/signature.rs +++ b/rust/auth-impls/src/signature.rs @@ -92,7 +92,7 @@ mod tests { use crate::signature::{SignatureValidatingAuthorizer, SIGNING_CONSTANT}; use api::auth::Authorizer; use api::error::VssError; - use secp256k1::{Message, PublicKey, Secp256k1, SecretKey}; + use secp256k1::{Message, PublicKey, SecretKey}; use std::collections::HashMap; use std::fmt::Write; use std::time::SystemTime;